Compare commits

...

9 Commits

Author SHA1 Message Date
Richard van der Hoff 5f4e4819ca fix lint and tests 2020-05-01 18:20:32 +01:00
Richard van der Hoff 3356d9fe23 Fix type annotations 2020-05-01 17:59:32 +01:00
Richard van der Hoff ba2b934307 Merge remote-tracking branch 'origin/develop' into rav/fix_account_data_catchup 2020-05-01 17:50:09 +01:00
Erik Johnston 0e719f2398
Thread through instance name to replication client. (#7369)
For in memory streams when fetching updates on workers we need to query the source of the stream, which currently is hard coded to be master. This PR threads through the source instance we received via `POSITION` through to the update function in each stream, which can then be passed to the replication client for in memory streams.
2020-05-01 17:19:56 +01:00
Erik Johnston 3085cde577
Use `stream.current_token()` and remove `stream_positions()` (#7172)
We move the processing of typing and federation replication traffic into their handlers so that `Stream.current_token()` points to a valid token. This allows us to remove `get_streams_to_replicate()` and `stream_positions()`.
2020-05-01 15:21:35 +01:00
Andrew Morgan 6b22921b19
async/await is_server_admin (#7363) 2020-05-01 15:15:36 +01:00
Andrew Morgan 2e8955f4a6
Further improvements to requesting the public rooms list on a homeserver which has it set to private (#7368) 2020-05-01 15:15:08 +01:00
Richard van der Hoff b2dba06079
Workaround for assertion errors from db_query_to_update_function (#7378)
Hopefully this is no worse than what we have on master...
2020-05-01 09:25:16 +01:00
Patrick Cloke 627b0f5f27
Persist user interactive authentication sessions (#7302)
By persisting the user interactive authentication sessions to the database, this fixes
situations where a user hits different works throughout their auth session and also
allows sessions to persist through restarts of Synapse.
2020-04-30 13:47:49 -04:00
65 changed files with 1061 additions and 798 deletions

1
changelog.d/7172.misc Normal file
View File

@ -0,0 +1 @@
Use `stream.current_token()` and remove `stream_positions()`.

1
changelog.d/7302.bugfix Normal file
View File

@ -0,0 +1 @@
Persist user interactive authentication sessions across workers and Synapse restarts.

1
changelog.d/7363.misc Normal file
View File

@ -0,0 +1 @@
Convert RegistrationWorkerStore.is_server_admin and dependent code to async/await.

1
changelog.d/7368.bugfix Normal file
View File

@ -0,0 +1 @@
Improve error responses when accessing remote public room lists.

1
changelog.d/7369.misc Normal file
View File

@ -0,0 +1 @@
Thread through instance name to replication client.

1
changelog.d/7378.misc Normal file
View File

@ -0,0 +1 @@
Move catchup of replication streams logic to worker.

View File

@ -537,8 +537,7 @@ class Auth(object):
return defer.succeed(auth_ids)
@defer.inlineCallbacks
def check_can_change_room_list(self, room_id: str, user: UserID):
async def check_can_change_room_list(self, room_id: str, user: UserID):
"""Determine whether the user is allowed to edit the room's entry in the
published room list.
@ -547,17 +546,17 @@ class Auth(object):
user
"""
is_admin = yield self.is_server_admin(user)
is_admin = await self.is_server_admin(user)
if is_admin:
return True
user_id = user.to_string()
yield self.check_user_in_room(room_id, user_id)
await self.check_user_in_room(room_id, user_id)
# We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the
# m.room.canonical_alias events
power_level_event = yield self.state.get_current_state(
power_level_event = await self.state.get_current_state(
room_id, EventTypes.PowerLevels, ""
)

View File

@ -127,6 +127,7 @@ from synapse.storage.data_stores.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.data_stores.main.presence import UserPresenceState
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
@ -412,12 +413,6 @@ class GenericWorkerTyping(object):
# map room IDs to sets of users currently typing
self._room_typing = {}
def stream_positions(self):
# We must update this typing token from the response of the previous
# sync. In particular, the stream id may "reset" back to zero/a low
# value which we *must* use for the next replication request.
return {"typing": self._latest_room_serial}
def process_replication_rows(self, token, rows):
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just
@ -439,6 +434,7 @@ class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
UserDirectoryStore,
UIAuthWorkerStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedReceiptsStore,
@ -650,20 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
else:
self.send_handler = None
async def on_rdata(self, stream_name, token, rows):
await super(GenericWorkerReplicationHandler, self).on_rdata(
stream_name, token, rows
)
await self.process_and_notify(stream_name, token, rows)
async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, instance_name, token, rows)
await self._process_and_notify(stream_name, instance_name, token, rows)
def get_streams_to_replicate(self):
args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
args.update(self.typing_handler.stream_positions())
if self.send_handler:
args.update(self.send_handler.stream_positions())
return args
async def process_and_notify(self, stream_name, token, rows):
async def _process_and_notify(self, stream_name, instance_name, token, rows):
try:
if self.send_handler:
await self.send_handler.process_replication_rows(
@ -797,9 +784,6 @@ class FederationSenderHandler(object):
def wake_destination(self, server: str):
self.federation_sender.wake_destination(server)
def stream_positions(self):
return {"federation": self.federation_position}
async def process_replication_rows(self, stream_name, token, rows):
# The federation stream contains things that we want to send out, e.g.
# presence, typing, etc.

View File

@ -883,18 +883,37 @@ class FederationClient(FederationBase):
def get_public_rooms(
self,
destination,
limit=None,
since_token=None,
search_filter=None,
include_all_networks=False,
third_party_instance_id=None,
remote_server: str,
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[Dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
):
if destination == self.server_name:
return
"""Get the list of public rooms from a remote homeserver
Args:
remote_server: The name of the remote server
limit: Maximum amount of rooms to return
since_token: Used for result pagination
search_filter: A filter dictionary to send the remote homeserver
and filter the result set
include_all_networks: Whether to include results from all third party instances
third_party_instance_id: Whether to only include results from a specific third
party instance
Returns:
Deferred[Dict[str, Any]]: The response from the remote server, or None if
`remote_server` is the same as the local server_name
Raises:
HttpResponseException: There was an exception returned from the remote server
SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom
requests over federation
"""
return self.transport_layer.get_public_rooms(
destination,
remote_server,
limit,
since_token,
search_filter,
@ -957,14 +976,13 @@ class FederationClient(FederationBase):
return signed_events
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
async def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
if destination == self.server_name:
continue
try:
yield self.transport_layer.exchange_third_party_invite(
await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict
)
return None

View File

@ -15,13 +15,14 @@
# limitations under the License.
import logging
from typing import Any, Dict
from typing import Any, Dict, Optional
from six.moves import urllib
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
FEDERATION_V1_PREFIX,
@ -326,18 +327,25 @@ class TransportLayerClient(object):
@log_function
def get_public_rooms(
self,
remote_server,
limit,
since_token,
search_filter=None,
include_all_networks=False,
third_party_instance_id=None,
remote_server: str,
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[Dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
):
"""Get the list of public rooms from a remote homeserver
See synapse.federation.federation_client.FederationClient.get_public_rooms for
more information.
"""
if search_filter:
# this uses MSC2197 (Search Filtering over Federation)
path = _create_v1_path("/publicRooms")
data = {"include_all_networks": "true" if include_all_networks else "false"}
data = {
"include_all_networks": "true" if include_all_networks else "false"
} # type: Dict[str, Any]
if third_party_instance_id:
data["third_party_instance_id"] = third_party_instance_id
if limit:
@ -347,9 +355,19 @@ class TransportLayerClient(object):
data["filter"] = search_filter
response = yield self.client.post_json(
destination=remote_server, path=path, data=data, ignore_backoff=True
)
try:
response = yield self.client.post_json(
destination=remote_server, path=path, data=data, ignore_backoff=True
)
except HttpResponseException as e:
if e.code == 403:
raise SynapseError(
403,
"You are not allowed to view the public rooms list of %s"
% (remote_server,),
errcode=Codes.FORBIDDEN,
)
raise
else:
path = _create_v1_path("/publicRooms")
@ -363,9 +381,19 @@ class TransportLayerClient(object):
if since_token:
args["since"] = [since_token]
response = yield self.client.get_json(
destination=remote_server, path=path, args=args, ignore_backoff=True
)
try:
response = yield self.client.get_json(
destination=remote_server, path=path, args=args, ignore_backoff=True
)
except HttpResponseException as e:
if e.code == 403:
raise SynapseError(
403,
"You are not allowed to view the public rooms list of %s"
% (remote_server,),
errcode=Codes.FORBIDDEN,
)
raise
return response

View File

@ -748,17 +748,18 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise NotImplementedError()
@defer.inlineCallbacks
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):
"""Remove a user from the group; either a user is leaving or an admin
kicked them.
"""
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_kick = False
if requester_user_id != user_id:
is_admin = yield self.store.is_user_admin_in_group(
is_admin = await self.store.is_user_admin_in_group(
group_id, requester_user_id
)
if not is_admin:
@ -766,30 +767,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
is_kick = True
yield self.store.remove_user_from_group(group_id, user_id)
await self.store.remove_user_from_group(group_id, user_id)
if is_kick:
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
yield groups_local.user_removed_from_group(group_id, user_id, {})
await groups_local.user_removed_from_group(group_id, user_id, {})
else:
yield self.transport_client.remove_user_from_group_notification(
await self.transport_client.remove_user_from_group_notification(
get_domain_from_id(user_id), group_id, user_id, {}
)
if not self.hs.is_mine_id(user_id):
yield self.store.maybe_delete_remote_profile_cache(user_id)
await self.store.maybe_delete_remote_profile_cache(user_id)
# Delete group if the last user has left
users = yield self.store.get_users_in_group(group_id, include_private=True)
users = await self.store.get_users_in_group(group_id, include_private=True)
if not users:
yield self.store.delete_group(group_id)
await self.store.delete_group(group_id)
return {}
@defer.inlineCallbacks
def create_group(self, group_id, requester_user_id, content):
group = yield self.check_group_is_ours(group_id, requester_user_id)
async def create_group(self, group_id, requester_user_id, content):
group = await self.check_group_is_ours(group_id, requester_user_id)
logger.info("Attempting to create group with ID: %r", group_id)
@ -799,7 +799,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if group:
raise SynapseError(400, "Group already exists")
is_admin = yield self.auth.is_server_admin(
is_admin = await self.auth.is_server_admin(
UserID.from_string(requester_user_id)
)
if not is_admin:
@ -822,7 +822,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
long_description = profile.get("long_description")
user_profile = content.get("user_profile", {})
yield self.store.create_group(
await self.store.create_group(
group_id,
requester_user_id,
name=name,
@ -834,7 +834,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
if not self.hs.is_mine_id(requester_user_id):
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
await self.attestations.verify_attestation(
remote_attestation, user_id=requester_user_id, group_id=group_id
)
@ -845,7 +845,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
local_attestation = None
remote_attestation = None
yield self.store.add_user_to_group(
await self.store.add_user_to_group(
group_id,
requester_user_id,
is_admin=True,
@ -855,7 +855,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
)
if not self.hs.is_mine_id(requester_user_id):
yield self.store.add_remote_profile_cache(
await self.store.add_remote_profile_cache(
requester_user_id,
displayname=user_profile.get("displayname"),
avatar_url=user_profile.get("avatar_url"),
@ -863,8 +863,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {"group_id": group_id}
@defer.inlineCallbacks
def delete_group(self, group_id, requester_user_id):
async def delete_group(self, group_id, requester_user_id):
"""Deletes a group, kicking out all current members.
Only group admins or server admins can call this request
@ -877,14 +876,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
Deferred
"""
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
# Only server admins or group admins can delete groups.
is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id)
is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id)
if not is_admin:
is_admin = yield self.auth.is_server_admin(
is_admin = await self.auth.is_server_admin(
UserID.from_string(requester_user_id)
)
@ -892,18 +891,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise SynapseError(403, "User is not an admin")
# Before deleting the group lets kick everyone out of it
users = yield self.store.get_users_in_group(group_id, include_private=True)
users = await self.store.get_users_in_group(group_id, include_private=True)
@defer.inlineCallbacks
def _kick_user_from_group(user_id):
async def _kick_user_from_group(user_id):
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
yield groups_local.user_removed_from_group(group_id, user_id, {})
await groups_local.user_removed_from_group(group_id, user_id, {})
else:
yield self.transport_client.remove_user_from_group_notification(
await self.transport_client.remove_user_from_group_notification(
get_domain_from_id(user_id), group_id, user_id, {}
)
yield self.store.maybe_delete_remote_profile_cache(user_id)
await self.store.maybe_delete_remote_profile_cache(user_id)
# We kick users out in the order of:
# 1. Non-admins
@ -922,11 +920,11 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
else:
non_admins.append(u["user_id"])
yield concurrently_execute(_kick_user_from_group, non_admins, 10)
yield concurrently_execute(_kick_user_from_group, admins, 10)
yield _kick_user_from_group(requester_user_id)
await concurrently_execute(_kick_user_from_group, non_admins, 10)
await concurrently_execute(_kick_user_from_group, admins, 10)
await _kick_user_from_group(requester_user_id)
yield self.store.delete_group(group_id)
await self.store.delete_group(group_id)
def _parse_join_policy_from_contents(content):

View File

@ -126,30 +126,28 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now))
)
@defer.inlineCallbacks
def maybe_kick_guest_users(self, event, context=None):
async def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
current_state_ids = yield context.get_current_state_ids()
current_state = yield self.store.get_events(
current_state_ids = await context.get_current_state_ids()
current_state = await self.store.get_events(
list(current_state_ids.values())
)
else:
current_state = yield self.state_handler.get_current_state(
current_state = await self.state_handler.get_current_state(
event.room_id
)
current_state = list(current_state.values())
logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state)
await self.kick_guest_users(current_state)
@defer.inlineCallbacks
def kick_guest_users(self, current_state):
async def kick_guest_users(self, current_state):
for member_event in current_state:
try:
if member_event.type != EventTypes.Member:
@ -180,7 +178,7 @@ class BaseHandler(object):
# homeserver.
requester = synapse.types.create_requester(target_user, is_guest=True)
handler = self.hs.get_room_member_handler()
yield handler.update_membership(
await handler.update_membership(
requester,
target_user,
member_event.room_id,

View File

@ -41,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.server import finish_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
from synapse.util.caches.expiringcache import ExpiringCache
from ._base import BaseHandler
@ -69,15 +69,6 @@ class AuthHandler(BaseHandler):
self.bcrypt_rounds = hs.config.bcrypt_rounds
# This is not a cache per se, but a store of all current sessions that
# expire after N hours
self.sessions = ExpiringCache(
cache_name="register_sessions",
clock=hs.get_clock(),
expiry_ms=self.SESSION_EXPIRE_MS,
reset_expiry_on_get=True,
)
account_handler = ModuleApi(hs, self)
self.password_providers = [
module(config=config, account_handler=account_handler)
@ -119,6 +110,15 @@ class AuthHandler(BaseHandler):
self._clock = self.hs.get_clock()
# Expire old UI auth sessions after a period of time.
if hs.config.worker_app is None:
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
"expire_old_sessions",
self._expire_old_sessions,
)
# Load the SSO HTML templates.
# The following template is shown to the user during a client login via SSO,
@ -301,16 +301,21 @@ class AuthHandler(BaseHandler):
if "session" in authdict:
sid = authdict["session"]
# Convert the URI and method to strings.
uri = request.uri.decode("utf-8")
method = request.uri.decode("utf-8")
# If there's no session ID, create a new session.
if not sid:
session = self._create_session(
clientdict, (request.uri, request.method, clientdict), description
session = await self.store.create_ui_auth_session(
clientdict, uri, method, description
)
session_id = session["id"]
else:
session = self._get_session_info(sid)
session_id = sid
try:
session = await self.store.get_ui_auth_session(sid)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (sid,))
if not clientdict:
# This was designed to allow the client to omit the parameters
@ -322,15 +327,15 @@ class AuthHandler(BaseHandler):
# on a homeserver.
# Revisit: Assuming the REST APIs do sensible validation, the data
# isn't arbitrary.
clientdict = session["clientdict"]
clientdict = session.clientdict
# Ensure that the queried operation does not vary between stages of
# the UI authentication session. This is done by generating a stable
# comparator based on the URI, method, and body (minus the auth dict)
# and storing it during the initial query. Subsequent queries ensure
# that this comparator has not changed.
comparator = (request.uri, request.method, clientdict)
if session["ui_auth"] != comparator:
comparator = (uri, method, clientdict)
if (session.uri, session.method, session.clientdict) != comparator:
raise SynapseError(
403,
"Requested operation has changed during the UI authentication session.",
@ -338,11 +343,9 @@ class AuthHandler(BaseHandler):
if not authdict:
raise InteractiveAuthIncompleteError(
self._auth_dict_for_flows(flows, session_id)
self._auth_dict_for_flows(flows, session.session_id)
)
creds = session["creds"]
# check auth type currently being presented
errordict = {} # type: Dict[str, Any]
if "type" in authdict:
@ -350,8 +353,9 @@ class AuthHandler(BaseHandler):
try:
result = await self._check_auth_dict(authdict, clientip)
if result:
creds[login_type] = result
self._save_session(session)
await self.store.mark_ui_auth_stage_complete(
session.session_id, login_type, result
)
except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
@ -367,6 +371,7 @@ class AuthHandler(BaseHandler):
# so that the client can have another go.
errordict = e.error_dict()
creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
@ -380,9 +385,9 @@ class AuthHandler(BaseHandler):
list(clientdict),
)
return creds, clientdict, session_id
return creds, clientdict, session.session_id
ret = self._auth_dict_for_flows(flows, session_id)
ret = self._auth_dict_for_flows(flows, session.session_id)
ret["completed"] = list(creds)
ret.update(errordict)
raise InteractiveAuthIncompleteError(ret)
@ -399,13 +404,11 @@ class AuthHandler(BaseHandler):
if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
sess = self._get_session_info(authdict["session"])
creds = sess["creds"]
result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
await self.store.mark_ui_auth_stage_complete(
authdict["session"], stagetype, result
)
return True
return False
@ -427,7 +430,7 @@ class AuthHandler(BaseHandler):
sid = authdict["session"]
return sid
def set_session_data(self, session_id: str, key: str, value: Any) -> None:
async def set_session_data(self, session_id: str, key: str, value: Any) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
@ -438,11 +441,12 @@ class AuthHandler(BaseHandler):
key: The key to store the data under
value: The data to store
"""
sess = self._get_session_info(session_id)
sess["serverdict"][key] = value
self._save_session(sess)
try:
await self.store.set_ui_auth_session_data(session_id, key, value)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
def get_session_data(
async def get_session_data(
self, session_id: str, key: str, default: Optional[Any] = None
) -> Any:
"""
@ -453,8 +457,18 @@ class AuthHandler(BaseHandler):
key: The key to store the data under
default: Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess["serverdict"].get(key, default)
try:
return await self.store.get_ui_auth_session_data(session_id, key, default)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
async def _expire_old_sessions(self):
"""
Invalidate any user interactive authentication sessions that have expired.
"""
now = self._clock.time_msec()
expiration_time = now - self.SESSION_EXPIRE_MS
await self.store.delete_old_ui_auth_sessions(expiration_time)
async def _check_auth_dict(
self, authdict: Dict[str, Any], clientip: str
@ -534,67 +548,6 @@ class AuthHandler(BaseHandler):
"params": params,
}
def _create_session(
self,
clientdict: Dict[str, Any],
ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
description: str,
) -> dict:
"""
Creates a new user interactive authentication session.
The session can be used to track data across multiple requests, e.g. for
interactive authentication.
Each session has the following keys:
id:
A unique identifier for this session. Passed back to the client
and returned for each stage.
clientdict:
The dictionary from the client root level, not the 'auth' key.
ui_auth:
A tuple which is checked at each stage of the authentication to
ensure that the asked for operation has not changed.
creds:
A map, which maps each auth-type (str) to the relevant identity
authenticated by that auth-type (mostly str, but for captcha, bool).
serverdict:
A map of data that is stored server-side and cannot be modified
by the client.
description:
A string description of the operation that the current
authentication is authorising.
Returns:
The newly created session.
"""
session_id = None
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
self.sessions[session_id] = {
"id": session_id,
"clientdict": clientdict,
"ui_auth": ui_auth,
"creds": {},
"serverdict": {},
"description": description,
}
return self.sessions[session_id]
def _get_session_info(self, session_id: str) -> dict:
"""
Gets a session given a session ID.
The session can be used to track data across multiple requests, e.g. for
interactive authentication.
"""
try:
return self.sessions[session_id]
except KeyError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
@ -994,13 +947,6 @@ class AuthHandler(BaseHandler):
await self.store.user_delete_threepid(user_id, medium, address)
return result
def _save_session(self, session: Dict[str, Any]) -> None:
"""Update the last used time on the session to now and add it back to the session store."""
# TODO: Persistent storage
logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session
async def hash(self, password: str) -> str:
"""Computes a secure hash of password.
@ -1052,7 +998,7 @@ class AuthHandler(BaseHandler):
else:
return False
def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
"""
Get the HTML for the SSO redirect confirmation page.
@ -1063,12 +1009,15 @@ class AuthHandler(BaseHandler):
Returns:
The HTML to render.
"""
session = self._get_session_info(session_id)
try:
session = await self.store.get_ui_auth_session(session_id)
except StoreError:
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
return self._sso_auth_confirm_template.render(
description=session["description"], redirect_url=redirect_url,
description=session.description, redirect_url=redirect_url,
)
def complete_sso_ui_auth(
async def complete_sso_ui_auth(
self, registered_user_id: str, session_id: str, request: SynapseRequest,
):
"""Having figured out a mxid for this user, complete the HTTP request
@ -1080,13 +1029,11 @@ class AuthHandler(BaseHandler):
process.
"""
# Mark the stage of the authentication as successful.
sess = self._get_session_info(session_id)
creds = sess["creds"]
# Save the user who authenticated with SSO, this will be used to ensure
# that the account be modified is also the person who logged in.
creds[LoginType.SSO] = registered_user_id
self._save_session(sess)
await self.store.mark_ui_auth_stage_complete(
session_id, LoginType.SSO, registered_user_id
)
# Render the HTML and return.
html_bytes = self._sso_auth_success_template.encode("utf-8")

View File

@ -206,7 +206,7 @@ class CasHandler:
registered_user_id = await self._auth_handler.check_user_exists(user_id)
if session:
self._auth_handler.complete_sso_ui_auth(
await self._auth_handler.complete_sso_ui_auth(
registered_user_id, session, request,
)

View File

@ -86,8 +86,7 @@ class DirectoryHandler(BaseHandler):
room_alias, room_id, servers, creator=creator
)
@defer.inlineCallbacks
def create_association(
async def create_association(
self,
requester: Requester,
room_alias: RoomAlias,
@ -129,10 +128,10 @@ class DirectoryHandler(BaseHandler):
else:
# Server admins are not subject to the same constraints as normal
# users when creating an alias (e.g. being in the room).
is_admin = yield self.auth.is_server_admin(requester.user)
is_admin = await self.auth.is_server_admin(requester.user)
if (self.require_membership and check_membership) and not is_admin:
rooms_for_user = yield self.store.get_rooms_for_user(user_id)
rooms_for_user = await self.store.get_rooms_for_user(user_id)
if room_id not in rooms_for_user:
raise AuthError(
403, "You must be in the room to create an alias for it"
@ -149,7 +148,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to create alias")
can_create = yield self.can_modify_alias(room_alias, user_id=user_id)
can_create = await self.can_modify_alias(room_alias, user_id=user_id)
if not can_create:
raise AuthError(
400,
@ -157,10 +156,9 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
yield self._create_association(room_alias, room_id, servers, creator=user_id)
await self._create_association(room_alias, room_id, servers, creator=user_id)
@defer.inlineCallbacks
def delete_association(self, requester: Requester, room_alias: RoomAlias):
async def delete_association(self, requester: Requester, room_alias: RoomAlias):
"""Remove an alias from the directory
(this is only meant for human users; AS users should call
@ -184,7 +182,7 @@ class DirectoryHandler(BaseHandler):
user_id = requester.user.to_string()
try:
can_delete = yield self._user_can_delete_alias(room_alias, user_id)
can_delete = await self._user_can_delete_alias(room_alias, user_id)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown room alias")
@ -193,7 +191,7 @@ class DirectoryHandler(BaseHandler):
if not can_delete:
raise AuthError(403, "You don't have permission to delete the alias.")
can_delete = yield self.can_modify_alias(room_alias, user_id=user_id)
can_delete = await self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete:
raise SynapseError(
400,
@ -201,10 +199,10 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
room_id = yield self._delete_association(room_alias)
room_id = await self._delete_association(room_alias)
try:
yield self._update_canonical_alias(requester, user_id, room_id, room_alias)
await self._update_canonical_alias(requester, user_id, room_id, room_alias)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
@ -296,15 +294,14 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND,
)
@defer.inlineCallbacks
def _update_canonical_alias(
async def _update_canonical_alias(
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
):
"""
Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field.
"""
alias_event = yield self.state.get_current_state(
alias_event = await self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, ""
)
@ -335,7 +332,7 @@ class DirectoryHandler(BaseHandler):
del content["alt_aliases"]
if send_update:
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
@ -376,8 +373,7 @@ class DirectoryHandler(BaseHandler):
# either no interested services, or no service with an exclusive lock
return defer.succeed(True)
@defer.inlineCallbacks
def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
"""Determine whether a user can delete an alias.
One of the following must be true:
@ -388,24 +384,23 @@ class DirectoryHandler(BaseHandler):
for the current room.
"""
creator = yield self.store.get_room_alias_creator(alias.to_string())
creator = await self.store.get_room_alias_creator(alias.to_string())
if creator is not None and creator == user_id:
return True
# Resolve the alias to the corresponding room.
room_mapping = yield self.get_association(alias)
room_mapping = await self.get_association(alias)
room_id = room_mapping["room_id"]
if not room_id:
return False
res = yield self.auth.check_can_change_room_list(
res = await self.auth.check_can_change_room_list(
room_id, UserID.from_string(user_id)
)
return res
@defer.inlineCallbacks
def edit_published_room_list(
async def edit_published_room_list(
self, requester: Requester, room_id: str, visibility: str
):
"""Edit the entry of the room in the published room list.
@ -433,11 +428,11 @@ class DirectoryHandler(BaseHandler):
403, "This user is not permitted to publish rooms to the room list"
)
room = yield self.store.get_room(room_id)
room = await self.store.get_room(room_id)
if room is None:
raise SynapseError(400, "Unknown room")
can_change_room_list = yield self.auth.check_can_change_room_list(
can_change_room_list = await self.auth.check_can_change_room_list(
room_id, requester.user
)
if not can_change_room_list:
@ -449,8 +444,8 @@ class DirectoryHandler(BaseHandler):
making_public = visibility == "public"
if making_public:
room_aliases = yield self.store.get_aliases_for_room(room_id)
canonical_alias = yield self.store.get_canonical_alias_for_room(room_id)
room_aliases = await self.store.get_aliases_for_room(room_id)
canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
if canonical_alias:
room_aliases.append(canonical_alias)
@ -462,7 +457,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to publish room")
yield self.store.set_room_is_public(room_id, making_public)
await self.store.set_room_is_public(room_id, making_public)
@defer.inlineCallbacks
def edit_published_appservice_room_list(

View File

@ -2562,9 +2562,8 @@ class FederationHandler(BaseHandler):
"missing": [e.event_id for e in missing_locals],
}
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(
async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed
):
third_party_invite = {"signed": signed}
@ -2580,16 +2579,16 @@ class FederationHandler(BaseHandler):
"state_key": target_user_id,
}
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
room_version = yield self.store.get_room_version_id(room_id)
if await self.auth.check_host_in_room(room_id, self.hs.hostname):
room_version = await self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
event_allowed = yield self.third_party_event_rules.check_event_allowed(
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@ -2601,7 +2600,7 @@ class FederationHandler(BaseHandler):
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
event, context = yield self.add_display_name_to_third_party_invite(
event, context = await self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context
)
@ -2612,19 +2611,19 @@ class FederationHandler(BaseHandler):
event.internal_metadata.send_on_behalf_of = self.hs.hostname
try:
yield self.auth.check_from_context(room_version, event, context)
await self.auth.check_from_context(room_version, event, context)
except AuthError as e:
logger.warning("Denying new third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, context)
await self._check_signature(event, context)
# We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
await member_handler.send_membership_event(None, event, context)
else:
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
yield self.federation_client.forward_third_party_invite(
await self.federation_client.forward_third_party_invite(
destinations, room_id, event_dict
)

View File

@ -284,15 +284,14 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
set_group_join_policy = _create_rerouter("set_group_join_policy")
@defer.inlineCallbacks
def create_group(self, group_id, user_id, content):
async def create_group(self, group_id, user_id, content):
"""Create a group
"""
logger.info("Asking to create group with ID: %r", group_id)
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.create_group(
res = await self.groups_server_handler.create_group(
group_id, user_id, content
)
local_attestation = None
@ -301,10 +300,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
content["user_profile"] = await self.profile_handler.get_profile(user_id)
try:
res = yield self.transport_client.create_group(
res = await self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content
)
except HttpResponseException as e:
@ -313,7 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"]
yield self.attestations.verify_attestation(
await self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
@ -321,7 +320,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
)
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
token = await self.store.register_user_group_membership(
group_id,
user_id,
membership="join",
@ -482,12 +481,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile}
@defer.inlineCallbacks
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):
"""Remove a user from a group
"""
if user_id == requester_user_id:
token = yield self.store.register_user_group_membership(
token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave"
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
@ -496,13 +496,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
# retry if the group server is currently down.
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.remove_user_from_group(
res = await self.groups_server_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content
)
else:
content["requester_user_id"] = requester_user_id
try:
res = yield self.transport_client.remove_user_from_group(
res = await self.transport_client.remove_user_from_group(
get_domain_from_id(group_id),
group_id,
requester_user_id,

View File

@ -626,8 +626,7 @@ class EventCreationHandler(object):
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
async def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
Persists and notifies local clients and federation of an event.
@ -647,7 +646,7 @@ class EventCreationHandler(object):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
prev_state = await self.deduplicate_state_event(event, context)
if prev_state is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
@ -656,7 +655,7 @@ class EventCreationHandler(object):
)
return prev_state
yield self.handle_new_client_event(
await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
)
@ -683,8 +682,7 @@ class EventCreationHandler(object):
return prev_event
return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
async def create_and_send_nonmember_event(
self, requester, event_dict, ratelimit=True, txn_id=None
):
"""
@ -698,8 +696,8 @@ class EventCreationHandler(object):
# a situation where event persistence can't keep up, causing
# extremities to pile up, which in turn leads to state resolution
# taking longer.
with (yield self.limiter.queue(event_dict["room_id"])):
event, context = yield self.create_event(
with (await self.limiter.queue(event_dict["room_id"])):
event, context = await self.create_event(
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
)
@ -709,7 +707,7 @@ class EventCreationHandler(object):
spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
yield self.send_nonmember_event(
await self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit
)
return event
@ -770,8 +768,7 @@ class EventCreationHandler(object):
return (event, context)
@measure_func("handle_new_client_event")
@defer.inlineCallbacks
def handle_new_client_event(
async def handle_new_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Processes a new event. This includes checking auth, persisting it,
@ -794,9 +791,9 @@ class EventCreationHandler(object):
):
room_version = event.content.get("room_version", RoomVersions.V1.identifier)
else:
room_version = yield self.store.get_room_version_id(event.room_id)
room_version = await self.store.get_room_version_id(event.room_id)
event_allowed = yield self.third_party_event_rules.check_event_allowed(
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@ -805,7 +802,7 @@ class EventCreationHandler(object):
)
try:
yield self.auth.check_from_context(room_version, event, context)
await self.auth.check_from_context(room_version, event, context)
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
raise err
@ -818,7 +815,7 @@ class EventCreationHandler(object):
logger.exception("Failed to encode content: %r", event.content)
raise
yield self.action_generator.handle_push_actions_for_event(event, context)
await self.action_generator.handle_push_actions_for_event(event, context)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead.
@ -826,7 +823,7 @@ class EventCreationHandler(object):
try:
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
yield self.send_event_to_master(
await self.send_event_to_master(
event_id=event.event_id,
store=self.store,
requester=requester,
@ -838,7 +835,7 @@ class EventCreationHandler(object):
success = True
return
yield self.persist_and_notify_client_event(
await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
@ -883,8 +880,7 @@ class EventCreationHandler(object):
Codes.BAD_ALIAS,
)
@defer.inlineCallbacks
def persist_and_notify_client_event(
async def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Called when we have fully built the event, have already
@ -901,7 +897,7 @@ class EventCreationHandler(object):
# user is actually admin or not).
is_admin_redaction = False
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
@ -913,11 +909,11 @@ class EventCreationHandler(object):
original_event and event.sender != original_event.sender
)
yield self.base_handler.ratelimit(
await self.base_handler.ratelimit(
requester, is_admin_redaction=is_admin_redaction
)
yield self.base_handler.maybe_kick_guest_users(event, context)
await self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Validate a newly added alias or newly added alt_aliases.
@ -927,7 +923,7 @@ class EventCreationHandler(object):
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
original_event = yield self.store.get_event(original_event_id)
original_event = await self.store.get_event(original_event_id)
if original_event:
original_alias = original_event.content.get("alias", None)
@ -937,7 +933,7 @@ class EventCreationHandler(object):
room_alias_str = event.content.get("alias", None)
directory_handler = self.hs.get_handlers().directory_handler
if room_alias_str and room_alias_str != original_alias:
yield self._validate_canonical_alias(
await self._validate_canonical_alias(
directory_handler, room_alias_str, event.room_id
)
@ -957,7 +953,7 @@ class EventCreationHandler(object):
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
if new_alt_aliases:
for alias_str in new_alt_aliases:
yield self._validate_canonical_alias(
await self._validate_canonical_alias(
directory_handler, alias_str, event.room_id
)
@ -969,7 +965,7 @@ class EventCreationHandler(object):
def is_inviter_member_event(e):
return e.type == EventTypes.Member and e.sender == event.sender
current_state_ids = yield context.get_current_state_ids()
current_state_ids = await context.get_current_state_ids()
state_to_include_ids = [
e_id
@ -978,7 +974,7 @@ class EventCreationHandler(object):
or k == (EventTypes.Member, event.sender)
]
state_to_include = yield self.store.get_events(state_to_include_ids)
state_to_include = await self.store.get_events(state_to_include_ids)
event.unsigned["invite_room_state"] = [
{
@ -996,8 +992,8 @@ class EventCreationHandler(object):
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
returned_invite = yield defer.ensureDeferred(
federation_handler.send_invite(invitee.domain, event)
returned_invite = await federation_handler.send_invite(
invitee.domain, event
)
event.unsigned.pop("room_state", None)
@ -1005,7 +1001,7 @@ class EventCreationHandler(object):
event.signatures.update(returned_invite.signatures)
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
@ -1021,14 +1017,14 @@ class EventCreationHandler(object):
if original_event.room_id != event.room_id:
raise SynapseError(400, "Cannot redact event from a different room")
prev_state_ids = yield context.get_prev_state_ids()
auth_events_ids = yield self.auth.compute_auth_events(
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version = yield self.store.get_room_version_id(event.room_id)
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
if event_auth.check_redaction(
@ -1047,11 +1043,11 @@ class EventCreationHandler(object):
event.internal_metadata.recheck_redaction = False
if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
event_stream_id, max_stream_id = yield self.storage.persistence.persist_event(
event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
event, context=context
)
@ -1059,7 +1055,7 @@ class EventCreationHandler(object):
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _notify():
try:
@ -1083,13 +1079,12 @@ class EventCreationHandler(object):
except Exception:
logger.exception("Error bumping presence active time")
@defer.inlineCallbacks
def _send_dummy_events_to_fill_extremities(self):
async def _send_dummy_events_to_fill_extremities(self):
"""Background task to send dummy events into rooms that have a large
number of extremities
"""
self._expire_rooms_to_exclude_from_dummy_event_insertion()
room_ids = yield self.store.get_rooms_with_many_extremities(
room_ids = await self.store.get_rooms_with_many_extremities(
min_count=10,
limit=5,
room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(),
@ -1099,9 +1094,9 @@ class EventCreationHandler(object):
# For each room we need to find a joined member we can use to send
# the dummy event with.
latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
members = yield self.state.get_current_users_in_room(
members = await self.state.get_current_users_in_room(
room_id, latest_event_ids=latest_event_ids
)
dummy_event_sent = False
@ -1110,7 +1105,7 @@ class EventCreationHandler(object):
continue
requester = create_requester(user_id)
try:
event, context = yield self.create_event(
event, context = await self.create_event(
requester,
{
"type": "org.matrix.dummy_event",
@ -1123,7 +1118,7 @@ class EventCreationHandler(object):
event.internal_metadata.proactively_send = False
yield self.send_nonmember_event(
await self.send_nonmember_event(
requester, event, context, ratelimit=False
)
dummy_event_sent = True

View File

@ -141,8 +141,9 @@ class BaseProfileHandler(BaseHandler):
return result["displayname"]
@defer.inlineCallbacks
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
async def set_displayname(
self, target_user, requester, new_displayname, by_admin=False
):
"""Set the displayname of a user
Args:
@ -158,7 +159,7 @@ class BaseProfileHandler(BaseHandler):
raise AuthError(400, "Cannot set another user's displayname")
if not by_admin and not self.hs.config.enable_set_displayname:
profile = yield self.store.get_profileinfo(target_user.localpart)
profile = await self.store.get_profileinfo(target_user.localpart)
if profile.display_name:
raise SynapseError(
400,
@ -180,15 +181,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin:
requester = create_requester(target_user)
yield self.store.set_profile_displayname(target_user.localpart, new_displayname)
await self.store.set_profile_displayname(target_user.localpart, new_displayname)
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change(
profile = await self.store.get_profileinfo(target_user.localpart)
await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
yield self._update_join_states(requester, target_user)
await self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
@ -217,8 +218,9 @@ class BaseProfileHandler(BaseHandler):
return result["avatar_url"]
@defer.inlineCallbacks
def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
async def set_avatar_url(
self, target_user, requester, new_avatar_url, by_admin=False
):
"""target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user):
@ -228,7 +230,7 @@ class BaseProfileHandler(BaseHandler):
raise AuthError(400, "Cannot set another user's avatar_url")
if not by_admin and not self.hs.config.enable_set_avatar_url:
profile = yield self.store.get_profileinfo(target_user.localpart)
profile = await self.store.get_profileinfo(target_user.localpart)
if profile.avatar_url:
raise SynapseError(
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
@ -243,15 +245,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin:
requester = create_requester(target_user)
yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change(
profile = await self.store.get_profileinfo(target_user.localpart)
await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
yield self._update_join_states(requester, target_user)
await self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def on_profile_query(self, args):
@ -279,21 +281,20 @@ class BaseProfileHandler(BaseHandler):
return response
@defer.inlineCallbacks
def _update_join_states(self, requester, target_user):
async def _update_join_states(self, requester, target_user):
if not self.hs.is_mine(target_user):
return
yield self.ratelimit(requester)
await self.ratelimit(requester)
room_ids = yield self.store.get_rooms_for_user(target_user.to_string())
room_ids = await self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids:
handler = self.hs.get_room_member_handler()
try:
# Assume the target_user isn't a guest,
# because we don't let guests set profile or avatar data.
yield handler.update_membership(
await handler.update_membership(
requester,
target_user,
room_id,

View File

@ -145,9 +145,9 @@ class RegistrationHandler(BaseHandler):
"""Registers a new client on the server.
Args:
localpart : The local part of the user ID to register. If None,
localpart: The local part of the user ID to register. If None,
one will be generated.
password (unicode) : The password to assign to this user so they can
password (unicode): The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
user_type (str|None): type of user. One of the values from
@ -244,7 +244,7 @@ class RegistrationHandler(BaseHandler):
fail_count += 1
if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id)
yield defer.ensureDeferred(self._auto_join_rooms(user_id))
else:
logger.info(
"Skipping auto-join for %s because consent is required at registration",
@ -266,8 +266,7 @@ class RegistrationHandler(BaseHandler):
return user_id
@defer.inlineCallbacks
def _auto_join_rooms(self, user_id):
async def _auto_join_rooms(self, user_id):
"""Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created.
@ -281,9 +280,9 @@ class RegistrationHandler(BaseHandler):
# that an auto-generated support or bot user is not a real user and will never be
# the user to create the room
should_auto_create_rooms = False
is_real_user = yield self.store.is_real_user(user_id)
is_real_user = await self.store.is_real_user(user_id)
if self.hs.config.autocreate_auto_join_rooms and is_real_user:
count = yield self.store.count_real_users()
count = await self.store.count_real_users()
should_auto_create_rooms = count == 1
for r in self.hs.config.auto_join_rooms:
logger.info("Auto-joining %s to %s", user_id, r)
@ -302,7 +301,7 @@ class RegistrationHandler(BaseHandler):
# getting the RoomCreationHandler during init gives a dependency
# loop
yield self.hs.get_room_creation_handler().create_room(
await self.hs.get_room_creation_handler().create_room(
fake_requester,
config={
"preset": "public_chat",
@ -311,7 +310,7 @@ class RegistrationHandler(BaseHandler):
ratelimit=False,
)
else:
yield self._join_user_to_room(fake_requester, r)
await self._join_user_to_room(fake_requester, r)
except ConsentNotGivenError as e:
# Technically not necessary to pull out this error though
# moving away from bare excepts is a good thing to do.
@ -319,15 +318,14 @@ class RegistrationHandler(BaseHandler):
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
@defer.inlineCallbacks
def post_consent_actions(self, user_id):
async def post_consent_actions(self, user_id):
"""A series of registration actions that can only be carried out once consent
has been granted
Args:
user_id (str): The user to join
"""
yield self._auto_join_rooms(user_id)
await self._auto_join_rooms(user_id)
@defer.inlineCallbacks
def appservice_register(self, user_localpart, as_token):
@ -394,14 +392,13 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id += 1
return str(id)
@defer.inlineCallbacks
def _join_user_to_room(self, requester, room_identifier):
async def _join_user_to_room(self, requester, room_identifier):
room_member_handler = self.hs.get_room_member_handler()
if RoomID.is_valid(room_identifier):
room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
room_id, remote_room_hosts = await room_member_handler.lookup_room_alias(
room_alias
)
room_id = room_id.to_string()
@ -410,7 +407,7 @@ class RegistrationHandler(BaseHandler):
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
yield room_member_handler.update_membership(
await room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
@ -550,8 +547,7 @@ class RegistrationHandler(BaseHandler):
return (device_id, access_token)
@defer.inlineCallbacks
def post_registration_actions(self, user_id, auth_result, access_token):
async def post_registration_actions(self, user_id, auth_result, access_token):
"""A user has completed registration
Args:
@ -562,7 +558,7 @@ class RegistrationHandler(BaseHandler):
device, or None if `inhibit_login` enabled.
"""
if self.hs.config.worker_app:
yield self._post_registration_client(
await self._post_registration_client(
user_id=user_id, auth_result=auth_result, access_token=access_token
)
return
@ -574,19 +570,18 @@ class RegistrationHandler(BaseHandler):
if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
yield self.store.upsert_monthly_active_user(user_id)
await self.store.upsert_monthly_active_user(user_id)
yield self._register_email_threepid(user_id, threepid, access_token)
await self._register_email_threepid(user_id, threepid, access_token)
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
yield self._register_msisdn_threepid(user_id, threepid)
await self._register_msisdn_threepid(user_id, threepid)
if auth_result and LoginType.TERMS in auth_result:
yield self._on_user_consented(user_id, self.hs.config.user_consent_version)
await self._on_user_consented(user_id, self.hs.config.user_consent_version)
@defer.inlineCallbacks
def _on_user_consented(self, user_id, consent_version):
async def _on_user_consented(self, user_id, consent_version):
"""A user consented to the terms on registration
Args:
@ -595,8 +590,8 @@ class RegistrationHandler(BaseHandler):
consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
yield self.store.user_set_consent_version(user_id, consent_version)
yield self.post_consent_actions(user_id)
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
@defer.inlineCallbacks
def _register_email_threepid(self, user_id, threepid, token):

View File

@ -148,17 +148,16 @@ class RoomCreationHandler(BaseHandler):
return ret
@defer.inlineCallbacks
def _upgrade_room(
async def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
user_id = requester.user.to_string()
# start by allocating a new room id
r = yield self.store.get_room(old_room_id)
r = await self.store.get_room(old_room_id)
if r is None:
raise NotFoundError("Unknown room id %s" % (old_room_id,))
new_room_id = yield self._generate_room_id(
new_room_id = await self._generate_room_id(
creator_id=user_id, is_public=r["is_public"], room_version=new_version,
)
@ -169,7 +168,7 @@ class RoomCreationHandler(BaseHandler):
(
tombstone_event,
tombstone_context,
) = yield self.event_creation_handler.create_event(
) = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Tombstone,
@ -183,12 +182,12 @@ class RoomCreationHandler(BaseHandler):
},
token_id=requester.access_token_id,
)
old_room_version = yield self.store.get_room_version_id(old_room_id)
yield self.auth.check_from_context(
old_room_version = await self.store.get_room_version_id(old_room_id)
await self.auth.check_from_context(
old_room_version, tombstone_event, tombstone_context
)
yield self.clone_existing_room(
await self.clone_existing_room(
requester,
old_room_id=old_room_id,
new_room_id=new_room_id,
@ -197,32 +196,31 @@ class RoomCreationHandler(BaseHandler):
)
# now send the tombstone
yield self.event_creation_handler.send_nonmember_event(
await self.event_creation_handler.send_nonmember_event(
requester, tombstone_event, tombstone_context
)
old_room_state = yield tombstone_context.get_current_state_ids()
old_room_state = await tombstone_context.get_current_state_ids()
# update any aliases
yield self._move_aliases_to_new_room(
await self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state
)
# Copy over user push rules, tags and migrate room directory state
yield self.room_member_handler.transfer_room_state_on_room_upgrade(
await self.room_member_handler.transfer_room_state_on_room_upgrade(
old_room_id, new_room_id
)
# finally, shut down the PLs in the old room, and update them in the new
# room.
yield self._update_upgraded_room_pls(
await self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state,
)
return new_room_id
@defer.inlineCallbacks
def _update_upgraded_room_pls(
async def _update_upgraded_room_pls(
self,
requester: Requester,
old_room_id: str,
@ -249,7 +247,7 @@ class RoomCreationHandler(BaseHandler):
)
return
old_room_pl_state = yield self.store.get_event(old_room_pl_event_id)
old_room_pl_state = await self.store.get_event(old_room_pl_event_id)
# we try to stop regular users from speaking by setting the PL required
# to send regular events and invites to 'Moderator' level. That's normally
@ -278,7 +276,7 @@ class RoomCreationHandler(BaseHandler):
if updated:
try:
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.PowerLevels,
@ -292,7 +290,7 @@ class RoomCreationHandler(BaseHandler):
except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e)
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.PowerLevels,
@ -304,8 +302,7 @@ class RoomCreationHandler(BaseHandler):
ratelimit=False,
)
@defer.inlineCallbacks
def clone_existing_room(
async def clone_existing_room(
self,
requester: Requester,
old_room_id: str,
@ -338,7 +335,7 @@ class RoomCreationHandler(BaseHandler):
# Check if old room was non-federatable
# Get old room's create event
old_room_create_event = yield self.store.get_create_event_for_room(old_room_id)
old_room_create_event = await self.store.get_create_event_for_room(old_room_id)
# Check if the create event specified a non-federatable room
if not old_room_create_event.content.get("m.federate", True):
@ -361,11 +358,11 @@ class RoomCreationHandler(BaseHandler):
(EventTypes.PowerLevels, ""),
)
old_room_state_ids = yield self.store.get_filtered_current_state_ids(
old_room_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types(types_to_copy)
)
# map from event_id to BaseEvent
old_room_state_events = yield self.store.get_events(old_room_state_ids.values())
old_room_state_events = await self.store.get_events(old_room_state_ids.values())
for k, old_event_id in iteritems(old_room_state_ids):
old_event = old_room_state_events.get(old_event_id)
@ -400,7 +397,7 @@ class RoomCreationHandler(BaseHandler):
if current_power_level < needed_power_level:
power_levels["users"][user_id] = needed_power_level
yield self._send_events_for_new_room(
await self._send_events_for_new_room(
requester,
new_room_id,
# we expect to override all the presets with initial_state, so this is
@ -412,12 +409,12 @@ class RoomCreationHandler(BaseHandler):
)
# Transfer membership events
old_room_member_state_ids = yield self.store.get_filtered_current_state_ids(
old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
)
# map from event_id to BaseEvent
old_room_member_state_events = yield self.store.get_events(
old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values()
)
for k, old_event in iteritems(old_room_member_state_events):
@ -426,7 +423,7 @@ class RoomCreationHandler(BaseHandler):
"membership" in old_event.content
and old_event.content["membership"] == "ban"
):
yield self.room_member_handler.update_membership(
await self.room_member_handler.update_membership(
requester,
UserID.from_string(old_event["state_key"]),
new_room_id,
@ -438,8 +435,7 @@ class RoomCreationHandler(BaseHandler):
# XXX invites/joins
# XXX 3pid invites
@defer.inlineCallbacks
def _move_aliases_to_new_room(
async def _move_aliases_to_new_room(
self,
requester: Requester,
old_room_id: str,
@ -448,13 +444,13 @@ class RoomCreationHandler(BaseHandler):
):
directory_handler = self.hs.get_handlers().directory_handler
aliases = yield self.store.get_aliases_for_room(old_room_id)
aliases = await self.store.get_aliases_for_room(old_room_id)
# check to see if we have a canonical alias.
canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event_id:
canonical_alias_event = yield self.store.get_event(canonical_alias_event_id)
canonical_alias_event = await self.store.get_event(canonical_alias_event_id)
# first we try to remove the aliases from the old room (we suppress sending
# the room_aliases event until the end).
@ -472,7 +468,7 @@ class RoomCreationHandler(BaseHandler):
for alias_str in aliases:
alias = RoomAlias.from_string(alias_str)
try:
yield directory_handler.delete_association(requester, alias)
await directory_handler.delete_association(requester, alias)
removed_aliases.append(alias_str)
except SynapseError as e:
logger.warning("Unable to remove alias %s from old room: %s", alias, e)
@ -485,7 +481,7 @@ class RoomCreationHandler(BaseHandler):
# we can now add any aliases we successfully removed to the new room.
for alias in removed_aliases:
try:
yield directory_handler.create_association(
await directory_handler.create_association(
requester,
RoomAlias.from_string(alias),
new_room_id,
@ -502,7 +498,7 @@ class RoomCreationHandler(BaseHandler):
# alias event for the new room with a copy of the information.
try:
if canonical_alias_event:
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
@ -518,8 +514,9 @@ class RoomCreationHandler(BaseHandler):
# we returned the new room to the client at this point.
logger.error("Unable to send updated alias events in new room: %s", e)
@defer.inlineCallbacks
def create_room(self, requester, config, ratelimit=True, creator_join_profile=None):
async def create_room(
self, requester, config, ratelimit=True, creator_join_profile=None
):
""" Creates a new room.
Args:
@ -547,7 +544,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
yield self.auth.check_auth_blocking(user_id)
await self.auth.check_auth_blocking(user_id)
if (
self._server_notices_mxid is not None
@ -556,11 +553,11 @@ class RoomCreationHandler(BaseHandler):
# allow the server notices mxid to create rooms
is_requester_admin = True
else:
is_requester_admin = yield self.auth.is_server_admin(requester.user)
is_requester_admin = await self.auth.is_server_admin(requester.user)
# Check whether the third party rules allows/changes the room create
# request.
event_allowed = yield self.third_party_event_rules.on_create_room(
event_allowed = await self.third_party_event_rules.on_create_room(
requester, config, is_requester_admin=is_requester_admin
)
if not event_allowed:
@ -574,7 +571,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(403, "You are not permitted to create rooms")
if ratelimit:
yield self.ratelimit(requester)
await self.ratelimit(requester)
room_version_id = config.get(
"room_version", self.config.default_room_version.identifier
@ -597,7 +594,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(400, "Invalid characters in room alias")
room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
mapping = yield self.store.get_association_from_room_alias(room_alias)
mapping = await self.store.get_association_from_room_alias(room_alias)
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
@ -612,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
yield self.event_creation_handler.assert_accepted_privacy_policy(requester)
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override")
if (
@ -631,13 +628,13 @@ class RoomCreationHandler(BaseHandler):
visibility = config.get("visibility", None)
is_public = visibility == "public"
room_id = yield self._generate_room_id(
room_id = await self._generate_room_id(
creator_id=user_id, is_public=is_public, room_version=room_version,
)
directory_handler = self.hs.get_handlers().directory_handler
if room_alias:
yield directory_handler.create_association(
await directory_handler.create_association(
requester=requester,
room_id=room_id,
room_alias=room_alias,
@ -670,7 +667,7 @@ class RoomCreationHandler(BaseHandler):
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier
yield self._send_events_for_new_room(
await self._send_events_for_new_room(
requester,
room_id,
preset_config=preset_config,
@ -684,7 +681,7 @@ class RoomCreationHandler(BaseHandler):
if "name" in config:
name = config["name"]
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name,
@ -698,7 +695,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config:
topic = config["topic"]
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic,
@ -716,7 +713,7 @@ class RoomCreationHandler(BaseHandler):
if is_direct:
content["is_direct"] = is_direct
yield self.room_member_handler.update_membership(
await self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
@ -730,7 +727,7 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"]
medium = invite_3pid["medium"]
yield self.hs.get_room_member_handler().do_3pid_invite(
await self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
medium,
@ -748,8 +745,7 @@ class RoomCreationHandler(BaseHandler):
return result
@defer.inlineCallbacks
def _send_events_for_new_room(
async def _send_events_for_new_room(
self,
creator, # A Requester object.
room_id,
@ -769,11 +765,10 @@ class RoomCreationHandler(BaseHandler):
return e
@defer.inlineCallbacks
def send(etype, content, **kwargs):
async def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False
)
@ -784,10 +779,10 @@ class RoomCreationHandler(BaseHandler):
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
creation_content.update({"creator": creator_id})
yield send(etype=EventTypes.Create, content=creation_content)
await send(etype=EventTypes.Create, content=creation_content)
logger.debug("Sending %s in new room", EventTypes.Member)
yield self.room_member_handler.update_membership(
await self.room_member_handler.update_membership(
creator,
creator.user,
room_id,
@ -800,7 +795,7 @@ class RoomCreationHandler(BaseHandler):
# of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
yield send(etype=EventTypes.PowerLevels, content=pl_content)
await send(etype=EventTypes.PowerLevels, content=pl_content)
else:
power_level_content = {
"users": {creator_id: 100},
@ -833,33 +828,33 @@ class RoomCreationHandler(BaseHandler):
if power_level_content_override:
power_level_content.update(power_level_content_override)
yield send(etype=EventTypes.PowerLevels, content=power_level_content)
await send(etype=EventTypes.PowerLevels, content=power_level_content)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
yield send(
await send(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
if (EventTypes.JoinRules, "") not in initial_state:
yield send(
await send(
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
)
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
yield send(
await send(
etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]},
)
if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state:
yield send(
await send(
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
)
for (etype, state_key), content in initial_state.items():
yield send(etype=etype, state_key=state_key, content=content)
await send(etype=etype, state_key=state_key, content=content)
@defer.inlineCallbacks
def _generate_room_id(

View File

@ -142,8 +142,7 @@ class RoomMemberHandler(object):
"""
raise NotImplementedError()
@defer.inlineCallbacks
def _local_membership_update(
async def _local_membership_update(
self,
requester,
target,
@ -164,7 +163,7 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
event, context = yield self.event_creation_handler.create_event(
event, context = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
@ -182,18 +181,18 @@ class RoomMemberHandler(object):
)
# Check if this event matches the previous membership event for the user.
duplicate = yield self.event_creation_handler.deduplicate_state_event(
duplicate = await self.event_creation_handler.deduplicate_state_event(
event, context
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
return duplicate
yield self.event_creation_handler.handle_new_client_event(
await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit
)
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
@ -203,15 +202,15 @@ class RoomMemberHandler(object):
# info.
newly_joined = True
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield self._user_joined_room(target, room_id)
await self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
yield self._user_left_room(target, room_id)
await self._user_left_room(target, room_id)
return event
@ -253,8 +252,7 @@ class RoomMemberHandler(object):
for tag, tag_content in room_tags.items():
yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
@defer.inlineCallbacks
def update_membership(
async def update_membership(
self,
requester,
target,
@ -269,8 +267,8 @@ class RoomMemberHandler(object):
):
key = (room_id,)
with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership(
with (await self.member_linearizer.queue(key)):
result = await self._update_membership(
requester,
target,
room_id,
@ -285,8 +283,7 @@ class RoomMemberHandler(object):
return result
@defer.inlineCallbacks
def _update_membership(
async def _update_membership(
self,
requester,
target,
@ -321,7 +318,7 @@ class RoomMemberHandler(object):
# if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join.
if third_party_signed is not None:
yield self.federation_handler.exchange_third_party_invite(
await self.federation_handler.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
@ -332,7 +329,7 @@ class RoomMemberHandler(object):
remote_room_hosts = []
if effective_membership_state not in ("leave", "ban"):
is_blocked = yield self.store.is_room_blocked(room_id)
is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
@ -351,7 +348,7 @@ class RoomMemberHandler(object):
is_requester_admin = True
else:
is_requester_admin = yield self.auth.is_server_admin(requester.user)
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
if self.config.block_non_admin_invites:
@ -370,9 +367,9 @@ class RoomMemberHandler(object):
if block_invite:
raise SynapseError(403, "Invites have been disabled on this server")
latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
current_state_ids = yield self.state_handler.get_current_state_ids(
current_state_ids = await self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids
)
@ -381,7 +378,7 @@ class RoomMemberHandler(object):
# transitions and generic otherwise
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
if old_state_id:
old_state = yield self.store.get_event(old_state_id, allow_none=True)
old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
@ -413,7 +410,7 @@ class RoomMemberHandler(object):
old_membership == Membership.INVITE
and effective_membership_state == Membership.LEAVE
):
is_blocked = yield self._is_server_notice_room(room_id)
is_blocked = await self._is_server_notice_room(room_id)
if is_blocked:
raise SynapseError(
http_client.FORBIDDEN,
@ -424,18 +421,18 @@ class RoomMemberHandler(object):
if action == "kick":
raise AuthError(403, "The target user is not in the room")
is_host_in_room = yield self._is_host_in_room(current_state_ids)
is_host_in_room = await self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN:
if requester.is_guest:
guest_can_join = yield self._can_guest_join(current_state_ids)
guest_can_join = await self._can_guest_join(current_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
inviter = yield self._get_inviter(target.to_string(), room_id)
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
@ -443,13 +440,13 @@ class RoomMemberHandler(object):
profile = self.profile_handler
if not content_specified:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
content["displayname"] = await profile.get_displayname(target)
content["avatar_url"] = await profile.get_avatar_url(target)
if requester.is_guest:
content["kind"] = "guest"
remote_join_response = yield self._remote_join(
remote_join_response = await self._remote_join(
requester, remote_room_hosts, room_id, target, content
)
@ -458,7 +455,7 @@ class RoomMemberHandler(object):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
inviter = yield self._get_inviter(target.to_string(), room_id)
inviter = await self._get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
@ -472,12 +469,12 @@ class RoomMemberHandler(object):
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
res = yield self._remote_reject_invite(
res = await self._remote_reject_invite(
requester, remote_room_hosts, room_id, target, content,
)
return res
res = yield self._local_membership_update(
res = await self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
@ -572,8 +569,7 @@ class RoomMemberHandler(object):
)
continue
@defer.inlineCallbacks
def send_membership_event(self, requester, event, context, ratelimit=True):
async def send_membership_event(self, requester, event, context, ratelimit=True):
"""
Change the membership status of a user in a room.
@ -599,27 +595,27 @@ class RoomMemberHandler(object):
else:
requester = types.create_requester(target_user)
prev_event = yield self.event_creation_handler.deduplicate_state_event(
prev_event = await self.event_creation_handler.deduplicate_state_event(
event, context
)
if prev_event is not None:
return
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
if event.membership == Membership.JOIN:
if requester.is_guest:
guest_can_join = yield self._can_guest_join(prev_state_ids)
guest_can_join = await self._can_guest_join(prev_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if event.membership not in (Membership.LEAVE, Membership.BAN):
is_blocked = yield self.store.is_room_blocked(room_id)
is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
yield self.event_creation_handler.handle_new_client_event(
await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target_user], ratelimit=ratelimit
)
@ -633,15 +629,15 @@ class RoomMemberHandler(object):
# info.
newly_joined = True
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield self._user_joined_room(target_user, room_id)
await self._user_joined_room(target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
yield self._user_left_room(target_user, room_id)
await self._user_left_room(target_user, room_id)
@defer.inlineCallbacks
def _can_guest_join(self, current_state_ids):
@ -699,8 +695,7 @@ class RoomMemberHandler(object):
if invite:
return UserID.from_string(invite.sender)
@defer.inlineCallbacks
def do_3pid_invite(
async def do_3pid_invite(
self,
room_id,
inviter,
@ -712,7 +707,7 @@ class RoomMemberHandler(object):
id_access_token=None,
):
if self.config.block_non_admin_invites:
is_requester_admin = yield self.auth.is_server_admin(requester.user)
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
raise SynapseError(
403, "Invites have been disabled on this server", Codes.FORBIDDEN
@ -720,9 +715,9 @@ class RoomMemberHandler(object):
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
yield self.base_handler.ratelimit(requester)
await self.base_handler.ratelimit(requester)
can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id
)
if not can_invite:
@ -737,16 +732,16 @@ class RoomMemberHandler(object):
403, "Looking up third-party identifiers is denied from this server"
)
invitee = yield self.identity_handler.lookup_3pid(
invitee = await self.identity_handler.lookup_3pid(
id_server, medium, address, id_access_token
)
if invitee:
yield self.update_membership(
await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
yield self._make_and_store_3pid_invite(
await self._make_and_store_3pid_invite(
requester,
id_server,
medium,
@ -757,8 +752,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
async def _make_and_store_3pid_invite(
self,
requester,
id_server,
@ -769,7 +763,7 @@ class RoomMemberHandler(object):
txn_id,
id_access_token=None,
):
room_state = yield self.state_handler.get_current_state(room_id)
room_state = await self.state_handler.get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
@ -807,7 +801,7 @@ class RoomMemberHandler(object):
public_keys,
fallback_public_key,
display_name,
) = yield self.identity_handler.ask_id_server_for_third_party_invite(
) = await self.identity_handler.ask_id_server_for_third_party_invite(
requester=requester,
id_server=id_server,
medium=medium,
@ -823,7 +817,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
@ -917,8 +911,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return complexity["v1"] > max_complexity
@defer.inlineCallbacks
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join
"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
@ -933,7 +926,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if self.hs.config.limit_remote_rooms.enabled:
# Fetch the room complexity
too_complex = yield self._is_remote_room_too_complex(
too_complex = await self._is_remote_room_too_complex(
room_id, remote_room_hosts
)
if too_complex is True:
@ -947,12 +940,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield defer.ensureDeferred(
self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user.to_string(), content
)
await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user.to_string(), content
)
yield self._user_joined_room(user, room_id)
await self._user_joined_room(user, room_id)
# Check the room we just joined wasn't too large, if we didn't fetch the
# complexity of it before.
@ -962,7 +953,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return
# Check again, but with the local state events
too_complex = yield self._is_local_room_too_complex(room_id)
too_complex = await self._is_local_room_too_complex(room_id)
if too_complex is False:
# We're under the limit.
@ -970,7 +961,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# The room is too large. Leave.
requester = types.create_requester(user, None, False, None)
yield self.update_membership(
await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
raise SynapseError(
@ -1008,12 +999,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room
"""
return user_joined_room(self.distributor, target, room_id)
return defer.succeed(user_joined_room(self.distributor, target, room_id))
def _user_left_room(self, target, room_id):
"""Implements RoomMemberHandler._user_left_room
"""
return user_left_room(self.distributor, target, room_id)
return defer.succeed(user_left_room(self.distributor, target, room_id))
@defer.inlineCallbacks
def forget(self, user, room_id):

View File

@ -149,7 +149,7 @@ class SamlHandler:
# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
self._auth_handler.complete_sso_ui_auth(
await self._auth_handler.complete_sso_ui_auth(
user_id, current_session.ui_auth_session_id, request
)

View File

@ -16,6 +16,7 @@
import abc
import logging
import re
from inspect import signature
from typing import Dict, List, Tuple
from six import raise_from
@ -60,6 +61,8 @@ class ReplicationEndpoint(object):
must call `register` to register the path with the HTTP server.
Requests can be sent by calling the client returned by `make_client`.
Requests are sent to master process by default, but can be sent to other
named processes by specifying an `instance_name` keyword argument.
Attributes:
NAME (str): A name for the endpoint, added to the path as well as used
@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
)
# We reserve `instance_name` as a parameter to sending requests, so we
# assert here that sub classes don't try and use the name.
assert (
"instance_name" not in self.PATH_ARGS
), "`instance_name` is a reserved paramater name"
assert (
"instance_name"
not in signature(self.__class__._serialize_payload).parameters
), "`instance_name` is a reserved paramater name"
assert self.METHOD in ("PUT", "POST", "GET")
@abc.abstractmethod
@ -135,7 +148,11 @@ class ReplicationEndpoint(object):
@trace(opname="outgoing_replication_request")
@defer.inlineCallbacks
def send_request(**kwargs):
def send_request(instance_name="master", **kwargs):
# Currently we only support sending requests to master process.
if instance_name != "master":
raise Exception("Unknown instance")
data = yield cls._serialize_payload(**kwargs)
url_args = [

View File

@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
def __init__(self, hs):
super().__init__(hs)
self._instance_name = hs.get_instance_name()
# We pull the streams from the replication steamer (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_replication_streamer().get_streams()
@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
upto_token = parse_integer(request, "upto_token", required=True)
updates, upto_token, limited = await stream.get_updates_since(
from_token, upto_token
self._instance_name, from_token, upto_token
)
return (

View File

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Dict, Optional
from typing import Optional
import six
@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
self.hs = hs
def stream_positions(self) -> Dict[str, int]:
"""
Get the current positions of all the streams this store wants to subscribe to
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()

View File

@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedAccountDataStore, self).stream_positions()
position = self._account_data_id_gen.get_current_token()
result["user_account_data"] = position
result["room_account_data"] = position
result["tag_account_data"] = position
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "tag_account_data":
self._account_data_id_gen.advance(token)

View File

@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
expiry_ms=30 * 60 * 1000,
)
def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
result["to_device"] = self._device_inbox_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "to_device":
self._device_inbox_id_gen.advance(token)

View File

@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
# The user signature stream uses the same stream ID generator as the
# device list stream, so set them both to the device list ID
# generator's current token.
current_token = self._device_list_id_gen.get_current_token()
result[DeviceListsStream.NAME] = current_token
result[UserSignatureStream.NAME] = current_token
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)

View File

@ -93,12 +93,6 @@ class SlavedEventStore(
def get_room_min_stream_ordering(self):
return self._backfill_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
result["backfill"] = -self._backfill_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "events":
self._stream_id_gen.advance(token)

View File

@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedGroupServerStore, self).stream_positions()
result["groups"] = self._group_updates_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)

View File

@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPresenceStore, self).stream_positions()
if self.hs.config.use_presence:
position = self._presence_id_gen.get_current_token()
result["presence"] = position
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "presence":
self._presence_id_gen.advance(token)

View File

@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "push_rules":
self._push_rules_stream_id_gen.advance(token)

View File

@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
def stream_positions(self):
result = super(SlavedPusherStore, self).stream_positions()
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()

View File

@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
return result
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,))

View File

@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
def stream_positions(self):
result = super(RoomStore, self).stream_positions()
result["public_rooms"] = self._public_room_id_gen.get_current_token()
return result
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "public_rooms":
self._public_room_id_gen.advance(token)

View File

@ -16,7 +16,7 @@
"""
import logging
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING
from twisted.internet.protocol import ReconnectingClientFactory
@ -86,37 +86,22 @@ class ReplicationDataHandler:
def __init__(self, store: BaseSlavedStore):
self.store = store
async def on_rdata(self, stream_name: str, token: int, rows: list):
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
stream_name: name of the replication stream for this batch of rows
instance_name: the instance that wrote the rows.
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, token, rows)
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
room_account_data = args.pop("room_account_data", None)
if user_account_data:
args["account_data"] = user_account_data
elif room_account_data:
args["account_data"] = room_account_data
return args
async def on_position(self, stream_name: str, token: int):
self.store.process_replication_rows(stream_name, token, [])

View File

@ -278,19 +278,24 @@ class ReplicationCommandHandler:
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
await self.on_rdata(stream_name, cmd.token, rows)
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata(self, stream_name: str, token: int, rows: list):
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name: name of the replication stream for this batch of rows
instance_name: the instance that wrote the rows.
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
await self._replication_data_handler.on_rdata(stream_name, token, rows)
await self._replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
@ -314,15 +319,7 @@ class ReplicationCommandHandler:
self._pending_batches.pop(cmd.stream_name, [])
# Find where we previously streamed up to.
current_token = self._replication_data_handler.get_streams_to_replicate().get(
cmd.stream_name
)
if current_token is None:
logger.warning(
"Got POSITION for stream we're not subscribed to: %s",
cmd.stream_name,
)
return
current_token = stream.current_token()
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
@ -333,7 +330,9 @@ class ReplicationCommandHandler:
updates,
current_token,
missing_updates,
) = await stream.get_updates_since(current_token, cmd.token)
) = await stream.get_updates_since(
cmd.instance_name, current_token, cmd.token
)
# TODO: add some tests for this
@ -342,7 +341,10 @@ class ReplicationCommandHandler:
for token, rows in _batch_updates(updates):
await self.on_rdata(
cmd.stream_name, token, [stream.parse_row(row) for row in rows],
cmd.stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
)
# We've now caught up to position sent to us, notify handler.

View File

@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
import logging
from collections import namedtuple
@ -21,10 +22,10 @@ from typing import (
Any,
Awaitable,
Callable,
Iterable,
List,
Optional,
Tuple,
TypeVar,
)
import attr
@ -49,7 +50,7 @@ Token = int
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
# just a row from a database query, though this is dependent on the stream in question.
#
StreamRow = Tuple
StreamRow = TypeVar("StreamRow", bound=Tuple)
# The type returned by the update_function of a stream, as well as get_updates(),
# get_updates_since, etc.
@ -65,6 +66,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
#
# The arguments are:
#
# * instance_name: the writer of the stream
# * from_token: the previous stream token: the starting point for fetching the
# updates
# * to_token: the new stream token: the point to get updates up to
@ -74,7 +76,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
# If there are more updates available, it should set `limited` in the result, and
# it will be called again to get the next batch.
#
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
class Stream(object):
@ -105,6 +107,7 @@ class Stream(object):
def __init__(
self,
local_instance_name: str,
current_token_function: Callable[[], Token],
update_function: UpdateFunction,
):
@ -120,9 +123,11 @@ class Stream(object):
stream tokens. See the UpdateFunction type definition for more info.
Args:
local_instance_name: The instance name of the current process
current_token_function: callback to get the current token, as above
update_function: callback go get stream updates, as above
"""
self.local_instance_name = local_instance_name
self.current_token = current_token_function
self.update_function = update_function
@ -147,14 +152,14 @@ class Stream(object):
"""
current_token = self.current_token()
updates, current_token, limited = await self.get_updates_since(
self.last_token, current_token
self.local_instance_name, self.last_token, current_token
)
self.last_token = current_token
return updates, current_token, limited
async def get_updates_since(
self, from_token: Token, upto_token: Token
self, instance_name: str, from_token: Token, upto_token: Token
) -> StreamUpdateResult:
"""Like get_updates except allows specifying from when we should
stream updates
@ -172,26 +177,25 @@ class Stream(object):
return [], upto_token, False
updates, upto_token, limited = await self.update_function(
from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
)
return updates, upto_token, limited
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
async def update_function(from_token, upto_token, limit):
async def update_function(instance_name, from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) == limit:
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
assert len(updates) <= limit
return updates, upto_token, limited
@ -206,10 +210,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
client = ReplicationGetStreamUpdates.make_client(hs)
async def update_function(
from_token: int, upto_token: int, limit: int
instance_name: str, from_token: int, upto_token: int, limit: int
) -> StreamUpdateResult:
result = await client(
stream_name=stream_name, from_token=from_token, upto_token=upto_token,
instance_name=instance_name,
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,
)
return result["updates"], result["upto_token"], result["limited"]
@ -239,6 +246,7 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_backfill_token,
db_query_to_update_function(store.get_all_new_backfill_event_rows),
)
@ -274,7 +282,9 @@ class PresenceStream(Stream):
# Query master process
update_function = make_http_update_function(hs, self.NAME)
super().__init__(store.get_current_presence_token, update_function)
super().__init__(
hs.get_instance_name(), store.get_current_presence_token, update_function
)
class TypingStream(Stream):
@ -297,7 +307,9 @@ class TypingStream(Stream):
# Query master process
update_function = make_http_update_function(hs, self.NAME)
super().__init__(typing_handler.get_current_token, update_function)
super().__init__(
hs.get_instance_name(), typing_handler.get_current_token, update_function
)
class ReceiptsStream(Stream):
@ -318,6 +330,7 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_receipt_stream_id,
db_query_to_update_function(store.get_all_updated_receipts),
)
@ -335,14 +348,16 @@ class PushRulesStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
super(PushRulesStream, self).__init__(
self._current_token, self._update_function
hs.get_instance_name(), self._current_token, self._update_function
)
def _current_token(self) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
async def _update_function(self, from_token: Token, to_token: Token, limit: int):
async def _update_function(
self, instance_name: str, from_token: Token, to_token: Token, limit: int
):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
limited = False
@ -369,6 +384,7 @@ class PushersStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_pushers_stream_token,
db_query_to_update_function(store.get_all_updated_pushers_rows),
)
@ -400,6 +416,7 @@ class CachesStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token,
db_query_to_update_function(store.get_all_updated_caches),
)
@ -425,6 +442,7 @@ class PublicRoomsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_current_public_room_stream_id,
db_query_to_update_function(store.get_all_new_public_rooms),
)
@ -445,6 +463,7 @@ class DeviceListsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
)
@ -462,6 +481,7 @@ class ToDeviceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_to_device_stream_token,
db_query_to_update_function(store.get_all_new_device_messages),
)
@ -481,6 +501,7 @@ class TagAccountDataStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_max_account_data_stream_id,
db_query_to_update_function(store.get_all_updated_tags),
)
@ -501,11 +522,13 @@ class AccountDataStream(Stream):
def __init__(self, hs: "synapse.server.HomeServer"):
self.store = hs.get_datastore()
super().__init__(
self.store.get_max_account_data_stream_id, self._update_function,
hs.get_instance_name(),
self.store.get_max_account_data_stream_id,
self._update_function,
)
async def _update_function(
self, from_token: int, to_token: int, limit: int
self, instance_name: str, from_token: int, to_token: int, limit: int
) -> StreamUpdateResult:
limited = False
global_results = await self.store.get_updated_global_account_data(
@ -530,16 +553,19 @@ class AccountDataStream(Stream):
# convert the global results to the right format, and limit them to the to_token
# at the same time
global_results = (
global_rows = (
(stream_id, (user_id, None, account_data_type))
for stream_id, user_id, account_data_type in global_results
if stream_id <= to_token
)
room_results = ((stream_id, rest) for stream_id, *rest in room_results)
room_rows = (
(stream_id, (user_id, room_id, account_data_type))
for stream_id, user_id, room_id, account_data_type in room_results
)
# we need to return a sorted list, so merge them together.
updates = list(heapq.merge(room_results, global_results))
updates = list(heapq.merge(room_rows, global_rows))
return updates, to_token, limited
@ -555,6 +581,7 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_group_stream_token,
db_query_to_update_function(store.get_all_groups_changes),
)
@ -572,6 +599,7 @@ class UserSignatureStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_device_stream_token,
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes

View File

@ -118,11 +118,17 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
super().__init__(
self._store.get_current_events_token, self._update_function,
hs.get_instance_name(),
self._store.get_current_events_token,
self._update_function,
)
async def _update_function(
self, from_token: Token, current_token: Token, target_row_count: int
self,
instance_name: str,
from_token: Token,
current_token: Token,
target_row_count: int,
) -> StreamUpdateResult:
# the events stream merges together three separate sources:

View File

@ -48,8 +48,8 @@ class FederationStream(Stream):
current_token = lambda: 0
update_function = self._stub_update_function
super().__init__(current_token, update_function)
super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
async def _stub_update_function(from_token, upto_token, limit):
async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False

View File

@ -140,7 +140,7 @@ class AuthRestServlet(RestServlet):
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
def on_GET(self, request, stagetype):
async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
if not session:
raise SynapseError(400, "No session supplied")
@ -180,7 +180,7 @@ class AuthRestServlet(RestServlet):
else:
raise SynapseError(400, "Homeserver not configured for SSO.")
html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
else:
raise SynapseError(404, "Unknown auth stage type")

View File

@ -499,7 +499,7 @@ class RegisterRestServlet(RestServlet):
# registered a user for this session, so we could just return the
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = self.auth_handler.get_session_data(
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
)
@ -598,7 +598,7 @@ class RegisterRestServlet(RestServlet):
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
self.auth_handler.set_session_data(
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)

View File

@ -16,8 +16,6 @@ import logging
from six import iteritems, string_types
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.api.urls import ConsentURIBuilder
from synapse.config import ConfigError
@ -59,8 +57,7 @@ class ConsentServerNotices(object):
self._consent_uri_builder = ConsentURIBuilder(hs.config)
@defer.inlineCallbacks
def maybe_send_server_notice_to_user(self, user_id):
async def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, and does so if so
Args:
@ -78,7 +75,7 @@ class ConsentServerNotices(object):
return
self._users_in_progress.add(user_id)
try:
u = yield self._store.get_user_by_id(user_id)
u = await self._store.get_user_by_id(user_id)
if u["is_guest"] and not self._send_to_guests:
# don't send to guests
@ -100,8 +97,8 @@ class ConsentServerNotices(object):
content = copy_with_str_subst(
self._server_notice_content, {"consent_uri": consent_uri}
)
yield self._server_notices_manager.send_notice(user_id, content)
yield self._store.user_set_consent_server_notice_sent(
await self._server_notices_manager.send_notice(user_id, content)
await self._store.user_set_consent_server_notice_sent(
user_id, self._current_consent_version
)
except SynapseError as e:

View File

@ -50,8 +50,7 @@ class ResourceLimitsServerNotices(object):
self._notifier = hs.get_notifier()
@defer.inlineCallbacks
def maybe_send_server_notice_to_user(self, user_id):
async def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, this will be true in
two cases.
1. The server has reached its limit does not reflect this
@ -74,13 +73,13 @@ class ResourceLimitsServerNotices(object):
# Don't try and send server notices unless they've been enabled
return
timestamp = yield self._store.user_last_seen_monthly_active(user_id)
timestamp = await self._store.user_last_seen_monthly_active(user_id)
if timestamp is None:
# This user will be blocked from receiving the notice anyway.
# In practice, not sure we can ever get here
return
room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user(
room_id = await self._server_notices_manager.get_or_create_notice_room_for_user(
user_id
)
@ -88,10 +87,10 @@ class ResourceLimitsServerNotices(object):
logger.warning("Failed to get server notices room")
return
yield self._check_and_set_tags(user_id, room_id)
await self._check_and_set_tags(user_id, room_id)
# Determine current state of room
currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id)
currently_blocked, ref_events = await self._is_room_currently_blocked(room_id)
limit_msg = None
limit_type = None
@ -99,7 +98,7 @@ class ResourceLimitsServerNotices(object):
# Normally should always pass in user_id to check_auth_blocking
# if you have it, but in this case are checking what would happen
# to other users if they were to arrive.
yield self._auth.check_auth_blocking()
await self._auth.check_auth_blocking()
except ResourceLimitError as e:
limit_msg = e.msg
limit_type = e.limit_type
@ -112,22 +111,21 @@ class ResourceLimitsServerNotices(object):
# We have hit the MAU limit, but MAU alerting is disabled:
# reset room if necessary and return
if currently_blocked:
self._remove_limit_block_notification(user_id, ref_events)
await self._remove_limit_block_notification(user_id, ref_events)
return
if currently_blocked and not limit_msg:
# Room is notifying of a block, when it ought not to be.
yield self._remove_limit_block_notification(user_id, ref_events)
await self._remove_limit_block_notification(user_id, ref_events)
elif not currently_blocked and limit_msg:
# Room is not notifying of a block, when it ought to be.
yield self._apply_limit_block_notification(
await self._apply_limit_block_notification(
user_id, limit_msg, limit_type
)
except SynapseError as e:
logger.error("Error sending resource limits server notice: %s", e)
@defer.inlineCallbacks
def _remove_limit_block_notification(self, user_id, ref_events):
async def _remove_limit_block_notification(self, user_id, ref_events):
"""Utility method to remove limit block notifications from the server
notices room.
@ -137,12 +135,13 @@ class ResourceLimitsServerNotices(object):
limit blocking and need to be preserved.
"""
content = {"pinned": ref_events}
yield self._server_notices_manager.send_notice(
await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Pinned, ""
)
@defer.inlineCallbacks
def _apply_limit_block_notification(self, user_id, event_body, event_limit_type):
async def _apply_limit_block_notification(
self, user_id, event_body, event_limit_type
):
"""Utility method to apply limit block notifications in the server
notices room.
@ -159,12 +158,12 @@ class ResourceLimitsServerNotices(object):
"admin_contact": self._config.admin_contact,
"limit_type": event_limit_type,
}
event = yield self._server_notices_manager.send_notice(
event = await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Message
)
content = {"pinned": [event.event_id]}
yield self._server_notices_manager.send_notice(
await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Pinned, ""
)
@ -198,7 +197,7 @@ class ResourceLimitsServerNotices(object):
room_id(str): The room id of the server notices room
Returns:
Deferred[Tuple[bool, List]]:
bool: Is the room currently blocked
list: The list of pinned events that are unrelated to limit blocking
This list can be used as a convenience in the case where the block

View File

@ -14,11 +14,9 @@
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
from synapse.types import UserID, create_requester
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@ -51,8 +49,7 @@ class ServerNoticesManager(object):
"""
return self._config.server_notices_mxid is not None
@defer.inlineCallbacks
def send_notice(
async def send_notice(
self, user_id, event_content, type=EventTypes.Message, state_key=None
):
"""Send a notice to the given user
@ -68,8 +65,8 @@ class ServerNoticesManager(object):
Returns:
Deferred[FrozenEvent]
"""
room_id = yield self.get_or_create_notice_room_for_user(user_id)
yield self.maybe_invite_user_to_room(user_id, room_id)
room_id = await self.get_or_create_notice_room_for_user(user_id)
await self.maybe_invite_user_to_room(user_id, room_id)
system_mxid = self._config.server_notices_mxid
requester = create_requester(system_mxid)
@ -86,13 +83,13 @@ class ServerNoticesManager(object):
if state_key is not None:
event_dict["state_key"] = state_key
res = yield self._event_creation_handler.create_and_send_nonmember_event(
res = await self._event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, ratelimit=False
)
return res
@cachedInlineCallbacks()
def get_or_create_notice_room_for_user(self, user_id):
@cached()
async def get_or_create_notice_room_for_user(self, user_id):
"""Get the room for notices for a given user
If we have not yet created a notice room for this user, create it, but don't
@ -109,7 +106,7 @@ class ServerNoticesManager(object):
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
for room in rooms:
@ -118,7 +115,7 @@ class ServerNoticesManager(object):
# be joined. This is kinda deliberate, in that if somebody somehow
# manages to invite the system user to a room, that doesn't make it
# the server notices room.
user_ids = yield self._store.get_users_in_room(room.room_id)
user_ids = await self._store.get_users_in_room(room.room_id)
if self.server_notices_mxid in user_ids:
# we found a room which our user shares with the system notice
# user
@ -146,7 +143,7 @@ class ServerNoticesManager(object):
}
requester = create_requester(self.server_notices_mxid)
info = yield self._room_creation_handler.create_room(
info = await self._room_creation_handler.create_room(
requester,
config={
"preset": RoomCreationPreset.PRIVATE_CHAT,
@ -158,7 +155,7 @@ class ServerNoticesManager(object):
)
room_id = info["room_id"]
max_id = yield self._store.add_tag_to_room(
max_id = await self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
@ -166,8 +163,7 @@ class ServerNoticesManager(object):
logger.info("Created server notices room %s for %s", room_id, user_id)
return room_id
@defer.inlineCallbacks
def maybe_invite_user_to_room(self, user_id: str, room_id: str):
async def maybe_invite_user_to_room(self, user_id: str, room_id: str):
"""Invite the given user to the given server room, unless the user has already
joined or been invited to it.
@ -179,14 +175,14 @@ class ServerNoticesManager(object):
# Check whether the user has already joined or been invited to this room. If
# that's the case, there is no need to re-invite them.
joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
for room in joined_rooms:
if room.room_id == room_id:
return
yield self._room_member_handler.update_membership(
await self._room_member_handler.update_membership(
requester=requester,
target=UserID.from_string(user_id),
room_id=room_id,

View File

@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.server_notices.consent_server_notices import ConsentServerNotices
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
@ -36,18 +34,16 @@ class ServerNoticesSender(object):
ResourceLimitsServerNotices(hs),
)
@defer.inlineCallbacks
def on_user_syncing(self, user_id):
async def on_user_syncing(self, user_id):
"""Called when the user performs a sync operation.
Args:
user_id (str): mxid of user who synced
"""
for sn in self._server_notices:
yield sn.maybe_send_server_notice_to_user(user_id)
await sn.maybe_send_server_notice_to_user(user_id)
@defer.inlineCallbacks
def on_user_ip(self, user_id):
async def on_user_ip(self, user_id):
"""Called on the master when a worker process saw a client request.
Args:
@ -57,4 +53,4 @@ class ServerNoticesSender(object):
# we check for notices to send to the user in on_user_ip as well as
# in on_user_syncing
for sn in self._server_notices:
yield sn.maybe_send_server_notice_to_user(user_id)
await sn.maybe_send_server_notice_to_user(user_id)

View File

@ -66,6 +66,7 @@ from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
from .transactions import TransactionStore
from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore
@ -112,6 +113,7 @@ class DataStore(
StatsStore,
RelationsStore,
CacheInvalidationStore,
UIAuthStore,
):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs

View File

@ -178,7 +178,7 @@ class AccountDataWorkerStore(SQLBaseStore):
async def get_updated_global_account_data(
self, last_id: int, current_id: int, limit: int
) -> List[Tuple]:
) -> List[Tuple[int, str, str]]:
"""Get the global account_data that has changed, for the account_data stream
Args:
@ -208,7 +208,7 @@ class AccountDataWorkerStore(SQLBaseStore):
async def get_updated_room_account_data(
self, last_id: int, current_id: int, limit: int
) -> List[Tuple]:
) -> List[Tuple[int, str, str, str]]:
"""Get the global account_data that has changed, for the account_data stream
Args:

View File

@ -273,8 +273,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="delete_account_validity_for_user",
)
@defer.inlineCallbacks
def is_server_admin(self, user):
async def is_server_admin(self, user):
"""Determines if a user is an admin of this homeserver.
Args:
@ -283,7 +282,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns (bool):
true iff the user is a server admin, false otherwise.
"""
res = yield self.db.simple_select_one_onecol(
res = await self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",

View File

@ -0,0 +1,36 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS ui_auth_sessions(
session_id TEXT NOT NULL, -- The session ID passed to the client.
creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds).
serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse.
clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client.
uri TEXT NOT NULL, -- The URI the UI authentication session is using.
method TEXT NOT NULL, -- The HTTP method the UI authentication session is using.
-- The clientdict, uri, and method make up an tuple that must be immutable
-- throughout the lifetime of the UI Auth session.
description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur.
UNIQUE (session_id)
);
CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials(
session_id TEXT NOT NULL, -- The corresponding UI Auth session.
stage_type TEXT NOT NULL, -- The stage type.
result TEXT NOT NULL, -- The result of the stage verification, stored as JSON.
UNIQUE (session_id, stage_type),
FOREIGN KEY (session_id)
REFERENCES ui_auth_sessions (session_id)
);

View File

@ -0,0 +1,279 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Dict, Optional, Union
import attr
import synapse.util.stringutils as stringutils
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.types import JsonDict
@attr.s
class UIAuthSessionData:
session_id = attr.ib(type=str)
# The dictionary from the client root level, not the 'auth' key.
clientdict = attr.ib(type=JsonDict)
# The URI and method the session was intiatied with. These are checked at
# each stage of the authentication to ensure that the asked for operation
# has not changed.
uri = attr.ib(type=str)
method = attr.ib(type=str)
# A string description of the operation that the current authentication is
# authorising.
description = attr.ib(type=str)
class UIAuthWorkerStore(SQLBaseStore):
"""
Manage user interactive authentication sessions.
"""
async def create_ui_auth_session(
self, clientdict: JsonDict, uri: str, method: str, description: str,
) -> UIAuthSessionData:
"""
Creates a new user interactive authentication session.
The session can be used to track the stages necessary to authenticate a
user across multiple HTTP requests.
Args:
clientdict:
The dictionary from the client root level, not the 'auth' key.
uri:
The URI this session was initiated with, this is checked at each
stage of the authentication to ensure that the asked for
operation has not changed.
method:
The method this session was initiated with, this is checked at each
stage of the authentication to ensure that the asked for
operation has not changed.
description:
A string description of the operation that the current
authentication is authorising.
Returns:
The newly created session.
Raises:
StoreError if a unique session ID cannot be generated.
"""
# The clientdict gets stored as JSON.
clientdict_json = json.dumps(clientdict)
# autogen a session ID and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
attempts = 0
while attempts < 5:
session_id = stringutils.random_string(24)
try:
await self.db.simple_insert(
table="ui_auth_sessions",
values={
"session_id": session_id,
"clientdict": clientdict_json,
"uri": uri,
"method": method,
"description": description,
"serverdict": "{}",
"creation_time": self.hs.get_clock().time_msec(),
},
desc="create_ui_auth_session",
)
return UIAuthSessionData(
session_id, clientdict, uri, method, description
)
except self.db.engine.module.IntegrityError:
attempts += 1
raise StoreError(500, "Couldn't generate a session ID.")
async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
"""Retrieve a UI auth session.
Args:
session_id: The ID of the session.
Returns:
A dict containing the device information.
Raises:
StoreError if the session is not found.
"""
result = await self.db.simple_select_one(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("clientdict", "uri", "method", "description"),
desc="get_ui_auth_session",
)
result["clientdict"] = json.loads(result["clientdict"])
return UIAuthSessionData(session_id, **result)
async def mark_ui_auth_stage_complete(
self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
):
"""
Mark a session stage as completed.
Args:
session_id: The ID of the corresponding session.
stage_type: The completed stage type.
result: The result of the stage verification.
Raises:
StoreError if the session cannot be found.
"""
# Add (or update) the results of the current stage to the database.
#
# Note that we need to allow for the same stage to complete multiple
# times here so that registration is idempotent.
try:
await self.db.simple_upsert(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id, "stage_type": stage_type},
values={"result": json.dumps(result)},
desc="mark_ui_auth_stage_complete",
)
except self.db.engine.module.IntegrityError:
raise StoreError(400, "Unknown session ID: %s" % (session_id,))
async def get_completed_ui_auth_stages(
self, session_id: str
) -> Dict[str, Union[str, bool, JsonDict]]:
"""
Retrieve the completed stages of a UI authentication session.
Args:
session_id: The ID of the session.
Returns:
The completed stages mapped to the result of the verification of
that auth-type.
"""
results = {}
for row in await self.db.simple_select_list(
table="ui_auth_sessions_credentials",
keyvalues={"session_id": session_id},
retcols=("stage_type", "result"),
desc="get_completed_ui_auth_stages",
):
results[row["stage_type"]] = json.loads(row["result"])
return results
async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any):
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
the client.
Args:
session_id: The ID of this session as returned from check_auth
key: The key to store the data under
value: The data to store
Raises:
StoreError if the session cannot be found.
"""
await self.db.runInteraction(
"set_ui_auth_session_data",
self._set_ui_auth_session_data_txn,
session_id,
key,
value,
)
def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
# Get the current value.
result = self.db.simple_select_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
)
# Update it and add it back to the database.
serverdict = json.loads(result["serverdict"])
serverdict[key] = value
self.db.simple_update_one_txn(
txn,
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
updatevalues={"serverdict": json.dumps(serverdict)},
)
async def get_ui_auth_session_data(
self, session_id: str, key: str, default: Optional[Any] = None
) -> Any:
"""
Retrieve data stored with set_session_data
Args:
session_id: The ID of this session as returned from check_auth
key: The key to store the data under
default: Value to return if the key has not been set
Raises:
StoreError if the session cannot be found.
"""
result = await self.db.simple_select_one(
table="ui_auth_sessions",
keyvalues={"session_id": session_id},
retcols=("serverdict",),
desc="get_ui_auth_session_data",
)
serverdict = json.loads(result["serverdict"])
return serverdict.get(key, default)
class UIAuthStore(UIAuthWorkerStore):
def delete_old_ui_auth_sessions(self, expiration_time: int):
"""
Remove sessions which were last used earlier than the expiration time.
Args:
expiration_time: The latest time that is still considered valid.
This is an epoch time in milliseconds.
"""
return self.db.runInteraction(
"delete_old_ui_auth_sessions",
self._delete_old_ui_auth_sessions_txn,
expiration_time,
)
def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
# Get the expired sessions.
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
txn.execute(sql, [expiration_time])
session_ids = [r[0] for r in txn.fetchall()]
# Delete the corresponding completed credentials.
self.db.simple_delete_many_txn(
txn,
table="ui_auth_sessions_credentials",
column="session_id",
iterable=session_ids,
keyvalues={},
)
# Finally, delete the sessions.
self.db.simple_delete_many_txn(
txn,
table="ui_auth_sessions",
column="session_id",
iterable=session_ids,
keyvalues={},
)

View File

@ -85,6 +85,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
prepare_database(db_conn, self, config=None)
db_conn.create_function("rank", 1, _rank)
db_conn.execute("PRAGMA foreign_keys = ON;")
def is_deadlock(self, error):
return False

View File

@ -82,18 +82,26 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_name(self):
yield self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
yield defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
)
)
self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank Jr.",
)
# Set displayname again
yield self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank"
yield defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank"
)
)
self.assertEquals(
@ -112,16 +120,20 @@ class ProfileTestCase(unittest.TestCase):
)
# Setting displayname a second time is forbidden
d = self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
d = defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
)
)
yield self.assertFailure(d, SynapseError)
@defer.inlineCallbacks
def test_set_my_name_noauth(self):
d = self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
d = defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
)
)
yield self.assertFailure(d, AuthError)
@ -165,10 +177,12 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_avatar(self):
yield self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
yield defer.ensureDeferred(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
)
)
self.assertEquals(
@ -177,10 +191,12 @@ class ProfileTestCase(unittest.TestCase):
)
# Set avatar again
yield self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/me.png",
yield defer.ensureDeferred(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/me.png",
)
)
self.assertEquals(
@ -203,10 +219,12 @@ class ProfileTestCase(unittest.TestCase):
)
# Set avatar a second time is forbidden
d = self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
d = defer.ensureDeferred(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
)
)
yield self.assertFailure(d, SynapseError)

View File

@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
self.store.is_real_user = Mock(return_value=False)
self.store.is_real_user = Mock(return_value=defer.succeed(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
self.store.count_real_users = Mock(return_value=1)
self.store.is_real_user = Mock(return_value=True)
self.store.count_real_users = Mock(return_value=defer.succeed(1))
self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
self.store.count_real_users = Mock(return_value=2)
self.store.is_real_user = Mock(return_value=True)
self.store.count_real_users = Mock(return_value=defer.succeed(2))
self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
@defer.inlineCallbacks
def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
yield self.hs.get_auth().check_auth_blocking()
await self.hs.get_auth().check_auth_blocking()
need_register = True
try:
yield self.handler.check_username(localpart)
await self.handler.check_username(localpart)
except SynapseError as e:
if e.errcode == Codes.USER_IN_USE:
need_register = False
@ -288,23 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
token = self.macaroon_generator.generate_access_token(user_id)
if need_register:
yield self.handler.register_with_store(
await self.handler.register_with_store(
user_id=user_id,
password_hash=password_hash,
create_profile_with_displayname=user.localpart,
)
else:
yield defer.ensureDeferred(
self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
)
await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
yield self.store.add_access_token_to_user(
await self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None
)
if displayname is not None:
# logger.info("setting user display name: %s -> %s", user_id, displayname)
yield self.hs.get_profile_handler().set_displayname(
await self.hs.get_profile_handler().set_displayname(
user, requester, displayname, by_admin=True
)

View File

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, List, Optional, Tuple
import attr
@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
from synapse.app.generic_worker import GenericWorkerServer
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
from synapse.http.site import SynapseRequest
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._server_transport = None
def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs.get_datastore())
return TestReplicationDataHandler(self.worker_hs)
def reconnect(self):
if self._client_transport:
@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(request.method, b"GET")
class TestReplicationDataHandler(ReplicationDataHandler):
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
def __init__(self, store: BaseSlavedStore):
super().__init__(store)
# streams to subscribe to: map from stream id to position
self.stream_positions = {} # type: Dict[str, int]
def __init__(self, hs: HomeServer):
super().__init__(hs)
# list of received (stream_name, token, row) tuples
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
def get_streams_to_replicate(self):
return self.stream_positions
async def on_rdata(self, stream_name, token, rows):
await super().on_rdata(stream_name, token, rows)
async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, instance_name, token, rows)
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
if (
stream_name in self.stream_positions
and token > self.stream_positions[stream_name]
):
self.stream_positions[stream_name] = token
@attr.s()
class OneShotRequestFactory:

View File

@ -12,9 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from twisted.internet import defer
from synapse.replication.tcp.streams._base import (
_STREAM_UPDATE_TARGET_ROW_COUNT,
@ -25,10 +22,6 @@ from tests.replication.tcp.streams._base import BaseStreamTestCase
class AccountDataStreamTestCase(BaseStreamTestCase):
def prepare(self, reactor, clock, hs):
super().prepare(reactor, clock, hs)
self.test_handler.stream_positions["account_data"] = 0
def test_update_function_room_account_data_limit(self):
"""Test replication with many room account data updates
"""

View File

@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.user_tok = self.login("u1", "pass")
self.reconnect()
self.test_handler.stream_positions["events"] = 0
self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear()
@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect()
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
# we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
for event in events:
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect()
self.replicate()
# now we should have received all the expected rows in the right order.
# we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
#
# we expect:
#
@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
# of the states that got reverted.
# - two rows for state2
received_rows = self.test_handler.received_rdata_rows
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
# first check the first two rows, which should be state1
@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.reconnect()
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
# we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
self.assertGreaterEqual(len(received_rows), len(events))
for i in range(NUM_USERS):
# for each user, we expect the PL event row, followed by state rows for

View File

@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
def test_receipt(self):
self.reconnect()
# make the client subscribe to the receipts stream
self.test_handler.stream_positions.update({"receipts": 0})
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastore().insert_receipt(
@ -44,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
# there should be one RDATA command
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
@ -74,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
# We should now have caught up and get the missing data
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(token, 3)
self.assertEqual(1, len(rdata_rows))

View File

@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.reconnect()
# make the client subscribe to the typing stream
self.test_handler.stream_positions.update({"typing": 0})
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
self.reactor.advance(0)
@ -50,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assert_request_is_get_repl_stream_updates(request, "typing")
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
@ -77,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(int(request.args[b"from_token"][0]), token)
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]

View File

@ -181,3 +181,43 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(channel.code, 403)
def test_complete_operation_unknown_session(self):
"""
Attempting to mark an invalid session as complete should error.
"""
# Make the initial request to register. (Later on a different password
# will be used.)
request, channel = self.make_request(
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
)
self.render(request)
# Returns a 401 as per the spec
self.assertEqual(request.code, 401)
# Grab the session
session = channel.json_body["session"]
# Assert our configured public key is being given
self.assertEqual(
channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
)
request, channel = self.make_request(
"GET", "auth/m.login.recaptcha/fallback/web?session=" + session
)
self.render(request)
self.assertEqual(request.code, 200)
# Attempt to complete an unknown session, which should return an error.
unknown_session = session + "unknown"
request, channel = self.make_request(
"POST",
"auth/m.login.recaptcha/fallback/web?session="
+ unknown_session
+ "&g-recaptcha-response=a",
)
self.render(request)
self.assertEqual(request.code, 400)

View File

@ -55,25 +55,18 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
self._rlsn._server_notices_manager.send_notice = Mock()
self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
self._rlsn._server_notices_manager.send_notice = Mock(
return_value=defer.succeed(Mock())
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
self.hs.config.limit_usage_by_mau = True
self.user_id = "@user_id:test"
# self.server_notices_mxid = "@server:test"
# self.server_notices_mxid_display_name = None
# self.server_notices_mxid_avatar_url = None
# self.server_notices_room_name = "Server Notices"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
returnValue=""
return_value=defer.succeed("!something:localhost")
)
self._rlsn._store.add_tag_to_room = Mock()
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_tags_for_room = Mock(return_value={})
self.hs.config.admin_contact = "mailto:user@test.com"
@ -95,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed"""
self._rlsn._auth.check_auth_blocking = Mock()
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
self._send_notice.assert_called_once()
@ -112,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, "foo")
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
)
mock_event = Mock(
@ -121,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
@ -129,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, but should have one
"""
self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, "foo")
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -142,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock()
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -153,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
self._rlsn._auth.check_auth_blocking = Mock()
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None)
)
@ -167,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
an alert message is not sent into the room
"""
self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
)
),
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self.assertTrue(self._send_notice.call_count == 0)
self.assertEqual(self._send_notice.call_count, 0)
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
"""
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
)
),
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -198,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
)
),
)
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
return_value=defer.succeed((True, []))
)
@ -256,7 +254,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(return_value=1000)
self.store.user_last_seen_monthly_active = Mock(return_value=1000)
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
)
# Call the function multiple times to ensure we only send the notice once
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))

View File

@ -27,8 +27,10 @@ class MessageAcceptTests(unittest.TestCase):
user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
room = room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
room = ensureDeferred(
room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
)
)
self.reactor.advance(0.1)
self.room_id = self.successResultOf(room)["room_id"]

View File

@ -512,8 +512,8 @@ class MockClock(object):
return t
def looping_call(self, function, interval):
self.loopers.append([function, interval / 1000.0, self.now])
def looping_call(self, function, interval, *args, **kwargs):
self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
@ -543,9 +543,9 @@ class MockClock(object):
self.timers.append(t)
for looped in self.loopers:
func, interval, last = looped
func, interval, last, args, kwargs = looped
if last + interval < self.now:
func()
func(*args, **kwargs)
looped[2] = self.now
def advance_time_msec(self, ms):

View File

@ -200,8 +200,9 @@ commands = mypy \
synapse/replication \
synapse/rest \
synapse/spam_checker_api \
synapse/storage/engines \
synapse/storage/data_stores/main/ui_auth.py \
synapse/storage/database.py \
synapse/storage/engines \
synapse/streams \
synapse/util/caches/stream_change_cache.py \
tests/replication/tcp/streams \