Compare commits
9 Commits
2a6ad9c3d5
...
5f4e4819ca
Author | SHA1 | Date |
---|---|---|
Richard van der Hoff | 5f4e4819ca | |
Richard van der Hoff | 3356d9fe23 | |
Richard van der Hoff | ba2b934307 | |
Erik Johnston | 0e719f2398 | |
Erik Johnston | 3085cde577 | |
Andrew Morgan | 6b22921b19 | |
Andrew Morgan | 2e8955f4a6 | |
Richard van der Hoff | b2dba06079 | |
Patrick Cloke | 627b0f5f27 |
|
@ -0,0 +1 @@
|
|||
Use `stream.current_token()` and remove `stream_positions()`.
|
|
@ -0,0 +1 @@
|
|||
Persist user interactive authentication sessions across workers and Synapse restarts.
|
|
@ -0,0 +1 @@
|
|||
Convert RegistrationWorkerStore.is_server_admin and dependent code to async/await.
|
|
@ -0,0 +1 @@
|
|||
Improve error responses when accessing remote public room lists.
|
|
@ -0,0 +1 @@
|
|||
Thread through instance name to replication client.
|
|
@ -0,0 +1 @@
|
|||
Move catchup of replication streams logic to worker.
|
|
@ -537,8 +537,7 @@ class Auth(object):
|
|||
|
||||
return defer.succeed(auth_ids)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_can_change_room_list(self, room_id: str, user: UserID):
|
||||
async def check_can_change_room_list(self, room_id: str, user: UserID):
|
||||
"""Determine whether the user is allowed to edit the room's entry in the
|
||||
published room list.
|
||||
|
||||
|
@ -547,17 +546,17 @@ class Auth(object):
|
|||
user
|
||||
"""
|
||||
|
||||
is_admin = yield self.is_server_admin(user)
|
||||
is_admin = await self.is_server_admin(user)
|
||||
if is_admin:
|
||||
return True
|
||||
|
||||
user_id = user.to_string()
|
||||
yield self.check_user_in_room(room_id, user_id)
|
||||
await self.check_user_in_room(room_id, user_id)
|
||||
|
||||
# We currently require the user is a "moderator" in the room. We do this
|
||||
# by checking if they would (theoretically) be able to change the
|
||||
# m.room.canonical_alias events
|
||||
power_level_event = yield self.state.get_current_state(
|
||||
power_level_event = await self.state.get_current_state(
|
||||
room_id, EventTypes.PowerLevels, ""
|
||||
)
|
||||
|
||||
|
|
|
@ -127,6 +127,7 @@ from synapse.storage.data_stores.main.monthly_active_users import (
|
|||
MonthlyActiveUsersWorkerStore,
|
||||
)
|
||||
from synapse.storage.data_stores.main.presence import UserPresenceState
|
||||
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
|
||||
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
|
||||
from synapse.types import ReadReceipt
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
|
@ -412,12 +413,6 @@ class GenericWorkerTyping(object):
|
|||
# map room IDs to sets of users currently typing
|
||||
self._room_typing = {}
|
||||
|
||||
def stream_positions(self):
|
||||
# We must update this typing token from the response of the previous
|
||||
# sync. In particular, the stream id may "reset" back to zero/a low
|
||||
# value which we *must* use for the next replication request.
|
||||
return {"typing": self._latest_room_serial}
|
||||
|
||||
def process_replication_rows(self, token, rows):
|
||||
if self._latest_room_serial > token:
|
||||
# The master has gone backwards. To prevent inconsistent data, just
|
||||
|
@ -439,6 +434,7 @@ class GenericWorkerSlavedStore(
|
|||
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
|
||||
# rather than going via the correct worker.
|
||||
UserDirectoryStore,
|
||||
UIAuthWorkerStore,
|
||||
SlavedDeviceInboxStore,
|
||||
SlavedDeviceStore,
|
||||
SlavedReceiptsStore,
|
||||
|
@ -650,20 +646,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
|
|||
else:
|
||||
self.send_handler = None
|
||||
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
await super(GenericWorkerReplicationHandler, self).on_rdata(
|
||||
stream_name, token, rows
|
||||
)
|
||||
await self.process_and_notify(stream_name, token, rows)
|
||||
async def on_rdata(self, stream_name, instance_name, token, rows):
|
||||
await super().on_rdata(stream_name, instance_name, token, rows)
|
||||
await self._process_and_notify(stream_name, instance_name, token, rows)
|
||||
|
||||
def get_streams_to_replicate(self):
|
||||
args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
|
||||
args.update(self.typing_handler.stream_positions())
|
||||
if self.send_handler:
|
||||
args.update(self.send_handler.stream_positions())
|
||||
return args
|
||||
|
||||
async def process_and_notify(self, stream_name, token, rows):
|
||||
async def _process_and_notify(self, stream_name, instance_name, token, rows):
|
||||
try:
|
||||
if self.send_handler:
|
||||
await self.send_handler.process_replication_rows(
|
||||
|
@ -797,9 +784,6 @@ class FederationSenderHandler(object):
|
|||
def wake_destination(self, server: str):
|
||||
self.federation_sender.wake_destination(server)
|
||||
|
||||
def stream_positions(self):
|
||||
return {"federation": self.federation_position}
|
||||
|
||||
async def process_replication_rows(self, stream_name, token, rows):
|
||||
# The federation stream contains things that we want to send out, e.g.
|
||||
# presence, typing, etc.
|
||||
|
|
|
@ -883,18 +883,37 @@ class FederationClient(FederationBase):
|
|||
|
||||
def get_public_rooms(
|
||||
self,
|
||||
destination,
|
||||
limit=None,
|
||||
since_token=None,
|
||||
search_filter=None,
|
||||
include_all_networks=False,
|
||||
third_party_instance_id=None,
|
||||
remote_server: str,
|
||||
limit: Optional[int] = None,
|
||||
since_token: Optional[str] = None,
|
||||
search_filter: Optional[Dict] = None,
|
||||
include_all_networks: bool = False,
|
||||
third_party_instance_id: Optional[str] = None,
|
||||
):
|
||||
if destination == self.server_name:
|
||||
return
|
||||
"""Get the list of public rooms from a remote homeserver
|
||||
|
||||
Args:
|
||||
remote_server: The name of the remote server
|
||||
limit: Maximum amount of rooms to return
|
||||
since_token: Used for result pagination
|
||||
search_filter: A filter dictionary to send the remote homeserver
|
||||
and filter the result set
|
||||
include_all_networks: Whether to include results from all third party instances
|
||||
third_party_instance_id: Whether to only include results from a specific third
|
||||
party instance
|
||||
|
||||
Returns:
|
||||
Deferred[Dict[str, Any]]: The response from the remote server, or None if
|
||||
`remote_server` is the same as the local server_name
|
||||
|
||||
Raises:
|
||||
HttpResponseException: There was an exception returned from the remote server
|
||||
SynapseException: M_FORBIDDEN when the remote server has disallowed publicRoom
|
||||
requests over federation
|
||||
|
||||
"""
|
||||
return self.transport_layer.get_public_rooms(
|
||||
destination,
|
||||
remote_server,
|
||||
limit,
|
||||
since_token,
|
||||
search_filter,
|
||||
|
@ -957,14 +976,13 @@ class FederationClient(FederationBase):
|
|||
|
||||
return signed_events
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def forward_third_party_invite(self, destinations, room_id, event_dict):
|
||||
async def forward_third_party_invite(self, destinations, room_id, event_dict):
|
||||
for destination in destinations:
|
||||
if destination == self.server_name:
|
||||
continue
|
||||
|
||||
try:
|
||||
yield self.transport_layer.exchange_third_party_invite(
|
||||
await self.transport_layer.exchange_third_party_invite(
|
||||
destination=destination, room_id=room_id, event_dict=event_dict
|
||||
)
|
||||
return None
|
||||
|
|
|
@ -15,13 +15,14 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from six.moves import urllib
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
from synapse.api.urls import (
|
||||
FEDERATION_UNSTABLE_PREFIX,
|
||||
FEDERATION_V1_PREFIX,
|
||||
|
@ -326,18 +327,25 @@ class TransportLayerClient(object):
|
|||
@log_function
|
||||
def get_public_rooms(
|
||||
self,
|
||||
remote_server,
|
||||
limit,
|
||||
since_token,
|
||||
search_filter=None,
|
||||
include_all_networks=False,
|
||||
third_party_instance_id=None,
|
||||
remote_server: str,
|
||||
limit: Optional[int] = None,
|
||||
since_token: Optional[str] = None,
|
||||
search_filter: Optional[Dict] = None,
|
||||
include_all_networks: bool = False,
|
||||
third_party_instance_id: Optional[str] = None,
|
||||
):
|
||||
"""Get the list of public rooms from a remote homeserver
|
||||
|
||||
See synapse.federation.federation_client.FederationClient.get_public_rooms for
|
||||
more information.
|
||||
"""
|
||||
if search_filter:
|
||||
# this uses MSC2197 (Search Filtering over Federation)
|
||||
path = _create_v1_path("/publicRooms")
|
||||
|
||||
data = {"include_all_networks": "true" if include_all_networks else "false"}
|
||||
data = {
|
||||
"include_all_networks": "true" if include_all_networks else "false"
|
||||
} # type: Dict[str, Any]
|
||||
if third_party_instance_id:
|
||||
data["third_party_instance_id"] = third_party_instance_id
|
||||
if limit:
|
||||
|
@ -347,9 +355,19 @@ class TransportLayerClient(object):
|
|||
|
||||
data["filter"] = search_filter
|
||||
|
||||
response = yield self.client.post_json(
|
||||
destination=remote_server, path=path, data=data, ignore_backoff=True
|
||||
)
|
||||
try:
|
||||
response = yield self.client.post_json(
|
||||
destination=remote_server, path=path, data=data, ignore_backoff=True
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
if e.code == 403:
|
||||
raise SynapseError(
|
||||
403,
|
||||
"You are not allowed to view the public rooms list of %s"
|
||||
% (remote_server,),
|
||||
errcode=Codes.FORBIDDEN,
|
||||
)
|
||||
raise
|
||||
else:
|
||||
path = _create_v1_path("/publicRooms")
|
||||
|
||||
|
@ -363,9 +381,19 @@ class TransportLayerClient(object):
|
|||
if since_token:
|
||||
args["since"] = [since_token]
|
||||
|
||||
response = yield self.client.get_json(
|
||||
destination=remote_server, path=path, args=args, ignore_backoff=True
|
||||
)
|
||||
try:
|
||||
response = yield self.client.get_json(
|
||||
destination=remote_server, path=path, args=args, ignore_backoff=True
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
if e.code == 403:
|
||||
raise SynapseError(
|
||||
403,
|
||||
"You are not allowed to view the public rooms list of %s"
|
||||
% (remote_server,),
|
||||
errcode=Codes.FORBIDDEN,
|
||||
)
|
||||
raise
|
||||
|
||||
return response
|
||||
|
||||
|
|
|
@ -748,17 +748,18 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||
|
||||
raise NotImplementedError()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
||||
async def remove_user_from_group(
|
||||
self, group_id, user_id, requester_user_id, content
|
||||
):
|
||||
"""Remove a user from the group; either a user is leaving or an admin
|
||||
kicked them.
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
is_kick = False
|
||||
if requester_user_id != user_id:
|
||||
is_admin = yield self.store.is_user_admin_in_group(
|
||||
is_admin = await self.store.is_user_admin_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
if not is_admin:
|
||||
|
@ -766,30 +767,29 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||
|
||||
is_kick = True
|
||||
|
||||
yield self.store.remove_user_from_group(group_id, user_id)
|
||||
await self.store.remove_user_from_group(group_id, user_id)
|
||||
|
||||
if is_kick:
|
||||
if self.hs.is_mine_id(user_id):
|
||||
groups_local = self.hs.get_groups_local_handler()
|
||||
yield groups_local.user_removed_from_group(group_id, user_id, {})
|
||||
await groups_local.user_removed_from_group(group_id, user_id, {})
|
||||
else:
|
||||
yield self.transport_client.remove_user_from_group_notification(
|
||||
await self.transport_client.remove_user_from_group_notification(
|
||||
get_domain_from_id(user_id), group_id, user_id, {}
|
||||
)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
||||
await self.store.maybe_delete_remote_profile_cache(user_id)
|
||||
|
||||
# Delete group if the last user has left
|
||||
users = yield self.store.get_users_in_group(group_id, include_private=True)
|
||||
users = await self.store.get_users_in_group(group_id, include_private=True)
|
||||
if not users:
|
||||
yield self.store.delete_group(group_id)
|
||||
await self.store.delete_group(group_id)
|
||||
|
||||
return {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_group(self, group_id, requester_user_id, content):
|
||||
group = yield self.check_group_is_ours(group_id, requester_user_id)
|
||||
async def create_group(self, group_id, requester_user_id, content):
|
||||
group = await self.check_group_is_ours(group_id, requester_user_id)
|
||||
|
||||
logger.info("Attempting to create group with ID: %r", group_id)
|
||||
|
||||
|
@ -799,7 +799,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||
if group:
|
||||
raise SynapseError(400, "Group already exists")
|
||||
|
||||
is_admin = yield self.auth.is_server_admin(
|
||||
is_admin = await self.auth.is_server_admin(
|
||||
UserID.from_string(requester_user_id)
|
||||
)
|
||||
if not is_admin:
|
||||
|
@ -822,7 +822,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||
long_description = profile.get("long_description")
|
||||
user_profile = content.get("user_profile", {})
|
||||
|
||||
yield self.store.create_group(
|
||||
await self.store.create_group(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
name=name,
|
||||
|
@ -834,7 +834,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||
if not self.hs.is_mine_id(requester_user_id):
|
||||
remote_attestation = content["attestation"]
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
await self.attestations.verify_attestation(
|
||||
remote_attestation, user_id=requester_user_id, group_id=group_id
|
||||
)
|
||||
|
||||
|
@ -845,7 +845,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||
local_attestation = None
|
||||
remote_attestation = None
|
||||
|
||||
yield self.store.add_user_to_group(
|
||||
await self.store.add_user_to_group(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
is_admin=True,
|
||||
|
@ -855,7 +855,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||
)
|
||||
|
||||
if not self.hs.is_mine_id(requester_user_id):
|
||||
yield self.store.add_remote_profile_cache(
|
||||
await self.store.add_remote_profile_cache(
|
||||
requester_user_id,
|
||||
displayname=user_profile.get("displayname"),
|
||||
avatar_url=user_profile.get("avatar_url"),
|
||||
|
@ -863,8 +863,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||
|
||||
return {"group_id": group_id}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group(self, group_id, requester_user_id):
|
||||
async def delete_group(self, group_id, requester_user_id):
|
||||
"""Deletes a group, kicking out all current members.
|
||||
|
||||
Only group admins or server admins can call this request
|
||||
|
@ -877,14 +876,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||
Deferred
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
|
||||
|
||||
# Only server admins or group admins can delete groups.
|
||||
|
||||
is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id)
|
||||
is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id)
|
||||
|
||||
if not is_admin:
|
||||
is_admin = yield self.auth.is_server_admin(
|
||||
is_admin = await self.auth.is_server_admin(
|
||||
UserID.from_string(requester_user_id)
|
||||
)
|
||||
|
||||
|
@ -892,18 +891,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||
raise SynapseError(403, "User is not an admin")
|
||||
|
||||
# Before deleting the group lets kick everyone out of it
|
||||
users = yield self.store.get_users_in_group(group_id, include_private=True)
|
||||
users = await self.store.get_users_in_group(group_id, include_private=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _kick_user_from_group(user_id):
|
||||
async def _kick_user_from_group(user_id):
|
||||
if self.hs.is_mine_id(user_id):
|
||||
groups_local = self.hs.get_groups_local_handler()
|
||||
yield groups_local.user_removed_from_group(group_id, user_id, {})
|
||||
await groups_local.user_removed_from_group(group_id, user_id, {})
|
||||
else:
|
||||
yield self.transport_client.remove_user_from_group_notification(
|
||||
await self.transport_client.remove_user_from_group_notification(
|
||||
get_domain_from_id(user_id), group_id, user_id, {}
|
||||
)
|
||||
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
||||
await self.store.maybe_delete_remote_profile_cache(user_id)
|
||||
|
||||
# We kick users out in the order of:
|
||||
# 1. Non-admins
|
||||
|
@ -922,11 +920,11 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
|
|||
else:
|
||||
non_admins.append(u["user_id"])
|
||||
|
||||
yield concurrently_execute(_kick_user_from_group, non_admins, 10)
|
||||
yield concurrently_execute(_kick_user_from_group, admins, 10)
|
||||
yield _kick_user_from_group(requester_user_id)
|
||||
await concurrently_execute(_kick_user_from_group, non_admins, 10)
|
||||
await concurrently_execute(_kick_user_from_group, admins, 10)
|
||||
await _kick_user_from_group(requester_user_id)
|
||||
|
||||
yield self.store.delete_group(group_id)
|
||||
await self.store.delete_group(group_id)
|
||||
|
||||
|
||||
def _parse_join_policy_from_contents(content):
|
||||
|
|
|
@ -126,30 +126,28 @@ class BaseHandler(object):
|
|||
retry_after_ms=int(1000 * (time_allowed - time_now))
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def maybe_kick_guest_users(self, event, context=None):
|
||||
async def maybe_kick_guest_users(self, event, context=None):
|
||||
# Technically this function invalidates current_state by changing it.
|
||||
# Hopefully this isn't that important to the caller.
|
||||
if event.type == EventTypes.GuestAccess:
|
||||
guest_access = event.content.get("guest_access", "forbidden")
|
||||
if guest_access != "can_join":
|
||||
if context:
|
||||
current_state_ids = yield context.get_current_state_ids()
|
||||
current_state = yield self.store.get_events(
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
current_state = await self.store.get_events(
|
||||
list(current_state_ids.values())
|
||||
)
|
||||
else:
|
||||
current_state = yield self.state_handler.get_current_state(
|
||||
current_state = await self.state_handler.get_current_state(
|
||||
event.room_id
|
||||
)
|
||||
|
||||
current_state = list(current_state.values())
|
||||
|
||||
logger.info("maybe_kick_guest_users %r", current_state)
|
||||
yield self.kick_guest_users(current_state)
|
||||
await self.kick_guest_users(current_state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def kick_guest_users(self, current_state):
|
||||
async def kick_guest_users(self, current_state):
|
||||
for member_event in current_state:
|
||||
try:
|
||||
if member_event.type != EventTypes.Member:
|
||||
|
@ -180,7 +178,7 @@ class BaseHandler(object):
|
|||
# homeserver.
|
||||
requester = synapse.types.create_requester(target_user, is_guest=True)
|
||||
handler = self.hs.get_room_member_handler()
|
||||
yield handler.update_membership(
|
||||
await handler.update_membership(
|
||||
requester,
|
||||
target_user,
|
||||
member_event.room_id,
|
||||
|
|
|
@ -41,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
|
|||
from synapse.http.server import finish_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import defer_to_thread
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
from synapse.types import Requester, UserID
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
|
@ -69,15 +69,6 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
self.bcrypt_rounds = hs.config.bcrypt_rounds
|
||||
|
||||
# This is not a cache per se, but a store of all current sessions that
|
||||
# expire after N hours
|
||||
self.sessions = ExpiringCache(
|
||||
cache_name="register_sessions",
|
||||
clock=hs.get_clock(),
|
||||
expiry_ms=self.SESSION_EXPIRE_MS,
|
||||
reset_expiry_on_get=True,
|
||||
)
|
||||
|
||||
account_handler = ModuleApi(hs, self)
|
||||
self.password_providers = [
|
||||
module(config=config, account_handler=account_handler)
|
||||
|
@ -119,6 +110,15 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
self._clock = self.hs.get_clock()
|
||||
|
||||
# Expire old UI auth sessions after a period of time.
|
||||
if hs.config.worker_app is None:
|
||||
self._clock.looping_call(
|
||||
run_as_background_process,
|
||||
5 * 60 * 1000,
|
||||
"expire_old_sessions",
|
||||
self._expire_old_sessions,
|
||||
)
|
||||
|
||||
# Load the SSO HTML templates.
|
||||
|
||||
# The following template is shown to the user during a client login via SSO,
|
||||
|
@ -301,16 +301,21 @@ class AuthHandler(BaseHandler):
|
|||
if "session" in authdict:
|
||||
sid = authdict["session"]
|
||||
|
||||
# Convert the URI and method to strings.
|
||||
uri = request.uri.decode("utf-8")
|
||||
method = request.uri.decode("utf-8")
|
||||
|
||||
# If there's no session ID, create a new session.
|
||||
if not sid:
|
||||
session = self._create_session(
|
||||
clientdict, (request.uri, request.method, clientdict), description
|
||||
session = await self.store.create_ui_auth_session(
|
||||
clientdict, uri, method, description
|
||||
)
|
||||
session_id = session["id"]
|
||||
|
||||
else:
|
||||
session = self._get_session_info(sid)
|
||||
session_id = sid
|
||||
try:
|
||||
session = await self.store.get_ui_auth_session(sid)
|
||||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (sid,))
|
||||
|
||||
if not clientdict:
|
||||
# This was designed to allow the client to omit the parameters
|
||||
|
@ -322,15 +327,15 @@ class AuthHandler(BaseHandler):
|
|||
# on a homeserver.
|
||||
# Revisit: Assuming the REST APIs do sensible validation, the data
|
||||
# isn't arbitrary.
|
||||
clientdict = session["clientdict"]
|
||||
clientdict = session.clientdict
|
||||
|
||||
# Ensure that the queried operation does not vary between stages of
|
||||
# the UI authentication session. This is done by generating a stable
|
||||
# comparator based on the URI, method, and body (minus the auth dict)
|
||||
# and storing it during the initial query. Subsequent queries ensure
|
||||
# that this comparator has not changed.
|
||||
comparator = (request.uri, request.method, clientdict)
|
||||
if session["ui_auth"] != comparator:
|
||||
comparator = (uri, method, clientdict)
|
||||
if (session.uri, session.method, session.clientdict) != comparator:
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Requested operation has changed during the UI authentication session.",
|
||||
|
@ -338,11 +343,9 @@ class AuthHandler(BaseHandler):
|
|||
|
||||
if not authdict:
|
||||
raise InteractiveAuthIncompleteError(
|
||||
self._auth_dict_for_flows(flows, session_id)
|
||||
self._auth_dict_for_flows(flows, session.session_id)
|
||||
)
|
||||
|
||||
creds = session["creds"]
|
||||
|
||||
# check auth type currently being presented
|
||||
errordict = {} # type: Dict[str, Any]
|
||||
if "type" in authdict:
|
||||
|
@ -350,8 +353,9 @@ class AuthHandler(BaseHandler):
|
|||
try:
|
||||
result = await self._check_auth_dict(authdict, clientip)
|
||||
if result:
|
||||
creds[login_type] = result
|
||||
self._save_session(session)
|
||||
await self.store.mark_ui_auth_stage_complete(
|
||||
session.session_id, login_type, result
|
||||
)
|
||||
except LoginError as e:
|
||||
if login_type == LoginType.EMAIL_IDENTITY:
|
||||
# riot used to have a bug where it would request a new
|
||||
|
@ -367,6 +371,7 @@ class AuthHandler(BaseHandler):
|
|||
# so that the client can have another go.
|
||||
errordict = e.error_dict()
|
||||
|
||||
creds = await self.store.get_completed_ui_auth_stages(session.session_id)
|
||||
for f in flows:
|
||||
if len(set(f) - set(creds)) == 0:
|
||||
# it's very useful to know what args are stored, but this can
|
||||
|
@ -380,9 +385,9 @@ class AuthHandler(BaseHandler):
|
|||
list(clientdict),
|
||||
)
|
||||
|
||||
return creds, clientdict, session_id
|
||||
return creds, clientdict, session.session_id
|
||||
|
||||
ret = self._auth_dict_for_flows(flows, session_id)
|
||||
ret = self._auth_dict_for_flows(flows, session.session_id)
|
||||
ret["completed"] = list(creds)
|
||||
ret.update(errordict)
|
||||
raise InteractiveAuthIncompleteError(ret)
|
||||
|
@ -399,13 +404,11 @@ class AuthHandler(BaseHandler):
|
|||
if "session" not in authdict:
|
||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||
|
||||
sess = self._get_session_info(authdict["session"])
|
||||
creds = sess["creds"]
|
||||
|
||||
result = await self.checkers[stagetype].check_auth(authdict, clientip)
|
||||
if result:
|
||||
creds[stagetype] = result
|
||||
self._save_session(sess)
|
||||
await self.store.mark_ui_auth_stage_complete(
|
||||
authdict["session"], stagetype, result
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -427,7 +430,7 @@ class AuthHandler(BaseHandler):
|
|||
sid = authdict["session"]
|
||||
return sid
|
||||
|
||||
def set_session_data(self, session_id: str, key: str, value: Any) -> None:
|
||||
async def set_session_data(self, session_id: str, key: str, value: Any) -> None:
|
||||
"""
|
||||
Store a key-value pair into the sessions data associated with this
|
||||
request. This data is stored server-side and cannot be modified by
|
||||
|
@ -438,11 +441,12 @@ class AuthHandler(BaseHandler):
|
|||
key: The key to store the data under
|
||||
value: The data to store
|
||||
"""
|
||||
sess = self._get_session_info(session_id)
|
||||
sess["serverdict"][key] = value
|
||||
self._save_session(sess)
|
||||
try:
|
||||
await self.store.set_ui_auth_session_data(session_id, key, value)
|
||||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
def get_session_data(
|
||||
async def get_session_data(
|
||||
self, session_id: str, key: str, default: Optional[Any] = None
|
||||
) -> Any:
|
||||
"""
|
||||
|
@ -453,8 +457,18 @@ class AuthHandler(BaseHandler):
|
|||
key: The key to store the data under
|
||||
default: Value to return if the key has not been set
|
||||
"""
|
||||
sess = self._get_session_info(session_id)
|
||||
return sess["serverdict"].get(key, default)
|
||||
try:
|
||||
return await self.store.get_ui_auth_session_data(session_id, key, default)
|
||||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
async def _expire_old_sessions(self):
|
||||
"""
|
||||
Invalidate any user interactive authentication sessions that have expired.
|
||||
"""
|
||||
now = self._clock.time_msec()
|
||||
expiration_time = now - self.SESSION_EXPIRE_MS
|
||||
await self.store.delete_old_ui_auth_sessions(expiration_time)
|
||||
|
||||
async def _check_auth_dict(
|
||||
self, authdict: Dict[str, Any], clientip: str
|
||||
|
@ -534,67 +548,6 @@ class AuthHandler(BaseHandler):
|
|||
"params": params,
|
||||
}
|
||||
|
||||
def _create_session(
|
||||
self,
|
||||
clientdict: Dict[str, Any],
|
||||
ui_auth: Tuple[bytes, bytes, Dict[str, Any]],
|
||||
description: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Creates a new user interactive authentication session.
|
||||
|
||||
The session can be used to track data across multiple requests, e.g. for
|
||||
interactive authentication.
|
||||
|
||||
Each session has the following keys:
|
||||
|
||||
id:
|
||||
A unique identifier for this session. Passed back to the client
|
||||
and returned for each stage.
|
||||
clientdict:
|
||||
The dictionary from the client root level, not the 'auth' key.
|
||||
ui_auth:
|
||||
A tuple which is checked at each stage of the authentication to
|
||||
ensure that the asked for operation has not changed.
|
||||
creds:
|
||||
A map, which maps each auth-type (str) to the relevant identity
|
||||
authenticated by that auth-type (mostly str, but for captcha, bool).
|
||||
serverdict:
|
||||
A map of data that is stored server-side and cannot be modified
|
||||
by the client.
|
||||
description:
|
||||
A string description of the operation that the current
|
||||
authentication is authorising.
|
||||
Returns:
|
||||
The newly created session.
|
||||
"""
|
||||
session_id = None
|
||||
while session_id is None or session_id in self.sessions:
|
||||
session_id = stringutils.random_string(24)
|
||||
|
||||
self.sessions[session_id] = {
|
||||
"id": session_id,
|
||||
"clientdict": clientdict,
|
||||
"ui_auth": ui_auth,
|
||||
"creds": {},
|
||||
"serverdict": {},
|
||||
"description": description,
|
||||
}
|
||||
|
||||
return self.sessions[session_id]
|
||||
|
||||
def _get_session_info(self, session_id: str) -> dict:
|
||||
"""
|
||||
Gets a session given a session ID.
|
||||
|
||||
The session can be used to track data across multiple requests, e.g. for
|
||||
interactive authentication.
|
||||
"""
|
||||
try:
|
||||
return self.sessions[session_id]
|
||||
except KeyError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
async def get_access_token_for_user_id(
|
||||
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
|
||||
):
|
||||
|
@ -994,13 +947,6 @@ class AuthHandler(BaseHandler):
|
|||
await self.store.user_delete_threepid(user_id, medium, address)
|
||||
return result
|
||||
|
||||
def _save_session(self, session: Dict[str, Any]) -> None:
|
||||
"""Update the last used time on the session to now and add it back to the session store."""
|
||||
# TODO: Persistent storage
|
||||
logger.debug("Saving session %s", session)
|
||||
session["last_used"] = self.hs.get_clock().time_msec()
|
||||
self.sessions[session["id"]] = session
|
||||
|
||||
async def hash(self, password: str) -> str:
|
||||
"""Computes a secure hash of password.
|
||||
|
||||
|
@ -1052,7 +998,7 @@ class AuthHandler(BaseHandler):
|
|||
else:
|
||||
return False
|
||||
|
||||
def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
|
||||
async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
|
||||
"""
|
||||
Get the HTML for the SSO redirect confirmation page.
|
||||
|
||||
|
@ -1063,12 +1009,15 @@ class AuthHandler(BaseHandler):
|
|||
Returns:
|
||||
The HTML to render.
|
||||
"""
|
||||
session = self._get_session_info(session_id)
|
||||
try:
|
||||
session = await self.store.get_ui_auth_session(session_id)
|
||||
except StoreError:
|
||||
raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
|
||||
return self._sso_auth_confirm_template.render(
|
||||
description=session["description"], redirect_url=redirect_url,
|
||||
description=session.description, redirect_url=redirect_url,
|
||||
)
|
||||
|
||||
def complete_sso_ui_auth(
|
||||
async def complete_sso_ui_auth(
|
||||
self, registered_user_id: str, session_id: str, request: SynapseRequest,
|
||||
):
|
||||
"""Having figured out a mxid for this user, complete the HTTP request
|
||||
|
@ -1080,13 +1029,11 @@ class AuthHandler(BaseHandler):
|
|||
process.
|
||||
"""
|
||||
# Mark the stage of the authentication as successful.
|
||||
sess = self._get_session_info(session_id)
|
||||
creds = sess["creds"]
|
||||
|
||||
# Save the user who authenticated with SSO, this will be used to ensure
|
||||
# that the account be modified is also the person who logged in.
|
||||
creds[LoginType.SSO] = registered_user_id
|
||||
self._save_session(sess)
|
||||
await self.store.mark_ui_auth_stage_complete(
|
||||
session_id, LoginType.SSO, registered_user_id
|
||||
)
|
||||
|
||||
# Render the HTML and return.
|
||||
html_bytes = self._sso_auth_success_template.encode("utf-8")
|
||||
|
|
|
@ -206,7 +206,7 @@ class CasHandler:
|
|||
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
||||
|
||||
if session:
|
||||
self._auth_handler.complete_sso_ui_auth(
|
||||
await self._auth_handler.complete_sso_ui_auth(
|
||||
registered_user_id, session, request,
|
||||
)
|
||||
|
||||
|
|
|
@ -86,8 +86,7 @@ class DirectoryHandler(BaseHandler):
|
|||
room_alias, room_id, servers, creator=creator
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_association(
|
||||
async def create_association(
|
||||
self,
|
||||
requester: Requester,
|
||||
room_alias: RoomAlias,
|
||||
|
@ -129,10 +128,10 @@ class DirectoryHandler(BaseHandler):
|
|||
else:
|
||||
# Server admins are not subject to the same constraints as normal
|
||||
# users when creating an alias (e.g. being in the room).
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
is_admin = await self.auth.is_server_admin(requester.user)
|
||||
|
||||
if (self.require_membership and check_membership) and not is_admin:
|
||||
rooms_for_user = yield self.store.get_rooms_for_user(user_id)
|
||||
rooms_for_user = await self.store.get_rooms_for_user(user_id)
|
||||
if room_id not in rooms_for_user:
|
||||
raise AuthError(
|
||||
403, "You must be in the room to create an alias for it"
|
||||
|
@ -149,7 +148,7 @@ class DirectoryHandler(BaseHandler):
|
|||
# per alias creation rule?
|
||||
raise SynapseError(403, "Not allowed to create alias")
|
||||
|
||||
can_create = yield self.can_modify_alias(room_alias, user_id=user_id)
|
||||
can_create = await self.can_modify_alias(room_alias, user_id=user_id)
|
||||
if not can_create:
|
||||
raise AuthError(
|
||||
400,
|
||||
|
@ -157,10 +156,9 @@ class DirectoryHandler(BaseHandler):
|
|||
errcode=Codes.EXCLUSIVE,
|
||||
)
|
||||
|
||||
yield self._create_association(room_alias, room_id, servers, creator=user_id)
|
||||
await self._create_association(room_alias, room_id, servers, creator=user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_association(self, requester: Requester, room_alias: RoomAlias):
|
||||
async def delete_association(self, requester: Requester, room_alias: RoomAlias):
|
||||
"""Remove an alias from the directory
|
||||
|
||||
(this is only meant for human users; AS users should call
|
||||
|
@ -184,7 +182,7 @@ class DirectoryHandler(BaseHandler):
|
|||
user_id = requester.user.to_string()
|
||||
|
||||
try:
|
||||
can_delete = yield self._user_can_delete_alias(room_alias, user_id)
|
||||
can_delete = await self._user_can_delete_alias(room_alias, user_id)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise NotFoundError("Unknown room alias")
|
||||
|
@ -193,7 +191,7 @@ class DirectoryHandler(BaseHandler):
|
|||
if not can_delete:
|
||||
raise AuthError(403, "You don't have permission to delete the alias.")
|
||||
|
||||
can_delete = yield self.can_modify_alias(room_alias, user_id=user_id)
|
||||
can_delete = await self.can_modify_alias(room_alias, user_id=user_id)
|
||||
if not can_delete:
|
||||
raise SynapseError(
|
||||
400,
|
||||
|
@ -201,10 +199,10 @@ class DirectoryHandler(BaseHandler):
|
|||
errcode=Codes.EXCLUSIVE,
|
||||
)
|
||||
|
||||
room_id = yield self._delete_association(room_alias)
|
||||
room_id = await self._delete_association(room_alias)
|
||||
|
||||
try:
|
||||
yield self._update_canonical_alias(requester, user_id, room_id, room_alias)
|
||||
await self._update_canonical_alias(requester, user_id, room_id, room_alias)
|
||||
except AuthError as e:
|
||||
logger.info("Failed to update alias events: %s", e)
|
||||
|
||||
|
@ -296,15 +294,14 @@ class DirectoryHandler(BaseHandler):
|
|||
Codes.NOT_FOUND,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_canonical_alias(
|
||||
async def _update_canonical_alias(
|
||||
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
|
||||
):
|
||||
"""
|
||||
Send an updated canonical alias event if the removed alias was set as
|
||||
the canonical alias or listed in the alt_aliases field.
|
||||
"""
|
||||
alias_event = yield self.state.get_current_state(
|
||||
alias_event = await self.state.get_current_state(
|
||||
room_id, EventTypes.CanonicalAlias, ""
|
||||
)
|
||||
|
||||
|
@ -335,7 +332,7 @@ class DirectoryHandler(BaseHandler):
|
|||
del content["alt_aliases"]
|
||||
|
||||
if send_update:
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.CanonicalAlias,
|
||||
|
@ -376,8 +373,7 @@ class DirectoryHandler(BaseHandler):
|
|||
# either no interested services, or no service with an exclusive lock
|
||||
return defer.succeed(True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
|
||||
async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
|
||||
"""Determine whether a user can delete an alias.
|
||||
|
||||
One of the following must be true:
|
||||
|
@ -388,24 +384,23 @@ class DirectoryHandler(BaseHandler):
|
|||
for the current room.
|
||||
|
||||
"""
|
||||
creator = yield self.store.get_room_alias_creator(alias.to_string())
|
||||
creator = await self.store.get_room_alias_creator(alias.to_string())
|
||||
|
||||
if creator is not None and creator == user_id:
|
||||
return True
|
||||
|
||||
# Resolve the alias to the corresponding room.
|
||||
room_mapping = yield self.get_association(alias)
|
||||
room_mapping = await self.get_association(alias)
|
||||
room_id = room_mapping["room_id"]
|
||||
if not room_id:
|
||||
return False
|
||||
|
||||
res = yield self.auth.check_can_change_room_list(
|
||||
res = await self.auth.check_can_change_room_list(
|
||||
room_id, UserID.from_string(user_id)
|
||||
)
|
||||
return res
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def edit_published_room_list(
|
||||
async def edit_published_room_list(
|
||||
self, requester: Requester, room_id: str, visibility: str
|
||||
):
|
||||
"""Edit the entry of the room in the published room list.
|
||||
|
@ -433,11 +428,11 @@ class DirectoryHandler(BaseHandler):
|
|||
403, "This user is not permitted to publish rooms to the room list"
|
||||
)
|
||||
|
||||
room = yield self.store.get_room(room_id)
|
||||
room = await self.store.get_room(room_id)
|
||||
if room is None:
|
||||
raise SynapseError(400, "Unknown room")
|
||||
|
||||
can_change_room_list = yield self.auth.check_can_change_room_list(
|
||||
can_change_room_list = await self.auth.check_can_change_room_list(
|
||||
room_id, requester.user
|
||||
)
|
||||
if not can_change_room_list:
|
||||
|
@ -449,8 +444,8 @@ class DirectoryHandler(BaseHandler):
|
|||
|
||||
making_public = visibility == "public"
|
||||
if making_public:
|
||||
room_aliases = yield self.store.get_aliases_for_room(room_id)
|
||||
canonical_alias = yield self.store.get_canonical_alias_for_room(room_id)
|
||||
room_aliases = await self.store.get_aliases_for_room(room_id)
|
||||
canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
|
||||
if canonical_alias:
|
||||
room_aliases.append(canonical_alias)
|
||||
|
||||
|
@ -462,7 +457,7 @@ class DirectoryHandler(BaseHandler):
|
|||
# per alias creation rule?
|
||||
raise SynapseError(403, "Not allowed to publish room")
|
||||
|
||||
yield self.store.set_room_is_public(room_id, making_public)
|
||||
await self.store.set_room_is_public(room_id, making_public)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def edit_published_appservice_room_list(
|
||||
|
|
|
@ -2562,9 +2562,8 @@ class FederationHandler(BaseHandler):
|
|||
"missing": [e.event_id for e in missing_locals],
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def exchange_third_party_invite(
|
||||
async def exchange_third_party_invite(
|
||||
self, sender_user_id, target_user_id, room_id, signed
|
||||
):
|
||||
third_party_invite = {"signed": signed}
|
||||
|
@ -2580,16 +2579,16 @@ class FederationHandler(BaseHandler):
|
|||
"state_key": target_user_id,
|
||||
}
|
||||
|
||||
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
|
||||
room_version = yield self.store.get_room_version_id(room_id)
|
||||
if await self.auth.check_host_in_room(room_id, self.hs.hostname):
|
||||
room_version = await self.store.get_room_version_id(room_id)
|
||||
builder = self.event_builder_factory.new(room_version, event_dict)
|
||||
|
||||
EventValidator().validate_builder(builder)
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
event, context = await self.event_creation_handler.create_new_client_event(
|
||||
builder=builder
|
||||
)
|
||||
|
||||
event_allowed = yield self.third_party_event_rules.check_event_allowed(
|
||||
event_allowed = await self.third_party_event_rules.check_event_allowed(
|
||||
event, context
|
||||
)
|
||||
if not event_allowed:
|
||||
|
@ -2601,7 +2600,7 @@ class FederationHandler(BaseHandler):
|
|||
403, "This event is not allowed in this context", Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
event, context = yield self.add_display_name_to_third_party_invite(
|
||||
event, context = await self.add_display_name_to_third_party_invite(
|
||||
room_version, event_dict, event, context
|
||||
)
|
||||
|
||||
|
@ -2612,19 +2611,19 @@ class FederationHandler(BaseHandler):
|
|||
event.internal_metadata.send_on_behalf_of = self.hs.hostname
|
||||
|
||||
try:
|
||||
yield self.auth.check_from_context(room_version, event, context)
|
||||
await self.auth.check_from_context(room_version, event, context)
|
||||
except AuthError as e:
|
||||
logger.warning("Denying new third party invite %r because %s", event, e)
|
||||
raise e
|
||||
|
||||
yield self._check_signature(event, context)
|
||||
await self._check_signature(event, context)
|
||||
|
||||
# We retrieve the room member handler here as to not cause a cyclic dependency
|
||||
member_handler = self.hs.get_room_member_handler()
|
||||
yield member_handler.send_membership_event(None, event, context)
|
||||
await member_handler.send_membership_event(None, event, context)
|
||||
else:
|
||||
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
|
||||
yield self.federation_client.forward_third_party_invite(
|
||||
await self.federation_client.forward_third_party_invite(
|
||||
destinations, room_id, event_dict
|
||||
)
|
||||
|
||||
|
|
|
@ -284,15 +284,14 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||
|
||||
set_group_join_policy = _create_rerouter("set_group_join_policy")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_group(self, group_id, user_id, content):
|
||||
async def create_group(self, group_id, user_id, content):
|
||||
"""Create a group
|
||||
"""
|
||||
|
||||
logger.info("Asking to create group with ID: %r", group_id)
|
||||
|
||||
if self.is_mine_id(group_id):
|
||||
res = yield self.groups_server_handler.create_group(
|
||||
res = await self.groups_server_handler.create_group(
|
||||
group_id, user_id, content
|
||||
)
|
||||
local_attestation = None
|
||||
|
@ -301,10 +300,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
content["attestation"] = local_attestation
|
||||
|
||||
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
|
||||
content["user_profile"] = await self.profile_handler.get_profile(user_id)
|
||||
|
||||
try:
|
||||
res = yield self.transport_client.create_group(
|
||||
res = await self.transport_client.create_group(
|
||||
get_domain_from_id(group_id), group_id, user_id, content
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
|
@ -313,7 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||
raise SynapseError(502, "Failed to contact group server")
|
||||
|
||||
remote_attestation = res["attestation"]
|
||||
yield self.attestations.verify_attestation(
|
||||
await self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
|
@ -321,7 +320,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||
)
|
||||
|
||||
is_publicised = content.get("publicise", False)
|
||||
token = yield self.store.register_user_group_membership(
|
||||
token = await self.store.register_user_group_membership(
|
||||
group_id,
|
||||
user_id,
|
||||
membership="join",
|
||||
|
@ -482,12 +481,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||
|
||||
return {"state": "invite", "user_profile": user_profile}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
||||
async def remove_user_from_group(
|
||||
self, group_id, user_id, requester_user_id, content
|
||||
):
|
||||
"""Remove a user from a group
|
||||
"""
|
||||
if user_id == requester_user_id:
|
||||
token = yield self.store.register_user_group_membership(
|
||||
token = await self.store.register_user_group_membership(
|
||||
group_id, user_id, membership="leave"
|
||||
)
|
||||
self.notifier.on_new_event("groups_key", token, users=[user_id])
|
||||
|
@ -496,13 +496,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
|||
# retry if the group server is currently down.
|
||||
|
||||
if self.is_mine_id(group_id):
|
||||
res = yield self.groups_server_handler.remove_user_from_group(
|
||||
res = await self.groups_server_handler.remove_user_from_group(
|
||||
group_id, user_id, requester_user_id, content
|
||||
)
|
||||
else:
|
||||
content["requester_user_id"] = requester_user_id
|
||||
try:
|
||||
res = yield self.transport_client.remove_user_from_group(
|
||||
res = await self.transport_client.remove_user_from_group(
|
||||
get_domain_from_id(group_id),
|
||||
group_id,
|
||||
requester_user_id,
|
||||
|
|
|
@ -626,8 +626,7 @@ class EventCreationHandler(object):
|
|||
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
|
||||
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_nonmember_event(self, requester, event, context, ratelimit=True):
|
||||
async def send_nonmember_event(self, requester, event, context, ratelimit=True):
|
||||
"""
|
||||
Persists and notifies local clients and federation of an event.
|
||||
|
||||
|
@ -647,7 +646,7 @@ class EventCreationHandler(object):
|
|||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||
|
||||
if event.is_state():
|
||||
prev_state = yield self.deduplicate_state_event(event, context)
|
||||
prev_state = await self.deduplicate_state_event(event, context)
|
||||
if prev_state is not None:
|
||||
logger.info(
|
||||
"Not bothering to persist state event %s duplicated by %s",
|
||||
|
@ -656,7 +655,7 @@ class EventCreationHandler(object):
|
|||
)
|
||||
return prev_state
|
||||
|
||||
yield self.handle_new_client_event(
|
||||
await self.handle_new_client_event(
|
||||
requester=requester, event=event, context=context, ratelimit=ratelimit
|
||||
)
|
||||
|
||||
|
@ -683,8 +682,7 @@ class EventCreationHandler(object):
|
|||
return prev_event
|
||||
return
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_and_send_nonmember_event(
|
||||
async def create_and_send_nonmember_event(
|
||||
self, requester, event_dict, ratelimit=True, txn_id=None
|
||||
):
|
||||
"""
|
||||
|
@ -698,8 +696,8 @@ class EventCreationHandler(object):
|
|||
# a situation where event persistence can't keep up, causing
|
||||
# extremities to pile up, which in turn leads to state resolution
|
||||
# taking longer.
|
||||
with (yield self.limiter.queue(event_dict["room_id"])):
|
||||
event, context = yield self.create_event(
|
||||
with (await self.limiter.queue(event_dict["room_id"])):
|
||||
event, context = await self.create_event(
|
||||
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
|
||||
)
|
||||
|
||||
|
@ -709,7 +707,7 @@ class EventCreationHandler(object):
|
|||
spam_error = "Spam is not permitted here"
|
||||
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
|
||||
|
||||
yield self.send_nonmember_event(
|
||||
await self.send_nonmember_event(
|
||||
requester, event, context, ratelimit=ratelimit
|
||||
)
|
||||
return event
|
||||
|
@ -770,8 +768,7 @@ class EventCreationHandler(object):
|
|||
return (event, context)
|
||||
|
||||
@measure_func("handle_new_client_event")
|
||||
@defer.inlineCallbacks
|
||||
def handle_new_client_event(
|
||||
async def handle_new_client_event(
|
||||
self, requester, event, context, ratelimit=True, extra_users=[]
|
||||
):
|
||||
"""Processes a new event. This includes checking auth, persisting it,
|
||||
|
@ -794,9 +791,9 @@ class EventCreationHandler(object):
|
|||
):
|
||||
room_version = event.content.get("room_version", RoomVersions.V1.identifier)
|
||||
else:
|
||||
room_version = yield self.store.get_room_version_id(event.room_id)
|
||||
room_version = await self.store.get_room_version_id(event.room_id)
|
||||
|
||||
event_allowed = yield self.third_party_event_rules.check_event_allowed(
|
||||
event_allowed = await self.third_party_event_rules.check_event_allowed(
|
||||
event, context
|
||||
)
|
||||
if not event_allowed:
|
||||
|
@ -805,7 +802,7 @@ class EventCreationHandler(object):
|
|||
)
|
||||
|
||||
try:
|
||||
yield self.auth.check_from_context(room_version, event, context)
|
||||
await self.auth.check_from_context(room_version, event, context)
|
||||
except AuthError as err:
|
||||
logger.warning("Denying new event %r because %s", event, err)
|
||||
raise err
|
||||
|
@ -818,7 +815,7 @@ class EventCreationHandler(object):
|
|||
logger.exception("Failed to encode content: %r", event.content)
|
||||
raise
|
||||
|
||||
yield self.action_generator.handle_push_actions_for_event(event, context)
|
||||
await self.action_generator.handle_push_actions_for_event(event, context)
|
||||
|
||||
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
|
||||
# hack around with a try/finally instead.
|
||||
|
@ -826,7 +823,7 @@ class EventCreationHandler(object):
|
|||
try:
|
||||
# If we're a worker we need to hit out to the master.
|
||||
if self.config.worker_app:
|
||||
yield self.send_event_to_master(
|
||||
await self.send_event_to_master(
|
||||
event_id=event.event_id,
|
||||
store=self.store,
|
||||
requester=requester,
|
||||
|
@ -838,7 +835,7 @@ class EventCreationHandler(object):
|
|||
success = True
|
||||
return
|
||||
|
||||
yield self.persist_and_notify_client_event(
|
||||
await self.persist_and_notify_client_event(
|
||||
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
|
||||
)
|
||||
|
||||
|
@ -883,8 +880,7 @@ class EventCreationHandler(object):
|
|||
Codes.BAD_ALIAS,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def persist_and_notify_client_event(
|
||||
async def persist_and_notify_client_event(
|
||||
self, requester, event, context, ratelimit=True, extra_users=[]
|
||||
):
|
||||
"""Called when we have fully built the event, have already
|
||||
|
@ -901,7 +897,7 @@ class EventCreationHandler(object):
|
|||
# user is actually admin or not).
|
||||
is_admin_redaction = False
|
||||
if event.type == EventTypes.Redaction:
|
||||
original_event = yield self.store.get_event(
|
||||
original_event = await self.store.get_event(
|
||||
event.redacts,
|
||||
redact_behaviour=EventRedactBehaviour.AS_IS,
|
||||
get_prev_content=False,
|
||||
|
@ -913,11 +909,11 @@ class EventCreationHandler(object):
|
|||
original_event and event.sender != original_event.sender
|
||||
)
|
||||
|
||||
yield self.base_handler.ratelimit(
|
||||
await self.base_handler.ratelimit(
|
||||
requester, is_admin_redaction=is_admin_redaction
|
||||
)
|
||||
|
||||
yield self.base_handler.maybe_kick_guest_users(event, context)
|
||||
await self.base_handler.maybe_kick_guest_users(event, context)
|
||||
|
||||
if event.type == EventTypes.CanonicalAlias:
|
||||
# Validate a newly added alias or newly added alt_aliases.
|
||||
|
@ -927,7 +923,7 @@ class EventCreationHandler(object):
|
|||
|
||||
original_event_id = event.unsigned.get("replaces_state")
|
||||
if original_event_id:
|
||||
original_event = yield self.store.get_event(original_event_id)
|
||||
original_event = await self.store.get_event(original_event_id)
|
||||
|
||||
if original_event:
|
||||
original_alias = original_event.content.get("alias", None)
|
||||
|
@ -937,7 +933,7 @@ class EventCreationHandler(object):
|
|||
room_alias_str = event.content.get("alias", None)
|
||||
directory_handler = self.hs.get_handlers().directory_handler
|
||||
if room_alias_str and room_alias_str != original_alias:
|
||||
yield self._validate_canonical_alias(
|
||||
await self._validate_canonical_alias(
|
||||
directory_handler, room_alias_str, event.room_id
|
||||
)
|
||||
|
||||
|
@ -957,7 +953,7 @@ class EventCreationHandler(object):
|
|||
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
|
||||
if new_alt_aliases:
|
||||
for alias_str in new_alt_aliases:
|
||||
yield self._validate_canonical_alias(
|
||||
await self._validate_canonical_alias(
|
||||
directory_handler, alias_str, event.room_id
|
||||
)
|
||||
|
||||
|
@ -969,7 +965,7 @@ class EventCreationHandler(object):
|
|||
def is_inviter_member_event(e):
|
||||
return e.type == EventTypes.Member and e.sender == event.sender
|
||||
|
||||
current_state_ids = yield context.get_current_state_ids()
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
|
||||
state_to_include_ids = [
|
||||
e_id
|
||||
|
@ -978,7 +974,7 @@ class EventCreationHandler(object):
|
|||
or k == (EventTypes.Member, event.sender)
|
||||
]
|
||||
|
||||
state_to_include = yield self.store.get_events(state_to_include_ids)
|
||||
state_to_include = await self.store.get_events(state_to_include_ids)
|
||||
|
||||
event.unsigned["invite_room_state"] = [
|
||||
{
|
||||
|
@ -996,8 +992,8 @@ class EventCreationHandler(object):
|
|||
# way? If we have been invited by a remote server, we need
|
||||
# to get them to sign the event.
|
||||
|
||||
returned_invite = yield defer.ensureDeferred(
|
||||
federation_handler.send_invite(invitee.domain, event)
|
||||
returned_invite = await federation_handler.send_invite(
|
||||
invitee.domain, event
|
||||
)
|
||||
event.unsigned.pop("room_state", None)
|
||||
|
||||
|
@ -1005,7 +1001,7 @@ class EventCreationHandler(object):
|
|||
event.signatures.update(returned_invite.signatures)
|
||||
|
||||
if event.type == EventTypes.Redaction:
|
||||
original_event = yield self.store.get_event(
|
||||
original_event = await self.store.get_event(
|
||||
event.redacts,
|
||||
redact_behaviour=EventRedactBehaviour.AS_IS,
|
||||
get_prev_content=False,
|
||||
|
@ -1021,14 +1017,14 @@ class EventCreationHandler(object):
|
|||
if original_event.room_id != event.room_id:
|
||||
raise SynapseError(400, "Cannot redact event from a different room")
|
||||
|
||||
prev_state_ids = yield context.get_prev_state_ids()
|
||||
auth_events_ids = yield self.auth.compute_auth_events(
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
auth_events_ids = await self.auth.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=True
|
||||
)
|
||||
auth_events = yield self.store.get_events(auth_events_ids)
|
||||
auth_events = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
||||
|
||||
room_version = yield self.store.get_room_version_id(event.room_id)
|
||||
room_version = await self.store.get_room_version_id(event.room_id)
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
|
||||
if event_auth.check_redaction(
|
||||
|
@ -1047,11 +1043,11 @@ class EventCreationHandler(object):
|
|||
event.internal_metadata.recheck_redaction = False
|
||||
|
||||
if event.type == EventTypes.Create:
|
||||
prev_state_ids = yield context.get_prev_state_ids()
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
if prev_state_ids:
|
||||
raise AuthError(403, "Changing the room create event is forbidden")
|
||||
|
||||
event_stream_id, max_stream_id = yield self.storage.persistence.persist_event(
|
||||
event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
|
||||
event, context=context
|
||||
)
|
||||
|
||||
|
@ -1059,7 +1055,7 @@ class EventCreationHandler(object):
|
|||
# If there's an expiry timestamp on the event, schedule its expiry.
|
||||
self._message_handler.maybe_schedule_expiry(event)
|
||||
|
||||
yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
|
||||
await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
|
||||
|
||||
def _notify():
|
||||
try:
|
||||
|
@ -1083,13 +1079,12 @@ class EventCreationHandler(object):
|
|||
except Exception:
|
||||
logger.exception("Error bumping presence active time")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _send_dummy_events_to_fill_extremities(self):
|
||||
async def _send_dummy_events_to_fill_extremities(self):
|
||||
"""Background task to send dummy events into rooms that have a large
|
||||
number of extremities
|
||||
"""
|
||||
self._expire_rooms_to_exclude_from_dummy_event_insertion()
|
||||
room_ids = yield self.store.get_rooms_with_many_extremities(
|
||||
room_ids = await self.store.get_rooms_with_many_extremities(
|
||||
min_count=10,
|
||||
limit=5,
|
||||
room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(),
|
||||
|
@ -1099,9 +1094,9 @@ class EventCreationHandler(object):
|
|||
# For each room we need to find a joined member we can use to send
|
||||
# the dummy event with.
|
||||
|
||||
latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
|
||||
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
|
||||
|
||||
members = yield self.state.get_current_users_in_room(
|
||||
members = await self.state.get_current_users_in_room(
|
||||
room_id, latest_event_ids=latest_event_ids
|
||||
)
|
||||
dummy_event_sent = False
|
||||
|
@ -1110,7 +1105,7 @@ class EventCreationHandler(object):
|
|||
continue
|
||||
requester = create_requester(user_id)
|
||||
try:
|
||||
event, context = yield self.create_event(
|
||||
event, context = await self.create_event(
|
||||
requester,
|
||||
{
|
||||
"type": "org.matrix.dummy_event",
|
||||
|
@ -1123,7 +1118,7 @@ class EventCreationHandler(object):
|
|||
|
||||
event.internal_metadata.proactively_send = False
|
||||
|
||||
yield self.send_nonmember_event(
|
||||
await self.send_nonmember_event(
|
||||
requester, event, context, ratelimit=False
|
||||
)
|
||||
dummy_event_sent = True
|
||||
|
|
|
@ -141,8 +141,9 @@ class BaseProfileHandler(BaseHandler):
|
|||
|
||||
return result["displayname"]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
|
||||
async def set_displayname(
|
||||
self, target_user, requester, new_displayname, by_admin=False
|
||||
):
|
||||
"""Set the displayname of a user
|
||||
|
||||
Args:
|
||||
|
@ -158,7 +159,7 @@ class BaseProfileHandler(BaseHandler):
|
|||
raise AuthError(400, "Cannot set another user's displayname")
|
||||
|
||||
if not by_admin and not self.hs.config.enable_set_displayname:
|
||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
if profile.display_name:
|
||||
raise SynapseError(
|
||||
400,
|
||||
|
@ -180,15 +181,15 @@ class BaseProfileHandler(BaseHandler):
|
|||
if by_admin:
|
||||
requester = create_requester(target_user)
|
||||
|
||||
yield self.store.set_profile_displayname(target_user.localpart, new_displayname)
|
||||
await self.store.set_profile_displayname(target_user.localpart, new_displayname)
|
||||
|
||||
if self.hs.config.user_directory_search_all_users:
|
||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
||||
yield self.user_directory_handler.handle_local_profile_change(
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
await self.user_directory_handler.handle_local_profile_change(
|
||||
target_user.to_string(), profile
|
||||
)
|
||||
|
||||
yield self._update_join_states(requester, target_user)
|
||||
await self._update_join_states(requester, target_user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_avatar_url(self, target_user):
|
||||
|
@ -217,8 +218,9 @@ class BaseProfileHandler(BaseHandler):
|
|||
|
||||
return result["avatar_url"]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
|
||||
async def set_avatar_url(
|
||||
self, target_user, requester, new_avatar_url, by_admin=False
|
||||
):
|
||||
"""target_user is the user whose avatar_url is to be changed;
|
||||
auth_user is the user attempting to make this change."""
|
||||
if not self.hs.is_mine(target_user):
|
||||
|
@ -228,7 +230,7 @@ class BaseProfileHandler(BaseHandler):
|
|||
raise AuthError(400, "Cannot set another user's avatar_url")
|
||||
|
||||
if not by_admin and not self.hs.config.enable_set_avatar_url:
|
||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
if profile.avatar_url:
|
||||
raise SynapseError(
|
||||
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
|
||||
|
@ -243,15 +245,15 @@ class BaseProfileHandler(BaseHandler):
|
|||
if by_admin:
|
||||
requester = create_requester(target_user)
|
||||
|
||||
yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
|
||||
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
|
||||
|
||||
if self.hs.config.user_directory_search_all_users:
|
||||
profile = yield self.store.get_profileinfo(target_user.localpart)
|
||||
yield self.user_directory_handler.handle_local_profile_change(
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
await self.user_directory_handler.handle_local_profile_change(
|
||||
target_user.to_string(), profile
|
||||
)
|
||||
|
||||
yield self._update_join_states(requester, target_user)
|
||||
await self._update_join_states(requester, target_user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_profile_query(self, args):
|
||||
|
@ -279,21 +281,20 @@ class BaseProfileHandler(BaseHandler):
|
|||
|
||||
return response
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_join_states(self, requester, target_user):
|
||||
async def _update_join_states(self, requester, target_user):
|
||||
if not self.hs.is_mine(target_user):
|
||||
return
|
||||
|
||||
yield self.ratelimit(requester)
|
||||
await self.ratelimit(requester)
|
||||
|
||||
room_ids = yield self.store.get_rooms_for_user(target_user.to_string())
|
||||
room_ids = await self.store.get_rooms_for_user(target_user.to_string())
|
||||
|
||||
for room_id in room_ids:
|
||||
handler = self.hs.get_room_member_handler()
|
||||
try:
|
||||
# Assume the target_user isn't a guest,
|
||||
# because we don't let guests set profile or avatar data.
|
||||
yield handler.update_membership(
|
||||
await handler.update_membership(
|
||||
requester,
|
||||
target_user,
|
||||
room_id,
|
||||
|
|
|
@ -145,9 +145,9 @@ class RegistrationHandler(BaseHandler):
|
|||
"""Registers a new client on the server.
|
||||
|
||||
Args:
|
||||
localpart : The local part of the user ID to register. If None,
|
||||
localpart: The local part of the user ID to register. If None,
|
||||
one will be generated.
|
||||
password (unicode) : The password to assign to this user so they can
|
||||
password (unicode): The password to assign to this user so they can
|
||||
login again. This can be None which means they cannot login again
|
||||
via a password (e.g. the user is an application service user).
|
||||
user_type (str|None): type of user. One of the values from
|
||||
|
@ -244,7 +244,7 @@ class RegistrationHandler(BaseHandler):
|
|||
fail_count += 1
|
||||
|
||||
if not self.hs.config.user_consent_at_registration:
|
||||
yield self._auto_join_rooms(user_id)
|
||||
yield defer.ensureDeferred(self._auto_join_rooms(user_id))
|
||||
else:
|
||||
logger.info(
|
||||
"Skipping auto-join for %s because consent is required at registration",
|
||||
|
@ -266,8 +266,7 @@ class RegistrationHandler(BaseHandler):
|
|||
|
||||
return user_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _auto_join_rooms(self, user_id):
|
||||
async def _auto_join_rooms(self, user_id):
|
||||
"""Automatically joins users to auto join rooms - creating the room in the first place
|
||||
if the user is the first to be created.
|
||||
|
||||
|
@ -281,9 +280,9 @@ class RegistrationHandler(BaseHandler):
|
|||
# that an auto-generated support or bot user is not a real user and will never be
|
||||
# the user to create the room
|
||||
should_auto_create_rooms = False
|
||||
is_real_user = yield self.store.is_real_user(user_id)
|
||||
is_real_user = await self.store.is_real_user(user_id)
|
||||
if self.hs.config.autocreate_auto_join_rooms and is_real_user:
|
||||
count = yield self.store.count_real_users()
|
||||
count = await self.store.count_real_users()
|
||||
should_auto_create_rooms = count == 1
|
||||
for r in self.hs.config.auto_join_rooms:
|
||||
logger.info("Auto-joining %s to %s", user_id, r)
|
||||
|
@ -302,7 +301,7 @@ class RegistrationHandler(BaseHandler):
|
|||
|
||||
# getting the RoomCreationHandler during init gives a dependency
|
||||
# loop
|
||||
yield self.hs.get_room_creation_handler().create_room(
|
||||
await self.hs.get_room_creation_handler().create_room(
|
||||
fake_requester,
|
||||
config={
|
||||
"preset": "public_chat",
|
||||
|
@ -311,7 +310,7 @@ class RegistrationHandler(BaseHandler):
|
|||
ratelimit=False,
|
||||
)
|
||||
else:
|
||||
yield self._join_user_to_room(fake_requester, r)
|
||||
await self._join_user_to_room(fake_requester, r)
|
||||
except ConsentNotGivenError as e:
|
||||
# Technically not necessary to pull out this error though
|
||||
# moving away from bare excepts is a good thing to do.
|
||||
|
@ -319,15 +318,14 @@ class RegistrationHandler(BaseHandler):
|
|||
except Exception as e:
|
||||
logger.error("Failed to join new user to %r: %r", r, e)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_consent_actions(self, user_id):
|
||||
async def post_consent_actions(self, user_id):
|
||||
"""A series of registration actions that can only be carried out once consent
|
||||
has been granted
|
||||
|
||||
Args:
|
||||
user_id (str): The user to join
|
||||
"""
|
||||
yield self._auto_join_rooms(user_id)
|
||||
await self._auto_join_rooms(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def appservice_register(self, user_localpart, as_token):
|
||||
|
@ -394,14 +392,13 @@ class RegistrationHandler(BaseHandler):
|
|||
self._next_generated_user_id += 1
|
||||
return str(id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _join_user_to_room(self, requester, room_identifier):
|
||||
async def _join_user_to_room(self, requester, room_identifier):
|
||||
room_member_handler = self.hs.get_room_member_handler()
|
||||
if RoomID.is_valid(room_identifier):
|
||||
room_id = room_identifier
|
||||
elif RoomAlias.is_valid(room_identifier):
|
||||
room_alias = RoomAlias.from_string(room_identifier)
|
||||
room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
|
||||
room_id, remote_room_hosts = await room_member_handler.lookup_room_alias(
|
||||
room_alias
|
||||
)
|
||||
room_id = room_id.to_string()
|
||||
|
@ -410,7 +407,7 @@ class RegistrationHandler(BaseHandler):
|
|||
400, "%s was not legal room ID or room alias" % (room_identifier,)
|
||||
)
|
||||
|
||||
yield room_member_handler.update_membership(
|
||||
await room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
target=requester.user,
|
||||
room_id=room_id,
|
||||
|
@ -550,8 +547,7 @@ class RegistrationHandler(BaseHandler):
|
|||
|
||||
return (device_id, access_token)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_registration_actions(self, user_id, auth_result, access_token):
|
||||
async def post_registration_actions(self, user_id, auth_result, access_token):
|
||||
"""A user has completed registration
|
||||
|
||||
Args:
|
||||
|
@ -562,7 +558,7 @@ class RegistrationHandler(BaseHandler):
|
|||
device, or None if `inhibit_login` enabled.
|
||||
"""
|
||||
if self.hs.config.worker_app:
|
||||
yield self._post_registration_client(
|
||||
await self._post_registration_client(
|
||||
user_id=user_id, auth_result=auth_result, access_token=access_token
|
||||
)
|
||||
return
|
||||
|
@ -574,19 +570,18 @@ class RegistrationHandler(BaseHandler):
|
|||
if is_threepid_reserved(
|
||||
self.hs.config.mau_limits_reserved_threepids, threepid
|
||||
):
|
||||
yield self.store.upsert_monthly_active_user(user_id)
|
||||
await self.store.upsert_monthly_active_user(user_id)
|
||||
|
||||
yield self._register_email_threepid(user_id, threepid, access_token)
|
||||
await self._register_email_threepid(user_id, threepid, access_token)
|
||||
|
||||
if auth_result and LoginType.MSISDN in auth_result:
|
||||
threepid = auth_result[LoginType.MSISDN]
|
||||
yield self._register_msisdn_threepid(user_id, threepid)
|
||||
await self._register_msisdn_threepid(user_id, threepid)
|
||||
|
||||
if auth_result and LoginType.TERMS in auth_result:
|
||||
yield self._on_user_consented(user_id, self.hs.config.user_consent_version)
|
||||
await self._on_user_consented(user_id, self.hs.config.user_consent_version)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _on_user_consented(self, user_id, consent_version):
|
||||
async def _on_user_consented(self, user_id, consent_version):
|
||||
"""A user consented to the terms on registration
|
||||
|
||||
Args:
|
||||
|
@ -595,8 +590,8 @@ class RegistrationHandler(BaseHandler):
|
|||
consented to.
|
||||
"""
|
||||
logger.info("%s has consented to the privacy policy", user_id)
|
||||
yield self.store.user_set_consent_version(user_id, consent_version)
|
||||
yield self.post_consent_actions(user_id)
|
||||
await self.store.user_set_consent_version(user_id, consent_version)
|
||||
await self.post_consent_actions(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _register_email_threepid(self, user_id, threepid, token):
|
||||
|
|
|
@ -148,17 +148,16 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _upgrade_room(
|
||||
async def _upgrade_room(
|
||||
self, requester: Requester, old_room_id: str, new_version: RoomVersion
|
||||
):
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
# start by allocating a new room id
|
||||
r = yield self.store.get_room(old_room_id)
|
||||
r = await self.store.get_room(old_room_id)
|
||||
if r is None:
|
||||
raise NotFoundError("Unknown room id %s" % (old_room_id,))
|
||||
new_room_id = yield self._generate_room_id(
|
||||
new_room_id = await self._generate_room_id(
|
||||
creator_id=user_id, is_public=r["is_public"], room_version=new_version,
|
||||
)
|
||||
|
||||
|
@ -169,7 +168,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
(
|
||||
tombstone_event,
|
||||
tombstone_context,
|
||||
) = yield self.event_creation_handler.create_event(
|
||||
) = await self.event_creation_handler.create_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Tombstone,
|
||||
|
@ -183,12 +182,12 @@ class RoomCreationHandler(BaseHandler):
|
|||
},
|
||||
token_id=requester.access_token_id,
|
||||
)
|
||||
old_room_version = yield self.store.get_room_version_id(old_room_id)
|
||||
yield self.auth.check_from_context(
|
||||
old_room_version = await self.store.get_room_version_id(old_room_id)
|
||||
await self.auth.check_from_context(
|
||||
old_room_version, tombstone_event, tombstone_context
|
||||
)
|
||||
|
||||
yield self.clone_existing_room(
|
||||
await self.clone_existing_room(
|
||||
requester,
|
||||
old_room_id=old_room_id,
|
||||
new_room_id=new_room_id,
|
||||
|
@ -197,32 +196,31 @@ class RoomCreationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
# now send the tombstone
|
||||
yield self.event_creation_handler.send_nonmember_event(
|
||||
await self.event_creation_handler.send_nonmember_event(
|
||||
requester, tombstone_event, tombstone_context
|
||||
)
|
||||
|
||||
old_room_state = yield tombstone_context.get_current_state_ids()
|
||||
old_room_state = await tombstone_context.get_current_state_ids()
|
||||
|
||||
# update any aliases
|
||||
yield self._move_aliases_to_new_room(
|
||||
await self._move_aliases_to_new_room(
|
||||
requester, old_room_id, new_room_id, old_room_state
|
||||
)
|
||||
|
||||
# Copy over user push rules, tags and migrate room directory state
|
||||
yield self.room_member_handler.transfer_room_state_on_room_upgrade(
|
||||
await self.room_member_handler.transfer_room_state_on_room_upgrade(
|
||||
old_room_id, new_room_id
|
||||
)
|
||||
|
||||
# finally, shut down the PLs in the old room, and update them in the new
|
||||
# room.
|
||||
yield self._update_upgraded_room_pls(
|
||||
await self._update_upgraded_room_pls(
|
||||
requester, old_room_id, new_room_id, old_room_state,
|
||||
)
|
||||
|
||||
return new_room_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_upgraded_room_pls(
|
||||
async def _update_upgraded_room_pls(
|
||||
self,
|
||||
requester: Requester,
|
||||
old_room_id: str,
|
||||
|
@ -249,7 +247,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
)
|
||||
return
|
||||
|
||||
old_room_pl_state = yield self.store.get_event(old_room_pl_event_id)
|
||||
old_room_pl_state = await self.store.get_event(old_room_pl_event_id)
|
||||
|
||||
# we try to stop regular users from speaking by setting the PL required
|
||||
# to send regular events and invites to 'Moderator' level. That's normally
|
||||
|
@ -278,7 +276,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
if updated:
|
||||
try:
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.PowerLevels,
|
||||
|
@ -292,7 +290,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
except AuthError as e:
|
||||
logger.warning("Unable to update PLs in old room: %s", e)
|
||||
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.PowerLevels,
|
||||
|
@ -304,8 +302,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
ratelimit=False,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def clone_existing_room(
|
||||
async def clone_existing_room(
|
||||
self,
|
||||
requester: Requester,
|
||||
old_room_id: str,
|
||||
|
@ -338,7 +335,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
# Check if old room was non-federatable
|
||||
|
||||
# Get old room's create event
|
||||
old_room_create_event = yield self.store.get_create_event_for_room(old_room_id)
|
||||
old_room_create_event = await self.store.get_create_event_for_room(old_room_id)
|
||||
|
||||
# Check if the create event specified a non-federatable room
|
||||
if not old_room_create_event.content.get("m.federate", True):
|
||||
|
@ -361,11 +358,11 @@ class RoomCreationHandler(BaseHandler):
|
|||
(EventTypes.PowerLevels, ""),
|
||||
)
|
||||
|
||||
old_room_state_ids = yield self.store.get_filtered_current_state_ids(
|
||||
old_room_state_ids = await self.store.get_filtered_current_state_ids(
|
||||
old_room_id, StateFilter.from_types(types_to_copy)
|
||||
)
|
||||
# map from event_id to BaseEvent
|
||||
old_room_state_events = yield self.store.get_events(old_room_state_ids.values())
|
||||
old_room_state_events = await self.store.get_events(old_room_state_ids.values())
|
||||
|
||||
for k, old_event_id in iteritems(old_room_state_ids):
|
||||
old_event = old_room_state_events.get(old_event_id)
|
||||
|
@ -400,7 +397,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
if current_power_level < needed_power_level:
|
||||
power_levels["users"][user_id] = needed_power_level
|
||||
|
||||
yield self._send_events_for_new_room(
|
||||
await self._send_events_for_new_room(
|
||||
requester,
|
||||
new_room_id,
|
||||
# we expect to override all the presets with initial_state, so this is
|
||||
|
@ -412,12 +409,12 @@ class RoomCreationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
# Transfer membership events
|
||||
old_room_member_state_ids = yield self.store.get_filtered_current_state_ids(
|
||||
old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
|
||||
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
|
||||
)
|
||||
|
||||
# map from event_id to BaseEvent
|
||||
old_room_member_state_events = yield self.store.get_events(
|
||||
old_room_member_state_events = await self.store.get_events(
|
||||
old_room_member_state_ids.values()
|
||||
)
|
||||
for k, old_event in iteritems(old_room_member_state_events):
|
||||
|
@ -426,7 +423,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
"membership" in old_event.content
|
||||
and old_event.content["membership"] == "ban"
|
||||
):
|
||||
yield self.room_member_handler.update_membership(
|
||||
await self.room_member_handler.update_membership(
|
||||
requester,
|
||||
UserID.from_string(old_event["state_key"]),
|
||||
new_room_id,
|
||||
|
@ -438,8 +435,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
# XXX invites/joins
|
||||
# XXX 3pid invites
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _move_aliases_to_new_room(
|
||||
async def _move_aliases_to_new_room(
|
||||
self,
|
||||
requester: Requester,
|
||||
old_room_id: str,
|
||||
|
@ -448,13 +444,13 @@ class RoomCreationHandler(BaseHandler):
|
|||
):
|
||||
directory_handler = self.hs.get_handlers().directory_handler
|
||||
|
||||
aliases = yield self.store.get_aliases_for_room(old_room_id)
|
||||
aliases = await self.store.get_aliases_for_room(old_room_id)
|
||||
|
||||
# check to see if we have a canonical alias.
|
||||
canonical_alias_event = None
|
||||
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
|
||||
if canonical_alias_event_id:
|
||||
canonical_alias_event = yield self.store.get_event(canonical_alias_event_id)
|
||||
canonical_alias_event = await self.store.get_event(canonical_alias_event_id)
|
||||
|
||||
# first we try to remove the aliases from the old room (we suppress sending
|
||||
# the room_aliases event until the end).
|
||||
|
@ -472,7 +468,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
for alias_str in aliases:
|
||||
alias = RoomAlias.from_string(alias_str)
|
||||
try:
|
||||
yield directory_handler.delete_association(requester, alias)
|
||||
await directory_handler.delete_association(requester, alias)
|
||||
removed_aliases.append(alias_str)
|
||||
except SynapseError as e:
|
||||
logger.warning("Unable to remove alias %s from old room: %s", alias, e)
|
||||
|
@ -485,7 +481,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
# we can now add any aliases we successfully removed to the new room.
|
||||
for alias in removed_aliases:
|
||||
try:
|
||||
yield directory_handler.create_association(
|
||||
await directory_handler.create_association(
|
||||
requester,
|
||||
RoomAlias.from_string(alias),
|
||||
new_room_id,
|
||||
|
@ -502,7 +498,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
# alias event for the new room with a copy of the information.
|
||||
try:
|
||||
if canonical_alias_event:
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.CanonicalAlias,
|
||||
|
@ -518,8 +514,9 @@ class RoomCreationHandler(BaseHandler):
|
|||
# we returned the new room to the client at this point.
|
||||
logger.error("Unable to send updated alias events in new room: %s", e)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_room(self, requester, config, ratelimit=True, creator_join_profile=None):
|
||||
async def create_room(
|
||||
self, requester, config, ratelimit=True, creator_join_profile=None
|
||||
):
|
||||
""" Creates a new room.
|
||||
|
||||
Args:
|
||||
|
@ -547,7 +544,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
yield self.auth.check_auth_blocking(user_id)
|
||||
await self.auth.check_auth_blocking(user_id)
|
||||
|
||||
if (
|
||||
self._server_notices_mxid is not None
|
||||
|
@ -556,11 +553,11 @@ class RoomCreationHandler(BaseHandler):
|
|||
# allow the server notices mxid to create rooms
|
||||
is_requester_admin = True
|
||||
else:
|
||||
is_requester_admin = yield self.auth.is_server_admin(requester.user)
|
||||
is_requester_admin = await self.auth.is_server_admin(requester.user)
|
||||
|
||||
# Check whether the third party rules allows/changes the room create
|
||||
# request.
|
||||
event_allowed = yield self.third_party_event_rules.on_create_room(
|
||||
event_allowed = await self.third_party_event_rules.on_create_room(
|
||||
requester, config, is_requester_admin=is_requester_admin
|
||||
)
|
||||
if not event_allowed:
|
||||
|
@ -574,7 +571,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
raise SynapseError(403, "You are not permitted to create rooms")
|
||||
|
||||
if ratelimit:
|
||||
yield self.ratelimit(requester)
|
||||
await self.ratelimit(requester)
|
||||
|
||||
room_version_id = config.get(
|
||||
"room_version", self.config.default_room_version.identifier
|
||||
|
@ -597,7 +594,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
raise SynapseError(400, "Invalid characters in room alias")
|
||||
|
||||
room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
|
||||
mapping = yield self.store.get_association_from_room_alias(room_alias)
|
||||
mapping = await self.store.get_association_from_room_alias(room_alias)
|
||||
|
||||
if mapping:
|
||||
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
|
||||
|
@ -612,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
except Exception:
|
||||
raise SynapseError(400, "Invalid user_id: %s" % (i,))
|
||||
|
||||
yield self.event_creation_handler.assert_accepted_privacy_policy(requester)
|
||||
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
|
||||
|
||||
power_level_content_override = config.get("power_level_content_override")
|
||||
if (
|
||||
|
@ -631,13 +628,13 @@ class RoomCreationHandler(BaseHandler):
|
|||
visibility = config.get("visibility", None)
|
||||
is_public = visibility == "public"
|
||||
|
||||
room_id = yield self._generate_room_id(
|
||||
room_id = await self._generate_room_id(
|
||||
creator_id=user_id, is_public=is_public, room_version=room_version,
|
||||
)
|
||||
|
||||
directory_handler = self.hs.get_handlers().directory_handler
|
||||
if room_alias:
|
||||
yield directory_handler.create_association(
|
||||
await directory_handler.create_association(
|
||||
requester=requester,
|
||||
room_id=room_id,
|
||||
room_alias=room_alias,
|
||||
|
@ -670,7 +667,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
# override any attempt to set room versions via the creation_content
|
||||
creation_content["room_version"] = room_version.identifier
|
||||
|
||||
yield self._send_events_for_new_room(
|
||||
await self._send_events_for_new_room(
|
||||
requester,
|
||||
room_id,
|
||||
preset_config=preset_config,
|
||||
|
@ -684,7 +681,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
if "name" in config:
|
||||
name = config["name"]
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Name,
|
||||
|
@ -698,7 +695,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
if "topic" in config:
|
||||
topic = config["topic"]
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Topic,
|
||||
|
@ -716,7 +713,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
if is_direct:
|
||||
content["is_direct"] = is_direct
|
||||
|
||||
yield self.room_member_handler.update_membership(
|
||||
await self.room_member_handler.update_membership(
|
||||
requester,
|
||||
UserID.from_string(invitee),
|
||||
room_id,
|
||||
|
@ -730,7 +727,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
id_access_token = invite_3pid.get("id_access_token") # optional
|
||||
address = invite_3pid["address"]
|
||||
medium = invite_3pid["medium"]
|
||||
yield self.hs.get_room_member_handler().do_3pid_invite(
|
||||
await self.hs.get_room_member_handler().do_3pid_invite(
|
||||
room_id,
|
||||
requester.user,
|
||||
medium,
|
||||
|
@ -748,8 +745,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _send_events_for_new_room(
|
||||
async def _send_events_for_new_room(
|
||||
self,
|
||||
creator, # A Requester object.
|
||||
room_id,
|
||||
|
@ -769,11 +765,10 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
return e
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send(etype, content, **kwargs):
|
||||
async def send(etype, content, **kwargs):
|
||||
event = create(etype, content, **kwargs)
|
||||
logger.debug("Sending %s in new room", etype)
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
creator, event, ratelimit=False
|
||||
)
|
||||
|
||||
|
@ -784,10 +779,10 @@ class RoomCreationHandler(BaseHandler):
|
|||
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
|
||||
|
||||
creation_content.update({"creator": creator_id})
|
||||
yield send(etype=EventTypes.Create, content=creation_content)
|
||||
await send(etype=EventTypes.Create, content=creation_content)
|
||||
|
||||
logger.debug("Sending %s in new room", EventTypes.Member)
|
||||
yield self.room_member_handler.update_membership(
|
||||
await self.room_member_handler.update_membership(
|
||||
creator,
|
||||
creator.user,
|
||||
room_id,
|
||||
|
@ -800,7 +795,7 @@ class RoomCreationHandler(BaseHandler):
|
|||
# of the first events that get sent into a room.
|
||||
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
|
||||
if pl_content is not None:
|
||||
yield send(etype=EventTypes.PowerLevels, content=pl_content)
|
||||
await send(etype=EventTypes.PowerLevels, content=pl_content)
|
||||
else:
|
||||
power_level_content = {
|
||||
"users": {creator_id: 100},
|
||||
|
@ -833,33 +828,33 @@ class RoomCreationHandler(BaseHandler):
|
|||
if power_level_content_override:
|
||||
power_level_content.update(power_level_content_override)
|
||||
|
||||
yield send(etype=EventTypes.PowerLevels, content=power_level_content)
|
||||
await send(etype=EventTypes.PowerLevels, content=power_level_content)
|
||||
|
||||
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
|
||||
yield send(
|
||||
await send(
|
||||
etype=EventTypes.CanonicalAlias,
|
||||
content={"alias": room_alias.to_string()},
|
||||
)
|
||||
|
||||
if (EventTypes.JoinRules, "") not in initial_state:
|
||||
yield send(
|
||||
await send(
|
||||
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
|
||||
)
|
||||
|
||||
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
|
||||
yield send(
|
||||
await send(
|
||||
etype=EventTypes.RoomHistoryVisibility,
|
||||
content={"history_visibility": config["history_visibility"]},
|
||||
)
|
||||
|
||||
if config["guest_can_join"]:
|
||||
if (EventTypes.GuestAccess, "") not in initial_state:
|
||||
yield send(
|
||||
await send(
|
||||
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
|
||||
)
|
||||
|
||||
for (etype, state_key), content in initial_state.items():
|
||||
yield send(etype=etype, state_key=state_key, content=content)
|
||||
await send(etype=etype, state_key=state_key, content=content)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_room_id(
|
||||
|
|
|
@ -142,8 +142,7 @@ class RoomMemberHandler(object):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _local_membership_update(
|
||||
async def _local_membership_update(
|
||||
self,
|
||||
requester,
|
||||
target,
|
||||
|
@ -164,7 +163,7 @@ class RoomMemberHandler(object):
|
|||
if requester.is_guest:
|
||||
content["kind"] = "guest"
|
||||
|
||||
event, context = yield self.event_creation_handler.create_event(
|
||||
event, context = await self.event_creation_handler.create_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
|
@ -182,18 +181,18 @@ class RoomMemberHandler(object):
|
|||
)
|
||||
|
||||
# Check if this event matches the previous membership event for the user.
|
||||
duplicate = yield self.event_creation_handler.deduplicate_state_event(
|
||||
duplicate = await self.event_creation_handler.deduplicate_state_event(
|
||||
event, context
|
||||
)
|
||||
if duplicate is not None:
|
||||
# Discard the new event since this membership change is a no-op.
|
||||
return duplicate
|
||||
|
||||
yield self.event_creation_handler.handle_new_client_event(
|
||||
await self.event_creation_handler.handle_new_client_event(
|
||||
requester, event, context, extra_users=[target], ratelimit=ratelimit
|
||||
)
|
||||
|
||||
prev_state_ids = yield context.get_prev_state_ids()
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
|
||||
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
|
||||
|
||||
|
@ -203,15 +202,15 @@ class RoomMemberHandler(object):
|
|||
# info.
|
||||
newly_joined = True
|
||||
if prev_member_event_id:
|
||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
if newly_joined:
|
||||
yield self._user_joined_room(target, room_id)
|
||||
await self._user_joined_room(target, room_id)
|
||||
elif event.membership == Membership.LEAVE:
|
||||
if prev_member_event_id:
|
||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
if prev_member_event.membership == Membership.JOIN:
|
||||
yield self._user_left_room(target, room_id)
|
||||
await self._user_left_room(target, room_id)
|
||||
|
||||
return event
|
||||
|
||||
|
@ -253,8 +252,7 @@ class RoomMemberHandler(object):
|
|||
for tag, tag_content in room_tags.items():
|
||||
yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_membership(
|
||||
async def update_membership(
|
||||
self,
|
||||
requester,
|
||||
target,
|
||||
|
@ -269,8 +267,8 @@ class RoomMemberHandler(object):
|
|||
):
|
||||
key = (room_id,)
|
||||
|
||||
with (yield self.member_linearizer.queue(key)):
|
||||
result = yield self._update_membership(
|
||||
with (await self.member_linearizer.queue(key)):
|
||||
result = await self._update_membership(
|
||||
requester,
|
||||
target,
|
||||
room_id,
|
||||
|
@ -285,8 +283,7 @@ class RoomMemberHandler(object):
|
|||
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_membership(
|
||||
async def _update_membership(
|
||||
self,
|
||||
requester,
|
||||
target,
|
||||
|
@ -321,7 +318,7 @@ class RoomMemberHandler(object):
|
|||
# if this is a join with a 3pid signature, we may need to turn a 3pid
|
||||
# invite into a normal invite before we can handle the join.
|
||||
if third_party_signed is not None:
|
||||
yield self.federation_handler.exchange_third_party_invite(
|
||||
await self.federation_handler.exchange_third_party_invite(
|
||||
third_party_signed["sender"],
|
||||
target.to_string(),
|
||||
room_id,
|
||||
|
@ -332,7 +329,7 @@ class RoomMemberHandler(object):
|
|||
remote_room_hosts = []
|
||||
|
||||
if effective_membership_state not in ("leave", "ban"):
|
||||
is_blocked = yield self.store.is_room_blocked(room_id)
|
||||
is_blocked = await self.store.is_room_blocked(room_id)
|
||||
if is_blocked:
|
||||
raise SynapseError(403, "This room has been blocked on this server")
|
||||
|
||||
|
@ -351,7 +348,7 @@ class RoomMemberHandler(object):
|
|||
is_requester_admin = True
|
||||
|
||||
else:
|
||||
is_requester_admin = yield self.auth.is_server_admin(requester.user)
|
||||
is_requester_admin = await self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_requester_admin:
|
||||
if self.config.block_non_admin_invites:
|
||||
|
@ -370,9 +367,9 @@ class RoomMemberHandler(object):
|
|||
if block_invite:
|
||||
raise SynapseError(403, "Invites have been disabled on this server")
|
||||
|
||||
latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
|
||||
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
|
||||
|
||||
current_state_ids = yield self.state_handler.get_current_state_ids(
|
||||
current_state_ids = await self.state_handler.get_current_state_ids(
|
||||
room_id, latest_event_ids=latest_event_ids
|
||||
)
|
||||
|
||||
|
@ -381,7 +378,7 @@ class RoomMemberHandler(object):
|
|||
# transitions and generic otherwise
|
||||
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
|
||||
if old_state_id:
|
||||
old_state = yield self.store.get_event(old_state_id, allow_none=True)
|
||||
old_state = await self.store.get_event(old_state_id, allow_none=True)
|
||||
old_membership = old_state.content.get("membership") if old_state else None
|
||||
if action == "unban" and old_membership != "ban":
|
||||
raise SynapseError(
|
||||
|
@ -413,7 +410,7 @@ class RoomMemberHandler(object):
|
|||
old_membership == Membership.INVITE
|
||||
and effective_membership_state == Membership.LEAVE
|
||||
):
|
||||
is_blocked = yield self._is_server_notice_room(room_id)
|
||||
is_blocked = await self._is_server_notice_room(room_id)
|
||||
if is_blocked:
|
||||
raise SynapseError(
|
||||
http_client.FORBIDDEN,
|
||||
|
@ -424,18 +421,18 @@ class RoomMemberHandler(object):
|
|||
if action == "kick":
|
||||
raise AuthError(403, "The target user is not in the room")
|
||||
|
||||
is_host_in_room = yield self._is_host_in_room(current_state_ids)
|
||||
is_host_in_room = await self._is_host_in_room(current_state_ids)
|
||||
|
||||
if effective_membership_state == Membership.JOIN:
|
||||
if requester.is_guest:
|
||||
guest_can_join = yield self._can_guest_join(current_state_ids)
|
||||
guest_can_join = await self._can_guest_join(current_state_ids)
|
||||
if not guest_can_join:
|
||||
# This should be an auth check, but guests are a local concept,
|
||||
# so don't really fit into the general auth process.
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
||||
if not is_host_in_room:
|
||||
inviter = yield self._get_inviter(target.to_string(), room_id)
|
||||
inviter = await self._get_inviter(target.to_string(), room_id)
|
||||
if inviter and not self.hs.is_mine(inviter):
|
||||
remote_room_hosts.append(inviter.domain)
|
||||
|
||||
|
@ -443,13 +440,13 @@ class RoomMemberHandler(object):
|
|||
|
||||
profile = self.profile_handler
|
||||
if not content_specified:
|
||||
content["displayname"] = yield profile.get_displayname(target)
|
||||
content["avatar_url"] = yield profile.get_avatar_url(target)
|
||||
content["displayname"] = await profile.get_displayname(target)
|
||||
content["avatar_url"] = await profile.get_avatar_url(target)
|
||||
|
||||
if requester.is_guest:
|
||||
content["kind"] = "guest"
|
||||
|
||||
remote_join_response = yield self._remote_join(
|
||||
remote_join_response = await self._remote_join(
|
||||
requester, remote_room_hosts, room_id, target, content
|
||||
)
|
||||
|
||||
|
@ -458,7 +455,7 @@ class RoomMemberHandler(object):
|
|||
elif effective_membership_state == Membership.LEAVE:
|
||||
if not is_host_in_room:
|
||||
# perhaps we've been invited
|
||||
inviter = yield self._get_inviter(target.to_string(), room_id)
|
||||
inviter = await self._get_inviter(target.to_string(), room_id)
|
||||
if not inviter:
|
||||
raise SynapseError(404, "Not a known room")
|
||||
|
||||
|
@ -472,12 +469,12 @@ class RoomMemberHandler(object):
|
|||
else:
|
||||
# send the rejection to the inviter's HS.
|
||||
remote_room_hosts = remote_room_hosts + [inviter.domain]
|
||||
res = yield self._remote_reject_invite(
|
||||
res = await self._remote_reject_invite(
|
||||
requester, remote_room_hosts, room_id, target, content,
|
||||
)
|
||||
return res
|
||||
|
||||
res = yield self._local_membership_update(
|
||||
res = await self._local_membership_update(
|
||||
requester=requester,
|
||||
target=target,
|
||||
room_id=room_id,
|
||||
|
@ -572,8 +569,7 @@ class RoomMemberHandler(object):
|
|||
)
|
||||
continue
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_membership_event(self, requester, event, context, ratelimit=True):
|
||||
async def send_membership_event(self, requester, event, context, ratelimit=True):
|
||||
"""
|
||||
Change the membership status of a user in a room.
|
||||
|
||||
|
@ -599,27 +595,27 @@ class RoomMemberHandler(object):
|
|||
else:
|
||||
requester = types.create_requester(target_user)
|
||||
|
||||
prev_event = yield self.event_creation_handler.deduplicate_state_event(
|
||||
prev_event = await self.event_creation_handler.deduplicate_state_event(
|
||||
event, context
|
||||
)
|
||||
if prev_event is not None:
|
||||
return
|
||||
|
||||
prev_state_ids = yield context.get_prev_state_ids()
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
if event.membership == Membership.JOIN:
|
||||
if requester.is_guest:
|
||||
guest_can_join = yield self._can_guest_join(prev_state_ids)
|
||||
guest_can_join = await self._can_guest_join(prev_state_ids)
|
||||
if not guest_can_join:
|
||||
# This should be an auth check, but guests are a local concept,
|
||||
# so don't really fit into the general auth process.
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
||||
if event.membership not in (Membership.LEAVE, Membership.BAN):
|
||||
is_blocked = yield self.store.is_room_blocked(room_id)
|
||||
is_blocked = await self.store.is_room_blocked(room_id)
|
||||
if is_blocked:
|
||||
raise SynapseError(403, "This room has been blocked on this server")
|
||||
|
||||
yield self.event_creation_handler.handle_new_client_event(
|
||||
await self.event_creation_handler.handle_new_client_event(
|
||||
requester, event, context, extra_users=[target_user], ratelimit=ratelimit
|
||||
)
|
||||
|
||||
|
@ -633,15 +629,15 @@ class RoomMemberHandler(object):
|
|||
# info.
|
||||
newly_joined = True
|
||||
if prev_member_event_id:
|
||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
if newly_joined:
|
||||
yield self._user_joined_room(target_user, room_id)
|
||||
await self._user_joined_room(target_user, room_id)
|
||||
elif event.membership == Membership.LEAVE:
|
||||
if prev_member_event_id:
|
||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
if prev_member_event.membership == Membership.JOIN:
|
||||
yield self._user_left_room(target_user, room_id)
|
||||
await self._user_left_room(target_user, room_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _can_guest_join(self, current_state_ids):
|
||||
|
@ -699,8 +695,7 @@ class RoomMemberHandler(object):
|
|||
if invite:
|
||||
return UserID.from_string(invite.sender)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_3pid_invite(
|
||||
async def do_3pid_invite(
|
||||
self,
|
||||
room_id,
|
||||
inviter,
|
||||
|
@ -712,7 +707,7 @@ class RoomMemberHandler(object):
|
|||
id_access_token=None,
|
||||
):
|
||||
if self.config.block_non_admin_invites:
|
||||
is_requester_admin = yield self.auth.is_server_admin(requester.user)
|
||||
is_requester_admin = await self.auth.is_server_admin(requester.user)
|
||||
if not is_requester_admin:
|
||||
raise SynapseError(
|
||||
403, "Invites have been disabled on this server", Codes.FORBIDDEN
|
||||
|
@ -720,9 +715,9 @@ class RoomMemberHandler(object):
|
|||
|
||||
# We need to rate limit *before* we send out any 3PID invites, so we
|
||||
# can't just rely on the standard ratelimiting of events.
|
||||
yield self.base_handler.ratelimit(requester)
|
||||
await self.base_handler.ratelimit(requester)
|
||||
|
||||
can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
|
||||
can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
|
||||
medium, address, room_id
|
||||
)
|
||||
if not can_invite:
|
||||
|
@ -737,16 +732,16 @@ class RoomMemberHandler(object):
|
|||
403, "Looking up third-party identifiers is denied from this server"
|
||||
)
|
||||
|
||||
invitee = yield self.identity_handler.lookup_3pid(
|
||||
invitee = await self.identity_handler.lookup_3pid(
|
||||
id_server, medium, address, id_access_token
|
||||
)
|
||||
|
||||
if invitee:
|
||||
yield self.update_membership(
|
||||
await self.update_membership(
|
||||
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
|
||||
)
|
||||
else:
|
||||
yield self._make_and_store_3pid_invite(
|
||||
await self._make_and_store_3pid_invite(
|
||||
requester,
|
||||
id_server,
|
||||
medium,
|
||||
|
@ -757,8 +752,7 @@ class RoomMemberHandler(object):
|
|||
id_access_token=id_access_token,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _make_and_store_3pid_invite(
|
||||
async def _make_and_store_3pid_invite(
|
||||
self,
|
||||
requester,
|
||||
id_server,
|
||||
|
@ -769,7 +763,7 @@ class RoomMemberHandler(object):
|
|||
txn_id,
|
||||
id_access_token=None,
|
||||
):
|
||||
room_state = yield self.state_handler.get_current_state(room_id)
|
||||
room_state = await self.state_handler.get_current_state(room_id)
|
||||
|
||||
inviter_display_name = ""
|
||||
inviter_avatar_url = ""
|
||||
|
@ -807,7 +801,7 @@ class RoomMemberHandler(object):
|
|||
public_keys,
|
||||
fallback_public_key,
|
||||
display_name,
|
||||
) = yield self.identity_handler.ask_id_server_for_third_party_invite(
|
||||
) = await self.identity_handler.ask_id_server_for_third_party_invite(
|
||||
requester=requester,
|
||||
id_server=id_server,
|
||||
medium=medium,
|
||||
|
@ -823,7 +817,7 @@ class RoomMemberHandler(object):
|
|||
id_access_token=id_access_token,
|
||||
)
|
||||
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.ThirdPartyInvite,
|
||||
|
@ -917,8 +911,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||
|
||||
return complexity["v1"] > max_complexity
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
|
||||
async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
|
||||
"""Implements RoomMemberHandler._remote_join
|
||||
"""
|
||||
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
|
||||
|
@ -933,7 +926,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||
|
||||
if self.hs.config.limit_remote_rooms.enabled:
|
||||
# Fetch the room complexity
|
||||
too_complex = yield self._is_remote_room_too_complex(
|
||||
too_complex = await self._is_remote_room_too_complex(
|
||||
room_id, remote_room_hosts
|
||||
)
|
||||
if too_complex is True:
|
||||
|
@ -947,12 +940,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||
# join dance for now, since we're kinda implicitly checking
|
||||
# that we are allowed to join when we decide whether or not we
|
||||
# need to do the invite/join dance.
|
||||
yield defer.ensureDeferred(
|
||||
self.federation_handler.do_invite_join(
|
||||
remote_room_hosts, room_id, user.to_string(), content
|
||||
)
|
||||
await self.federation_handler.do_invite_join(
|
||||
remote_room_hosts, room_id, user.to_string(), content
|
||||
)
|
||||
yield self._user_joined_room(user, room_id)
|
||||
await self._user_joined_room(user, room_id)
|
||||
|
||||
# Check the room we just joined wasn't too large, if we didn't fetch the
|
||||
# complexity of it before.
|
||||
|
@ -962,7 +953,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||
return
|
||||
|
||||
# Check again, but with the local state events
|
||||
too_complex = yield self._is_local_room_too_complex(room_id)
|
||||
too_complex = await self._is_local_room_too_complex(room_id)
|
||||
|
||||
if too_complex is False:
|
||||
# We're under the limit.
|
||||
|
@ -970,7 +961,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||
|
||||
# The room is too large. Leave.
|
||||
requester = types.create_requester(user, None, False, None)
|
||||
yield self.update_membership(
|
||||
await self.update_membership(
|
||||
requester=requester, target=user, room_id=room_id, action="leave"
|
||||
)
|
||||
raise SynapseError(
|
||||
|
@ -1008,12 +999,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||
def _user_joined_room(self, target, room_id):
|
||||
"""Implements RoomMemberHandler._user_joined_room
|
||||
"""
|
||||
return user_joined_room(self.distributor, target, room_id)
|
||||
return defer.succeed(user_joined_room(self.distributor, target, room_id))
|
||||
|
||||
def _user_left_room(self, target, room_id):
|
||||
"""Implements RoomMemberHandler._user_left_room
|
||||
"""
|
||||
return user_left_room(self.distributor, target, room_id)
|
||||
return defer.succeed(user_left_room(self.distributor, target, room_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def forget(self, user, room_id):
|
||||
|
|
|
@ -149,7 +149,7 @@ class SamlHandler:
|
|||
|
||||
# Complete the interactive auth session or the login.
|
||||
if current_session and current_session.ui_auth_session_id:
|
||||
self._auth_handler.complete_sso_ui_auth(
|
||||
await self._auth_handler.complete_sso_ui_auth(
|
||||
user_id, current_session.ui_auth_session_id, request
|
||||
)
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import abc
|
||||
import logging
|
||||
import re
|
||||
from inspect import signature
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from six import raise_from
|
||||
|
@ -60,6 +61,8 @@ class ReplicationEndpoint(object):
|
|||
must call `register` to register the path with the HTTP server.
|
||||
|
||||
Requests can be sent by calling the client returned by `make_client`.
|
||||
Requests are sent to master process by default, but can be sent to other
|
||||
named processes by specifying an `instance_name` keyword argument.
|
||||
|
||||
Attributes:
|
||||
NAME (str): A name for the endpoint, added to the path as well as used
|
||||
|
@ -91,6 +94,16 @@ class ReplicationEndpoint(object):
|
|||
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
|
||||
)
|
||||
|
||||
# We reserve `instance_name` as a parameter to sending requests, so we
|
||||
# assert here that sub classes don't try and use the name.
|
||||
assert (
|
||||
"instance_name" not in self.PATH_ARGS
|
||||
), "`instance_name` is a reserved paramater name"
|
||||
assert (
|
||||
"instance_name"
|
||||
not in signature(self.__class__._serialize_payload).parameters
|
||||
), "`instance_name` is a reserved paramater name"
|
||||
|
||||
assert self.METHOD in ("PUT", "POST", "GET")
|
||||
|
||||
@abc.abstractmethod
|
||||
|
@ -135,7 +148,11 @@ class ReplicationEndpoint(object):
|
|||
|
||||
@trace(opname="outgoing_replication_request")
|
||||
@defer.inlineCallbacks
|
||||
def send_request(**kwargs):
|
||||
def send_request(instance_name="master", **kwargs):
|
||||
# Currently we only support sending requests to master process.
|
||||
if instance_name != "master":
|
||||
raise Exception("Unknown instance")
|
||||
|
||||
data = yield cls._serialize_payload(**kwargs)
|
||||
|
||||
url_args = [
|
||||
|
|
|
@ -50,6 +50,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
|||
def __init__(self, hs):
|
||||
super().__init__(hs)
|
||||
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
# We pull the streams from the replication steamer (if we try and make
|
||||
# them ourselves we end up in an import loop).
|
||||
self.streams = hs.get_replication_streamer().get_streams()
|
||||
|
@ -67,7 +69,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
|||
upto_token = parse_integer(request, "upto_token", required=True)
|
||||
|
||||
updates, upto_token, limited = await stream.get_updates_since(
|
||||
from_token, upto_token
|
||||
self._instance_name, from_token, upto_token
|
||||
)
|
||||
|
||||
return (
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import six
|
||||
|
||||
|
@ -49,19 +49,6 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
|
|||
|
||||
self.hs = hs
|
||||
|
||||
def stream_positions(self) -> Dict[str, int]:
|
||||
"""
|
||||
Get the current positions of all the streams this store wants to subscribe to
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
pos = {}
|
||||
if self._cache_id_gen:
|
||||
pos["caches"] = self._cache_id_gen.get_current_token()
|
||||
return pos
|
||||
|
||||
def get_cache_stream_token(self):
|
||||
if self._cache_id_gen:
|
||||
return self._cache_id_gen.get_current_token()
|
||||
|
|
|
@ -32,14 +32,6 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
|
|||
def get_max_account_data_stream_id(self):
|
||||
return self._account_data_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedAccountDataStore, self).stream_positions()
|
||||
position = self._account_data_id_gen.get_current_token()
|
||||
result["user_account_data"] = position
|
||||
result["room_account_data"] = position
|
||||
result["tag_account_data"] = position
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "tag_account_data":
|
||||
self._account_data_id_gen.advance(token)
|
||||
|
|
|
@ -43,11 +43,6 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
|
|||
expiry_ms=30 * 60 * 1000,
|
||||
)
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedDeviceInboxStore, self).stream_positions()
|
||||
result["to_device"] = self._device_inbox_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "to_device":
|
||||
self._device_inbox_id_gen.advance(token)
|
||||
|
|
|
@ -48,16 +48,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
|||
"DeviceListFederationStreamChangeCache", device_list_max
|
||||
)
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedDeviceStore, self).stream_positions()
|
||||
# The user signature stream uses the same stream ID generator as the
|
||||
# device list stream, so set them both to the device list ID
|
||||
# generator's current token.
|
||||
current_token = self._device_list_id_gen.get_current_token()
|
||||
result[DeviceListsStream.NAME] = current_token
|
||||
result[UserSignatureStream.NAME] = current_token
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == DeviceListsStream.NAME:
|
||||
self._device_list_id_gen.advance(token)
|
||||
|
|
|
@ -93,12 +93,6 @@ class SlavedEventStore(
|
|||
def get_room_min_stream_ordering(self):
|
||||
return self._backfill_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedEventStore, self).stream_positions()
|
||||
result["events"] = self._stream_id_gen.get_current_token()
|
||||
result["backfill"] = -self._backfill_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "events":
|
||||
self._stream_id_gen.advance(token)
|
||||
|
|
|
@ -37,11 +37,6 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
|
|||
def get_group_stream_token(self):
|
||||
return self._group_updates_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedGroupServerStore, self).stream_positions()
|
||||
result["groups"] = self._group_updates_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "groups":
|
||||
self._group_updates_id_gen.advance(token)
|
||||
|
|
|
@ -41,15 +41,6 @@ class SlavedPresenceStore(BaseSlavedStore):
|
|||
def get_current_presence_token(self):
|
||||
return self._presence_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedPresenceStore, self).stream_positions()
|
||||
|
||||
if self.hs.config.use_presence:
|
||||
position = self._presence_id_gen.get_current_token()
|
||||
result["presence"] = position
|
||||
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "presence":
|
||||
self._presence_id_gen.advance(token)
|
||||
|
|
|
@ -37,11 +37,6 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
|
|||
def get_max_push_rules_stream_id(self):
|
||||
return self._push_rules_stream_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedPushRuleStore, self).stream_positions()
|
||||
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "push_rules":
|
||||
self._push_rules_stream_id_gen.advance(token)
|
||||
|
|
|
@ -28,11 +28,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
|||
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
||||
)
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedPusherStore, self).stream_positions()
|
||||
result["pushers"] = self._pushers_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def get_pushers_stream_token(self):
|
||||
return self._pushers_id_gen.get_current_token()
|
||||
|
||||
|
|
|
@ -42,11 +42,6 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
|||
def get_max_receipt_stream_id(self):
|
||||
return self._receipts_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedReceiptsStore, self).stream_positions()
|
||||
result["receipts"] = self._receipts_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
||||
self.get_receipts_for_user.invalidate((user_id, receipt_type))
|
||||
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
|
||||
|
|
|
@ -30,11 +30,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
|
|||
def get_current_public_room_stream_id(self):
|
||||
return self._public_room_id_gen.get_current_token()
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(RoomStore, self).stream_positions()
|
||||
result["public_rooms"] = self._public_room_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "public_rooms":
|
||||
self._public_room_id_gen.advance(token)
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
|
||||
|
@ -86,37 +86,22 @@ class ReplicationDataHandler:
|
|||
def __init__(self, store: BaseSlavedStore):
|
||||
self.store = store
|
||||
|
||||
async def on_rdata(self, stream_name: str, token: int, rows: list):
|
||||
async def on_rdata(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
By default this just pokes the slave store. Can be overridden in subclasses to
|
||||
handle more.
|
||||
|
||||
Args:
|
||||
stream_name (str): name of the replication stream for this batch of rows
|
||||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
stream_name: name of the replication stream for this batch of rows
|
||||
instance_name: the instance that wrote the rows.
|
||||
token: stream token for this batch of rows
|
||||
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
|
||||
"""
|
||||
self.store.process_replication_rows(stream_name, token, rows)
|
||||
|
||||
def get_streams_to_replicate(self) -> Dict[str, int]:
|
||||
"""Called when a new connection has been established and we need to
|
||||
subscribe to streams.
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
args = self.store.stream_positions()
|
||||
user_account_data = args.pop("user_account_data", None)
|
||||
room_account_data = args.pop("room_account_data", None)
|
||||
if user_account_data:
|
||||
args["account_data"] = user_account_data
|
||||
elif room_account_data:
|
||||
args["account_data"] = room_account_data
|
||||
return args
|
||||
|
||||
async def on_position(self, stream_name: str, token: int):
|
||||
self.store.process_replication_rows(stream_name, token, [])
|
||||
|
||||
|
|
|
@ -278,19 +278,24 @@ class ReplicationCommandHandler:
|
|||
# Check if this is the last of a batch of updates
|
||||
rows = self._pending_batches.pop(stream_name, [])
|
||||
rows.append(row)
|
||||
await self.on_rdata(stream_name, cmd.token, rows)
|
||||
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
|
||||
|
||||
async def on_rdata(self, stream_name: str, token: int, rows: list):
|
||||
async def on_rdata(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
Args:
|
||||
stream_name: name of the replication stream for this batch of rows
|
||||
instance_name: the instance that wrote the rows.
|
||||
token: stream token for this batch of rows
|
||||
rows: a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||
await self._replication_data_handler.on_rdata(stream_name, token, rows)
|
||||
await self._replication_data_handler.on_rdata(
|
||||
stream_name, instance_name, token, rows
|
||||
)
|
||||
|
||||
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
||||
if cmd.instance_name == self._instance_name:
|
||||
|
@ -314,15 +319,7 @@ class ReplicationCommandHandler:
|
|||
self._pending_batches.pop(cmd.stream_name, [])
|
||||
|
||||
# Find where we previously streamed up to.
|
||||
current_token = self._replication_data_handler.get_streams_to_replicate().get(
|
||||
cmd.stream_name
|
||||
)
|
||||
if current_token is None:
|
||||
logger.warning(
|
||||
"Got POSITION for stream we're not subscribed to: %s",
|
||||
cmd.stream_name,
|
||||
)
|
||||
return
|
||||
current_token = stream.current_token()
|
||||
|
||||
# If the position token matches our current token then we're up to
|
||||
# date and there's nothing to do. Otherwise, fetch all updates
|
||||
|
@ -333,7 +330,9 @@ class ReplicationCommandHandler:
|
|||
updates,
|
||||
current_token,
|
||||
missing_updates,
|
||||
) = await stream.get_updates_since(current_token, cmd.token)
|
||||
) = await stream.get_updates_since(
|
||||
cmd.instance_name, current_token, cmd.token
|
||||
)
|
||||
|
||||
# TODO: add some tests for this
|
||||
|
||||
|
@ -342,7 +341,10 @@ class ReplicationCommandHandler:
|
|||
|
||||
for token, rows in _batch_updates(updates):
|
||||
await self.on_rdata(
|
||||
cmd.stream_name, token, [stream.parse_row(row) for row in rows],
|
||||
cmd.stream_name,
|
||||
cmd.instance_name,
|
||||
token,
|
||||
[stream.parse_row(row) for row in rows],
|
||||
)
|
||||
|
||||
# We've now caught up to position sent to us, notify handler.
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import heapq
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
@ -21,10 +22,10 @@ from typing import (
|
|||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
@ -49,7 +50,7 @@ Token = int
|
|||
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
|
||||
# just a row from a database query, though this is dependent on the stream in question.
|
||||
#
|
||||
StreamRow = Tuple
|
||||
StreamRow = TypeVar("StreamRow", bound=Tuple)
|
||||
|
||||
# The type returned by the update_function of a stream, as well as get_updates(),
|
||||
# get_updates_since, etc.
|
||||
|
@ -65,6 +66,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
|
|||
#
|
||||
# The arguments are:
|
||||
#
|
||||
# * instance_name: the writer of the stream
|
||||
# * from_token: the previous stream token: the starting point for fetching the
|
||||
# updates
|
||||
# * to_token: the new stream token: the point to get updates up to
|
||||
|
@ -74,7 +76,7 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
|
|||
# If there are more updates available, it should set `limited` in the result, and
|
||||
# it will be called again to get the next batch.
|
||||
#
|
||||
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
|
||||
UpdateFunction = Callable[[str, Token, Token, int], Awaitable[StreamUpdateResult]]
|
||||
|
||||
|
||||
class Stream(object):
|
||||
|
@ -105,6 +107,7 @@ class Stream(object):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
local_instance_name: str,
|
||||
current_token_function: Callable[[], Token],
|
||||
update_function: UpdateFunction,
|
||||
):
|
||||
|
@ -120,9 +123,11 @@ class Stream(object):
|
|||
stream tokens. See the UpdateFunction type definition for more info.
|
||||
|
||||
Args:
|
||||
local_instance_name: The instance name of the current process
|
||||
current_token_function: callback to get the current token, as above
|
||||
update_function: callback go get stream updates, as above
|
||||
"""
|
||||
self.local_instance_name = local_instance_name
|
||||
self.current_token = current_token_function
|
||||
self.update_function = update_function
|
||||
|
||||
|
@ -147,14 +152,14 @@ class Stream(object):
|
|||
"""
|
||||
current_token = self.current_token()
|
||||
updates, current_token, limited = await self.get_updates_since(
|
||||
self.last_token, current_token
|
||||
self.local_instance_name, self.last_token, current_token
|
||||
)
|
||||
self.last_token = current_token
|
||||
|
||||
return updates, current_token, limited
|
||||
|
||||
async def get_updates_since(
|
||||
self, from_token: Token, upto_token: Token
|
||||
self, instance_name: str, from_token: Token, upto_token: Token
|
||||
) -> StreamUpdateResult:
|
||||
"""Like get_updates except allows specifying from when we should
|
||||
stream updates
|
||||
|
@ -172,26 +177,25 @@ class Stream(object):
|
|||
return [], upto_token, False
|
||||
|
||||
updates, upto_token, limited = await self.update_function(
|
||||
from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
|
||||
instance_name, from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
|
||||
)
|
||||
return updates, upto_token, limited
|
||||
|
||||
|
||||
def db_query_to_update_function(
|
||||
query_function: Callable[[Token, Token, int], Awaitable[Iterable[tuple]]]
|
||||
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
|
||||
) -> UpdateFunction:
|
||||
"""Wraps a db query function which returns a list of rows to make it
|
||||
suitable for use as an `update_function` for the Stream class
|
||||
"""
|
||||
|
||||
async def update_function(from_token, upto_token, limit):
|
||||
async def update_function(instance_name, from_token, upto_token, limit):
|
||||
rows = await query_function(from_token, upto_token, limit)
|
||||
updates = [(row[0], row[1:]) for row in rows]
|
||||
limited = False
|
||||
if len(updates) == limit:
|
||||
if len(updates) >= limit:
|
||||
upto_token = updates[-1][0]
|
||||
limited = True
|
||||
assert len(updates) <= limit
|
||||
|
||||
return updates, upto_token, limited
|
||||
|
||||
|
@ -206,10 +210,13 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
|
|||
client = ReplicationGetStreamUpdates.make_client(hs)
|
||||
|
||||
async def update_function(
|
||||
from_token: int, upto_token: int, limit: int
|
||||
instance_name: str, from_token: int, upto_token: int, limit: int
|
||||
) -> StreamUpdateResult:
|
||||
result = await client(
|
||||
stream_name=stream_name, from_token=from_token, upto_token=upto_token,
|
||||
instance_name=instance_name,
|
||||
stream_name=stream_name,
|
||||
from_token=from_token,
|
||||
upto_token=upto_token,
|
||||
)
|
||||
return result["updates"], result["upto_token"], result["limited"]
|
||||
|
||||
|
@ -239,6 +246,7 @@ class BackfillStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_current_backfill_token,
|
||||
db_query_to_update_function(store.get_all_new_backfill_event_rows),
|
||||
)
|
||||
|
@ -274,7 +282,9 @@ class PresenceStream(Stream):
|
|||
# Query master process
|
||||
update_function = make_http_update_function(hs, self.NAME)
|
||||
|
||||
super().__init__(store.get_current_presence_token, update_function)
|
||||
super().__init__(
|
||||
hs.get_instance_name(), store.get_current_presence_token, update_function
|
||||
)
|
||||
|
||||
|
||||
class TypingStream(Stream):
|
||||
|
@ -297,7 +307,9 @@ class TypingStream(Stream):
|
|||
# Query master process
|
||||
update_function = make_http_update_function(hs, self.NAME)
|
||||
|
||||
super().__init__(typing_handler.get_current_token, update_function)
|
||||
super().__init__(
|
||||
hs.get_instance_name(), typing_handler.get_current_token, update_function
|
||||
)
|
||||
|
||||
|
||||
class ReceiptsStream(Stream):
|
||||
|
@ -318,6 +330,7 @@ class ReceiptsStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_max_receipt_stream_id,
|
||||
db_query_to_update_function(store.get_all_updated_receipts),
|
||||
)
|
||||
|
@ -335,14 +348,16 @@ class PushRulesStream(Stream):
|
|||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
super(PushRulesStream, self).__init__(
|
||||
self._current_token, self._update_function
|
||||
hs.get_instance_name(), self._current_token, self._update_function
|
||||
)
|
||||
|
||||
def _current_token(self) -> int:
|
||||
push_rules_token, _ = self.store.get_push_rules_stream_token()
|
||||
return push_rules_token
|
||||
|
||||
async def _update_function(self, from_token: Token, to_token: Token, limit: int):
|
||||
async def _update_function(
|
||||
self, instance_name: str, from_token: Token, to_token: Token, limit: int
|
||||
):
|
||||
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
|
||||
|
||||
limited = False
|
||||
|
@ -369,6 +384,7 @@ class PushersStream(Stream):
|
|||
store = hs.get_datastore()
|
||||
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_pushers_stream_token,
|
||||
db_query_to_update_function(store.get_all_updated_pushers_rows),
|
||||
)
|
||||
|
@ -400,6 +416,7 @@ class CachesStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_cache_stream_token,
|
||||
db_query_to_update_function(store.get_all_updated_caches),
|
||||
)
|
||||
|
@ -425,6 +442,7 @@ class PublicRoomsStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_current_public_room_stream_id,
|
||||
db_query_to_update_function(store.get_all_new_public_rooms),
|
||||
)
|
||||
|
@ -445,6 +463,7 @@ class DeviceListsStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_device_stream_token,
|
||||
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
|
||||
)
|
||||
|
@ -462,6 +481,7 @@ class ToDeviceStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_to_device_stream_token,
|
||||
db_query_to_update_function(store.get_all_new_device_messages),
|
||||
)
|
||||
|
@ -481,6 +501,7 @@ class TagAccountDataStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_max_account_data_stream_id,
|
||||
db_query_to_update_function(store.get_all_updated_tags),
|
||||
)
|
||||
|
@ -501,11 +522,13 @@ class AccountDataStream(Stream):
|
|||
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
super().__init__(
|
||||
self.store.get_max_account_data_stream_id, self._update_function,
|
||||
hs.get_instance_name(),
|
||||
self.store.get_max_account_data_stream_id,
|
||||
self._update_function,
|
||||
)
|
||||
|
||||
async def _update_function(
|
||||
self, from_token: int, to_token: int, limit: int
|
||||
self, instance_name: str, from_token: int, to_token: int, limit: int
|
||||
) -> StreamUpdateResult:
|
||||
limited = False
|
||||
global_results = await self.store.get_updated_global_account_data(
|
||||
|
@ -530,16 +553,19 @@ class AccountDataStream(Stream):
|
|||
|
||||
# convert the global results to the right format, and limit them to the to_token
|
||||
# at the same time
|
||||
global_results = (
|
||||
global_rows = (
|
||||
(stream_id, (user_id, None, account_data_type))
|
||||
for stream_id, user_id, account_data_type in global_results
|
||||
if stream_id <= to_token
|
||||
)
|
||||
|
||||
room_results = ((stream_id, rest) for stream_id, *rest in room_results)
|
||||
room_rows = (
|
||||
(stream_id, (user_id, room_id, account_data_type))
|
||||
for stream_id, user_id, room_id, account_data_type in room_results
|
||||
)
|
||||
|
||||
# we need to return a sorted list, so merge them together.
|
||||
updates = list(heapq.merge(room_results, global_results))
|
||||
updates = list(heapq.merge(room_rows, global_rows))
|
||||
return updates, to_token, limited
|
||||
|
||||
|
||||
|
@ -555,6 +581,7 @@ class GroupServerStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_group_stream_token,
|
||||
db_query_to_update_function(store.get_all_groups_changes),
|
||||
)
|
||||
|
@ -572,6 +599,7 @@ class UserSignatureStream(Stream):
|
|||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_device_stream_token,
|
||||
db_query_to_update_function(
|
||||
store.get_all_user_signature_changes_for_remotes
|
||||
|
|
|
@ -118,11 +118,17 @@ class EventsStream(Stream):
|
|||
def __init__(self, hs):
|
||||
self._store = hs.get_datastore()
|
||||
super().__init__(
|
||||
self._store.get_current_events_token, self._update_function,
|
||||
hs.get_instance_name(),
|
||||
self._store.get_current_events_token,
|
||||
self._update_function,
|
||||
)
|
||||
|
||||
async def _update_function(
|
||||
self, from_token: Token, current_token: Token, target_row_count: int
|
||||
self,
|
||||
instance_name: str,
|
||||
from_token: Token,
|
||||
current_token: Token,
|
||||
target_row_count: int,
|
||||
) -> StreamUpdateResult:
|
||||
|
||||
# the events stream merges together three separate sources:
|
||||
|
|
|
@ -48,8 +48,8 @@ class FederationStream(Stream):
|
|||
current_token = lambda: 0
|
||||
update_function = self._stub_update_function
|
||||
|
||||
super().__init__(current_token, update_function)
|
||||
super().__init__(hs.get_instance_name(), current_token, update_function)
|
||||
|
||||
@staticmethod
|
||||
async def _stub_update_function(from_token, upto_token, limit):
|
||||
async def _stub_update_function(instance_name, from_token, upto_token, limit):
|
||||
return [], upto_token, False
|
||||
|
|
|
@ -140,7 +140,7 @@ class AuthRestServlet(RestServlet):
|
|||
self._cas_server_url = hs.config.cas_server_url
|
||||
self._cas_service_url = hs.config.cas_service_url
|
||||
|
||||
def on_GET(self, request, stagetype):
|
||||
async def on_GET(self, request, stagetype):
|
||||
session = parse_string(request, "session")
|
||||
if not session:
|
||||
raise SynapseError(400, "No session supplied")
|
||||
|
@ -180,7 +180,7 @@ class AuthRestServlet(RestServlet):
|
|||
else:
|
||||
raise SynapseError(400, "Homeserver not configured for SSO.")
|
||||
|
||||
html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
|
||||
html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
|
||||
|
||||
else:
|
||||
raise SynapseError(404, "Unknown auth stage type")
|
||||
|
|
|
@ -499,7 +499,7 @@ class RegisterRestServlet(RestServlet):
|
|||
# registered a user for this session, so we could just return the
|
||||
# user here. We carry on and go through the auth checks though,
|
||||
# for paranoia.
|
||||
registered_user_id = self.auth_handler.get_session_data(
|
||||
registered_user_id = await self.auth_handler.get_session_data(
|
||||
session_id, "registered_user_id", None
|
||||
)
|
||||
|
||||
|
@ -598,7 +598,7 @@ class RegisterRestServlet(RestServlet):
|
|||
|
||||
# remember that we've now registered that user account, and with
|
||||
# what user ID (since the user may not have specified)
|
||||
self.auth_handler.set_session_data(
|
||||
await self.auth_handler.set_session_data(
|
||||
session_id, "registered_user_id", registered_user_id
|
||||
)
|
||||
|
||||
|
|
|
@ -16,8 +16,6 @@ import logging
|
|||
|
||||
from six import iteritems, string_types
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.urls import ConsentURIBuilder
|
||||
from synapse.config import ConfigError
|
||||
|
@ -59,8 +57,7 @@ class ConsentServerNotices(object):
|
|||
|
||||
self._consent_uri_builder = ConsentURIBuilder(hs.config)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def maybe_send_server_notice_to_user(self, user_id):
|
||||
async def maybe_send_server_notice_to_user(self, user_id):
|
||||
"""Check if we need to send a notice to this user, and does so if so
|
||||
|
||||
Args:
|
||||
|
@ -78,7 +75,7 @@ class ConsentServerNotices(object):
|
|||
return
|
||||
self._users_in_progress.add(user_id)
|
||||
try:
|
||||
u = yield self._store.get_user_by_id(user_id)
|
||||
u = await self._store.get_user_by_id(user_id)
|
||||
|
||||
if u["is_guest"] and not self._send_to_guests:
|
||||
# don't send to guests
|
||||
|
@ -100,8 +97,8 @@ class ConsentServerNotices(object):
|
|||
content = copy_with_str_subst(
|
||||
self._server_notice_content, {"consent_uri": consent_uri}
|
||||
)
|
||||
yield self._server_notices_manager.send_notice(user_id, content)
|
||||
yield self._store.user_set_consent_server_notice_sent(
|
||||
await self._server_notices_manager.send_notice(user_id, content)
|
||||
await self._store.user_set_consent_server_notice_sent(
|
||||
user_id, self._current_consent_version
|
||||
)
|
||||
except SynapseError as e:
|
||||
|
|
|
@ -50,8 +50,7 @@ class ResourceLimitsServerNotices(object):
|
|||
|
||||
self._notifier = hs.get_notifier()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def maybe_send_server_notice_to_user(self, user_id):
|
||||
async def maybe_send_server_notice_to_user(self, user_id):
|
||||
"""Check if we need to send a notice to this user, this will be true in
|
||||
two cases.
|
||||
1. The server has reached its limit does not reflect this
|
||||
|
@ -74,13 +73,13 @@ class ResourceLimitsServerNotices(object):
|
|||
# Don't try and send server notices unless they've been enabled
|
||||
return
|
||||
|
||||
timestamp = yield self._store.user_last_seen_monthly_active(user_id)
|
||||
timestamp = await self._store.user_last_seen_monthly_active(user_id)
|
||||
if timestamp is None:
|
||||
# This user will be blocked from receiving the notice anyway.
|
||||
# In practice, not sure we can ever get here
|
||||
return
|
||||
|
||||
room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user(
|
||||
room_id = await self._server_notices_manager.get_or_create_notice_room_for_user(
|
||||
user_id
|
||||
)
|
||||
|
||||
|
@ -88,10 +87,10 @@ class ResourceLimitsServerNotices(object):
|
|||
logger.warning("Failed to get server notices room")
|
||||
return
|
||||
|
||||
yield self._check_and_set_tags(user_id, room_id)
|
||||
await self._check_and_set_tags(user_id, room_id)
|
||||
|
||||
# Determine current state of room
|
||||
currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id)
|
||||
currently_blocked, ref_events = await self._is_room_currently_blocked(room_id)
|
||||
|
||||
limit_msg = None
|
||||
limit_type = None
|
||||
|
@ -99,7 +98,7 @@ class ResourceLimitsServerNotices(object):
|
|||
# Normally should always pass in user_id to check_auth_blocking
|
||||
# if you have it, but in this case are checking what would happen
|
||||
# to other users if they were to arrive.
|
||||
yield self._auth.check_auth_blocking()
|
||||
await self._auth.check_auth_blocking()
|
||||
except ResourceLimitError as e:
|
||||
limit_msg = e.msg
|
||||
limit_type = e.limit_type
|
||||
|
@ -112,22 +111,21 @@ class ResourceLimitsServerNotices(object):
|
|||
# We have hit the MAU limit, but MAU alerting is disabled:
|
||||
# reset room if necessary and return
|
||||
if currently_blocked:
|
||||
self._remove_limit_block_notification(user_id, ref_events)
|
||||
await self._remove_limit_block_notification(user_id, ref_events)
|
||||
return
|
||||
|
||||
if currently_blocked and not limit_msg:
|
||||
# Room is notifying of a block, when it ought not to be.
|
||||
yield self._remove_limit_block_notification(user_id, ref_events)
|
||||
await self._remove_limit_block_notification(user_id, ref_events)
|
||||
elif not currently_blocked and limit_msg:
|
||||
# Room is not notifying of a block, when it ought to be.
|
||||
yield self._apply_limit_block_notification(
|
||||
await self._apply_limit_block_notification(
|
||||
user_id, limit_msg, limit_type
|
||||
)
|
||||
except SynapseError as e:
|
||||
logger.error("Error sending resource limits server notice: %s", e)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _remove_limit_block_notification(self, user_id, ref_events):
|
||||
async def _remove_limit_block_notification(self, user_id, ref_events):
|
||||
"""Utility method to remove limit block notifications from the server
|
||||
notices room.
|
||||
|
||||
|
@ -137,12 +135,13 @@ class ResourceLimitsServerNotices(object):
|
|||
limit blocking and need to be preserved.
|
||||
"""
|
||||
content = {"pinned": ref_events}
|
||||
yield self._server_notices_manager.send_notice(
|
||||
await self._server_notices_manager.send_notice(
|
||||
user_id, content, EventTypes.Pinned, ""
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _apply_limit_block_notification(self, user_id, event_body, event_limit_type):
|
||||
async def _apply_limit_block_notification(
|
||||
self, user_id, event_body, event_limit_type
|
||||
):
|
||||
"""Utility method to apply limit block notifications in the server
|
||||
notices room.
|
||||
|
||||
|
@ -159,12 +158,12 @@ class ResourceLimitsServerNotices(object):
|
|||
"admin_contact": self._config.admin_contact,
|
||||
"limit_type": event_limit_type,
|
||||
}
|
||||
event = yield self._server_notices_manager.send_notice(
|
||||
event = await self._server_notices_manager.send_notice(
|
||||
user_id, content, EventTypes.Message
|
||||
)
|
||||
|
||||
content = {"pinned": [event.event_id]}
|
||||
yield self._server_notices_manager.send_notice(
|
||||
await self._server_notices_manager.send_notice(
|
||||
user_id, content, EventTypes.Pinned, ""
|
||||
)
|
||||
|
||||
|
@ -198,7 +197,7 @@ class ResourceLimitsServerNotices(object):
|
|||
room_id(str): The room id of the server notices room
|
||||
|
||||
Returns:
|
||||
|
||||
Deferred[Tuple[bool, List]]:
|
||||
bool: Is the room currently blocked
|
||||
list: The list of pinned events that are unrelated to limit blocking
|
||||
This list can be used as a convenience in the case where the block
|
||||
|
|
|
@ -14,11 +14,9 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
|
||||
from synapse.types import UserID, create_requester
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -51,8 +49,7 @@ class ServerNoticesManager(object):
|
|||
"""
|
||||
return self._config.server_notices_mxid is not None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_notice(
|
||||
async def send_notice(
|
||||
self, user_id, event_content, type=EventTypes.Message, state_key=None
|
||||
):
|
||||
"""Send a notice to the given user
|
||||
|
@ -68,8 +65,8 @@ class ServerNoticesManager(object):
|
|||
Returns:
|
||||
Deferred[FrozenEvent]
|
||||
"""
|
||||
room_id = yield self.get_or_create_notice_room_for_user(user_id)
|
||||
yield self.maybe_invite_user_to_room(user_id, room_id)
|
||||
room_id = await self.get_or_create_notice_room_for_user(user_id)
|
||||
await self.maybe_invite_user_to_room(user_id, room_id)
|
||||
|
||||
system_mxid = self._config.server_notices_mxid
|
||||
requester = create_requester(system_mxid)
|
||||
|
@ -86,13 +83,13 @@ class ServerNoticesManager(object):
|
|||
if state_key is not None:
|
||||
event_dict["state_key"] = state_key
|
||||
|
||||
res = yield self._event_creation_handler.create_and_send_nonmember_event(
|
||||
res = await self._event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, ratelimit=False
|
||||
)
|
||||
return res
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_or_create_notice_room_for_user(self, user_id):
|
||||
@cached()
|
||||
async def get_or_create_notice_room_for_user(self, user_id):
|
||||
"""Get the room for notices for a given user
|
||||
|
||||
If we have not yet created a notice room for this user, create it, but don't
|
||||
|
@ -109,7 +106,7 @@ class ServerNoticesManager(object):
|
|||
|
||||
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
|
||||
|
||||
rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
|
||||
rooms = await self._store.get_rooms_for_local_user_where_membership_is(
|
||||
user_id, [Membership.INVITE, Membership.JOIN]
|
||||
)
|
||||
for room in rooms:
|
||||
|
@ -118,7 +115,7 @@ class ServerNoticesManager(object):
|
|||
# be joined. This is kinda deliberate, in that if somebody somehow
|
||||
# manages to invite the system user to a room, that doesn't make it
|
||||
# the server notices room.
|
||||
user_ids = yield self._store.get_users_in_room(room.room_id)
|
||||
user_ids = await self._store.get_users_in_room(room.room_id)
|
||||
if self.server_notices_mxid in user_ids:
|
||||
# we found a room which our user shares with the system notice
|
||||
# user
|
||||
|
@ -146,7 +143,7 @@ class ServerNoticesManager(object):
|
|||
}
|
||||
|
||||
requester = create_requester(self.server_notices_mxid)
|
||||
info = yield self._room_creation_handler.create_room(
|
||||
info = await self._room_creation_handler.create_room(
|
||||
requester,
|
||||
config={
|
||||
"preset": RoomCreationPreset.PRIVATE_CHAT,
|
||||
|
@ -158,7 +155,7 @@ class ServerNoticesManager(object):
|
|||
)
|
||||
room_id = info["room_id"]
|
||||
|
||||
max_id = yield self._store.add_tag_to_room(
|
||||
max_id = await self._store.add_tag_to_room(
|
||||
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
|
||||
)
|
||||
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
|
||||
|
@ -166,8 +163,7 @@ class ServerNoticesManager(object):
|
|||
logger.info("Created server notices room %s for %s", room_id, user_id)
|
||||
return room_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def maybe_invite_user_to_room(self, user_id: str, room_id: str):
|
||||
async def maybe_invite_user_to_room(self, user_id: str, room_id: str):
|
||||
"""Invite the given user to the given server room, unless the user has already
|
||||
joined or been invited to it.
|
||||
|
||||
|
@ -179,14 +175,14 @@ class ServerNoticesManager(object):
|
|||
|
||||
# Check whether the user has already joined or been invited to this room. If
|
||||
# that's the case, there is no need to re-invite them.
|
||||
joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
|
||||
joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is(
|
||||
user_id, [Membership.INVITE, Membership.JOIN]
|
||||
)
|
||||
for room in joined_rooms:
|
||||
if room.room_id == room_id:
|
||||
return
|
||||
|
||||
yield self._room_member_handler.update_membership(
|
||||
await self._room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
target=UserID.from_string(user_id),
|
||||
room_id=room_id,
|
||||
|
|
|
@ -12,8 +12,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.server_notices.consent_server_notices import ConsentServerNotices
|
||||
from synapse.server_notices.resource_limits_server_notices import (
|
||||
ResourceLimitsServerNotices,
|
||||
|
@ -36,18 +34,16 @@ class ServerNoticesSender(object):
|
|||
ResourceLimitsServerNotices(hs),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_user_syncing(self, user_id):
|
||||
async def on_user_syncing(self, user_id):
|
||||
"""Called when the user performs a sync operation.
|
||||
|
||||
Args:
|
||||
user_id (str): mxid of user who synced
|
||||
"""
|
||||
for sn in self._server_notices:
|
||||
yield sn.maybe_send_server_notice_to_user(user_id)
|
||||
await sn.maybe_send_server_notice_to_user(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_user_ip(self, user_id):
|
||||
async def on_user_ip(self, user_id):
|
||||
"""Called on the master when a worker process saw a client request.
|
||||
|
||||
Args:
|
||||
|
@ -57,4 +53,4 @@ class ServerNoticesSender(object):
|
|||
# we check for notices to send to the user in on_user_ip as well as
|
||||
# in on_user_syncing
|
||||
for sn in self._server_notices:
|
||||
yield sn.maybe_send_server_notice_to_user(user_id)
|
||||
await sn.maybe_send_server_notice_to_user(user_id)
|
||||
|
|
|
@ -66,6 +66,7 @@ from .stats import StatsStore
|
|||
from .stream import StreamStore
|
||||
from .tags import TagsStore
|
||||
from .transactions import TransactionStore
|
||||
from .ui_auth import UIAuthStore
|
||||
from .user_directory import UserDirectoryStore
|
||||
from .user_erasure_store import UserErasureStore
|
||||
|
||||
|
@ -112,6 +113,7 @@ class DataStore(
|
|||
StatsStore,
|
||||
RelationsStore,
|
||||
CacheInvalidationStore,
|
||||
UIAuthStore,
|
||||
):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
self.hs = hs
|
||||
|
|
|
@ -178,7 +178,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
|
||||
async def get_updated_global_account_data(
|
||||
self, last_id: int, current_id: int, limit: int
|
||||
) -> List[Tuple]:
|
||||
) -> List[Tuple[int, str, str]]:
|
||||
"""Get the global account_data that has changed, for the account_data stream
|
||||
|
||||
Args:
|
||||
|
@ -208,7 +208,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
|
||||
async def get_updated_room_account_data(
|
||||
self, last_id: int, current_id: int, limit: int
|
||||
) -> List[Tuple]:
|
||||
) -> List[Tuple[int, str, str, str]]:
|
||||
"""Get the global account_data that has changed, for the account_data stream
|
||||
|
||||
Args:
|
||||
|
|
|
@ -273,8 +273,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
desc="delete_account_validity_for_user",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_server_admin(self, user):
|
||||
async def is_server_admin(self, user):
|
||||
"""Determines if a user is an admin of this homeserver.
|
||||
|
||||
Args:
|
||||
|
@ -283,7 +282,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
Returns (bool):
|
||||
true iff the user is a server admin, false otherwise.
|
||||
"""
|
||||
res = yield self.db.simple_select_one_onecol(
|
||||
res = await self.db.simple_select_one_onecol(
|
||||
table="users",
|
||||
keyvalues={"name": user.to_string()},
|
||||
retcol="admin",
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ui_auth_sessions(
|
||||
session_id TEXT NOT NULL, -- The session ID passed to the client.
|
||||
creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds).
|
||||
serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse.
|
||||
clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client.
|
||||
uri TEXT NOT NULL, -- The URI the UI authentication session is using.
|
||||
method TEXT NOT NULL, -- The HTTP method the UI authentication session is using.
|
||||
-- The clientdict, uri, and method make up an tuple that must be immutable
|
||||
-- throughout the lifetime of the UI Auth session.
|
||||
description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur.
|
||||
UNIQUE (session_id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials(
|
||||
session_id TEXT NOT NULL, -- The corresponding UI Auth session.
|
||||
stage_type TEXT NOT NULL, -- The stage type.
|
||||
result TEXT NOT NULL, -- The result of the stage verification, stored as JSON.
|
||||
UNIQUE (session_id, stage_type),
|
||||
FOREIGN KEY (session_id)
|
||||
REFERENCES ui_auth_sessions (session_id)
|
||||
);
|
|
@ -0,0 +1,279 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import attr
|
||||
|
||||
import synapse.util.stringutils as stringutils
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.types import JsonDict
|
||||
|
||||
|
||||
@attr.s
|
||||
class UIAuthSessionData:
|
||||
session_id = attr.ib(type=str)
|
||||
# The dictionary from the client root level, not the 'auth' key.
|
||||
clientdict = attr.ib(type=JsonDict)
|
||||
# The URI and method the session was intiatied with. These are checked at
|
||||
# each stage of the authentication to ensure that the asked for operation
|
||||
# has not changed.
|
||||
uri = attr.ib(type=str)
|
||||
method = attr.ib(type=str)
|
||||
# A string description of the operation that the current authentication is
|
||||
# authorising.
|
||||
description = attr.ib(type=str)
|
||||
|
||||
|
||||
class UIAuthWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
Manage user interactive authentication sessions.
|
||||
"""
|
||||
|
||||
async def create_ui_auth_session(
|
||||
self, clientdict: JsonDict, uri: str, method: str, description: str,
|
||||
) -> UIAuthSessionData:
|
||||
"""
|
||||
Creates a new user interactive authentication session.
|
||||
|
||||
The session can be used to track the stages necessary to authenticate a
|
||||
user across multiple HTTP requests.
|
||||
|
||||
Args:
|
||||
clientdict:
|
||||
The dictionary from the client root level, not the 'auth' key.
|
||||
uri:
|
||||
The URI this session was initiated with, this is checked at each
|
||||
stage of the authentication to ensure that the asked for
|
||||
operation has not changed.
|
||||
method:
|
||||
The method this session was initiated with, this is checked at each
|
||||
stage of the authentication to ensure that the asked for
|
||||
operation has not changed.
|
||||
description:
|
||||
A string description of the operation that the current
|
||||
authentication is authorising.
|
||||
Returns:
|
||||
The newly created session.
|
||||
Raises:
|
||||
StoreError if a unique session ID cannot be generated.
|
||||
"""
|
||||
# The clientdict gets stored as JSON.
|
||||
clientdict_json = json.dumps(clientdict)
|
||||
|
||||
# autogen a session ID and try to create it. We may clash, so just
|
||||
# try a few times till one goes through, giving up eventually.
|
||||
attempts = 0
|
||||
while attempts < 5:
|
||||
session_id = stringutils.random_string(24)
|
||||
|
||||
try:
|
||||
await self.db.simple_insert(
|
||||
table="ui_auth_sessions",
|
||||
values={
|
||||
"session_id": session_id,
|
||||
"clientdict": clientdict_json,
|
||||
"uri": uri,
|
||||
"method": method,
|
||||
"description": description,
|
||||
"serverdict": "{}",
|
||||
"creation_time": self.hs.get_clock().time_msec(),
|
||||
},
|
||||
desc="create_ui_auth_session",
|
||||
)
|
||||
return UIAuthSessionData(
|
||||
session_id, clientdict, uri, method, description
|
||||
)
|
||||
except self.db.engine.module.IntegrityError:
|
||||
attempts += 1
|
||||
raise StoreError(500, "Couldn't generate a session ID.")
|
||||
|
||||
async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
|
||||
"""Retrieve a UI auth session.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session.
|
||||
Returns:
|
||||
A dict containing the device information.
|
||||
Raises:
|
||||
StoreError if the session is not found.
|
||||
"""
|
||||
result = await self.db.simple_select_one(
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("clientdict", "uri", "method", "description"),
|
||||
desc="get_ui_auth_session",
|
||||
)
|
||||
|
||||
result["clientdict"] = json.loads(result["clientdict"])
|
||||
|
||||
return UIAuthSessionData(session_id, **result)
|
||||
|
||||
async def mark_ui_auth_stage_complete(
|
||||
self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
|
||||
):
|
||||
"""
|
||||
Mark a session stage as completed.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the corresponding session.
|
||||
stage_type: The completed stage type.
|
||||
result: The result of the stage verification.
|
||||
Raises:
|
||||
StoreError if the session cannot be found.
|
||||
"""
|
||||
# Add (or update) the results of the current stage to the database.
|
||||
#
|
||||
# Note that we need to allow for the same stage to complete multiple
|
||||
# times here so that registration is idempotent.
|
||||
try:
|
||||
await self.db.simple_upsert(
|
||||
table="ui_auth_sessions_credentials",
|
||||
keyvalues={"session_id": session_id, "stage_type": stage_type},
|
||||
values={"result": json.dumps(result)},
|
||||
desc="mark_ui_auth_stage_complete",
|
||||
)
|
||||
except self.db.engine.module.IntegrityError:
|
||||
raise StoreError(400, "Unknown session ID: %s" % (session_id,))
|
||||
|
||||
async def get_completed_ui_auth_stages(
|
||||
self, session_id: str
|
||||
) -> Dict[str, Union[str, bool, JsonDict]]:
|
||||
"""
|
||||
Retrieve the completed stages of a UI authentication session.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session.
|
||||
Returns:
|
||||
The completed stages mapped to the result of the verification of
|
||||
that auth-type.
|
||||
"""
|
||||
results = {}
|
||||
for row in await self.db.simple_select_list(
|
||||
table="ui_auth_sessions_credentials",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("stage_type", "result"),
|
||||
desc="get_completed_ui_auth_stages",
|
||||
):
|
||||
results[row["stage_type"]] = json.loads(row["result"])
|
||||
|
||||
return results
|
||||
|
||||
async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any):
|
||||
"""
|
||||
Store a key-value pair into the sessions data associated with this
|
||||
request. This data is stored server-side and cannot be modified by
|
||||
the client.
|
||||
|
||||
Args:
|
||||
session_id: The ID of this session as returned from check_auth
|
||||
key: The key to store the data under
|
||||
value: The data to store
|
||||
Raises:
|
||||
StoreError if the session cannot be found.
|
||||
"""
|
||||
await self.db.runInteraction(
|
||||
"set_ui_auth_session_data",
|
||||
self._set_ui_auth_session_data_txn,
|
||||
session_id,
|
||||
key,
|
||||
value,
|
||||
)
|
||||
|
||||
def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
|
||||
# Get the current value.
|
||||
result = self.db.simple_select_one_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("serverdict",),
|
||||
)
|
||||
|
||||
# Update it and add it back to the database.
|
||||
serverdict = json.loads(result["serverdict"])
|
||||
serverdict[key] = value
|
||||
|
||||
self.db.simple_update_one_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
updatevalues={"serverdict": json.dumps(serverdict)},
|
||||
)
|
||||
|
||||
async def get_ui_auth_session_data(
|
||||
self, session_id: str, key: str, default: Optional[Any] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieve data stored with set_session_data
|
||||
|
||||
Args:
|
||||
session_id: The ID of this session as returned from check_auth
|
||||
key: The key to store the data under
|
||||
default: Value to return if the key has not been set
|
||||
Raises:
|
||||
StoreError if the session cannot be found.
|
||||
"""
|
||||
result = await self.db.simple_select_one(
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("serverdict",),
|
||||
desc="get_ui_auth_session_data",
|
||||
)
|
||||
|
||||
serverdict = json.loads(result["serverdict"])
|
||||
|
||||
return serverdict.get(key, default)
|
||||
|
||||
|
||||
class UIAuthStore(UIAuthWorkerStore):
|
||||
def delete_old_ui_auth_sessions(self, expiration_time: int):
|
||||
"""
|
||||
Remove sessions which were last used earlier than the expiration time.
|
||||
|
||||
Args:
|
||||
expiration_time: The latest time that is still considered valid.
|
||||
This is an epoch time in milliseconds.
|
||||
|
||||
"""
|
||||
return self.db.runInteraction(
|
||||
"delete_old_ui_auth_sessions",
|
||||
self._delete_old_ui_auth_sessions_txn,
|
||||
expiration_time,
|
||||
)
|
||||
|
||||
def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
|
||||
# Get the expired sessions.
|
||||
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
|
||||
txn.execute(sql, [expiration_time])
|
||||
session_ids = [r[0] for r in txn.fetchall()]
|
||||
|
||||
# Delete the corresponding completed credentials.
|
||||
self.db.simple_delete_many_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions_credentials",
|
||||
column="session_id",
|
||||
iterable=session_ids,
|
||||
keyvalues={},
|
||||
)
|
||||
|
||||
# Finally, delete the sessions.
|
||||
self.db.simple_delete_many_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions",
|
||||
column="session_id",
|
||||
iterable=session_ids,
|
||||
keyvalues={},
|
||||
)
|
|
@ -85,6 +85,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
|
|||
prepare_database(db_conn, self, config=None)
|
||||
|
||||
db_conn.create_function("rank", 1, _rank)
|
||||
db_conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
def is_deadlock(self, error):
|
||||
return False
|
||||
|
|
|
@ -82,18 +82,26 @@ class ProfileTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_set_my_name(self):
|
||||
yield self.handler.set_displayname(
|
||||
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.set_displayname(
|
||||
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
(yield self.store.get_profile_displayname(self.frank.localpart)),
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
)
|
||||
),
|
||||
"Frank Jr.",
|
||||
)
|
||||
|
||||
# Set displayname again
|
||||
yield self.handler.set_displayname(
|
||||
self.frank, synapse.types.create_requester(self.frank), "Frank"
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.set_displayname(
|
||||
self.frank, synapse.types.create_requester(self.frank), "Frank"
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
|
@ -112,16 +120,20 @@ class ProfileTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
# Setting displayname a second time is forbidden
|
||||
d = self.handler.set_displayname(
|
||||
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
||||
d = defer.ensureDeferred(
|
||||
self.handler.set_displayname(
|
||||
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
|
||||
)
|
||||
)
|
||||
|
||||
yield self.assertFailure(d, SynapseError)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_set_my_name_noauth(self):
|
||||
d = self.handler.set_displayname(
|
||||
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
|
||||
d = defer.ensureDeferred(
|
||||
self.handler.set_displayname(
|
||||
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
|
||||
)
|
||||
)
|
||||
|
||||
yield self.assertFailure(d, AuthError)
|
||||
|
@ -165,10 +177,12 @@ class ProfileTestCase(unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_set_my_avatar(self):
|
||||
yield self.handler.set_avatar_url(
|
||||
self.frank,
|
||||
synapse.types.create_requester(self.frank),
|
||||
"http://my.server/pic.gif",
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.set_avatar_url(
|
||||
self.frank,
|
||||
synapse.types.create_requester(self.frank),
|
||||
"http://my.server/pic.gif",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
|
@ -177,10 +191,12 @@ class ProfileTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
# Set avatar again
|
||||
yield self.handler.set_avatar_url(
|
||||
self.frank,
|
||||
synapse.types.create_requester(self.frank),
|
||||
"http://my.server/me.png",
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.set_avatar_url(
|
||||
self.frank,
|
||||
synapse.types.create_requester(self.frank),
|
||||
"http://my.server/me.png",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
|
@ -203,10 +219,12 @@ class ProfileTestCase(unittest.TestCase):
|
|||
)
|
||||
|
||||
# Set avatar a second time is forbidden
|
||||
d = self.handler.set_avatar_url(
|
||||
self.frank,
|
||||
synapse.types.create_requester(self.frank),
|
||||
"http://my.server/pic.gif",
|
||||
d = defer.ensureDeferred(
|
||||
self.handler.set_avatar_url(
|
||||
self.frank,
|
||||
synapse.types.create_requester(self.frank),
|
||||
"http://my.server/pic.gif",
|
||||
)
|
||||
)
|
||||
|
||||
yield self.assertFailure(d, SynapseError)
|
||||
|
|
|
@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
room_alias_str = "#room:test"
|
||||
self.hs.config.auto_join_rooms = [room_alias_str]
|
||||
|
||||
self.store.is_real_user = Mock(return_value=False)
|
||||
self.store.is_real_user = Mock(return_value=defer.succeed(False))
|
||||
user_id = self.get_success(self.handler.register_user(localpart="support"))
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||
self.assertEqual(len(rooms), 0)
|
||||
|
@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
room_alias_str = "#room:test"
|
||||
self.hs.config.auto_join_rooms = [room_alias_str]
|
||||
|
||||
self.store.count_real_users = Mock(return_value=1)
|
||||
self.store.is_real_user = Mock(return_value=True)
|
||||
self.store.count_real_users = Mock(return_value=defer.succeed(1))
|
||||
self.store.is_real_user = Mock(return_value=defer.succeed(True))
|
||||
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||
directory_handler = self.hs.get_handlers().directory_handler
|
||||
|
@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
room_alias_str = "#room:test"
|
||||
self.hs.config.auto_join_rooms = [room_alias_str]
|
||||
|
||||
self.store.count_real_users = Mock(return_value=2)
|
||||
self.store.is_real_user = Mock(return_value=True)
|
||||
self.store.count_real_users = Mock(return_value=defer.succeed(2))
|
||||
self.store.is_real_user = Mock(return_value=defer.succeed(True))
|
||||
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||
self.assertEqual(len(rooms), 0)
|
||||
|
@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
self.handler.register_user(localpart=invalid_user_id), SynapseError
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
|
||||
async def get_or_create_user(
|
||||
self, requester, localpart, displayname, password_hash=None
|
||||
):
|
||||
"""Creates a new user if the user does not exist,
|
||||
else revokes all previous access tokens and generates a new one.
|
||||
|
||||
|
@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
"""
|
||||
if localpart is None:
|
||||
raise SynapseError(400, "Request must include user id")
|
||||
yield self.hs.get_auth().check_auth_blocking()
|
||||
await self.hs.get_auth().check_auth_blocking()
|
||||
need_register = True
|
||||
|
||||
try:
|
||||
yield self.handler.check_username(localpart)
|
||||
await self.handler.check_username(localpart)
|
||||
except SynapseError as e:
|
||||
if e.errcode == Codes.USER_IN_USE:
|
||||
need_register = False
|
||||
|
@ -288,23 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
|||
token = self.macaroon_generator.generate_access_token(user_id)
|
||||
|
||||
if need_register:
|
||||
yield self.handler.register_with_store(
|
||||
await self.handler.register_with_store(
|
||||
user_id=user_id,
|
||||
password_hash=password_hash,
|
||||
create_profile_with_displayname=user.localpart,
|
||||
)
|
||||
else:
|
||||
yield defer.ensureDeferred(
|
||||
self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
|
||||
)
|
||||
await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
|
||||
|
||||
yield self.store.add_access_token_to_user(
|
||||
await self.store.add_access_token_to_user(
|
||||
user_id=user_id, token=token, device_id=None, valid_until_ms=None
|
||||
)
|
||||
|
||||
if displayname is not None:
|
||||
# logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||
yield self.hs.get_profile_handler().set_displayname(
|
||||
await self.hs.get_profile_handler().set_displayname(
|
||||
user, requester, displayname, by_admin=True
|
||||
)
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
|
|||
from twisted.internet.task import LoopingCall
|
||||
from twisted.web.http import HTTPChannel
|
||||
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
from synapse.app.generic_worker import (
|
||||
GenericWorkerReplicationHandler,
|
||||
GenericWorkerServer,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.tcp.client import ReplicationDataHandler
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
self._server_transport = None
|
||||
|
||||
def _build_replication_data_handler(self):
|
||||
return TestReplicationDataHandler(self.worker_hs.get_datastore())
|
||||
return TestReplicationDataHandler(self.worker_hs)
|
||||
|
||||
def reconnect(self):
|
||||
if self._client_transport:
|
||||
|
@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(request.method, b"GET")
|
||||
|
||||
|
||||
class TestReplicationDataHandler(ReplicationDataHandler):
|
||||
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
|
||||
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
|
||||
|
||||
def __init__(self, store: BaseSlavedStore):
|
||||
super().__init__(store)
|
||||
|
||||
# streams to subscribe to: map from stream id to position
|
||||
self.stream_positions = {} # type: Dict[str, int]
|
||||
def __init__(self, hs: HomeServer):
|
||||
super().__init__(hs)
|
||||
|
||||
# list of received (stream_name, token, row) tuples
|
||||
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
|
||||
|
||||
def get_streams_to_replicate(self):
|
||||
return self.stream_positions
|
||||
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
await super().on_rdata(stream_name, token, rows)
|
||||
async def on_rdata(self, stream_name, instance_name, token, rows):
|
||||
await super().on_rdata(stream_name, instance_name, token, rows)
|
||||
for r in rows:
|
||||
self.received_rdata_rows.append((stream_name, token, r))
|
||||
|
||||
if (
|
||||
stream_name in self.stream_positions
|
||||
and token > self.stream_positions[stream_name]
|
||||
):
|
||||
self.stream_positions[stream_name] = token
|
||||
|
||||
|
||||
@attr.s()
|
||||
class OneShotRequestFactory:
|
||||
|
|
|
@ -12,9 +12,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.replication.tcp.streams._base import (
|
||||
_STREAM_UPDATE_TARGET_ROW_COUNT,
|
||||
|
@ -25,10 +22,6 @@ from tests.replication.tcp.streams._base import BaseStreamTestCase
|
|||
|
||||
|
||||
class AccountDataStreamTestCase(BaseStreamTestCase):
|
||||
def prepare(self, reactor, clock, hs):
|
||||
super().prepare(reactor, clock, hs)
|
||||
self.test_handler.stream_positions["account_data"] = 0
|
||||
|
||||
def test_update_function_room_account_data_limit(self):
|
||||
"""Test replication with many room account data updates
|
||||
"""
|
||||
|
|
|
@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
self.user_tok = self.login("u1", "pass")
|
||||
|
||||
self.reconnect()
|
||||
self.test_handler.stream_positions["events"] = 0
|
||||
|
||||
self.room_id = self.helper.create_room_as(tok=self.user_tok)
|
||||
self.test_handler.received_rdata_rows.clear()
|
||||
|
@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
self.reconnect()
|
||||
self.replicate()
|
||||
|
||||
# we should have received all the expected rows in the right order
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
# we should have received all the expected rows in the right order (as
|
||||
# well as various cache invalidation updates which we ignore)
|
||||
received_rows = [
|
||||
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
|
||||
]
|
||||
|
||||
for event in events:
|
||||
stream_name, token, row = received_rows.pop(0)
|
||||
self.assertEqual("events", stream_name)
|
||||
|
@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
self.reconnect()
|
||||
self.replicate()
|
||||
|
||||
# now we should have received all the expected rows in the right order.
|
||||
# we should have received all the expected rows in the right order (as
|
||||
# well as various cache invalidation updates which we ignore)
|
||||
#
|
||||
# we expect:
|
||||
#
|
||||
|
@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
# of the states that got reverted.
|
||||
# - two rows for state2
|
||||
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
received_rows = [
|
||||
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
|
||||
]
|
||||
|
||||
# first check the first two rows, which should be state1
|
||||
|
||||
|
@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
|||
self.reconnect()
|
||||
self.replicate()
|
||||
|
||||
# we should have received all the expected rows in the right order
|
||||
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
# we should have received all the expected rows in the right order (as
|
||||
# well as various cache invalidation updates which we ignore)
|
||||
received_rows = [
|
||||
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
|
||||
]
|
||||
self.assertGreaterEqual(len(received_rows), len(events))
|
||||
for i in range(NUM_USERS):
|
||||
# for each user, we expect the PL event row, followed by state rows for
|
||||
|
|
|
@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
|
|||
def test_receipt(self):
|
||||
self.reconnect()
|
||||
|
||||
# make the client subscribe to the receipts stream
|
||||
self.test_handler.stream_positions.update({"receipts": 0})
|
||||
|
||||
# tell the master to send a new receipt
|
||||
self.get_success(
|
||||
self.hs.get_datastore().insert_receipt(
|
||||
|
@ -44,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
|
|||
|
||||
# there should be one RDATA command
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "receipts")
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
|
||||
|
@ -74,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
|
|||
|
||||
# We should now have caught up and get the missing data
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "receipts")
|
||||
self.assertEqual(token, 3)
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
|
|
|
@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||
|
||||
self.reconnect()
|
||||
|
||||
# make the client subscribe to the typing stream
|
||||
self.test_handler.stream_positions.update({"typing": 0})
|
||||
|
||||
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
|
||||
|
||||
self.reactor.advance(0)
|
||||
|
@ -50,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||
self.assert_request_is_get_repl_stream_updates(request, "typing")
|
||||
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "typing")
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
|
||||
|
@ -77,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
|||
self.assertEqual(int(request.args[b"from_token"][0]), token)
|
||||
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "typing")
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
row = rdata_rows[0]
|
||||
|
|
|
@ -181,3 +181,43 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
|||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, 403)
|
||||
|
||||
def test_complete_operation_unknown_session(self):
|
||||
"""
|
||||
Attempting to mark an invalid session as complete should error.
|
||||
"""
|
||||
|
||||
# Make the initial request to register. (Later on a different password
|
||||
# will be used.)
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"register",
|
||||
{"username": "user", "type": "m.login.password", "password": "bar"},
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
# Returns a 401 as per the spec
|
||||
self.assertEqual(request.code, 401)
|
||||
# Grab the session
|
||||
session = channel.json_body["session"]
|
||||
# Assert our configured public key is being given
|
||||
self.assertEqual(
|
||||
channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake"
|
||||
)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "auth/m.login.recaptcha/fallback/web?session=" + session
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(request.code, 200)
|
||||
|
||||
# Attempt to complete an unknown session, which should return an error.
|
||||
unknown_session = session + "unknown"
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"auth/m.login.recaptcha/fallback/web?session="
|
||||
+ unknown_session
|
||||
+ "&g-recaptcha-response=a",
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(request.code, 400)
|
||||
|
|
|
@ -55,25 +55,18 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
||||
return_value=defer.succeed(1000)
|
||||
)
|
||||
self._send_notice = self._rlsn._server_notices_manager.send_notice
|
||||
self._rlsn._server_notices_manager.send_notice = Mock()
|
||||
self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None))
|
||||
self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
|
||||
|
||||
self._rlsn._server_notices_manager.send_notice = Mock(
|
||||
return_value=defer.succeed(Mock())
|
||||
)
|
||||
self._send_notice = self._rlsn._server_notices_manager.send_notice
|
||||
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.user_id = "@user_id:test"
|
||||
|
||||
# self.server_notices_mxid = "@server:test"
|
||||
# self.server_notices_mxid_display_name = None
|
||||
# self.server_notices_mxid_avatar_url = None
|
||||
# self.server_notices_room_name = "Server Notices"
|
||||
|
||||
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
|
||||
returnValue=""
|
||||
return_value=defer.succeed("!something:localhost")
|
||||
)
|
||||
self._rlsn._store.add_tag_to_room = Mock()
|
||||
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
|
||||
self._rlsn._store.get_tags_for_room = Mock(return_value={})
|
||||
self.hs.config.admin_contact = "mailto:user@test.com"
|
||||
|
||||
|
@ -95,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
|
||||
"""Test when user has blocked notice, but should have it removed"""
|
||||
|
||||
self._rlsn._auth.check_auth_blocking = Mock()
|
||||
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
|
||||
mock_event = Mock(
|
||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||
)
|
||||
self._rlsn._store.get_events = Mock(
|
||||
return_value=defer.succeed({"123": mock_event})
|
||||
)
|
||||
|
||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||
# Would be better to check the content, but once == remove blocking event
|
||||
self._send_notice.assert_called_once()
|
||||
|
@ -112,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||
Test when user has blocked notice, but notice ought to be there (NOOP)
|
||||
"""
|
||||
self._rlsn._auth.check_auth_blocking = Mock(
|
||||
side_effect=ResourceLimitError(403, "foo")
|
||||
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
|
||||
)
|
||||
|
||||
mock_event = Mock(
|
||||
|
@ -121,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||
self._rlsn._store.get_events = Mock(
|
||||
return_value=defer.succeed({"123": mock_event})
|
||||
)
|
||||
|
||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||
|
||||
self._send_notice.assert_not_called()
|
||||
|
@ -129,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||
"""
|
||||
Test when user does not have blocked notice, but should have one
|
||||
"""
|
||||
|
||||
self._rlsn._auth.check_auth_blocking = Mock(
|
||||
side_effect=ResourceLimitError(403, "foo")
|
||||
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
|
||||
)
|
||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||
|
||||
|
@ -142,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||
"""
|
||||
Test when user does not have blocked notice, nor should they (NOOP)
|
||||
"""
|
||||
self._rlsn._auth.check_auth_blocking = Mock()
|
||||
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
|
||||
|
||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||
|
||||
|
@ -153,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||
Test when user is not part of the MAU cohort - this should not ever
|
||||
happen - but ...
|
||||
"""
|
||||
self._rlsn._auth.check_auth_blocking = Mock()
|
||||
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
|
||||
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
||||
return_value=defer.succeed(None)
|
||||
)
|
||||
|
@ -167,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||
an alert message is not sent into the room
|
||||
"""
|
||||
self.hs.config.mau_limit_alerting = False
|
||||
|
||||
self._rlsn._auth.check_auth_blocking = Mock(
|
||||
return_value=defer.succeed(None),
|
||||
side_effect=ResourceLimitError(
|
||||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
||||
)
|
||||
),
|
||||
)
|
||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||
|
||||
self.assertTrue(self._send_notice.call_count == 0)
|
||||
self.assertEqual(self._send_notice.call_count, 0)
|
||||
|
||||
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
|
||||
"""
|
||||
Test that when a server is disabled, that MAU limit alerting is ignored.
|
||||
"""
|
||||
self.hs.config.mau_limit_alerting = False
|
||||
|
||||
self._rlsn._auth.check_auth_blocking = Mock(
|
||||
return_value=defer.succeed(None),
|
||||
side_effect=ResourceLimitError(
|
||||
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
|
||||
)
|
||||
),
|
||||
)
|
||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||
|
||||
|
@ -198,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||
"""
|
||||
self.hs.config.mau_limit_alerting = False
|
||||
self._rlsn._auth.check_auth_blocking = Mock(
|
||||
return_value=defer.succeed(None),
|
||||
side_effect=ResourceLimitError(
|
||||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
|
||||
return_value=defer.succeed((True, []))
|
||||
)
|
||||
|
@ -256,7 +254,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
|
|||
def test_server_notice_only_sent_once(self):
|
||||
self.store.get_monthly_active_count = Mock(return_value=1000)
|
||||
|
||||
self.store.user_last_seen_monthly_active = Mock(return_value=1000)
|
||||
self.store.user_last_seen_monthly_active = Mock(
|
||||
return_value=defer.succeed(1000)
|
||||
)
|
||||
|
||||
# Call the function multiple times to ensure we only send the notice once
|
||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||
|
|
|
@ -27,8 +27,10 @@ class MessageAcceptTests(unittest.TestCase):
|
|||
user_id = UserID("us", "test")
|
||||
our_user = Requester(user_id, None, False, None, None)
|
||||
room_creator = self.homeserver.get_room_creation_handler()
|
||||
room = room_creator.create_room(
|
||||
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
|
||||
room = ensureDeferred(
|
||||
room_creator.create_room(
|
||||
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
|
||||
)
|
||||
)
|
||||
self.reactor.advance(0.1)
|
||||
self.room_id = self.successResultOf(room)["room_id"]
|
||||
|
|
|
@ -512,8 +512,8 @@ class MockClock(object):
|
|||
|
||||
return t
|
||||
|
||||
def looping_call(self, function, interval):
|
||||
self.loopers.append([function, interval / 1000.0, self.now])
|
||||
def looping_call(self, function, interval, *args, **kwargs):
|
||||
self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
|
||||
|
||||
def cancel_call_later(self, timer, ignore_errs=False):
|
||||
if timer[2]:
|
||||
|
@ -543,9 +543,9 @@ class MockClock(object):
|
|||
self.timers.append(t)
|
||||
|
||||
for looped in self.loopers:
|
||||
func, interval, last = looped
|
||||
func, interval, last, args, kwargs = looped
|
||||
if last + interval < self.now:
|
||||
func()
|
||||
func(*args, **kwargs)
|
||||
looped[2] = self.now
|
||||
|
||||
def advance_time_msec(self, ms):
|
||||
|
|
3
tox.ini
3
tox.ini
|
@ -200,8 +200,9 @@ commands = mypy \
|
|||
synapse/replication \
|
||||
synapse/rest \
|
||||
synapse/spam_checker_api \
|
||||
synapse/storage/engines \
|
||||
synapse/storage/data_stores/main/ui_auth.py \
|
||||
synapse/storage/database.py \
|
||||
synapse/storage/engines \
|
||||
synapse/streams \
|
||||
synapse/util/caches/stream_change_cache.py \
|
||||
tests/replication/tcp/streams \
|
||||
|
|
Loading…
Reference in New Issue