Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes
commit
1048e2ca6a
|
@ -0,0 +1 @@
|
||||||
|
Synapse is now permissive about trailing slashes on some of its federation endpoints, allowing zero or more to be present.
|
|
@ -0,0 +1 @@
|
||||||
|
Add ability for password providers to login/register a user via 3PID (email, phone).
|
|
@ -0,0 +1 @@
|
||||||
|
The user directory has been rewritten to make it faster, with less chance of falling behind on a large server.
|
|
@ -75,6 +75,20 @@ Password auth provider classes may optionally provide the following methods.
|
||||||
result from the ``/login`` call (including ``access_token``, ``device_id``,
|
result from the ``/login`` call (including ``access_token``, ``device_id``,
|
||||||
etc.)
|
etc.)
|
||||||
|
|
||||||
|
``someprovider.check_3pid_auth``\(*medium*, *address*, *password*)
|
||||||
|
|
||||||
|
This method, if implemented, is called when a user attempts to register or
|
||||||
|
log in with a third party identifier, such as email. It is passed the
|
||||||
|
medium (ex. "email"), an address (ex. "jdoe@example.com") and the user's
|
||||||
|
password.
|
||||||
|
|
||||||
|
The method should return a Twisted ``Deferred`` object, which resolves to
|
||||||
|
a ``str`` containing the user's (canonical) User ID if authentication was
|
||||||
|
successful, and ``None`` if not.
|
||||||
|
|
||||||
|
As with ``check_auth``, the ``Deferred`` may alternatively resolve to a
|
||||||
|
``(user_id, callback)`` tuple.
|
||||||
|
|
||||||
``someprovider.check_password``\(*user_id*, *password*)
|
``someprovider.check_password``\(*user_id*, *password*)
|
||||||
|
|
||||||
This method provides a simpler interface than ``get_supported_login_types``
|
This method provides a simpler interface than ``get_supported_login_types``
|
||||||
|
|
|
@ -621,13 +621,13 @@ class Auth(object):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the the sender is allowed to redact the target event if the
|
True if the the sender is allowed to redact the target event if the
|
||||||
target event was created by them.
|
target event was created by them.
|
||||||
False if the sender is allowed to redact the target event with no
|
False if the sender is allowed to redact the target event with no
|
||||||
further checks.
|
further checks.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if the event sender is definitely not allowed to redact
|
AuthError if the event sender is definitely not allowed to redact
|
||||||
the target event.
|
the target event.
|
||||||
"""
|
"""
|
||||||
return event_auth.check_redaction(room_version, event, auth_events)
|
return event_auth.check_redaction(room_version, event, auth_events)
|
||||||
|
|
||||||
|
@ -743,9 +743,9 @@ class Auth(object):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[tuple[str, str|None]]: Resolves to the current membership of
|
Deferred[tuple[str, str|None]]: Resolves to the current membership of
|
||||||
the user in the room and the membership event ID of the user. If
|
the user in the room and the membership event ID of the user. If
|
||||||
the user is not in the room and never has been, then
|
the user is not in the room and never has been, then
|
||||||
`(Membership.JOIN, None)` is returned.
|
`(Membership.JOIN, None)` is returned.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -777,13 +777,13 @@ class Auth(object):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str|None): If present, checks for presence against existing
|
user_id(str|None): If present, checks for presence against existing
|
||||||
MAU cohort
|
MAU cohort
|
||||||
|
|
||||||
threepid(dict|None): If present, checks for presence against configured
|
threepid(dict|None): If present, checks for presence against configured
|
||||||
reserved threepid. Used in cases where the user is trying register
|
reserved threepid. Used in cases where the user is trying register
|
||||||
with a MAU blocked server, normally they would be rejected but their
|
with a MAU blocked server, normally they would be rejected but their
|
||||||
threepid is on the reserved list. user_id and
|
threepid is on the reserved list. user_id and
|
||||||
threepid should never be set at the same time.
|
threepid should never be set at the same time.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Never fail an auth check for the server notices users or support user
|
# Never fail an auth check for the server notices users or support user
|
||||||
|
|
|
@ -173,7 +173,7 @@ class TransportLayerClient(object):
|
||||||
# generated by the json_data_callback.
|
# generated by the json_data_callback.
|
||||||
json_data = transaction.get_dict()
|
json_data = transaction.get_dict()
|
||||||
|
|
||||||
path = _create_v1_path("/send/%s/", transaction.transaction_id)
|
path = _create_v1_path("/send/%s", transaction.transaction_id)
|
||||||
|
|
||||||
response = yield self.client.put_json(
|
response = yield self.client.put_json(
|
||||||
transaction.destination,
|
transaction.destination,
|
||||||
|
|
|
@ -312,7 +312,7 @@ class BaseFederationServlet(object):
|
||||||
|
|
||||||
|
|
||||||
class FederationSendServlet(BaseFederationServlet):
|
class FederationSendServlet(BaseFederationServlet):
|
||||||
PATH = "/send/(?P<transaction_id>[^/]*)/"
|
PATH = "/send/(?P<transaction_id>[^/]*)/?"
|
||||||
|
|
||||||
def __init__(self, handler, server_name, **kwargs):
|
def __init__(self, handler, server_name, **kwargs):
|
||||||
super(FederationSendServlet, self).__init__(
|
super(FederationSendServlet, self).__init__(
|
||||||
|
@ -378,7 +378,7 @@ class FederationSendServlet(BaseFederationServlet):
|
||||||
|
|
||||||
|
|
||||||
class FederationEventServlet(BaseFederationServlet):
|
class FederationEventServlet(BaseFederationServlet):
|
||||||
PATH = "/event/(?P<event_id>[^/]*)/"
|
PATH = "/event/(?P<event_id>[^/]*)/?"
|
||||||
|
|
||||||
# This is when someone asks for a data item for a given server data_id pair.
|
# This is when someone asks for a data item for a given server data_id pair.
|
||||||
def on_GET(self, origin, content, query, event_id):
|
def on_GET(self, origin, content, query, event_id):
|
||||||
|
@ -386,7 +386,7 @@ class FederationEventServlet(BaseFederationServlet):
|
||||||
|
|
||||||
|
|
||||||
class FederationStateServlet(BaseFederationServlet):
|
class FederationStateServlet(BaseFederationServlet):
|
||||||
PATH = "/state/(?P<context>[^/]*)/"
|
PATH = "/state/(?P<context>[^/]*)/?"
|
||||||
|
|
||||||
# This is when someone asks for all data for a given context.
|
# This is when someone asks for all data for a given context.
|
||||||
def on_GET(self, origin, content, query, context):
|
def on_GET(self, origin, content, query, context):
|
||||||
|
@ -398,7 +398,7 @@ class FederationStateServlet(BaseFederationServlet):
|
||||||
|
|
||||||
|
|
||||||
class FederationStateIdsServlet(BaseFederationServlet):
|
class FederationStateIdsServlet(BaseFederationServlet):
|
||||||
PATH = "/state_ids/(?P<room_id>[^/]*)/"
|
PATH = "/state_ids/(?P<room_id>[^/]*)/?"
|
||||||
|
|
||||||
def on_GET(self, origin, content, query, room_id):
|
def on_GET(self, origin, content, query, room_id):
|
||||||
return self.handler.on_state_ids_request(
|
return self.handler.on_state_ids_request(
|
||||||
|
@ -409,7 +409,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
|
||||||
|
|
||||||
|
|
||||||
class FederationBackfillServlet(BaseFederationServlet):
|
class FederationBackfillServlet(BaseFederationServlet):
|
||||||
PATH = "/backfill/(?P<context>[^/]*)/"
|
PATH = "/backfill/(?P<context>[^/]*)/?"
|
||||||
|
|
||||||
def on_GET(self, origin, content, query, context):
|
def on_GET(self, origin, content, query, context):
|
||||||
versions = [x.decode('ascii') for x in query[b"v"]]
|
versions = [x.decode('ascii') for x in query[b"v"]]
|
||||||
|
@ -1080,7 +1080,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
|
||||||
"""Get all categories for a group
|
"""Get all categories for a group
|
||||||
"""
|
"""
|
||||||
PATH = (
|
PATH = (
|
||||||
"/groups/(?P<group_id>[^/]*)/categories/"
|
"/groups/(?P<group_id>[^/]*)/categories/?"
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -1150,7 +1150,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
|
||||||
"""Get roles in a group
|
"""Get roles in a group
|
||||||
"""
|
"""
|
||||||
PATH = (
|
PATH = (
|
||||||
"/groups/(?P<group_id>[^/]*)/roles/"
|
"/groups/(?P<group_id>[^/]*)/roles/?"
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -745,6 +745,42 @@ class AuthHandler(BaseHandler):
|
||||||
errcode=Codes.FORBIDDEN
|
errcode=Codes.FORBIDDEN
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def check_password_provider_3pid(self, medium, address, password):
|
||||||
|
"""Check if a password provider is able to validate a thirdparty login
|
||||||
|
|
||||||
|
Args:
|
||||||
|
medium (str): The medium of the 3pid (ex. email).
|
||||||
|
address (str): The address of the 3pid (ex. jdoe@example.com).
|
||||||
|
password (str): The password of the user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[(str|None, func|None)]: A tuple of `(user_id,
|
||||||
|
callback)`. If authentication is successful, `user_id` is a `str`
|
||||||
|
containing the authenticated, canonical user ID. `callback` is
|
||||||
|
then either a function to be later run after the server has
|
||||||
|
completed login/registration, or `None`. If authentication was
|
||||||
|
unsuccessful, `user_id` and `callback` are both `None`.
|
||||||
|
"""
|
||||||
|
for provider in self.password_providers:
|
||||||
|
if hasattr(provider, "check_3pid_auth"):
|
||||||
|
# This function is able to return a deferred that either
|
||||||
|
# resolves None, meaning authentication failure, or upon
|
||||||
|
# success, to a str (which is the user_id) or a tuple of
|
||||||
|
# (user_id, callback_func), where callback_func should be run
|
||||||
|
# after we've finished everything else
|
||||||
|
result = yield provider.check_3pid_auth(
|
||||||
|
medium, address, password,
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
# Check if the return value is a str or a tuple
|
||||||
|
if isinstance(result, str):
|
||||||
|
# If it's a str, set callback function to None
|
||||||
|
result = (result, None)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
defer.returnValue((None, None))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_local_password(self, user_id, password):
|
def _check_local_password(self, user_id, password):
|
||||||
"""Authenticate a user against the local password database.
|
"""Authenticate a user against the local password database.
|
||||||
|
@ -756,7 +792,8 @@ class AuthHandler(BaseHandler):
|
||||||
user_id (unicode): complete @user:id
|
user_id (unicode): complete @user:id
|
||||||
password (unicode): the provided password
|
password (unicode): the provided password
|
||||||
Returns:
|
Returns:
|
||||||
(unicode) the canonical_user_id, or None if unknown user / bad password
|
Deferred[unicode] the canonical_user_id, or Deferred[None] if
|
||||||
|
unknown user/bad password
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
LimitExceededError if the ratelimiter's login requests count for this
|
LimitExceededError if the ratelimiter's login requests count for this
|
||||||
|
|
|
@ -147,8 +147,14 @@ class BaseProfileHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
|
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
|
||||||
"""target_user is the user whose displayname is to be changed;
|
"""Set the displayname of a user
|
||||||
auth_user is the user attempting to make this change."""
|
|
||||||
|
Args:
|
||||||
|
target_user (UserID): the user whose displayname is to be changed.
|
||||||
|
requester (Requester): The user attempting to make this change.
|
||||||
|
new_displayname (str): The displayname to give this user.
|
||||||
|
by_admin (bool): Whether this change was made by an administrator.
|
||||||
|
"""
|
||||||
if not self.hs.is_mine(target_user):
|
if not self.hs.is_mine(target_user):
|
||||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||||
|
|
||||||
|
|
|
@ -171,7 +171,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
api.constants.UserTypes, or None for a normal user.
|
api.constants.UserTypes, or None for a normal user.
|
||||||
default_display_name (unicode|None): if set, the new user's displayname
|
default_display_name (unicode|None): if set, the new user's displayname
|
||||||
will be set to this. Defaults to 'localpart'.
|
will be set to this. Defaults to 'localpart'.
|
||||||
address (str|None): the IP address used to perform the regitration.
|
address (str|None): the IP address used to perform the registration.
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (user_id, access_token).
|
A tuple of (user_id, access_token).
|
||||||
Raises:
|
Raises:
|
||||||
|
@ -623,7 +623,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
admin (boolean): is an admin user?
|
admin (boolean): is an admin user?
|
||||||
user_type (str|None): type of user. One of the values from
|
user_type (str|None): type of user. One of the values from
|
||||||
api.constants.UserTypes, or None for a normal user.
|
api.constants.UserTypes, or None for a normal user.
|
||||||
address (str|None): the IP address used to perform the regitration.
|
address (str|None): the IP address used to perform the registration.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred
|
Deferred
|
||||||
|
@ -721,9 +721,9 @@ class RegistrationHandler(BaseHandler):
|
||||||
access_token (str|None): The access token of the newly logged in
|
access_token (str|None): The access token of the newly logged in
|
||||||
device, or None if `inhibit_login` enabled.
|
device, or None if `inhibit_login` enabled.
|
||||||
bind_email (bool): Whether to bind the email with the identity
|
bind_email (bool): Whether to bind the email with the identity
|
||||||
server
|
server.
|
||||||
bind_msisdn (bool): Whether to bind the msisdn with the identity
|
bind_msisdn (bool): Whether to bind the msisdn with the identity
|
||||||
server
|
server.
|
||||||
"""
|
"""
|
||||||
if self.hs.config.worker_app:
|
if self.hs.config.worker_app:
|
||||||
yield self._post_registration_client(
|
yield self._post_registration_client(
|
||||||
|
@ -765,7 +765,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
"""A user consented to the terms on registration
|
"""A user consented to the terms on registration
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The user ID that consented
|
user_id (str): The user ID that consented.
|
||||||
consent_version (str): version of the policy the user has
|
consent_version (str): version of the policy the user has
|
||||||
consented to.
|
consented to.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -231,6 +231,7 @@ class MatrixFederationHttpClient(object):
|
||||||
# Retry with a trailing slash if we received a 400 with
|
# Retry with a trailing slash if we received a 400 with
|
||||||
# 'M_UNRECOGNIZED' which some endpoints can return when omitting a
|
# 'M_UNRECOGNIZED' which some endpoints can return when omitting a
|
||||||
# trailing slash on Synapse <= v0.99.3.
|
# trailing slash on Synapse <= v0.99.3.
|
||||||
|
logger.info("Retrying request with trailing slash")
|
||||||
request.path += "/"
|
request.path += "/"
|
||||||
|
|
||||||
response = yield self._send_request(
|
response = yield self._send_request(
|
||||||
|
|
|
@ -73,14 +73,26 @@ class ModuleApi(object):
|
||||||
"""
|
"""
|
||||||
return self._auth_handler.check_user_exists(user_id)
|
return self._auth_handler.check_user_exists(user_id)
|
||||||
|
|
||||||
def register(self, localpart):
|
@defer.inlineCallbacks
|
||||||
"""Registers a new user with given localpart
|
def register(self, localpart, displayname=None):
|
||||||
|
"""Registers a new user with given localpart and optional
|
||||||
|
displayname.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
localpart (str): The localpart of the new user.
|
||||||
|
displayname (str|None): The displayname of the new user. If None,
|
||||||
|
the user's displayname will default to `localpart`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: a 2-tuple of (user_id, access_token)
|
Deferred: a 2-tuple of (user_id, access_token)
|
||||||
"""
|
"""
|
||||||
|
# Register the user
|
||||||
reg = self.hs.get_registration_handler()
|
reg = self.hs.get_registration_handler()
|
||||||
return reg.register(localpart=localpart)
|
user_id, access_token = yield reg.register(
|
||||||
|
localpart=localpart, default_display_name=displayname,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((user_id, access_token))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def invalidate_access_token(self, access_token):
|
def invalidate_access_token(self, access_token):
|
||||||
|
|
|
@ -201,6 +201,24 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
# We store all email addreses as lowercase in the DB.
|
# We store all email addreses as lowercase in the DB.
|
||||||
# (See add_threepid in synapse/handlers/auth.py)
|
# (See add_threepid in synapse/handlers/auth.py)
|
||||||
address = address.lower()
|
address = address.lower()
|
||||||
|
|
||||||
|
# Check for login providers that support 3pid login types
|
||||||
|
canonical_user_id, callback_3pid = (
|
||||||
|
yield self.auth_handler.check_password_provider_3pid(
|
||||||
|
medium,
|
||||||
|
address,
|
||||||
|
login_submission["password"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if canonical_user_id:
|
||||||
|
# Authentication through password provider and 3pid succeeded
|
||||||
|
result = yield self._register_device_with_callback(
|
||||||
|
canonical_user_id, login_submission, callback_3pid,
|
||||||
|
)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
# No password providers were able to handle this 3pid
|
||||||
|
# Check local store
|
||||||
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
medium, address,
|
medium, address,
|
||||||
)
|
)
|
||||||
|
@ -223,20 +241,43 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||||
if "user" not in identifier:
|
if "user" not in identifier:
|
||||||
raise SynapseError(400, "User identifier is missing 'user' key")
|
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||||
|
|
||||||
auth_handler = self.auth_handler
|
canonical_user_id, callback = yield self.auth_handler.validate_login(
|
||||||
canonical_user_id, callback = yield auth_handler.validate_login(
|
|
||||||
identifier["user"],
|
identifier["user"],
|
||||||
login_submission,
|
login_submission,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
result = yield self._register_device_with_callback(
|
||||||
|
canonical_user_id, login_submission, callback,
|
||||||
|
)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _register_device_with_callback(
|
||||||
|
self,
|
||||||
|
user_id,
|
||||||
|
login_submission,
|
||||||
|
callback=None,
|
||||||
|
):
|
||||||
|
""" Registers a device with a given user_id. Optionally run a callback
|
||||||
|
function after registration has completed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): ID of the user to register.
|
||||||
|
login_submission (dict): Dictionary of login information.
|
||||||
|
callback (func|None): Callback function to run after registration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
result (Dict[str,str]): Dictionary of account information after
|
||||||
|
successful registration.
|
||||||
|
"""
|
||||||
device_id = login_submission.get("device_id")
|
device_id = login_submission.get("device_id")
|
||||||
initial_display_name = login_submission.get("initial_device_display_name")
|
initial_display_name = login_submission.get("initial_device_display_name")
|
||||||
device_id, access_token = yield self.registration_handler.register_device(
|
device_id, access_token = yield self.registration_handler.register_device(
|
||||||
canonical_user_id, device_id, initial_display_name,
|
user_id, device_id, initial_display_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"user_id": canonical_user_id,
|
"user_id": user_id,
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
|
|
|
@ -135,7 +135,12 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _populate_user_directory_process_rooms(self, progress, batch_size):
|
def _populate_user_directory_process_rooms(self, progress, batch_size):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
progress (dict)
|
||||||
|
batch_size (int): Maximum number of state events to process
|
||||||
|
per cycle.
|
||||||
|
"""
|
||||||
state = self.hs.get_state_handler()
|
state = self.hs.get_state_handler()
|
||||||
|
|
||||||
# If we don't have progress filed, delete everything.
|
# If we don't have progress filed, delete everything.
|
||||||
|
@ -143,13 +148,14 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
||||||
yield self.delete_all_from_user_dir()
|
yield self.delete_all_from_user_dir()
|
||||||
|
|
||||||
def _get_next_batch(txn):
|
def _get_next_batch(txn):
|
||||||
|
# Only fetch 250 rooms, so we don't fetch too many at once, even
|
||||||
|
# if those 250 rooms have less than batch_size state events.
|
||||||
sql = """
|
sql = """
|
||||||
SELECT room_id FROM %s
|
SELECT room_id, events FROM %s
|
||||||
ORDER BY events DESC
|
ORDER BY events DESC
|
||||||
LIMIT %s
|
LIMIT 250
|
||||||
""" % (
|
""" % (
|
||||||
TEMP_TABLE + "_rooms",
|
TEMP_TABLE + "_rooms",
|
||||||
str(batch_size),
|
|
||||||
)
|
)
|
||||||
txn.execute(sql)
|
txn.execute(sql)
|
||||||
rooms_to_work_on = txn.fetchall()
|
rooms_to_work_on = txn.fetchall()
|
||||||
|
@ -157,8 +163,6 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
||||||
if not rooms_to_work_on:
|
if not rooms_to_work_on:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
rooms_to_work_on = [x[0] for x in rooms_to_work_on]
|
|
||||||
|
|
||||||
# Get how many are left to process, so we can give status on how
|
# Get how many are left to process, so we can give status on how
|
||||||
# far we are in processing
|
# far we are in processing
|
||||||
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
|
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
|
||||||
|
@ -180,7 +184,9 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
||||||
% (len(rooms_to_work_on), progress["remaining"])
|
% (len(rooms_to_work_on), progress["remaining"])
|
||||||
)
|
)
|
||||||
|
|
||||||
for room_id in rooms_to_work_on:
|
processed_event_count = 0
|
||||||
|
|
||||||
|
for room_id, event_count in rooms_to_work_on:
|
||||||
is_in_room = yield self.is_host_joined(room_id, self.server_name)
|
is_in_room = yield self.is_host_joined(room_id, self.server_name)
|
||||||
|
|
||||||
if is_in_room:
|
if is_in_room:
|
||||||
|
@ -247,7 +253,13 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
|
||||||
progress,
|
progress,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(len(rooms_to_work_on))
|
processed_event_count += event_count
|
||||||
|
|
||||||
|
if processed_event_count > batch_size:
|
||||||
|
# Don't process any more rooms, we've hit our batch size.
|
||||||
|
defer.returnValue(processed_event_count)
|
||||||
|
|
||||||
|
defer.returnValue(processed_event_count)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _populate_user_directory_process_users(self, progress, batch_size):
|
def _populate_user_directory_process_users(self, progress, batch_size):
|
||||||
|
|
|
@ -180,7 +180,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
put_json = self.hs.get_http_client().put_json
|
put_json = self.hs.get_http_client().put_json
|
||||||
put_json.assert_called_once_with(
|
put_json.assert_called_once_with(
|
||||||
"farm",
|
"farm",
|
||||||
path="/_matrix/federation/v1/send/1000000/",
|
path="/_matrix/federation/v1/send/1000000",
|
||||||
data=_expect_edu_transaction(
|
data=_expect_edu_transaction(
|
||||||
"m.typing",
|
"m.typing",
|
||||||
content={
|
content={
|
||||||
|
@ -202,7 +202,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
(request, channel) = self.make_request(
|
(request, channel) = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
"/_matrix/federation/v1/send/1000000/",
|
"/_matrix/federation/v1/send/1000000",
|
||||||
_make_edu_transaction_json(
|
_make_edu_transaction_json(
|
||||||
"m.typing",
|
"m.typing",
|
||||||
content={
|
content={
|
||||||
|
@ -258,7 +258,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
put_json = self.hs.get_http_client().put_json
|
put_json = self.hs.get_http_client().put_json
|
||||||
put_json.assert_called_once_with(
|
put_json.assert_called_once_with(
|
||||||
"farm",
|
"farm",
|
||||||
path="/_matrix/federation/v1/send/1000000/",
|
path="/_matrix/federation/v1/send/1000000",
|
||||||
data=_expect_edu_transaction(
|
data=_expect_edu_transaction(
|
||||||
"m.typing",
|
"m.typing",
|
||||||
content={
|
content={
|
||||||
|
|
Loading…
Reference in New Issue