Merge branch 'develop' into daniel/ick

pull/633/head
Daniel Wagner-Hall 2016-03-08 17:35:09 +00:00
commit 3b97797c8d
49 changed files with 1062 additions and 337 deletions

View File

@ -21,5 +21,6 @@ recursive-include synapse/static *.html
recursive-include synapse/static *.js recursive-include synapse/static *.js
exclude jenkins.sh exclude jenkins.sh
exclude jenkins*.sh
prune demo/etc prune demo/etc

22
jenkins-flake8.sh Executable file
View File

@ -0,0 +1,22 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
export PEP8SUFFIX="--output-file=violations.flake8.log"
rm .coverage* || echo "No coverage files to remove"
tox -e packaging -e pep8

92
jenkins-postgres.sh Executable file
View File

@ -0,0 +1,92 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox --notest
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
TOX_BIN=$WORKSPACE/.tox/py27/bin
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PERL5LIB:=$WORKSPACE/perl5/lib/perl5}
: ${PERL_MB_OPT:=--install_base=$WORKSPACE/perl5}
: ${PERL_MM_OPT:=INSTALL_BASE=$WORKSPACE/perl5}
export PERL5LIB PERL_MB_OPT PERL_MM_OPT
./install-deps.pl
: ${PORT_BASE:=8000}
if [[ -z "$POSTGRES_DB_1" ]]; then
echo >&2 "Variable POSTGRES_DB_1 not set"
exit 1
fi
if [[ -z "$POSTGRES_DB_2" ]]; then
echo >&2 "Variable POSTGRES_DB_2 not set"
exit 1
fi
mkdir -p "localhost-$(($PORT_BASE + 1))"
mkdir -p "localhost-$(($PORT_BASE + 2))"
cat > localhost-$(($PORT_BASE + 1))/database.yaml << EOF
name: psycopg2
args:
database: $POSTGRES_DB_1
EOF
cat > localhost-$(($PORT_BASE + 2))/database.yaml << EOF
name: psycopg2
args:
database: $POSTGRES_DB_2
EOF
# Run if both postgresql databases exist
echo >&2 "Running sytest with PostgreSQL";
$TOX_BIN/pip install psycopg2
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
--python $TOX_BIN/python --all --port-base $PORT_BASE > results-postgresql.tap
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml

63
jenkins-sqlite.sh Executable file
View File

@ -0,0 +1,63 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox --notest
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
TOX_BIN=$WORKSPACE/.tox/py27/bin
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PERL5LIB:=$WORKSPACE/perl5/lib/perl5}
: ${PERL_MB_OPT:=--install_base=$WORKSPACE/perl5}
: ${PERL_MM_OPT:=INSTALL_BASE=$WORKSPACE/perl5}
export PERL5LIB PERL_MB_OPT PERL_MM_OPT
./install-deps.pl
: ${PORT_BASE:=8500}
echo >&2 "Running sytest with SQLite3";
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
--python $TOX_BIN/python --all --port-base $PORT_BASE > results-sqlite3.tap
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml

25
jenkins-unittests.sh Executable file
View File

@ -0,0 +1,25 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox -e py27

View File

@ -534,7 +534,7 @@ class Auth(object):
) )
access_token = request.args["access_token"][0] access_token = request.args["access_token"][0]
user_info = yield self._get_user_by_access_token(access_token) user_info = yield self.get_user_by_access_token(access_token)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
is_guest = user_info["is_guest"] is_guest = user_info["is_guest"]
@ -595,7 +595,7 @@ class Auth(object):
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_user_by_access_token(self, token): def get_user_by_access_token(self, token):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:

View File

@ -722,7 +722,7 @@ def run(hs):
if hs.config.daemonize: if hs.config.daemonize:
if hs.config.print_pidfile: if hs.config.print_pidfile:
print hs.config.pid_file print (hs.config.pid_file)
daemon = Daemonize( daemon = Daemonize(
app="synapse-homeserver", app="synapse-homeserver",

View File

@ -29,13 +29,13 @@ NORMAL = "\x1b[m"
def start(configfile): def start(configfile):
print "Starting ...", print ("Starting ...")
args = SYNAPSE args = SYNAPSE
args.extend(["--daemonize", "-c", configfile]) args.extend(["--daemonize", "-c", configfile])
try: try:
subprocess.check_call(args) subprocess.check_call(args)
print GREEN + "started" + NORMAL print (GREEN + "started" + NORMAL)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print ( print (
RED + RED +
@ -48,7 +48,7 @@ def stop(pidfile):
if os.path.exists(pidfile): if os.path.exists(pidfile):
pid = int(open(pidfile).read()) pid = int(open(pidfile).read())
os.kill(pid, signal.SIGTERM) os.kill(pid, signal.SIGTERM)
print GREEN + "stopped" + NORMAL print (GREEN + "stopped" + NORMAL)
def main(): def main():

View File

@ -28,7 +28,7 @@ if __name__ == "__main__":
sys.stderr.write("\n" + e.message + "\n") sys.stderr.write("\n" + e.message + "\n")
sys.exit(1) sys.exit(1)
print getattr(config, key) print (getattr(config, key))
sys.exit(0) sys.exit(0)
else: else:
sys.stderr.write("Unknown command %r\n" % (action,)) sys.stderr.write("Unknown command %r\n" % (action,))

View File

@ -104,7 +104,7 @@ class Config(object):
dir_path = cls.abspath(dir_path) dir_path = cls.abspath(dir_path)
try: try:
os.makedirs(dir_path) os.makedirs(dir_path)
except OSError, e: except OSError as e:
if e.errno != errno.EEXIST: if e.errno != errno.EEXIST:
raise raise
if not os.path.isdir(dir_path): if not os.path.isdir(dir_path):

40
synapse/config/api.py Normal file
View File

@ -0,0 +1,40 @@
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
from synapse.api.constants import EventTypes
class ApiConfig(Config):
def read_config(self, config):
self.room_invite_state_types = config.get("room_invite_state_types", [
EventTypes.JoinRules,
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
EventTypes.Name,
])
def default_config(cls, **kwargs):
return """\
## API Configuration ##
# A list of event types that will be included in the room_invite_state
room_invite_state_types:
- "{JoinRules}"
- "{CanonicalAlias}"
- "{RoomAvatar}"
- "{Name}"
""".format(**vars(EventTypes))

View File

@ -23,6 +23,7 @@ from .captcha import CaptchaConfig
from .voip import VoipConfig from .voip import VoipConfig
from .registration import RegistrationConfig from .registration import RegistrationConfig
from .metrics import MetricsConfig from .metrics import MetricsConfig
from .api import ApiConfig
from .appservice import AppServiceConfig from .appservice import AppServiceConfig
from .key import KeyConfig from .key import KeyConfig
from .saml2 import SAML2Config from .saml2 import SAML2Config
@ -32,7 +33,7 @@ from .password import PasswordConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
PasswordConfig,): PasswordConfig,):
pass pass

View File

@ -114,7 +114,7 @@ class FederationClient(FederationBase):
@log_function @log_function
def make_query(self, destination, query_type, args, def make_query(self, destination, query_type, args,
retry_on_dns_fail=True): retry_on_dns_fail=False):
"""Sends a federation Query to a remote homeserver of the given type """Sends a federation Query to a remote homeserver of the given type
and arguments. and arguments.

View File

@ -160,6 +160,7 @@ class TransportLayerClient(object):
path=path, path=path,
args=args, args=args,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
timeout=10000,
) )
defer.returnValue(content) defer.returnValue(content)

View File

@ -29,6 +29,14 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VISIBILITY_PRIORITY = (
"world_readable",
"shared",
"invited",
"joined",
)
class BaseHandler(object): class BaseHandler(object):
""" """
Common base class for the event handlers. Common base class for the event handlers.
@ -85,10 +93,28 @@ class BaseHandler(object):
else: else:
visibility = "shared" visibility = "shared"
if visibility not in VISIBILITY_PRIORITY:
visibility = "shared"
# if it was world_readable, it's easy: everyone can read it # if it was world_readable, it's easy: everyone can read it
if visibility == "world_readable": if visibility == "world_readable":
return True return True
# Always allow history visibility events on boundaries. This is done
# by setting the effective visibility to the least restrictive
# of the old vs new.
if event.type == EventTypes.RoomHistoryVisibility:
prev_content = event.unsigned.get("prev_content", {})
prev_visibility = prev_content.get("history_visibility", None)
if prev_visibility not in VISIBILITY_PRIORITY:
prev_visibility = "shared"
new_priority = VISIBILITY_PRIORITY.index(visibility)
old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
if old_priority < new_priority:
visibility = prev_visibility
# get the user's membership at the time of the event. (or rather, # get the user's membership at the time of the event. (or rather,
# just *after* the event. Which means that people can see their # just *after* the event. Which means that people can see their
# own join events, but not (currently) their own leave events.) # own join events, but not (currently) their own leave events.)
@ -160,10 +186,10 @@ class BaseHandler(object):
) )
defer.returnValue(res.get(user_id, [])) defer.returnValue(res.get(user_id, []))
def ratelimit(self, user_id): def ratelimit(self, requester):
time_now = self.clock.time() time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.send_message( allowed, time_allowed = self.ratelimiter.send_message(
user_id, time_now, requester.user.to_string(), time_now,
msg_rate_hz=self.hs.config.rc_messages_per_second, msg_rate_hz=self.hs.config.rc_messages_per_second,
burst_count=self.hs.config.rc_message_burst_count, burst_count=self.hs.config.rc_message_burst_count,
) )
@ -199,8 +225,7 @@ class BaseHandler(object):
# events in the room, because we don't know enough about the graph # events in the room, because we don't know enough about the graph
# fragment we received to treat it like a graph, so the above returned # fragment we received to treat it like a graph, so the above returned
# no relevant events. It may have returned some events (if we have # no relevant events. It may have returned some events (if we have
# joined and left the room), but not useful ones, like the invite. So we # joined and left the room), but not useful ones, like the invite.
# forcibly set our context to the invite we received over federation.
if ( if (
not self.is_host_in_room(context.current_state) and not self.is_host_in_room(context.current_state) and
builder.type == EventTypes.Member builder.type == EventTypes.Member
@ -208,7 +233,27 @@ class BaseHandler(object):
prev_member_event = yield self.store.get_room_member( prev_member_event = yield self.store.get_room_member(
builder.sender, builder.room_id builder.sender, builder.room_id
) )
if prev_member_event:
# The prev_member_event may already be in context.current_state,
# despite us not being present in the room; in particular, if
# inviting user, and all other local users, have already left.
#
# In that case, we have all the information we need, and we don't
# want to drop "context" - not least because we may need to handle
# the invite locally, which will require us to have the whole
# context (not just prev_member_event) to auth it.
#
context_event_ids = (
e.event_id for e in context.current_state.values()
)
if (
prev_member_event and
prev_member_event.event_id not in context_event_ids
):
# The prev_member_event is missing from context, so it must
# have arrived over federation and is an outlier. We forcibly
# set our context to the invite we received over federation
builder.prev_events = ( builder.prev_events = (
prev_member_event.event_id, prev_member_event.event_id,
prev_member_event.prev_events prev_member_event.prev_events
@ -263,11 +308,18 @@ class BaseHandler(object):
return False return False
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_client_event(self, event, context, ratelimit=True, extra_users=[]): def handle_new_client_event(
self,
requester,
event,
context,
ratelimit=True,
extra_users=[]
):
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if ratelimit: if ratelimit:
self.ratelimit(event.sender) self.ratelimit(requester)
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
@ -307,12 +359,8 @@ class BaseHandler(object):
"sender": e.sender, "sender": e.sender,
} }
for k, e in context.current_state.items() for k, e in context.current_state.items()
if e.type in ( if e.type in self.hs.config.room_invite_state_types
EventTypes.JoinRules, or is_inviter_member_event(e)
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
EventTypes.Name,
) or is_inviter_member_event(e)
] ]
invitee = UserID.from_string(event.state_key) invitee = UserID.from_string(event.state_key)
@ -348,6 +396,12 @@ class BaseHandler(object):
"You don't have permission to redact events" "You don't have permission to redact events"
) )
if event.type == EventTypes.Create and context.current_state:
raise AuthError(
403,
"Changing the room create event is forbidden",
)
action_generator = ActionGenerator(self.hs) action_generator = ActionGenerator(self.hs)
yield action_generator.handle_push_actions_for_event( yield action_generator.handle_push_actions_for_event(
event, context, self event, context, self

View File

@ -477,4 +477,4 @@ class AuthHandler(BaseHandler):
Returns: Returns:
Whether self.hash(password) == stored_hash (bool). Whether self.hash(password) == stored_hash (bool).
""" """
return bcrypt.checkpw(password, stored_hash) return bcrypt.hashpw(password, stored_hash) == stored_hash

View File

@ -17,9 +17,9 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import SynapseError, Codes, CodeMessageException from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomAlias from synapse.types import RoomAlias, UserID
import logging import logging
import string import string
@ -38,7 +38,7 @@ class DirectoryHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_association(self, room_alias, room_id, servers=None): def _create_association(self, room_alias, room_id, servers=None, creator=None):
# general association creation for both human users and app services # general association creation for both human users and app services
for wchar in string.whitespace: for wchar in string.whitespace:
@ -60,7 +60,8 @@ class DirectoryHandler(BaseHandler):
yield self.store.create_room_alias_association( yield self.store.create_room_alias_association(
room_alias, room_alias,
room_id, room_id,
servers servers,
creator=creator,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -77,7 +78,7 @@ class DirectoryHandler(BaseHandler):
400, "This alias is reserved by an application service.", 400, "This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
yield self._create_association(room_alias, room_id, servers) yield self._create_association(room_alias, room_id, servers, creator=user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_appservice_association(self, service, room_alias, room_id, def create_appservice_association(self, service, room_alias, room_id,
@ -95,7 +96,11 @@ class DirectoryHandler(BaseHandler):
def delete_association(self, user_id, room_alias): def delete_association(self, user_id, room_alias):
# association deletion for human users # association deletion for human users
# TODO Check if server admin can_delete = yield self._user_can_delete_alias(room_alias, user_id)
if not can_delete:
raise AuthError(
403, "You don't have permission to delete the alias.",
)
can_delete = yield self.can_modify_alias( can_delete = yield self.can_modify_alias(
room_alias, room_alias,
@ -212,17 +217,21 @@ class DirectoryHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def send_room_alias_update_event(self, user_id, room_id): def send_room_alias_update_event(self, requester, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id) aliases = yield self.store.get_aliases_for_room(room_id)
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event({ yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Aliases, "type": EventTypes.Aliases,
"state_key": self.hs.hostname, "state_key": self.hs.hostname,
"room_id": room_id, "room_id": room_id,
"sender": user_id, "sender": user_id,
"content": {"aliases": aliases}, "content": {"aliases": aliases},
}, ratelimit=False) },
ratelimit=False
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias): def get_association_from_room_alias(self, room_alias):
@ -257,3 +266,13 @@ class DirectoryHandler(BaseHandler):
return return
# either no interested services, or no service with an exclusive lock # either no interested services, or no service with an exclusive lock
defer.returnValue(True) defer.returnValue(True)
@defer.inlineCallbacks
def _user_can_delete_alias(self, alias, user_id):
creator = yield self.store.get_room_alias_creator(alias.to_string())
if creator and creator == user_id:
defer.returnValue(True)
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
defer.returnValue(is_admin)

View File

@ -472,7 +472,7 @@ class FederationHandler(BaseHandler):
limit=100, limit=100,
extremities=[e for e in extremities.keys()] extremities=[e for e in extremities.keys()]
) )
except SynapseError: except SynapseError as e:
logger.info( logger.info(
"Failed to backfill from %s because %s", "Failed to backfill from %s because %s",
dom, e, dom, e,
@ -1657,7 +1657,7 @@ class FederationHandler(BaseHandler):
self.auth.check(event, context.current_state) self.auth.check(event, context.current_state)
yield self._check_signature(event, auth_events=context.current_state) yield self._check_signature(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(event, context, from_client=False) yield member_handler.send_membership_event(None, event, context)
else: else:
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id)) destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
yield self.replication_layer.forward_third_party_invite( yield self.replication_layer.forward_third_party_invite(
@ -1686,7 +1686,7 @@ class FederationHandler(BaseHandler):
# TODO: Make sure the signatures actually are correct. # TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures) event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(event, context, from_client=False) yield member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_display_name_to_third_party_invite(self, event_dict, event, context): def add_display_name_to_third_party_invite(self, event_dict, event, context):

View File

@ -196,11 +196,24 @@ class MessageHandler(BaseHandler):
if builder.type == EventTypes.Member: if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None) membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership == Membership.JOIN: if membership == Membership.JOIN:
joinee = UserID.from_string(builder.state_key)
# If event doesn't include a display name, add one. # If event doesn't include a display name, add one.
yield collect_presencelike_data( yield collect_presencelike_data(
self.distributor, joinee, builder.content self.distributor, target, builder.content
)
elif membership == Membership.INVITE:
profile = self.hs.get_handlers().profile_handler
content = builder.content
try:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
target, e
) )
if token_id is not None: if token_id is not None:
@ -215,7 +228,7 @@ class MessageHandler(BaseHandler):
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks
def send_nonmember_event(self, event, context, ratelimit=True): def send_nonmember_event(self, requester, event, context, ratelimit=True):
""" """
Persists and notifies local clients and federation of an event. Persists and notifies local clients and federation of an event.
@ -241,6 +254,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(prev_state) defer.returnValue(prev_state)
yield self.handle_new_client_event( yield self.handle_new_client_event(
requester=requester,
event=event, event=event,
context=context, context=context,
ratelimit=ratelimit, ratelimit=ratelimit,
@ -268,9 +282,9 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_nonmember_event( def create_and_send_nonmember_event(
self, self,
requester,
event_dict, event_dict,
ratelimit=True, ratelimit=True,
token_id=None,
txn_id=None txn_id=None
): ):
""" """
@ -280,10 +294,11 @@ class MessageHandler(BaseHandler):
""" """
event, context = yield self.create_event( event, context = yield self.create_event(
event_dict, event_dict,
token_id=token_id, token_id=requester.access_token_id,
txn_id=txn_id txn_id=txn_id
) )
yield self.send_nonmember_event( yield self.send_nonmember_event(
requester,
event, event,
context, context,
ratelimit=ratelimit, ratelimit=ratelimit,
@ -647,8 +662,8 @@ class MessageHandler(BaseHandler):
user_id, messages, is_peeking=is_peeking user_id, messages, is_peeking=is_peeking
) )
start_token = StreamToken(token[0], 0, 0, 0, 0) start_token = StreamToken.START.copy_and_replace("room_key", token[0])
end_token = StreamToken(token[1], 0, 0, 0, 0) end_token = StreamToken.START.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View File

@ -89,13 +89,13 @@ class ProfileHandler(BaseHandler):
defer.returnValue(result["displayname"]) defer.returnValue(result["displayname"])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_displayname(self, target_user, auth_user, new_displayname): def set_displayname(self, target_user, requester, new_displayname):
"""target_user is the user whose displayname is to be changed; """target_user is the user whose displayname is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname") raise AuthError(400, "Cannot set another user's displayname")
if new_displayname == '': if new_displayname == '':
@ -109,7 +109,7 @@ class ProfileHandler(BaseHandler):
"displayname": new_displayname, "displayname": new_displayname,
}) })
yield self._update_join_states(target_user) yield self._update_join_states(requester)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_avatar_url(self, target_user): def get_avatar_url(self, target_user):
@ -139,13 +139,13 @@ class ProfileHandler(BaseHandler):
defer.returnValue(result["avatar_url"]) defer.returnValue(result["avatar_url"])
@defer.inlineCallbacks @defer.inlineCallbacks
def set_avatar_url(self, target_user, auth_user, new_avatar_url): def set_avatar_url(self, target_user, requester, new_avatar_url):
"""target_user is the user whose avatar_url is to be changed; """target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change.""" auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url") raise AuthError(400, "Cannot set another user's avatar_url")
yield self.store.set_profile_avatar_url( yield self.store.set_profile_avatar_url(
@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler):
"avatar_url": new_avatar_url, "avatar_url": new_avatar_url,
}) })
yield self._update_join_states(target_user) yield self._update_join_states(requester)
@defer.inlineCallbacks @defer.inlineCallbacks
def collect_presencelike_data(self, user, state): def collect_presencelike_data(self, user, state):
@ -199,11 +199,12 @@ class ProfileHandler(BaseHandler):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_join_states(self, user): def _update_join_states(self, requester):
user = requester.user
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
return return
self.ratelimit(user.to_string()) self.ratelimit(requester)
joins = yield self.store.get_rooms_for_user( joins = yield self.store.get_rooms_for_user(
user.to_string(), user.to_string(),

View File

@ -157,6 +157,7 @@ class RegistrationHandler(BaseHandler):
) )
except SynapseError: except SynapseError:
# if user id is taken, just generate another # if user id is taken, just generate another
user = None
user_id = None user_id = None
token = None token = None
attempts += 1 attempts += 1
@ -240,7 +241,7 @@ class RegistrationHandler(BaseHandler):
password_hash=None password_hash=None
) )
yield registered_user(self.distributor, user) yield registered_user(self.distributor, user)
except Exception, e: except Exception as e:
yield self.store.add_access_token_to_user(user_id, token) yield self.store.add_access_token_to_user(user_id, token)
# Ignore Registration errors # Ignore Registration errors
logger.exception(e) logger.exception(e)

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken, Requester
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset, EventTypes, Membership, JoinRules, RoomCreationPreset,
) )
@ -90,7 +90,7 @@ class RoomCreationHandler(BaseHandler):
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
self.ratelimit(user_id) self.ratelimit(requester)
if "room_alias_name" in config: if "room_alias_name" in config:
for wchar in string.whitespace: for wchar in string.whitespace:
@ -185,26 +185,32 @@ class RoomCreationHandler(BaseHandler):
if "name" in config: if "name" in config:
name = config["name"] name = config["name"]
yield msg_handler.create_and_send_nonmember_event({ yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name, "type": EventTypes.Name,
"room_id": room_id, "room_id": room_id,
"sender": user_id, "sender": user_id,
"state_key": "", "state_key": "",
"content": {"name": name}, "content": {"name": name},
}, ratelimit=False) },
ratelimit=False)
if "topic" in config: if "topic" in config:
topic = config["topic"] topic = config["topic"]
yield msg_handler.create_and_send_nonmember_event({ yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic, "type": EventTypes.Topic,
"room_id": room_id, "room_id": room_id,
"sender": user_id, "sender": user_id,
"state_key": "", "state_key": "",
"content": {"topic": topic}, "content": {"topic": topic},
}, ratelimit=False) },
ratelimit=False)
for invitee in invite_list: for invitee in invite_list:
room_member_handler.update_membership( yield room_member_handler.update_membership(
requester, requester,
UserID.from_string(invitee), UserID.from_string(invitee),
room_id, room_id,
@ -231,7 +237,7 @@ class RoomCreationHandler(BaseHandler):
if room_alias: if room_alias:
result["room_alias"] = room_alias.to_string() result["room_alias"] = room_alias.to_string()
yield directory_handler.send_room_alias_update_event( yield directory_handler.send_room_alias_update_event(
user_id, room_id requester, user_id, room_id
) )
defer.returnValue(result) defer.returnValue(result)
@ -263,7 +269,11 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def send(etype, content, **kwargs): def send(etype, content, **kwargs):
event = create(etype, content, **kwargs) event = create(etype, content, **kwargs)
yield msg_handler.create_and_send_nonmember_event(event, ratelimit=False) yield msg_handler.create_and_send_nonmember_event(
creator,
event,
ratelimit=False
)
config = RoomCreationHandler.PRESETS_DICT[preset_config] config = RoomCreationHandler.PRESETS_DICT[preset_config]
@ -454,12 +464,11 @@ class RoomMemberHandler(BaseHandler):
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event( yield member_handler.send_membership_event(
requester,
event, event,
context, context,
is_guest=requester.is_guest,
ratelimit=ratelimit, ratelimit=ratelimit,
remote_room_hosts=remote_room_hosts, remote_room_hosts=remote_room_hosts,
from_client=True,
) )
if action == "forget": if action == "forget":
@ -468,17 +477,19 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def send_membership_event( def send_membership_event(
self, self,
requester,
event, event,
context, context,
is_guest=False,
remote_room_hosts=None, remote_room_hosts=None,
ratelimit=True, ratelimit=True,
from_client=True,
): ):
""" """
Change the membership status of a user in a room. Change the membership status of a user in a room.
Args: Args:
requester (Requester): The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
event (SynapseEvent): The membership event. event (SynapseEvent): The membership event.
context: The context of the event. context: The context of the event.
is_guest (bool): Whether the sender is a guest. is_guest (bool): Whether the sender is a guest.
@ -486,19 +497,23 @@ class RoomMemberHandler(BaseHandler):
the room, and could be danced with in order to join this the room, and could be danced with in order to join this
homeserver for the first time. homeserver for the first time.
ratelimit (bool): Whether to rate limit this request. ratelimit (bool): Whether to rate limit this request.
from_client (bool): Whether this request is the result of a local
client request (rather than over federation). If so, we will
perform extra checks, like that this homeserver can act as this
client.
Raises: Raises:
SynapseError if there was a problem changing the membership. SynapseError if there was a problem changing the membership.
""" """
remote_room_hosts = remote_room_hosts or []
target_user = UserID.from_string(event.state_key) target_user = UserID.from_string(event.state_key)
room_id = event.room_id room_id = event.room_id
if from_client: if requester is not None:
sender = UserID.from_string(event.sender) sender = UserID.from_string(event.sender)
assert sender == requester.user, (
"Sender (%s) must be same as requester (%s)" %
(sender, requester.user)
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = Requester(target_user, None, False)
message_handler = self.hs.get_handlers().message_handler message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context) prev_event = message_handler.deduplicate_state_event(event, context)
@ -508,7 +523,7 @@ class RoomMemberHandler(BaseHandler):
action = "send" action = "send"
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
if is_guest and not self._can_guest_join(context.current_state): if requester.is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept, # This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process. # so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
@ -521,7 +536,23 @@ class RoomMemberHandler(BaseHandler):
action = "remote_join" action = "remote_join"
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
is_host_in_room = self.is_host_in_room(context.current_state) is_host_in_room = self.is_host_in_room(context.current_state)
if not is_host_in_room: if not is_host_in_room:
# perhaps we've been invited
inviter = self.get_inviter(target_user.to_string(), context.current_state)
if not inviter:
raise SynapseError(404, "Not a known room")
if self.hs.is_mine(inviter):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
# This is a bit of a hack, because the room might still be
# active on other servers.
pass
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
action = "remote_reject" action = "remote_reject"
federation_handler = self.hs.get_handlers().federation_handler federation_handler = self.hs.get_handlers().federation_handler
@ -541,16 +572,14 @@ class RoomMemberHandler(BaseHandler):
event.content, event.content,
) )
elif action == "remote_reject": elif action == "remote_reject":
inviter = self.get_inviter(target_user.to_string(), context.current_state)
if not inviter:
raise SynapseError(404, "No known servers")
yield federation_handler.do_remotely_reject_invite( yield federation_handler.do_remotely_reject_invite(
[inviter.domain], remote_room_hosts,
room_id, room_id,
event.user_id event.user_id
) )
else: else:
yield self.handle_new_client_event( yield self.handle_new_client_event(
requester,
event, event,
context, context,
extra_users=[target_user], extra_users=[target_user],
@ -669,12 +698,12 @@ class RoomMemberHandler(BaseHandler):
) )
else: else:
yield self._make_and_store_3pid_invite( yield self._make_and_store_3pid_invite(
requester,
id_server, id_server,
medium, medium,
address, address,
room_id, room_id,
inviter, inviter,
requester.access_token_id,
txn_id=txn_id txn_id=txn_id
) )
@ -732,12 +761,12 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _make_and_store_3pid_invite( def _make_and_store_3pid_invite(
self, self,
requester,
id_server, id_server,
medium, medium,
address, address,
room_id, room_id,
user, user,
token_id,
txn_id txn_id
): ):
room_state = yield self.hs.get_state_handler().get_current_state(room_id) room_state = yield self.hs.get_state_handler().get_current_state(room_id)
@ -787,6 +816,7 @@ class RoomMemberHandler(BaseHandler):
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event( yield msg_handler.create_and_send_nonmember_event(
requester,
{ {
"type": EventTypes.ThirdPartyInvite, "type": EventTypes.ThirdPartyInvite,
"content": { "content": {
@ -801,7 +831,6 @@ class RoomMemberHandler(BaseHandler):
"sender": user.to_string(), "sender": user.to_string(),
"state_key": token, "state_key": token,
}, },
token_id=token_id,
txn_id=txn_id, txn_id=txn_id,
) )
@ -855,6 +884,10 @@ class RoomMemberHandler(BaseHandler):
inviter_user_id=inviter_user_id, inviter_user_id=inviter_user_id,
) )
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
guest_access_token
)
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % ( is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server, id_server_scheme, id_server,
) )
@ -871,6 +904,7 @@ class RoomMemberHandler(BaseHandler):
"sender": inviter_user_id, "sender": inviter_user_id,
"sender_display_name": inviter_display_name, "sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url, "sender_avatar_url": inviter_avatar_url,
"guest_user_id": guest_user_info["user"].to_string(),
"guest_access_token": guest_access_token, "guest_access_token": guest_access_token,
} }
) )

View File

@ -20,6 +20,7 @@ from synapse.api.constants import Membership, EventTypes
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import LoggingContext, preserve_fn from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.push.clientformat import format_push_rules_for_user
from twisted.internet import defer from twisted.internet import defer
@ -209,9 +210,9 @@ class SyncHandler(BaseHandler):
key=None key=None
) )
membership_list = (Membership.INVITE, Membership.JOIN) membership_list = (
if sync_config.filter_collection.include_leave: Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN
membership_list += (Membership.LEAVE, Membership.BAN) )
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=sync_config.user.to_string(), user_id=sync_config.user.to_string(),
@ -224,6 +225,10 @@ class SyncHandler(BaseHandler):
) )
) )
account_data['m.push_rules'] = yield self.push_rules_for_user(
sync_config.user
)
tags_by_room = yield self.store.get_tags_for_user( tags_by_room = yield self.store.get_tags_for_user(
sync_config.user.to_string() sync_config.user.to_string()
) )
@ -257,6 +262,12 @@ class SyncHandler(BaseHandler):
invite=invite, invite=invite,
)) ))
elif event.membership in (Membership.LEAVE, Membership.BAN): elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from.
if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE:
if sync_config.user.to_string() == event.sender:
continue
leave_token = now_token.copy_and_replace( leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,) "room_key", "s%d" % (event.stream_ordering,)
) )
@ -322,6 +333,14 @@ class SyncHandler(BaseHandler):
defer.returnValue(room_sync) defer.returnValue(room_sync)
@defer.inlineCallbacks
def push_rules_for_user(self, user):
user_id = user.to_string()
rawrules = yield self.store.get_push_rules_for_user(user_id)
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
rules = format_push_rules_for_user(user, rawrules, enabled_map)
defer.returnValue(rules)
def account_data_for_user(self, account_data): def account_data_for_user(self, account_data):
account_data_events = [] account_data_events = []
@ -481,6 +500,15 @@ class SyncHandler(BaseHandler):
) )
) )
push_rules_changed = yield self.store.have_push_rules_changed_for_user(
user_id, int(since_token.push_rules_key)
)
if push_rules_changed:
account_data["m.push_rules"] = yield self.push_rules_for_user(
sync_config.user
)
# Get a list of membership change events that have happened. # Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user( rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key user_id, since_token.room_key, now_token.room_key

View File

@ -103,7 +103,7 @@ class SimpleHttpClient(object):
# TODO: Do we ever want to log message contents? # TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args) logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
response = yield self.request( response = yield self.request(
"POST", "POST",
@ -249,7 +249,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def post_urlencoded_get_raw(self, url, args={}): def post_urlencoded_get_raw(self, url, args={}):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
response = yield self.request( response = yield self.request(
"POST", "POST",
@ -269,6 +269,19 @@ class CaptchaServerHttpClient(SimpleHttpClient):
defer.returnValue(e.response) defer.returnValue(e.response)
def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()}
def encode_urlencode_arg(arg):
if isinstance(arg, unicode):
return arg.encode('utf-8')
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
else:
return arg
def _print_ex(e): def _print_ex(e):
if hasattr(e, "reasons") and e.reasons: if hasattr(e, "reasons") and e.reasons:
for ex in e.reasons: for ex in e.reasons:

View File

@ -284,7 +284,7 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user_id, timeout, callback, room_ids=None, def wait_for_events(self, user_id, timeout, callback, room_ids=None,
from_token=StreamToken("s0", "0", "0", "0", "0")): from_token=StreamToken.START):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """

View File

@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.push.baserules import list_with_base_rules
from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
)
import copy
import simplejson as json
def format_push_rules_for_user(user, rawrules, enabled_map):
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}}
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
for r in ruleslist:
rulearray = None
template_name = _priority_class_to_template_name(r['priority_class'])
# Remove internal stuff.
for c in r["conditions"]:
c.pop("_id", None)
pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":
c["pattern"] = user.to_string()
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
rulearray = rules['global'][template_name]
template_rule = _rule_to_template(r)
if template_rule:
if r['rule_id'] in enabled_map:
template_rule['enabled'] = enabled_map[r['rule_id']]
elif 'enabled' in r:
template_rule['enabled'] = r['enabled']
else:
template_rule['enabled'] = True
rulearray.append(template_rule)
return rules
def _add_empty_priority_class_arrays(d):
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
def _rule_to_template(rule):
unscoped_rule_id = None
if 'rule_id' in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
template_name = _priority_class_to_template_name(rule['priority_class'])
if template_name in ['override', 'underride']:
templaterule = {k: rule[k] for k in ["conditions", "actions"]}
elif template_name in ["sender", "room"]:
templaterule = {'actions': rule['actions']}
unscoped_rule_id = rule['conditions'][0]['pattern']
elif template_name == 'content':
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
templaterule = {'actions': rule['actions']}
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
templaterule['rule_id'] = unscoped_rule_id
if 'default' in rule:
templaterule['default'] = rule['default']
return templaterule
def _rule_id_from_namespaced(in_rule_id):
return in_rule_id.split('/')[-1]
def _priority_class_to_template_name(pc):
return PRIORITY_CLASS_INVERSE_MAP[pc]

View File

@ -36,6 +36,7 @@ STREAM_NAMES = (
("receipts",), ("receipts",),
("user_account_data", "room_account_data", "tag_account_data",), ("user_account_data", "room_account_data", "tag_account_data",),
("backfill",), ("backfill",),
("push_rules",),
) )
@ -63,6 +64,7 @@ class ReplicationResource(Resource):
* "room_account_data: Per room per user account data. * "room_account_data: Per room per user account data.
* "tag_account_data": Per room per user tags. * "tag_account_data": Per room per user tags.
* "backfill": Old events that have been backfilled from other servers. * "backfill": Old events that have been backfilled from other servers.
* "push_rules": Per user changes to push rules.
The API takes two additional query parameters: The API takes two additional query parameters:
@ -117,14 +119,16 @@ class ReplicationResource(Resource):
def current_replication_token(self): def current_replication_token(self):
stream_token = yield self.sources.get_current_token() stream_token = yield self.sources.get_current_token()
backfill_token = yield self.store.get_current_backfill_token() backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
stream_token.room_stream_id, room_stream_token,
int(stream_token.presence_key), int(stream_token.presence_key),
int(stream_token.typing_key), int(stream_token.typing_key),
int(stream_token.receipt_key), int(stream_token.receipt_key),
int(stream_token.account_data_key), int(stream_token.account_data_key),
backfill_token, backfill_token,
push_rules_token,
)) ))
@request_handler @request_handler
@ -146,6 +150,7 @@ class ReplicationResource(Resource):
yield self.presence(writer, current_token) # TODO: implement limit yield self.presence(writer, current_token) # TODO: implement limit
yield self.typing(writer, current_token) # TODO: implement limit yield self.typing(writer, current_token) # TODO: implement limit
yield self.receipts(writer, current_token, limit) yield self.receipts(writer, current_token, limit)
yield self.push_rules(writer, current_token, limit)
self.streams(writer, current_token) self.streams(writer, current_token)
logger.info("Replicated %d rows", writer.total) logger.info("Replicated %d rows", writer.total)
@ -277,6 +282,21 @@ class ReplicationResource(Resource):
"position", "user_id", "room_id", "tags" "position", "user_id", "room_id", "tags"
)) ))
@defer.inlineCallbacks
def push_rules(self, writer, current_token, limit):
current_position = current_token.push_rules
push_rules = parse_integer(writer.request, "push_rules")
if push_rules is not None:
rows = yield self.store.get_all_push_rule_updates(
push_rules, current_position, limit
)
writer.write_header_and_rows("push_rules", rows, (
"position", "event_stream_ordering", "user_id", "rule_id", "op",
"priority_class", "priority", "conditions", "actions"
))
class _Writer(object): class _Writer(object):
"""Writes the streams as a JSON object as the response to the request""" """Writes the streams as a JSON object as the response to the request"""
@ -307,12 +327,16 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules"
))): ))):
__slots__ = [] __slots__ = []
def __new__(cls, *args): def __new__(cls, *args):
if len(args) == 1: if len(args) == 1:
return cls(*(int(value) for value in args[0].split("_"))) streams = [int(value) for value in args[0].split("_")]
if len(streams) < len(cls._fields):
streams.extend([0] * (len(cls._fields) - len(streams)))
return cls(*streams)
else: else:
return super(_ReplicationToken, cls).__new__(cls, *args) return super(_ReplicationToken, cls).__new__(cls, *args)

View File

@ -75,7 +75,11 @@ class ClientDirectoryServer(ClientV1RestServlet):
yield dir_handler.create_association( yield dir_handler.create_association(
user_id, room_alias, room_id, servers user_id, room_alias, room_id, servers
) )
yield dir_handler.send_room_alias_update_event(user_id, room_id) yield dir_handler.send_room_alias_update_event(
requester,
user_id,
room_id
)
except SynapseError as e: except SynapseError as e:
raise e raise e
except: except:
@ -118,9 +122,6 @@ class ClientDirectoryServer(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user user = requester.user
is_admin = yield self.auth.is_server_admin(user)
if not is_admin:
raise AuthError(403, "You need to be a server admin")
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)

View File

@ -252,7 +252,7 @@ class SAML2RestServlet(ClientV1RestServlet):
SP = Saml2Client(conf) SP = Saml2Client(conf)
saml2_auth = SP.parse_authn_request_response( saml2_auth = SP.parse_authn_request_response(
request.args['SAMLResponse'][0], BINDING_HTTP_POST) request.args['SAMLResponse'][0], BINDING_HTTP_POST)
except Exception, e: # Not authenticated except Exception as e: # Not authenticated
logger.exception(e) logger.exception(e)
if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed: if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
username = saml2_auth.name_id.text username = saml2_auth.name_id.text

View File

@ -51,7 +51,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname( yield self.handlers.profile_handler.set_displayname(
user, requester.user, new_name) user, requester, new_name)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -88,7 +88,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_avatar_url( yield self.handlers.profile_handler.set_avatar_url(
user, requester.user, new_name) user, requester, new_name)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -22,12 +22,10 @@ from .base import ClientV1RestServlet, client_path_patterns
from synapse.storage.push_rule import ( from synapse.storage.push_rule import (
InconsistentRuleException, RuleNotFoundException InconsistentRuleException, RuleNotFoundException
) )
from synapse.push.baserules import list_with_base_rules, BASE_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import ( from synapse.push.baserules import BASE_RULE_IDS
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP
)
import copy
import simplejson as json import simplejson as json
@ -36,6 +34,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash") "Unrecognised request: You probably wanted a trailing slash")
def __init__(self, hs):
super(PushRuleRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request): def on_PUT(self, request):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
@ -51,8 +54,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
content = _parse_json(request) content = _parse_json(request)
user_id = requester.user.to_string()
if 'attr' in spec: if 'attr' in spec:
yield self.set_rule_attr(requester.user.to_string(), spec, content) yield self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
if spec['rule_id'].startswith('.'): if spec['rule_id'].startswith('.'):
@ -77,8 +83,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
after = _namespaced_rule_id(spec, after[0]) after = _namespaced_rule_id(spec, after[0])
try: try:
yield self.hs.get_datastore().add_push_rule( yield self.store.add_push_rule(
user_id=requester.user.to_string(), user_id=user_id,
rule_id=_namespaced_rule_id_from_spec(spec), rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class, priority_class=priority_class,
conditions=conditions, conditions=conditions,
@ -86,6 +92,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
before=before, before=before,
after=after after=after
) )
self.notify_user(user_id)
except InconsistentRuleException as e: except InconsistentRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
except RuleNotFoundException as e: except RuleNotFoundException as e:
@ -98,13 +105,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try: try:
yield self.hs.get_datastore().delete_push_rule( yield self.store.delete_push_rule(
requester.user.to_string(), namespaced_rule_id user_id, namespaced_rule_id
) )
self.notify_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
@ -115,58 +124,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user user_id = requester.user.to_string()
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference # is probably not going to make a whole lot of difference
rawrules = yield self.hs.get_datastore().get_push_rules_for_user( rawrules = yield self.store.get_push_rules_for_user(user_id)
user.to_string()
)
ruleslist = [] enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}}
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
enabled_map = yield self.hs.get_datastore().\
get_push_rules_enabled_for_user(user.to_string())
for r in ruleslist:
rulearray = None
template_name = _priority_class_to_template_name(r['priority_class'])
# Remove internal stuff.
for c in r["conditions"]:
c.pop("_id", None)
pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":
c["pattern"] = user.to_string()
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
rulearray = rules['global'][template_name]
template_rule = _rule_to_template(r)
if template_rule:
if r['rule_id'] in enabled_map:
template_rule['enabled'] = enabled_map[r['rule_id']]
elif 'enabled' in r:
template_rule['enabled'] = r['enabled']
else:
template_rule['enabled'] = True
rulearray.append(template_rule)
path = request.postpath[1:] path = request.postpath[1:]
@ -188,6 +155,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
def notify_user(self, user_id):
stream_id, _ = self.store.get_push_rules_stream_token()
self.notifier.on_new_event(
"push_rules_key", stream_id, users=[user_id]
)
def set_rule_attr(self, user_id, spec, val): def set_rule_attr(self, user_id, spec, val):
if spec['attr'] == 'enabled': if spec['attr'] == 'enabled':
if isinstance(val, dict) and "enabled" in val: if isinstance(val, dict) and "enabled" in val:
@ -198,7 +171,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
# bools directly, so let's not break them. # bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean") raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
return self.hs.get_datastore().set_push_rule_enabled( return self.store.set_push_rule_enabled(
user_id, namespaced_rule_id, val user_id, namespaced_rule_id, val
) )
elif spec['attr'] == 'actions': elif spec['attr'] == 'actions':
@ -210,7 +183,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
if is_default_rule: if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS: if namespaced_rule_id not in BASE_RULE_IDS:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
return self.hs.get_datastore().set_push_rule_actions( return self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule user_id, namespaced_rule_id, actions, is_default_rule
) )
else: else:
@ -308,12 +281,6 @@ def _check_actions(actions):
raise InvalidRuleException("Unrecognised action") raise InvalidRuleException("Unrecognised action")
def _add_empty_priority_class_arrays(d):
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
def _filter_ruleset_with_path(ruleset, path): def _filter_ruleset_with_path(ruleset, path):
if path == []: if path == []:
raise UnrecognizedRequestError( raise UnrecognizedRequestError(
@ -362,37 +329,6 @@ def _priority_class_from_spec(spec):
return pc return pc
def _priority_class_to_template_name(pc):
return PRIORITY_CLASS_INVERSE_MAP[pc]
def _rule_to_template(rule):
unscoped_rule_id = None
if 'rule_id' in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
template_name = _priority_class_to_template_name(rule['priority_class'])
if template_name in ['override', 'underride']:
templaterule = {k: rule[k] for k in ["conditions", "actions"]}
elif template_name in ["sender", "room"]:
templaterule = {'actions': rule['actions']}
unscoped_rule_id = rule['conditions'][0]['pattern']
elif template_name == 'content':
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
templaterule = {'actions': rule['actions']}
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
templaterule['rule_id'] = unscoped_rule_id
if 'default' in rule:
templaterule['default'] = rule['default']
return templaterule
def _namespaced_rule_id_from_spec(spec): def _namespaced_rule_id_from_spec(spec):
return _namespaced_rule_id(spec, spec['rule_id']) return _namespaced_rule_id(spec, spec['rule_id'])
@ -401,10 +337,6 @@ def _namespaced_rule_id(spec, rule_id):
return "global/%s/%s" % (spec['template'], rule_id) return "global/%s/%s" % (spec['template'], rule_id)
def _rule_id_from_namespaced(in_rule_id):
return in_rule_id.split('/')[-1]
class InvalidRuleException(Exception): class InvalidRuleException(Exception):
pass pass

View File

@ -158,12 +158,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
if event_type == EventTypes.Member: if event_type == EventTypes.Member:
yield self.handlers.room_member_handler.send_membership_event( yield self.handlers.room_member_handler.send_membership_event(
requester,
event, event,
context, context,
is_guest=requester.is_guest,
) )
else: else:
yield msg_handler.send_nonmember_event(event, context) yield msg_handler.send_nonmember_event(requester, event, context)
defer.returnValue((200, {"event_id": event.event_id})) defer.returnValue((200, {"event_id": event.event_id}))
@ -183,13 +183,13 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_nonmember_event( event = yield msg_handler.create_and_send_nonmember_event(
requester,
{ {
"type": event_type, "type": event_type,
"content": content, "content": content,
"room_id": room_id, "room_id": room_id,
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
}, },
token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
) )
@ -504,6 +504,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_nonmember_event( event = yield msg_handler.create_and_send_nonmember_event(
requester,
{ {
"type": EventTypes.Redaction, "type": EventTypes.Redaction,
"content": content, "content": content,
@ -511,7 +512,6 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
"redacts": event_id, "redacts": event_id,
}, },
token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
) )

View File

@ -45,7 +45,7 @@ from .search import SearchStore
from .tags import TagsStore from .tags import TagsStore
from .account_data import AccountDataStore from .account_data import AccountDataStore
from util.id_generators import IdGenerator, StreamIdGenerator from util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -122,6 +122,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id") self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._push_rules_stream_id_gen = ChainedIdGenerator(
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
)
events_max = self._stream_id_gen.get_max_token() events_max = self._stream_id_gen.get_max_token()
event_cache_prefill, min_event_val = self._get_cache_dict( event_cache_prefill, min_event_val = self._get_cache_dict(
@ -157,6 +160,18 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=presence_cache_prefill prefilled_cache=presence_cache_prefill
) )
push_rules_prefill, push_rules_id = self._get_cache_dict(
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._push_rules_stream_id_gen.get_max_token()[0],
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache", push_rules_id,
prefilled_cache=push_rules_prefill,
)
super(DataStore, self).__init__(hs) super(DataStore, self).__init__(hs)
def take_presence_startup_info(self): def take_presence_startup_info(self):

View File

@ -766,6 +766,19 @@ class SQLBaseStore(object):
"""Executes a DELETE query on the named table, expecting to delete a """Executes a DELETE query on the named table, expecting to delete a
single row. single row.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
return self.runInteraction(
desc, self._simple_delete_one_txn, table, keyvalues
)
@staticmethod
def _simple_delete_one_txn(txn, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args: Args:
table : string giving the table name table : string giving the table name
keyvalues : dict of column names and values to select the row with keyvalues : dict of column names and values to select the row with
@ -775,13 +788,11 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k, ) for k in keyvalues)
) )
def func(txn):
txn.execute(sql, keyvalues.values()) txn.execute(sql, keyvalues.values())
if txn.rowcount == 0: if txn.rowcount == 0:
raise StoreError(404, "No row found") raise StoreError(404, "No row found")
if txn.rowcount > 1: if txn.rowcount > 1:
raise StoreError(500, "more than one row matched") raise StoreError(500, "more than one row matched")
return self.runInteraction(desc, func)
@staticmethod @staticmethod
def _simple_delete_txn(txn, table, keyvalues): def _simple_delete_txn(txn, table, keyvalues):

View File

@ -70,13 +70,14 @@ class DirectoryStore(SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room_alias_association(self, room_alias, room_id, servers): def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
""" Creates an associatin between a room alias and room_id/servers """ Creates an associatin between a room alias and room_id/servers
Args: Args:
room_alias (RoomAlias) room_alias (RoomAlias)
room_id (str) room_id (str)
servers (list) servers (list)
creator (str): Optional user_id of creator.
Returns: Returns:
Deferred Deferred
@ -87,6 +88,7 @@ class DirectoryStore(SQLBaseStore):
{ {
"room_alias": room_alias.to_string(), "room_alias": room_alias.to_string(),
"room_id": room_id, "room_id": room_id,
"creator": creator,
}, },
desc="create_room_alias_association", desc="create_room_alias_association",
) )
@ -107,6 +109,17 @@ class DirectoryStore(SQLBaseStore):
) )
self.get_aliases_for_room.invalidate((room_id,)) self.get_aliases_for_room.invalidate((room_id,))
def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol(
table="room_aliases",
keyvalues={
"room_alias": room_alias,
},
retcol="creator",
desc="get_room_alias_creator",
allow_none=True
)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_room_alias(self, room_alias): def delete_room_alias(self, room_alias):
room_id = yield self.runInteraction( room_id = yield self.runInteraction(

View File

@ -99,30 +99,32 @@ class PushRuleStore(SQLBaseStore):
results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled'] results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
defer.returnValue(results) defer.returnValue(results)
@defer.inlineCallbacks
def add_push_rule( def add_push_rule(
self, user_id, rule_id, priority_class, conditions, actions, self, user_id, rule_id, priority_class, conditions, actions,
before=None, after=None before=None, after=None
): ):
conditions_json = json.dumps(conditions) conditions_json = json.dumps(conditions)
actions_json = json.dumps(actions) actions_json = json.dumps(actions)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
if before or after: if before or after:
return self.runInteraction( yield self.runInteraction(
"_add_push_rule_relative_txn", "_add_push_rule_relative_txn",
self._add_push_rule_relative_txn, self._add_push_rule_relative_txn,
user_id, rule_id, priority_class, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json, before, after, conditions_json, actions_json, before, after,
) )
else: else:
return self.runInteraction( yield self.runInteraction(
"_add_push_rule_highest_priority_txn", "_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn, self._add_push_rule_highest_priority_txn,
user_id, rule_id, priority_class, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json, conditions_json, actions_json,
) )
def _add_push_rule_relative_txn( def _add_push_rule_relative_txn(
self, txn, user_id, rule_id, priority_class, self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json, before, after conditions_json, actions_json, before, after
): ):
# Lock the table since otherwise we'll have annoying races between the # Lock the table since otherwise we'll have annoying races between the
@ -174,12 +176,12 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_id, priority_class, new_rule_priority)) txn.execute(sql, (user_id, priority_class, new_rule_priority))
self._upsert_push_rule_txn( self._upsert_push_rule_txn(
txn, user_id, rule_id, priority_class, new_rule_priority, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json, new_rule_priority, conditions_json, actions_json,
) )
def _add_push_rule_highest_priority_txn( def _add_push_rule_highest_priority_txn(
self, txn, user_id, rule_id, priority_class, self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json conditions_json, actions_json
): ):
# Lock the table since otherwise we'll have annoying races between the # Lock the table since otherwise we'll have annoying races between the
@ -201,13 +203,13 @@ class PushRuleStore(SQLBaseStore):
self._upsert_push_rule_txn( self._upsert_push_rule_txn(
txn, txn,
user_id, rule_id, priority_class, new_prio, stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_prio,
conditions_json, actions_json, conditions_json, actions_json,
) )
def _upsert_push_rule_txn( def _upsert_push_rule_txn(
self, txn, user_id, rule_id, priority_class, self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
priority, conditions_json, actions_json priority, conditions_json, actions_json, update_stream=True
): ):
"""Specialised version of _simple_upsert_txn that picks a push_rule_id """Specialised version of _simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes using the _push_rule_id_gen if it needs to insert the rule. It assumes
@ -242,11 +244,16 @@ class PushRuleStore(SQLBaseStore):
}, },
) )
txn.call_after( if update_stream:
self.get_push_rules_for_user.invalidate, (user_id,) self._insert_push_rules_update_txn(
) txn, stream_id, event_stream_ordering, user_id, rule_id,
txn.call_after( op="ADD",
self.get_push_rules_enabled_for_user.invalidate, (user_id,) data={
"priority_class": priority_class,
"priority": priority,
"conditions": conditions_json,
"actions": actions_json,
}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -260,25 +267,37 @@ class PushRuleStore(SQLBaseStore):
user_id (str): The matrix ID of the push rule owner user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted rule_id (str): The rule_id of the rule to be deleted
""" """
yield self._simple_delete_one( def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
self._simple_delete_one_txn(
txn,
"push_rules", "push_rules",
{'user_name': user_id, 'rule_id': rule_id}, {'user_name': user_id, 'rule_id': rule_id},
desc="delete_push_rule",
) )
self.get_push_rules_for_user.invalidate((user_id,)) self._insert_push_rules_update_txn(
self.get_push_rules_enabled_for_user.invalidate((user_id,)) txn, stream_id, event_stream_ordering, user_id, rule_id,
op="DELETE"
)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.runInteraction(
"delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering
)
@defer.inlineCallbacks @defer.inlineCallbacks
def set_push_rule_enabled(self, user_id, rule_id, enabled): def set_push_rule_enabled(self, user_id, rule_id, enabled):
ret = yield self.runInteraction( with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.runInteraction(
"_set_push_rule_enabled_txn", "_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn, self._set_push_rule_enabled_txn,
user_id, rule_id, enabled stream_id, event_stream_ordering, user_id, rule_id, enabled
) )
defer.returnValue(ret)
def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled): def _set_push_rule_enabled_txn(
self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
):
new_id = self._push_rules_enable_id_gen.get_next() new_id = self._push_rules_enable_id_gen.get_next()
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
@ -287,25 +306,26 @@ class PushRuleStore(SQLBaseStore):
{'enabled': 1 if enabled else 0}, {'enabled': 1 if enabled else 0},
{'id': new_id}, {'id': new_id},
) )
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,) self._insert_push_rules_update_txn(
) txn, stream_id, event_stream_ordering, user_id, rule_id,
txn.call_after( op="ENABLE" if enabled else "DISABLE"
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
) )
@defer.inlineCallbacks
def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
actions_json = json.dumps(actions) actions_json = json.dumps(actions)
def set_push_rule_actions_txn(txn): def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
if is_default_rule: if is_default_rule:
# Add a dummy rule to the rules table with the user specified # Add a dummy rule to the rules table with the user specified
# actions. # actions.
priority_class = -1 priority_class = -1
priority = 1 priority = 1
self._upsert_push_rule_txn( self._upsert_push_rule_txn(
txn, user_id, rule_id, priority_class, priority, txn, stream_id, event_stream_ordering, user_id, rule_id,
"[]", actions_json priority_class, priority, "[]", actions_json,
update_stream=False
) )
else: else:
self._simple_update_one_txn( self._simple_update_one_txn(
@ -315,8 +335,79 @@ class PushRuleStore(SQLBaseStore):
{'actions': actions_json}, {'actions': actions_json},
) )
return self.runInteraction( self._insert_push_rules_update_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id,
op="ACTIONS", data={"actions": actions_json}
)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.runInteraction(
"set_push_rule_actions", set_push_rule_actions_txn, "set_push_rule_actions", set_push_rule_actions_txn,
stream_id, event_stream_ordering
)
def _insert_push_rules_update_txn(
self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
):
values = {
"stream_id": stream_id,
"event_stream_ordering": event_stream_ordering,
"user_id": user_id,
"rule_id": rule_id,
"op": op,
}
if data is not None:
values.update(data)
self._simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
txn.call_after(
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
def get_all_push_rule_updates(self, last_id, current_id, limit):
"""Get all the push rules changes that have happend on the server"""
def get_all_push_rule_updates_txn(txn):
sql = (
"SELECT stream_id, event_stream_ordering, user_id, rule_id,"
" op, priority_class, priority, conditions, actions"
" FROM push_rules_stream"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)
def get_push_rules_stream_token(self):
"""Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_max_token()
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False)
else:
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
count, = txn.fetchone()
return bool(count)
return self.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
) )

View File

@ -0,0 +1,16 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE room_aliases ADD COLUMN creator TEXT;

View File

@ -0,0 +1,38 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE push_rules_stream(
stream_id BIGINT NOT NULL,
event_stream_ordering BIGINT NOT NULL,
user_id TEXT NOT NULL,
rule_id TEXT NOT NULL,
op TEXT NOT NULL, -- One of "ENABLE", "DISABLE", "ACTIONS", "ADD", "DELETE"
priority_class SMALLINT,
priority INTEGER,
conditions TEXT,
actions TEXT
);
-- The extra data for each operation is:
-- * ENABLE, DISABLE, DELETE: []
-- * ACTIONS: ["actions"]
-- * ADD: ["priority_class", "priority", "actions", "conditions"]
-- Index for replication queries.
CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id);
-- Index for /sync queries.
CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id);

View File

@ -20,23 +20,21 @@ import threading
class IdGenerator(object): class IdGenerator(object):
def __init__(self, db_conn, table, column): def __init__(self, db_conn, table, column):
self.table = table
self.column = column
self._lock = threading.Lock() self._lock = threading.Lock()
cur = db_conn.cursor() self._next_id = _load_max_id(db_conn, table, column)
self._next_id = self._load_next_id(cur)
cur.close()
def _load_next_id(self, txn):
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,))
val, = txn.fetchone()
return val + 1 if val else 1
def get_next(self): def get_next(self):
with self._lock: with self._lock:
i = self._next_id
self._next_id += 1 self._next_id += 1
return i return self._next_id
def _load_max_id(db_conn, table, column):
cur = db_conn.cursor()
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
val, = cur.fetchone()
cur.close()
return val if val else 1
class StreamIdGenerator(object): class StreamIdGenerator(object):
@ -52,23 +50,10 @@ class StreamIdGenerator(object):
# ... persist event ... # ... persist event ...
""" """
def __init__(self, db_conn, table, column): def __init__(self, db_conn, table, column):
self.table = table
self.column = column
self._lock = threading.Lock() self._lock = threading.Lock()
self._current_max = _load_max_id(db_conn, table, column)
cur = db_conn.cursor()
self._current_max = self._load_current_max(cur)
cur.close()
self._unfinished_ids = deque() self._unfinished_ids = deque()
def _load_current_max(self, txn):
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
rows = txn.fetchall()
val, = rows[0]
return int(val) if val else 1
def get_next(self): def get_next(self):
""" """
Usage: Usage:
@ -124,3 +109,50 @@ class StreamIdGenerator(object):
return self._unfinished_ids[0] - 1 return self._unfinished_ids[0] - 1
return self._current_max return self._current_max
class ChainedIdGenerator(object):
"""Used to generate new stream ids where the stream must be kept in sync
with another stream. It generates pairs of IDs, the first element is an
integer ID for this stream, the second element is the ID for the stream
that this stream needs to be kept in sync with."""
def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator
self._lock = threading.Lock()
self._current_max = _load_max_id(db_conn, table, column)
self._unfinished_ids = deque()
def get_next(self):
"""
Usage:
with stream_id_gen.get_next() as (stream_id, chained_id):
# ... persist event ...
"""
with self._lock:
self._current_max += 1
next_id = self._current_max
chained_id = self.chained_generator.get_max_token()
self._unfinished_ids.append((next_id, chained_id))
@contextlib.contextmanager
def manager():
try:
yield (next_id, chained_id)
finally:
with self._lock:
self._unfinished_ids.remove((next_id, chained_id))
return manager()
def get_max_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
with self._lock:
if self._unfinished_ids:
stream_id, chained_id = self._unfinished_ids[0]
return (stream_id - 1, chained_id)
return (self._current_max, self.chained_generator.get_max_token())

View File

@ -38,9 +38,12 @@ class EventSources(object):
name: cls(hs) name: cls(hs)
for name, cls in EventSources.SOURCE_TYPES.items() for name, cls in EventSources.SOURCE_TYPES.items()
} }
self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_token(self, direction='f'): def get_current_token(self, direction='f'):
push_rules_key, _ = self.store.get_push_rules_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
yield self.sources["room"].get_current_key(direction) yield self.sources["room"].get_current_key(direction)
@ -57,5 +60,6 @@ class EventSources(object):
account_data_key=( account_data_key=(
yield self.sources["account_data"].get_current_key() yield self.sources["account_data"].get_current_key()
), ),
push_rules_key=push_rules_key,
) )
defer.returnValue(token) defer.returnValue(token)

View File

@ -115,6 +115,7 @@ class StreamToken(
"typing_key", "typing_key",
"receipt_key", "receipt_key",
"account_data_key", "account_data_key",
"push_rules_key",
)) ))
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -150,6 +151,7 @@ class StreamToken(
or (int(other.typing_key) < int(self.typing_key)) or (int(other.typing_key) < int(self.typing_key))
or (int(other.receipt_key) < int(self.receipt_key)) or (int(other.receipt_key) < int(self.receipt_key))
or (int(other.account_data_key) < int(self.account_data_key)) or (int(other.account_data_key) < int(self.account_data_key))
or (int(other.push_rules_key) < int(self.push_rules_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):
@ -174,6 +176,11 @@ class StreamToken(
return StreamToken(**d) return StreamToken(**d)
StreamToken.START = StreamToken(
*(["s0"] + ["0"] * (len(StreamToken._fields) - 1))
)
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
"""Tokens are positions between events. The token "s1" comes after event 1. """Tokens are positions between events. The token "s1" comes after event 1.

View File

@ -69,7 +69,7 @@ class ExpiringCache(object):
if self._max_len and len(self._cache.keys()) > self._max_len: if self._max_len and len(self._cache.keys()) > self._max_len:
sorted_entries = sorted( sorted_entries = sorted(
self._cache.items(), self._cache.items(),
key=lambda (k, v): v.time, key=lambda item: item[1].time,
) )
for k, _ in sorted_entries[self._max_len:]: for k, _ in sorted_entries[self._max_len:]:

View File

@ -23,7 +23,7 @@ from synapse.api.errors import AuthError
from synapse.handlers.profile import ProfileHandler from synapse.handlers.profile import ProfileHandler
from synapse.types import UserID from synapse.types import UserID
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver, requester_for_user
class ProfileHandlers(object): class ProfileHandlers(object):
@ -84,7 +84,11 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_name(self): def test_set_my_name(self):
yield self.handler.set_displayname(self.frank, self.frank, "Frank Jr.") yield self.handler.set_displayname(
self.frank,
requester_for_user(self.frank),
"Frank Jr."
)
self.assertEquals( self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), (yield self.store.get_profile_displayname(self.frank.localpart)),
@ -93,7 +97,11 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_name_noauth(self): def test_set_my_name_noauth(self):
d = self.handler.set_displayname(self.frank, self.bob, "Frank Jr.") d = self.handler.set_displayname(
self.frank,
requester_for_user(self.bob),
"Frank Jr."
)
yield self.assertFailure(d, AuthError) yield self.assertFailure(d, AuthError)
@ -136,7 +144,7 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_set_my_avatar(self): def test_set_my_avatar(self):
yield self.handler.set_avatar_url( yield self.handler.set_avatar_url(
self.frank, self.frank, "http://my.server/pic.gif" self.frank, requester_for_user(self.frank), "http://my.server/pic.gif"
) )
self.assertEquals( self.assertEquals(

View File

@ -18,7 +18,7 @@ from synapse.types import Requester, UserID
from twisted.internet import defer from twisted.internet import defer
from tests import unittest from tests import unittest
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver, requester_for_user
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
import json import json
import contextlib import contextlib
@ -35,7 +35,8 @@ class ReplicationResourceCase(unittest.TestCase):
"send_message", "send_message",
]), ]),
) )
self.user = UserID.from_string("@seeing:red") self.user_id = "@seeing:red"
self.user = UserID.from_string(self.user_id)
self.hs.get_ratelimiter().send_message.return_value = (True, 0) self.hs.get_ratelimiter().send_message.return_value = (True, 0)
@ -101,7 +102,7 @@ class ReplicationResourceCase(unittest.TestCase):
event_id = yield self.send_text_message(room_id, "Hello, World") event_id = yield self.send_text_message(room_id, "Hello, World")
get = self.get(receipts="-1") get = self.get(receipts="-1")
yield self.hs.get_handlers().receipts_handler.received_client_receipt( yield self.hs.get_handlers().receipts_handler.received_client_receipt(
room_id, "m.read", self.user.to_string(), event_id room_id, "m.read", self.user_id, event_id
) )
code, body = yield get code, body = yield get
self.assertEquals(code, 200) self.assertEquals(code, 200)
@ -129,16 +130,20 @@ class ReplicationResourceCase(unittest.TestCase):
test_timeout_room_account_data = _test_timeout("room_account_data") test_timeout_room_account_data = _test_timeout("room_account_data")
test_timeout_tag_account_data = _test_timeout("tag_account_data") test_timeout_tag_account_data = _test_timeout("tag_account_data")
test_timeout_backfill = _test_timeout("backfill") test_timeout_backfill = _test_timeout("backfill")
test_timeout_push_rules = _test_timeout("push_rules")
@defer.inlineCallbacks @defer.inlineCallbacks
def send_text_message(self, room_id, message): def send_text_message(self, room_id, message):
handler = self.hs.get_handlers().message_handler handler = self.hs.get_handlers().message_handler
event = yield handler.create_and_send_nonmember_event({ event = yield handler.create_and_send_nonmember_event(
requester_for_user(self.user),
{
"type": "m.room.message", "type": "m.room.message",
"content": {"body": "message", "msgtype": "m.text"}, "content": {"body": "message", "msgtype": "m.text"},
"room_id": room_id, "room_id": room_id,
"sender": self.user.to_string(), "sender": self.user.to_string(),
}) }
)
defer.returnValue(event.event_id) defer.returnValue(event.event_id)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -86,7 +86,7 @@ class ProfileTestCase(unittest.TestCase):
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][1].localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.") self.assertEquals(mocked_set.call_args[0][2], "Frank Jr.")
@defer.inlineCallbacks @defer.inlineCallbacks
@ -155,5 +155,5 @@ class ProfileTestCase(unittest.TestCase):
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][0].localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][1].localpart, "1234ABCD") self.assertEquals(mocked_set.call_args[0][1].user.localpart, "1234ABCD")
self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif") self.assertEquals(mocked_set.call_args[0][2], "http://my.server/pic.gif")

View File

@ -54,13 +54,13 @@ class RoomPermissionsTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False, "is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -419,13 +419,13 @@ class RoomsMemberListTestCase(RestTestCase):
self.auth_user_id = self.user_id self.auth_user_id = self.user_id
def _get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False, "is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -501,13 +501,13 @@ class RoomsCreateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False, "is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -593,14 +593,14 @@ class RoomTopicTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False, "is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -708,13 +708,13 @@ class RoomMemberStateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False, "is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -840,13 +840,13 @@ class RoomMessagesTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False, "is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -942,13 +942,13 @@ class RoomInitialSyncTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False, "is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -1014,13 +1014,13 @@ class RoomMessageListTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False, "is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)
@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_topo_token_is_accepted(self): def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0" token = "t1-0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))
@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self): def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0" token = "s0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))

View File

@ -61,14 +61,14 @@ class RoomTypingTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False, "is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth().get_user_by_access_token = get_user_by_access_token
def _insert_client_ip(*args, **kwargs): def _insert_client_ip(*args, **kwargs):
return defer.succeed(None) return defer.succeed(None)

View File

@ -43,13 +43,13 @@ class V2AlphaRestTestCase(unittest.TestCase):
resource_for_federation=self.mock_resource, resource_for_federation=self.mock_resource,
) )
def _get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.USER_ID), "user": UserID.from_string(self.USER_ID),
"token_id": 1, "token_id": 1,
"is_guest": False, "is_guest": False,
} }
hs.get_auth()._get_user_by_access_token = _get_user_by_access_token hs.get_auth().get_user_by_access_token = get_user_by_access_token
for r in self.TO_REGISTER: for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource) r.register_servlets(hs, self.mock_resource)

View File

@ -20,6 +20,7 @@ from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.federation.transport import server from synapse.federation.transport import server
from synapse.types import Requester
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
@ -50,6 +51,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.macaroon_secret_key = "not even a little secret" config.macaroon_secret_key = "not even a little secret"
config.server_name = "server.under.test" config.server_name = "server.under.test"
config.trusted_third_party_id_servers = [] config.trusted_third_party_id_servers = []
config.room_invite_state_types = []
config.database_config = {"name": "sqlite3"} config.database_config = {"name": "sqlite3"}
@ -510,3 +512,7 @@ class DeferredMockCallable(object):
"call(%s)" % _format_call(c[0], c[1]) for c in calls "call(%s)" % _format_call(c[0], c[1]) for c in calls
]) ])
) )
def requester_for_user(user):
return Requester(user, None, False)