_update_membership

pull/7363/head
Andrew Morgan 2020-04-28 21:09:45 +01:00
parent 047cbcfab7
commit bf700782fa
17 changed files with 343 additions and 356 deletions

View File

@ -957,14 +957,13 @@ class FederationClient(FederationBase):
return signed_events
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
async def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
if destination == self.server_name:
continue
try:
yield self.transport_layer.exchange_third_party_invite(
await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict
)
return None

View File

@ -126,30 +126,28 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now))
)
@defer.inlineCallbacks
def maybe_kick_guest_users(self, event, context=None):
async def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
current_state_ids = yield context.get_current_state_ids()
current_state = yield self.store.get_events(
current_state_ids = await context.get_current_state_ids()
current_state = await self.store.get_events(
list(current_state_ids.values())
)
else:
current_state = yield self.state_handler.get_current_state(
current_state = await self.state_handler.get_current_state(
event.room_id
)
current_state = list(current_state.values())
logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state)
await self.kick_guest_users(current_state)
@defer.inlineCallbacks
def kick_guest_users(self, current_state):
async def kick_guest_users(self, current_state):
for member_event in current_state:
try:
if member_event.type != EventTypes.Member:
@ -180,7 +178,7 @@ class BaseHandler(object):
# homeserver.
requester = synapse.types.create_requester(target_user, is_guest=True)
handler = self.hs.get_room_member_handler()
yield handler.update_membership(
await handler.update_membership(
requester,
target_user,
member_event.room_id,

View File

@ -295,15 +295,14 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND,
)
@defer.inlineCallbacks
def _update_canonical_alias(
async def _update_canonical_alias(
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
):
"""
Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field.
"""
alias_event = yield self.state.get_current_state(
alias_event = await self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, ""
)
@ -334,7 +333,7 @@ class DirectoryHandler(BaseHandler):
del content["alt_aliases"]
if send_update:
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,

View File

@ -2563,9 +2563,8 @@ class FederationHandler(BaseHandler):
"missing": [e.event_id for e in missing_locals],
}
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(
async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed
):
third_party_invite = {"signed": signed}
@ -2581,16 +2580,16 @@ class FederationHandler(BaseHandler):
"state_key": target_user_id,
}
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
room_version = yield self.store.get_room_version_id(room_id)
if await self.auth.check_host_in_room(room_id, self.hs.hostname):
room_version = await self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
event_allowed = yield self.third_party_event_rules.check_event_allowed(
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@ -2602,7 +2601,7 @@ class FederationHandler(BaseHandler):
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
event, context = yield self.add_display_name_to_third_party_invite(
event, context = await self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context
)
@ -2613,19 +2612,19 @@ class FederationHandler(BaseHandler):
event.internal_metadata.send_on_behalf_of = self.hs.hostname
try:
yield self.auth.check_from_context(room_version, event, context)
await self.auth.check_from_context(room_version, event, context)
except AuthError as e:
logger.warning("Denying new third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, context)
await self._check_signature(event, context)
# We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
await member_handler.send_membership_event(None, event, context)
else:
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
yield self.federation_client.forward_third_party_invite(
await self.federation_client.forward_third_party_invite(
destinations, room_id, event_dict
)

View File

@ -626,8 +626,7 @@ class EventCreationHandler(object):
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
async def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
Persists and notifies local clients and federation of an event.
@ -647,7 +646,7 @@ class EventCreationHandler(object):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
prev_state = await self.deduplicate_state_event(event, context)
if prev_state is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
@ -656,7 +655,7 @@ class EventCreationHandler(object):
)
return prev_state
yield self.handle_new_client_event(
await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
)
@ -683,8 +682,7 @@ class EventCreationHandler(object):
return prev_event
return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
async def create_and_send_nonmember_event(
self, requester, event_dict, ratelimit=True, txn_id=None
):
"""
@ -698,8 +696,8 @@ class EventCreationHandler(object):
# a situation where event persistence can't keep up, causing
# extremities to pile up, which in turn leads to state resolution
# taking longer.
with (yield self.limiter.queue(event_dict["room_id"])):
event, context = yield self.create_event(
with (await self.limiter.queue(event_dict["room_id"])):
event, context = await self.create_event(
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
)
@ -709,7 +707,7 @@ class EventCreationHandler(object):
spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
yield self.send_nonmember_event(
await self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit
)
return event
@ -770,8 +768,7 @@ class EventCreationHandler(object):
return (event, context)
@measure_func("handle_new_client_event")
@defer.inlineCallbacks
def handle_new_client_event(
async def handle_new_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Processes a new event. This includes checking auth, persisting it,
@ -794,9 +791,9 @@ class EventCreationHandler(object):
):
room_version = event.content.get("room_version", RoomVersions.V1.identifier)
else:
room_version = yield self.store.get_room_version_id(event.room_id)
room_version = await self.store.get_room_version_id(event.room_id)
event_allowed = yield self.third_party_event_rules.check_event_allowed(
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@ -805,7 +802,7 @@ class EventCreationHandler(object):
)
try:
yield self.auth.check_from_context(room_version, event, context)
await self.auth.check_from_context(room_version, event, context)
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
raise err
@ -818,7 +815,7 @@ class EventCreationHandler(object):
logger.exception("Failed to encode content: %r", event.content)
raise
yield self.action_generator.handle_push_actions_for_event(event, context)
await self.action_generator.handle_push_actions_for_event(event, context)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead.
@ -826,7 +823,7 @@ class EventCreationHandler(object):
try:
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
yield self.send_event_to_master(
await self.send_event_to_master(
event_id=event.event_id,
store=self.store,
requester=requester,
@ -838,7 +835,7 @@ class EventCreationHandler(object):
success = True
return
yield self.persist_and_notify_client_event(
await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
@ -883,8 +880,7 @@ class EventCreationHandler(object):
Codes.BAD_ALIAS,
)
@defer.inlineCallbacks
def persist_and_notify_client_event(
async def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Called when we have fully built the event, have already
@ -901,7 +897,7 @@ class EventCreationHandler(object):
# user is actually admin or not).
is_admin_redaction = False
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
@ -913,11 +909,11 @@ class EventCreationHandler(object):
original_event and event.sender != original_event.sender
)
yield self.base_handler.ratelimit(
await self.base_handler.ratelimit(
requester, is_admin_redaction=is_admin_redaction
)
yield self.base_handler.maybe_kick_guest_users(event, context)
await self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Validate a newly added alias or newly added alt_aliases.
@ -927,7 +923,7 @@ class EventCreationHandler(object):
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
original_event = yield self.store.get_event(original_event_id)
original_event = await self.store.get_event(original_event_id)
if original_event:
original_alias = original_event.content.get("alias", None)
@ -937,7 +933,7 @@ class EventCreationHandler(object):
room_alias_str = event.content.get("alias", None)
directory_handler = self.hs.get_handlers().directory_handler
if room_alias_str and room_alias_str != original_alias:
yield self._validate_canonical_alias(
await self._validate_canonical_alias(
directory_handler, room_alias_str, event.room_id
)
@ -957,7 +953,7 @@ class EventCreationHandler(object):
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
if new_alt_aliases:
for alias_str in new_alt_aliases:
yield self._validate_canonical_alias(
await self._validate_canonical_alias(
directory_handler, alias_str, event.room_id
)
@ -969,7 +965,7 @@ class EventCreationHandler(object):
def is_inviter_member_event(e):
return e.type == EventTypes.Member and e.sender == event.sender
current_state_ids = yield context.get_current_state_ids()
current_state_ids = await context.get_current_state_ids()
state_to_include_ids = [
e_id
@ -978,7 +974,7 @@ class EventCreationHandler(object):
or k == (EventTypes.Member, event.sender)
]
state_to_include = yield self.store.get_events(state_to_include_ids)
state_to_include = await self.store.get_events(state_to_include_ids)
event.unsigned["invite_room_state"] = [
{
@ -996,7 +992,7 @@ class EventCreationHandler(object):
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
returned_invite = yield defer.ensureDeferred(
returned_invite = await defer.ensureDeferred(
federation_handler.send_invite(invitee.domain, event)
)
event.unsigned.pop("room_state", None)
@ -1005,7 +1001,7 @@ class EventCreationHandler(object):
event.signatures.update(returned_invite.signatures)
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
@ -1021,14 +1017,14 @@ class EventCreationHandler(object):
if original_event.room_id != event.room_id:
raise SynapseError(400, "Cannot redact event from a different room")
prev_state_ids = yield context.get_prev_state_ids()
auth_events_ids = yield self.auth.compute_auth_events(
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version = yield self.store.get_room_version_id(event.room_id)
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
if event_auth.check_redaction(
@ -1047,11 +1043,11 @@ class EventCreationHandler(object):
event.internal_metadata.recheck_redaction = False
if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
event_stream_id, max_stream_id = yield self.storage.persistence.persist_event(
event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
event, context=context
)
@ -1059,7 +1055,7 @@ class EventCreationHandler(object):
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _notify():
try:
@ -1083,13 +1079,12 @@ class EventCreationHandler(object):
except Exception:
logger.exception("Error bumping presence active time")
@defer.inlineCallbacks
def _send_dummy_events_to_fill_extremities(self):
async def _send_dummy_events_to_fill_extremities(self):
"""Background task to send dummy events into rooms that have a large
number of extremities
"""
self._expire_rooms_to_exclude_from_dummy_event_insertion()
room_ids = yield self.store.get_rooms_with_many_extremities(
room_ids = await self.store.get_rooms_with_many_extremities(
min_count=10,
limit=5,
room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(),
@ -1099,9 +1094,9 @@ class EventCreationHandler(object):
# For each room we need to find a joined member we can use to send
# the dummy event with.
latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
members = yield self.state.get_current_users_in_room(
members = await self.state.get_current_users_in_room(
room_id, latest_event_ids=latest_event_ids
)
dummy_event_sent = False
@ -1110,7 +1105,7 @@ class EventCreationHandler(object):
continue
requester = create_requester(user_id)
try:
event, context = yield self.create_event(
event, context = await self.create_event(
requester,
{
"type": "org.matrix.dummy_event",
@ -1123,7 +1118,7 @@ class EventCreationHandler(object):
event.internal_metadata.proactively_send = False
yield self.send_nonmember_event(
await self.send_nonmember_event(
requester, event, context, ratelimit=False
)
dummy_event_sent = True

View File

@ -141,8 +141,9 @@ class BaseProfileHandler(BaseHandler):
return result["displayname"]
@defer.inlineCallbacks
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
async def set_displayname(
self, target_user, requester, new_displayname, by_admin=False
):
"""Set the displayname of a user
Args:
@ -158,7 +159,7 @@ class BaseProfileHandler(BaseHandler):
raise AuthError(400, "Cannot set another user's displayname")
if not by_admin and not self.hs.config.enable_set_displayname:
profile = yield self.store.get_profileinfo(target_user.localpart)
profile = await self.store.get_profileinfo(target_user.localpart)
if profile.display_name:
raise SynapseError(
400,
@ -180,15 +181,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin:
requester = create_requester(target_user)
yield self.store.set_profile_displayname(target_user.localpart, new_displayname)
await self.store.set_profile_displayname(target_user.localpart, new_displayname)
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change(
profile = await self.store.get_profileinfo(target_user.localpart)
await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
yield self._update_join_states(requester, target_user)
await self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
@ -217,8 +218,9 @@ class BaseProfileHandler(BaseHandler):
return result["avatar_url"]
@defer.inlineCallbacks
def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
async def set_avatar_url(
self, target_user, requester, new_avatar_url, by_admin=False
):
"""target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user):
@ -228,7 +230,7 @@ class BaseProfileHandler(BaseHandler):
raise AuthError(400, "Cannot set another user's avatar_url")
if not by_admin and not self.hs.config.enable_set_avatar_url:
profile = yield self.store.get_profileinfo(target_user.localpart)
profile = await self.store.get_profileinfo(target_user.localpart)
if profile.avatar_url:
raise SynapseError(
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
@ -243,15 +245,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin:
requester = create_requester(target_user)
yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
yield self.user_directory_handler.handle_local_profile_change(
profile = await self.store.get_profileinfo(target_user.localpart)
await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
yield self._update_join_states(requester, target_user)
await self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def on_profile_query(self, args):
@ -279,21 +281,20 @@ class BaseProfileHandler(BaseHandler):
return response
@defer.inlineCallbacks
def _update_join_states(self, requester, target_user):
async def _update_join_states(self, requester, target_user):
if not self.hs.is_mine(target_user):
return
yield self.ratelimit(requester)
await self.ratelimit(requester)
room_ids = yield self.store.get_rooms_for_user(target_user.to_string())
room_ids = await self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids:
handler = self.hs.get_room_member_handler()
try:
# Assume the target_user isn't a guest,
# because we don't let guests set profile or avatar data.
yield handler.update_membership(
await handler.update_membership(
requester,
target_user,
room_id,

View File

@ -145,9 +145,9 @@ class RegistrationHandler(BaseHandler):
"""Registers a new client on the server.
Args:
localpart : The local part of the user ID to register. If None,
localpart: The local part of the user ID to register. If None,
one will be generated.
password (unicode) : The password to assign to this user so they can
password (unicode): The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
user_type (str|None): type of user. One of the values from
@ -242,7 +242,7 @@ class RegistrationHandler(BaseHandler):
fail_count += 1
if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id)
yield defer.ensureDeferred(self._auto_join_rooms(user_id))
else:
logger.info(
"Skipping auto-join for %s because consent is required at registration",
@ -264,8 +264,7 @@ class RegistrationHandler(BaseHandler):
return user_id
@defer.inlineCallbacks
def _auto_join_rooms(self, user_id):
async def _auto_join_rooms(self, user_id):
"""Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created.
@ -279,9 +278,9 @@ class RegistrationHandler(BaseHandler):
# that an auto-generated support or bot user is not a real user and will never be
# the user to create the room
should_auto_create_rooms = False
is_real_user = yield self.store.is_real_user(user_id)
is_real_user = await self.store.is_real_user(user_id)
if self.hs.config.autocreate_auto_join_rooms and is_real_user:
count = yield self.store.count_real_users()
count = await self.store.count_real_users()
should_auto_create_rooms = count == 1
for r in self.hs.config.auto_join_rooms:
logger.info("Auto-joining %s to %s", user_id, r)
@ -300,7 +299,7 @@ class RegistrationHandler(BaseHandler):
# getting the RoomCreationHandler during init gives a dependency
# loop
yield self.hs.get_room_creation_handler().create_room(
await self.hs.get_room_creation_handler().create_room(
fake_requester,
config={
"preset": "public_chat",
@ -309,7 +308,7 @@ class RegistrationHandler(BaseHandler):
ratelimit=False,
)
else:
yield self._join_user_to_room(fake_requester, r)
await self._join_user_to_room(fake_requester, r)
except ConsentNotGivenError as e:
# Technically not necessary to pull out this error though
# moving away from bare excepts is a good thing to do.
@ -317,15 +316,14 @@ class RegistrationHandler(BaseHandler):
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
@defer.inlineCallbacks
def post_consent_actions(self, user_id):
async def post_consent_actions(self, user_id):
"""A series of registration actions that can only be carried out once consent
has been granted
Args:
user_id (str): The user to join
"""
yield self._auto_join_rooms(user_id)
await self._auto_join_rooms(user_id)
@defer.inlineCallbacks
def appservice_register(self, user_localpart, as_token):
@ -392,14 +390,13 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id += 1
return str(id)
@defer.inlineCallbacks
def _join_user_to_room(self, requester, room_identifier):
async def _join_user_to_room(self, requester, room_identifier):
room_member_handler = self.hs.get_room_member_handler()
if RoomID.is_valid(room_identifier):
room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
room_id, remote_room_hosts = await room_member_handler.lookup_room_alias(
room_alias
)
room_id = room_id.to_string()
@ -408,7 +405,7 @@ class RegistrationHandler(BaseHandler):
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
yield room_member_handler.update_membership(
await room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
@ -546,8 +543,7 @@ class RegistrationHandler(BaseHandler):
return (device_id, access_token)
@defer.inlineCallbacks
def post_registration_actions(self, user_id, auth_result, access_token):
async def post_registration_actions(self, user_id, auth_result, access_token):
"""A user has completed registration
Args:
@ -558,7 +554,7 @@ class RegistrationHandler(BaseHandler):
device, or None if `inhibit_login` enabled.
"""
if self.hs.config.worker_app:
yield self._post_registration_client(
await self._post_registration_client(
user_id=user_id, auth_result=auth_result, access_token=access_token
)
return
@ -570,19 +566,18 @@ class RegistrationHandler(BaseHandler):
if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
yield self.store.upsert_monthly_active_user(user_id)
await self.store.upsert_monthly_active_user(user_id)
yield self._register_email_threepid(user_id, threepid, access_token)
await self._register_email_threepid(user_id, threepid, access_token)
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
yield self._register_msisdn_threepid(user_id, threepid)
await self._register_msisdn_threepid(user_id, threepid)
if auth_result and LoginType.TERMS in auth_result:
yield self._on_user_consented(user_id, self.hs.config.user_consent_version)
await self._on_user_consented(user_id, self.hs.config.user_consent_version)
@defer.inlineCallbacks
def _on_user_consented(self, user_id, consent_version):
async def _on_user_consented(self, user_id, consent_version):
"""A user consented to the terms on registration
Args:
@ -591,8 +586,8 @@ class RegistrationHandler(BaseHandler):
consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
yield self.store.user_set_consent_version(user_id, consent_version)
yield self.post_consent_actions(user_id)
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
@defer.inlineCallbacks
def _register_email_threepid(self, user_id, threepid, token):

View File

@ -148,17 +148,16 @@ class RoomCreationHandler(BaseHandler):
return ret
@defer.inlineCallbacks
def _upgrade_room(
async def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
user_id = requester.user.to_string()
# start by allocating a new room id
r = yield self.store.get_room(old_room_id)
r = await self.store.get_room(old_room_id)
if r is None:
raise NotFoundError("Unknown room id %s" % (old_room_id,))
new_room_id = yield self._generate_room_id(
new_room_id = await self._generate_room_id(
creator_id=user_id, is_public=r["is_public"], room_version=new_version,
)
@ -169,7 +168,7 @@ class RoomCreationHandler(BaseHandler):
(
tombstone_event,
tombstone_context,
) = yield self.event_creation_handler.create_event(
) = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Tombstone,
@ -183,12 +182,12 @@ class RoomCreationHandler(BaseHandler):
},
token_id=requester.access_token_id,
)
old_room_version = yield self.store.get_room_version_id(old_room_id)
yield self.auth.check_from_context(
old_room_version = await self.store.get_room_version_id(old_room_id)
await self.auth.check_from_context(
old_room_version, tombstone_event, tombstone_context
)
yield self.clone_existing_room(
await self.clone_existing_room(
requester,
old_room_id=old_room_id,
new_room_id=new_room_id,
@ -197,32 +196,31 @@ class RoomCreationHandler(BaseHandler):
)
# now send the tombstone
yield self.event_creation_handler.send_nonmember_event(
await self.event_creation_handler.send_nonmember_event(
requester, tombstone_event, tombstone_context
)
old_room_state = yield tombstone_context.get_current_state_ids()
old_room_state = await tombstone_context.get_current_state_ids()
# update any aliases
yield self._move_aliases_to_new_room(
await self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state
)
# Copy over user push rules, tags and migrate room directory state
yield self.room_member_handler.transfer_room_state_on_room_upgrade(
await self.room_member_handler.transfer_room_state_on_room_upgrade(
old_room_id, new_room_id
)
# finally, shut down the PLs in the old room, and update them in the new
# room.
yield self._update_upgraded_room_pls(
await self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state,
)
return new_room_id
@defer.inlineCallbacks
def _update_upgraded_room_pls(
async def _update_upgraded_room_pls(
self,
requester: Requester,
old_room_id: str,
@ -249,7 +247,7 @@ class RoomCreationHandler(BaseHandler):
)
return
old_room_pl_state = yield self.store.get_event(old_room_pl_event_id)
old_room_pl_state = await self.store.get_event(old_room_pl_event_id)
# we try to stop regular users from speaking by setting the PL required
# to send regular events and invites to 'Moderator' level. That's normally
@ -278,7 +276,7 @@ class RoomCreationHandler(BaseHandler):
if updated:
try:
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.PowerLevels,
@ -292,7 +290,7 @@ class RoomCreationHandler(BaseHandler):
except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e)
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.PowerLevels,
@ -304,8 +302,7 @@ class RoomCreationHandler(BaseHandler):
ratelimit=False,
)
@defer.inlineCallbacks
def clone_existing_room(
async def clone_existing_room(
self,
requester: Requester,
old_room_id: str,
@ -338,7 +335,7 @@ class RoomCreationHandler(BaseHandler):
# Check if old room was non-federatable
# Get old room's create event
old_room_create_event = yield self.store.get_create_event_for_room(old_room_id)
old_room_create_event = await self.store.get_create_event_for_room(old_room_id)
# Check if the create event specified a non-federatable room
if not old_room_create_event.content.get("m.federate", True):
@ -361,11 +358,11 @@ class RoomCreationHandler(BaseHandler):
(EventTypes.PowerLevels, ""),
)
old_room_state_ids = yield self.store.get_filtered_current_state_ids(
old_room_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types(types_to_copy)
)
# map from event_id to BaseEvent
old_room_state_events = yield self.store.get_events(old_room_state_ids.values())
old_room_state_events = await self.store.get_events(old_room_state_ids.values())
for k, old_event_id in iteritems(old_room_state_ids):
old_event = old_room_state_events.get(old_event_id)
@ -400,7 +397,7 @@ class RoomCreationHandler(BaseHandler):
if current_power_level < needed_power_level:
power_levels["users"][user_id] = needed_power_level
yield self._send_events_for_new_room(
await self._send_events_for_new_room(
requester,
new_room_id,
# we expect to override all the presets with initial_state, so this is
@ -412,12 +409,12 @@ class RoomCreationHandler(BaseHandler):
)
# Transfer membership events
old_room_member_state_ids = yield self.store.get_filtered_current_state_ids(
old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
)
# map from event_id to BaseEvent
old_room_member_state_events = yield self.store.get_events(
old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values()
)
for k, old_event in iteritems(old_room_member_state_events):
@ -426,7 +423,7 @@ class RoomCreationHandler(BaseHandler):
"membership" in old_event.content
and old_event.content["membership"] == "ban"
):
yield self.room_member_handler.update_membership(
await self.room_member_handler.update_membership(
requester,
UserID.from_string(old_event["state_key"]),
new_room_id,
@ -438,8 +435,7 @@ class RoomCreationHandler(BaseHandler):
# XXX invites/joins
# XXX 3pid invites
@defer.inlineCallbacks
def _move_aliases_to_new_room(
async def _move_aliases_to_new_room(
self,
requester: Requester,
old_room_id: str,
@ -448,13 +444,13 @@ class RoomCreationHandler(BaseHandler):
):
directory_handler = self.hs.get_handlers().directory_handler
aliases = yield self.store.get_aliases_for_room(old_room_id)
aliases = await self.store.get_aliases_for_room(old_room_id)
# check to see if we have a canonical alias.
canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event_id:
canonical_alias_event = yield self.store.get_event(canonical_alias_event_id)
canonical_alias_event = await self.store.get_event(canonical_alias_event_id)
# first we try to remove the aliases from the old room (we suppress sending
# the room_aliases event until the end).
@ -472,7 +468,7 @@ class RoomCreationHandler(BaseHandler):
for alias_str in aliases:
alias = RoomAlias.from_string(alias_str)
try:
yield directory_handler.delete_association(requester, alias)
await directory_handler.delete_association(requester, alias)
removed_aliases.append(alias_str)
except SynapseError as e:
logger.warning("Unable to remove alias %s from old room: %s", alias, e)
@ -485,7 +481,7 @@ class RoomCreationHandler(BaseHandler):
# we can now add any aliases we successfully removed to the new room.
for alias in removed_aliases:
try:
yield directory_handler.create_association(
await directory_handler.create_association(
requester,
RoomAlias.from_string(alias),
new_room_id,
@ -502,7 +498,7 @@ class RoomCreationHandler(BaseHandler):
# alias event for the new room with a copy of the information.
try:
if canonical_alias_event:
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
@ -518,8 +514,9 @@ class RoomCreationHandler(BaseHandler):
# we returned the new room to the client at this point.
logger.error("Unable to send updated alias events in new room: %s", e)
@defer.inlineCallbacks
def create_room(self, requester, config, ratelimit=True, creator_join_profile=None):
async def create_room(
self, requester, config, ratelimit=True, creator_join_profile=None
):
""" Creates a new room.
Args:
@ -547,7 +544,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
yield self.auth.check_auth_blocking(user_id)
await self.auth.check_auth_blocking(user_id)
if (
self._server_notices_mxid is not None
@ -556,11 +553,11 @@ class RoomCreationHandler(BaseHandler):
# allow the server notices mxid to create rooms
is_requester_admin = True
else:
is_requester_admin = yield self.auth.is_server_admin(requester.user)
is_requester_admin = await self.auth.is_server_admin(requester.user)
# Check whether the third party rules allows/changes the room create
# request.
event_allowed = yield self.third_party_event_rules.on_create_room(
event_allowed = await self.third_party_event_rules.on_create_room(
requester, config, is_requester_admin=is_requester_admin
)
if not event_allowed:
@ -574,7 +571,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(403, "You are not permitted to create rooms")
if ratelimit:
yield self.ratelimit(requester)
await self.ratelimit(requester)
room_version_id = config.get(
"room_version", self.config.default_room_version.identifier
@ -597,7 +594,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(400, "Invalid characters in room alias")
room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
mapping = yield self.store.get_association_from_room_alias(room_alias)
mapping = await self.store.get_association_from_room_alias(room_alias)
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
@ -612,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
yield self.event_creation_handler.assert_accepted_privacy_policy(requester)
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override")
if (
@ -631,13 +628,13 @@ class RoomCreationHandler(BaseHandler):
visibility = config.get("visibility", None)
is_public = visibility == "public"
room_id = yield self._generate_room_id(
room_id = await self._generate_room_id(
creator_id=user_id, is_public=is_public, room_version=room_version,
)
directory_handler = self.hs.get_handlers().directory_handler
if room_alias:
yield directory_handler.create_association(
await directory_handler.create_association(
requester=requester,
room_id=room_id,
room_alias=room_alias,
@ -663,7 +660,7 @@ class RoomCreationHandler(BaseHandler):
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier
yield self._send_events_for_new_room(
await self._send_events_for_new_room(
requester,
room_id,
preset_config=preset_config,
@ -677,7 +674,7 @@ class RoomCreationHandler(BaseHandler):
if "name" in config:
name = config["name"]
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name,
@ -691,7 +688,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config:
topic = config["topic"]
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic,
@ -709,7 +706,7 @@ class RoomCreationHandler(BaseHandler):
if is_direct:
content["is_direct"] = is_direct
yield self.room_member_handler.update_membership(
await self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
@ -723,7 +720,7 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"]
medium = invite_3pid["medium"]
yield self.hs.get_room_member_handler().do_3pid_invite(
await self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
medium,
@ -741,8 +738,7 @@ class RoomCreationHandler(BaseHandler):
return result
@defer.inlineCallbacks
def _send_events_for_new_room(
async def _send_events_for_new_room(
self,
creator, # A Requester object.
room_id,
@ -762,11 +758,10 @@ class RoomCreationHandler(BaseHandler):
return e
@defer.inlineCallbacks
def send(etype, content, **kwargs):
async def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False
)
@ -777,10 +772,10 @@ class RoomCreationHandler(BaseHandler):
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
creation_content.update({"creator": creator_id})
yield send(etype=EventTypes.Create, content=creation_content)
await send(etype=EventTypes.Create, content=creation_content)
logger.debug("Sending %s in new room", EventTypes.Member)
yield self.room_member_handler.update_membership(
await self.room_member_handler.update_membership(
creator,
creator.user,
room_id,
@ -793,7 +788,7 @@ class RoomCreationHandler(BaseHandler):
# of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
yield send(etype=EventTypes.PowerLevels, content=pl_content)
await send(etype=EventTypes.PowerLevels, content=pl_content)
else:
power_level_content = {
"users": {creator_id: 100},
@ -825,33 +820,33 @@ class RoomCreationHandler(BaseHandler):
if power_level_content_override:
power_level_content.update(power_level_content_override)
yield send(etype=EventTypes.PowerLevels, content=power_level_content)
await send(etype=EventTypes.PowerLevels, content=power_level_content)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
yield send(
await send(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
if (EventTypes.JoinRules, "") not in initial_state:
yield send(
await send(
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
)
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
yield send(
await send(
etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]},
)
if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state:
yield send(
await send(
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
)
for (etype, state_key), content in initial_state.items():
yield send(etype=etype, state_key=state_key, content=content)
await send(etype=etype, state_key=state_key, content=content)
@defer.inlineCallbacks
def _generate_room_id(

View File

@ -142,8 +142,7 @@ class RoomMemberHandler(object):
"""
raise NotImplementedError()
@defer.inlineCallbacks
def _local_membership_update(
async def _local_membership_update(
self,
requester,
target,
@ -164,7 +163,7 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
event, context = yield self.event_creation_handler.create_event(
event, context = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
@ -182,18 +181,18 @@ class RoomMemberHandler(object):
)
# Check if this event matches the previous membership event for the user.
duplicate = yield self.event_creation_handler.deduplicate_state_event(
duplicate = await self.event_creation_handler.deduplicate_state_event(
event, context
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
return duplicate
yield self.event_creation_handler.handle_new_client_event(
await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit
)
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
@ -203,15 +202,19 @@ class RoomMemberHandler(object):
# info.
newly_joined = True
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield self._user_joined_room(target, room_id)
d = self._user_joined_room(target, room_id)
if d:
await d
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
yield self._user_left_room(target, room_id)
d = self._user_left_room(target, room_id)
if d:
await d
return event
@ -253,8 +256,7 @@ class RoomMemberHandler(object):
for tag, tag_content in room_tags.items():
yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
@defer.inlineCallbacks
def update_membership(
async def update_membership(
self,
requester,
target,
@ -269,8 +271,8 @@ class RoomMemberHandler(object):
):
key = (room_id,)
with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership(
with (await self.member_linearizer.queue(key)):
result = await self._update_membership(
requester,
target,
room_id,
@ -285,8 +287,7 @@ class RoomMemberHandler(object):
return result
@defer.inlineCallbacks
def _update_membership(
async def _update_membership(
self,
requester,
target,
@ -321,7 +322,7 @@ class RoomMemberHandler(object):
# if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join.
if third_party_signed is not None:
yield self.federation_handler.exchange_third_party_invite(
await self.federation_handler.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
@ -332,7 +333,7 @@ class RoomMemberHandler(object):
remote_room_hosts = []
if effective_membership_state not in ("leave", "ban"):
is_blocked = yield self.store.is_room_blocked(room_id)
is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
@ -351,7 +352,7 @@ class RoomMemberHandler(object):
is_requester_admin = True
else:
is_requester_admin = yield self.auth.is_server_admin(requester.user)
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
if self.config.block_non_admin_invites:
@ -370,9 +371,9 @@ class RoomMemberHandler(object):
if block_invite:
raise SynapseError(403, "Invites have been disabled on this server")
latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
current_state_ids = yield self.state_handler.get_current_state_ids(
current_state_ids = await self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids
)
@ -381,7 +382,7 @@ class RoomMemberHandler(object):
# transitions and generic otherwise
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
if old_state_id:
old_state = yield self.store.get_event(old_state_id, allow_none=True)
old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
@ -413,7 +414,7 @@ class RoomMemberHandler(object):
old_membership == Membership.INVITE
and effective_membership_state == Membership.LEAVE
):
is_blocked = yield self._is_server_notice_room(room_id)
is_blocked = await self._is_server_notice_room(room_id)
if is_blocked:
raise SynapseError(
http_client.FORBIDDEN,
@ -424,18 +425,18 @@ class RoomMemberHandler(object):
if action == "kick":
raise AuthError(403, "The target user is not in the room")
is_host_in_room = yield self._is_host_in_room(current_state_ids)
is_host_in_room = await self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN:
if requester.is_guest:
guest_can_join = yield self._can_guest_join(current_state_ids)
guest_can_join = await self._can_guest_join(current_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
inviter = yield self._get_inviter(target.to_string(), room_id)
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
@ -443,13 +444,13 @@ class RoomMemberHandler(object):
profile = self.profile_handler
if not content_specified:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
content["displayname"] = await profile.get_displayname(target)
content["avatar_url"] = await profile.get_avatar_url(target)
if requester.is_guest:
content["kind"] = "guest"
remote_join_response = yield self._remote_join(
remote_join_response = await self._remote_join(
requester, remote_room_hosts, room_id, target, content
)
@ -458,7 +459,7 @@ class RoomMemberHandler(object):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
inviter = yield self._get_inviter(target.to_string(), room_id)
inviter = await self._get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
@ -472,12 +473,12 @@ class RoomMemberHandler(object):
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
res = yield self._remote_reject_invite(
res = await self._remote_reject_invite(
requester, remote_room_hosts, room_id, target, content,
)
return res
res = yield self._local_membership_update(
res = await self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
@ -572,8 +573,7 @@ class RoomMemberHandler(object):
)
continue
@defer.inlineCallbacks
def send_membership_event(self, requester, event, context, ratelimit=True):
async def send_membership_event(self, requester, event, context, ratelimit=True):
"""
Change the membership status of a user in a room.
@ -599,27 +599,27 @@ class RoomMemberHandler(object):
else:
requester = types.create_requester(target_user)
prev_event = yield self.event_creation_handler.deduplicate_state_event(
prev_event = await self.event_creation_handler.deduplicate_state_event(
event, context
)
if prev_event is not None:
return
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids()
if event.membership == Membership.JOIN:
if requester.is_guest:
guest_can_join = yield self._can_guest_join(prev_state_ids)
guest_can_join = await self._can_guest_join(prev_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if event.membership not in (Membership.LEAVE, Membership.BAN):
is_blocked = yield self.store.is_room_blocked(room_id)
is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
yield self.event_creation_handler.handle_new_client_event(
await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target_user], ratelimit=ratelimit
)
@ -633,15 +633,15 @@ class RoomMemberHandler(object):
# info.
newly_joined = True
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield self._user_joined_room(target_user, room_id)
await self._user_joined_room(target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
yield self._user_left_room(target_user, room_id)
await self._user_left_room(target_user, room_id)
@defer.inlineCallbacks
def _can_guest_join(self, current_state_ids):
@ -699,8 +699,7 @@ class RoomMemberHandler(object):
if invite:
return UserID.from_string(invite.sender)
@defer.inlineCallbacks
def do_3pid_invite(
async def do_3pid_invite(
self,
room_id,
inviter,
@ -712,7 +711,7 @@ class RoomMemberHandler(object):
id_access_token=None,
):
if self.config.block_non_admin_invites:
is_requester_admin = yield self.auth.is_server_admin(requester.user)
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
raise SynapseError(
403, "Invites have been disabled on this server", Codes.FORBIDDEN
@ -720,9 +719,9 @@ class RoomMemberHandler(object):
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
yield self.base_handler.ratelimit(requester)
await self.base_handler.ratelimit(requester)
can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id
)
if not can_invite:
@ -737,16 +736,16 @@ class RoomMemberHandler(object):
403, "Looking up third-party identifiers is denied from this server"
)
invitee = yield self.identity_handler.lookup_3pid(
invitee = await self.identity_handler.lookup_3pid(
id_server, medium, address, id_access_token
)
if invitee:
yield self.update_membership(
await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
yield self._make_and_store_3pid_invite(
await self._make_and_store_3pid_invite(
requester,
id_server,
medium,
@ -757,8 +756,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
async def _make_and_store_3pid_invite(
self,
requester,
id_server,
@ -769,7 +767,7 @@ class RoomMemberHandler(object):
txn_id,
id_access_token=None,
):
room_state = yield self.state_handler.get_current_state(room_id)
room_state = await self.state_handler.get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
@ -807,7 +805,7 @@ class RoomMemberHandler(object):
public_keys,
fallback_public_key,
display_name,
) = yield self.identity_handler.ask_id_server_for_third_party_invite(
) = await self.identity_handler.ask_id_server_for_third_party_invite(
requester=requester,
id_server=id_server,
medium=medium,
@ -823,7 +821,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
yield self.event_creation_handler.create_and_send_nonmember_event(
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
@ -917,8 +915,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return complexity["v1"] > max_complexity
@defer.inlineCallbacks
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join
"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
@ -933,7 +930,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if self.hs.config.limit_remote_rooms.enabled:
# Fetch the room complexity
too_complex = yield self._is_remote_room_too_complex(
too_complex = await self._is_remote_room_too_complex(
room_id, remote_room_hosts
)
if too_complex is True:
@ -947,12 +944,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield defer.ensureDeferred(
self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user.to_string(), content
)
await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user.to_string(), content
)
yield self._user_joined_room(user, room_id)
d = self._user_joined_room(user, room_id)
if d:
await d
# Check the room we just joined wasn't too large, if we didn't fetch the
# complexity of it before.
@ -962,7 +959,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return
# Check again, but with the local state events
too_complex = yield self._is_local_room_too_complex(room_id)
too_complex = await self._is_local_room_too_complex(room_id)
if too_complex is False:
# We're under the limit.
@ -970,7 +967,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# The room is too large. Leave.
requester = types.create_requester(user, None, False, None)
yield self.update_membership(
await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
raise SynapseError(

View File

@ -16,8 +16,6 @@ import logging
from six import iteritems, string_types
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.api.urls import ConsentURIBuilder
from synapse.config import ConfigError
@ -59,8 +57,7 @@ class ConsentServerNotices(object):
self._consent_uri_builder = ConsentURIBuilder(hs.config)
@defer.inlineCallbacks
def maybe_send_server_notice_to_user(self, user_id):
async def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, and does so if so
Args:
@ -78,7 +75,7 @@ class ConsentServerNotices(object):
return
self._users_in_progress.add(user_id)
try:
u = yield self._store.get_user_by_id(user_id)
u = await self._store.get_user_by_id(user_id)
if u["is_guest"] and not self._send_to_guests:
# don't send to guests
@ -100,8 +97,8 @@ class ConsentServerNotices(object):
content = copy_with_str_subst(
self._server_notice_content, {"consent_uri": consent_uri}
)
yield self._server_notices_manager.send_notice(user_id, content)
yield self._store.user_set_consent_server_notice_sent(
await self._server_notices_manager.send_notice(user_id, content)
await self._store.user_set_consent_server_notice_sent(
user_id, self._current_consent_version
)
except SynapseError as e:

View File

@ -50,8 +50,7 @@ class ResourceLimitsServerNotices(object):
self._notifier = hs.get_notifier()
@defer.inlineCallbacks
def maybe_send_server_notice_to_user(self, user_id):
async def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, this will be true in
two cases.
1. The server has reached its limit does not reflect this
@ -74,13 +73,13 @@ class ResourceLimitsServerNotices(object):
# Don't try and send server notices unless they've been enabled
return
timestamp = yield self._store.user_last_seen_monthly_active(user_id)
timestamp = await self._store.user_last_seen_monthly_active(user_id)
if timestamp is None:
# This user will be blocked from receiving the notice anyway.
# In practice, not sure we can ever get here
return
room_id = yield self._server_notices_manager.get_or_create_notice_room_for_user(
room_id = await self._server_notices_manager.get_or_create_notice_room_for_user(
user_id
)
@ -88,10 +87,10 @@ class ResourceLimitsServerNotices(object):
logger.warning("Failed to get server notices room")
return
yield self._check_and_set_tags(user_id, room_id)
await self._check_and_set_tags(user_id, room_id)
# Determine current state of room
currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id)
currently_blocked, ref_events = await self._is_room_currently_blocked(room_id)
limit_msg = None
limit_type = None
@ -99,7 +98,7 @@ class ResourceLimitsServerNotices(object):
# Normally should always pass in user_id to check_auth_blocking
# if you have it, but in this case are checking what would happen
# to other users if they were to arrive.
yield self._auth.check_auth_blocking()
await self._auth.check_auth_blocking()
except ResourceLimitError as e:
limit_msg = e.msg
limit_type = e.limit_type
@ -112,22 +111,21 @@ class ResourceLimitsServerNotices(object):
# We have hit the MAU limit, but MAU alerting is disabled:
# reset room if necessary and return
if currently_blocked:
self._remove_limit_block_notification(user_id, ref_events)
await self._remove_limit_block_notification(user_id, ref_events)
return
if currently_blocked and not limit_msg:
# Room is notifying of a block, when it ought not to be.
yield self._remove_limit_block_notification(user_id, ref_events)
await self._remove_limit_block_notification(user_id, ref_events)
elif not currently_blocked and limit_msg:
# Room is not notifying of a block, when it ought to be.
yield self._apply_limit_block_notification(
await self._apply_limit_block_notification(
user_id, limit_msg, limit_type
)
except SynapseError as e:
logger.error("Error sending resource limits server notice: %s", e)
@defer.inlineCallbacks
def _remove_limit_block_notification(self, user_id, ref_events):
async def _remove_limit_block_notification(self, user_id, ref_events):
"""Utility method to remove limit block notifications from the server
notices room.
@ -137,12 +135,13 @@ class ResourceLimitsServerNotices(object):
limit blocking and need to be preserved.
"""
content = {"pinned": ref_events}
yield self._server_notices_manager.send_notice(
await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Pinned, ""
)
@defer.inlineCallbacks
def _apply_limit_block_notification(self, user_id, event_body, event_limit_type):
async def _apply_limit_block_notification(
self, user_id, event_body, event_limit_type
):
"""Utility method to apply limit block notifications in the server
notices room.
@ -159,12 +158,12 @@ class ResourceLimitsServerNotices(object):
"admin_contact": self._config.admin_contact,
"limit_type": event_limit_type,
}
event = yield self._server_notices_manager.send_notice(
event = await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Message
)
content = {"pinned": [event.event_id]}
yield self._server_notices_manager.send_notice(
await self._server_notices_manager.send_notice(
user_id, content, EventTypes.Pinned, ""
)
@ -198,7 +197,7 @@ class ResourceLimitsServerNotices(object):
room_id(str): The room id of the server notices room
Returns:
Deferred[bool|List]:
bool: Is the room currently blocked
list: The list of pinned events that are unrelated to limit blocking
This list can be used as a convenience in the case where the block

View File

@ -14,11 +14,9 @@
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
from synapse.types import UserID, create_requester
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@ -51,8 +49,7 @@ class ServerNoticesManager(object):
"""
return self._config.server_notices_mxid is not None
@defer.inlineCallbacks
def send_notice(
async def send_notice(
self, user_id, event_content, type=EventTypes.Message, state_key=None
):
"""Send a notice to the given user
@ -68,8 +65,8 @@ class ServerNoticesManager(object):
Returns:
Deferred[FrozenEvent]
"""
room_id = yield self.get_or_create_notice_room_for_user(user_id)
yield self.maybe_invite_user_to_room(user_id, room_id)
room_id = await self.get_or_create_notice_room_for_user(user_id)
await self.maybe_invite_user_to_room(user_id, room_id)
system_mxid = self._config.server_notices_mxid
requester = create_requester(system_mxid)
@ -86,13 +83,13 @@ class ServerNoticesManager(object):
if state_key is not None:
event_dict["state_key"] = state_key
res = yield self._event_creation_handler.create_and_send_nonmember_event(
res = await self._event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, ratelimit=False
)
return res
@cachedInlineCallbacks()
def get_or_create_notice_room_for_user(self, user_id):
@cached()
async def get_or_create_notice_room_for_user(self, user_id):
"""Get the room for notices for a given user
If we have not yet created a notice room for this user, create it, but don't
@ -109,7 +106,7 @@ class ServerNoticesManager(object):
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
for room in rooms:
@ -118,7 +115,7 @@ class ServerNoticesManager(object):
# be joined. This is kinda deliberate, in that if somebody somehow
# manages to invite the system user to a room, that doesn't make it
# the server notices room.
user_ids = yield self._store.get_users_in_room(room.room_id)
user_ids = await self._store.get_users_in_room(room.room_id)
if self.server_notices_mxid in user_ids:
# we found a room which our user shares with the system notice
# user
@ -146,7 +143,7 @@ class ServerNoticesManager(object):
}
requester = create_requester(self.server_notices_mxid)
info = yield self._room_creation_handler.create_room(
info = await self._room_creation_handler.create_room(
requester,
config={
"preset": RoomCreationPreset.PRIVATE_CHAT,
@ -158,7 +155,7 @@ class ServerNoticesManager(object):
)
room_id = info["room_id"]
max_id = yield self._store.add_tag_to_room(
max_id = await self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
@ -166,8 +163,7 @@ class ServerNoticesManager(object):
logger.info("Created server notices room %s for %s", room_id, user_id)
return room_id
@defer.inlineCallbacks
def maybe_invite_user_to_room(self, user_id: str, room_id: str):
async def maybe_invite_user_to_room(self, user_id: str, room_id: str):
"""Invite the given user to the given server room, unless the user has already
joined or been invited to it.
@ -179,14 +175,14 @@ class ServerNoticesManager(object):
# Check whether the user has already joined or been invited to this room. If
# that's the case, there is no need to re-invite them.
joined_rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
joined_rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
for room in joined_rooms:
if room.room_id == room_id:
return
yield self._room_member_handler.update_membership(
await self._room_member_handler.update_membership(
requester=requester,
target=UserID.from_string(user_id),
room_id=room_id,

View File

@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.server_notices.consent_server_notices import ConsentServerNotices
from synapse.server_notices.resource_limits_server_notices import (
ResourceLimitsServerNotices,
@ -36,18 +34,16 @@ class ServerNoticesSender(object):
ResourceLimitsServerNotices(hs),
)
@defer.inlineCallbacks
def on_user_syncing(self, user_id):
async def on_user_syncing(self, user_id):
"""Called when the user performs a sync operation.
Args:
user_id (str): mxid of user who synced
"""
for sn in self._server_notices:
yield sn.maybe_send_server_notice_to_user(user_id)
await sn.maybe_send_server_notice_to_user(user_id)
@defer.inlineCallbacks
def on_user_ip(self, user_id):
async def on_user_ip(self, user_id):
"""Called on the master when a worker process saw a client request.
Args:
@ -57,4 +53,4 @@ class ServerNoticesSender(object):
# we check for notices to send to the user in on_user_ip as well as
# in on_user_syncing
for sn in self._server_notices:
yield sn.maybe_send_server_notice_to_user(user_id)
await sn.maybe_send_server_notice_to_user(user_id)

View File

@ -82,18 +82,26 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_name(self):
yield self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
yield defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
)
)
self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank Jr.",
)
# Set displayname again
yield self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank"
yield defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank"
)
)
self.assertEquals(
@ -112,16 +120,20 @@ class ProfileTestCase(unittest.TestCase):
)
# Setting displayname a second time is forbidden
d = self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
d = defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
)
)
yield self.assertFailure(d, SynapseError)
@defer.inlineCallbacks
def test_set_my_name_noauth(self):
d = self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
d = defer.ensureDeferred(
self.handler.set_displayname(
self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
)
)
yield self.assertFailure(d, AuthError)
@ -165,10 +177,12 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_avatar(self):
yield self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
yield defer.ensureDeferred(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
)
)
self.assertEquals(
@ -177,10 +191,12 @@ class ProfileTestCase(unittest.TestCase):
)
# Set avatar again
yield self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/me.png",
yield defer.ensureDeferred(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/me.png",
)
)
self.assertEquals(
@ -203,10 +219,12 @@ class ProfileTestCase(unittest.TestCase):
)
# Set avatar a second time is forbidden
d = self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
d = defer.ensureDeferred(
self.handler.set_avatar_url(
self.frank,
synapse.types.create_requester(self.frank),
"http://my.server/pic.gif",
)
)
yield self.assertFailure(d, SynapseError)

View File

@ -175,7 +175,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
self.store.is_real_user = Mock(return_value=False)
self.store.is_real_user = Mock(return_value=defer.succeed(False))
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@ -187,8 +187,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
self.store.count_real_users = Mock(return_value=1)
self.store.is_real_user = Mock(return_value=True)
self.store.count_real_users = Mock(return_value=defer.succeed(1))
self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_handlers().directory_handler
@ -202,8 +202,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
self.hs.config.auto_join_rooms = [room_alias_str]
self.store.count_real_users = Mock(return_value=2)
self.store.is_real_user = Mock(return_value=True)
self.store.count_real_users = Mock(return_value=defer.succeed(2))
self.store.is_real_user = Mock(return_value=defer.succeed(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@ -256,8 +256,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart=invalid_user_id), SynapseError
)
@defer.inlineCallbacks
def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
async def get_or_create_user(
self, requester, localpart, displayname, password_hash=None
):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@ -272,11 +273,11 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
yield self.hs.get_auth().check_auth_blocking()
await self.hs.get_auth().check_auth_blocking()
need_register = True
try:
yield self.handler.check_username(localpart)
await self.handler.check_username(localpart)
except SynapseError as e:
if e.errcode == Codes.USER_IN_USE:
need_register = False
@ -288,21 +289,21 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
token = self.macaroon_generator.generate_access_token(user_id)
if need_register:
yield self.handler.register_with_store(
await self.handler.register_with_store(
user_id=user_id,
password_hash=password_hash,
create_profile_with_displayname=user.localpart,
)
else:
yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
await self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
yield self.store.add_access_token_to_user(
await self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None
)
if displayname is not None:
# logger.info("setting user display name: %s -> %s", user_id, displayname)
yield self.hs.get_profile_handler().set_displayname(
await self.hs.get_profile_handler().set_displayname(
user, requester, displayname, by_admin=True
)

View File

@ -55,25 +55,18 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
self._rlsn._server_notices_manager.send_notice = Mock()
self._rlsn._state.get_current_state = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_events = Mock(return_value=defer.succeed({}))
self._rlsn._server_notices_manager.send_notice = Mock(
return_value=defer.succeed(Mock())
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
self.hs.config.limit_usage_by_mau = True
self.user_id = "@user_id:test"
# self.server_notices_mxid = "@server:test"
# self.server_notices_mxid_display_name = None
# self.server_notices_mxid_avatar_url = None
# self.server_notices_room_name = "Server Notices"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
returnValue=""
return_value=defer.succeed("!something:localhost")
)
self._rlsn._store.add_tag_to_room = Mock()
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_tags_for_room = Mock(return_value={})
self.hs.config.admin_contact = "mailto:user@test.com"
@ -95,14 +88,13 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed"""
self._rlsn._auth.check_auth_blocking = Mock()
self._rlsn._auth.check_auth_blocking = Mock(defer.succeed(None))
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
self._send_notice.assert_called_once()
@ -112,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, "foo")
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
)
mock_event = Mock(
@ -121,6 +113,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self._send_notice.assert_not_called()
@ -129,9 +122,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, but should have one
"""
self._rlsn._auth.check_auth_blocking = Mock(
side_effect=ResourceLimitError(403, "foo")
return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -142,7 +134,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock()
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -153,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
self._rlsn._auth.check_auth_blocking = Mock()
self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(None)
)
@ -167,24 +159,28 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
an alert message is not sent into the room
"""
self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
)
),
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
self.assertTrue(self._send_notice.call_count == 0)
self.assertEqual(self._send_notice.call_count, 0)
def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
"""
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
)
),
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -198,10 +194,12 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
self.hs.config.mau_limit_alerting = False
self._rlsn._auth.check_auth_blocking = Mock(
return_value=defer.succeed(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
)
),
)
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
return_value=defer.succeed((True, []))
)
@ -256,7 +254,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
def test_server_notice_only_sent_once(self):
self.store.get_monthly_active_count = Mock(return_value=1000)
self.store.user_last_seen_monthly_active = Mock(return_value=1000)
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
)
# Call the function multiple times to ensure we only send the notice once
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))

View File

@ -27,8 +27,10 @@ class MessageAcceptTests(unittest.TestCase):
user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, None, None)
room_creator = self.homeserver.get_room_creation_handler()
room = room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
room = ensureDeferred(
room_creator.create_room(
our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False
)
)
self.reactor.advance(0.1)
self.room_id = self.successResultOf(room)["room_id"]