From 16090a077f9f387dcd42edabda58063e3df6b771 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 15 May 2020 17:17:42 +0100 Subject: [PATCH 01/14] Prevent 0-member/null room_version rooms from appearing in group room queries (#7465) --- changelog.d/7465.bugfix | 1 + .../storage/data_stores/main/group_server.py | 92 ++++++++++++++++--- 2 files changed, 79 insertions(+), 14 deletions(-) create mode 100644 changelog.d/7465.bugfix diff --git a/changelog.d/7465.bugfix b/changelog.d/7465.bugfix new file mode 100644 index 0000000000..1cbe50caa5 --- /dev/null +++ b/changelog.d/7465.bugfix @@ -0,0 +1 @@ +Prevent rooms with 0 members or with invalid version strings from breaking group queries. \ No newline at end of file diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py index 0963e6c250..fb1361f1c1 100644 --- a/synapse/storage/data_stores/main/group_server.py +++ b/synapse/storage/data_stores/main/group_server.py @@ -68,24 +68,78 @@ class GroupServerWorkerStore(SQLBaseStore): desc="get_invited_users_in_group", ) - def get_rooms_in_group(self, group_id, include_private=False): + def get_rooms_in_group(self, group_id: str, include_private: bool = False): + """Retrieve the rooms that belong to a given group. Does not return rooms that + lack members. + + Args: + group_id: The ID of the group to query for rooms + include_private: Whether to return private rooms in results + + Returns: + Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the + form of: + + { + "room_id": "!a_room_id:example.com", # The ID of the room + "is_public": False # Whether this is a public room or not + } + """ # TODO: Pagination - keyvalues = {"group_id": group_id} - if not include_private: - keyvalues["is_public"] = True + def _get_rooms_in_group_txn(txn): + sql = """ + SELECT room_id, is_public FROM group_rooms + WHERE group_id = ? + AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? + ) + """ + args = [group_id, False] - return self.db.simple_select_list( - table="group_rooms", - keyvalues=keyvalues, - retcols=("room_id", "is_public"), - desc="get_rooms_in_group", - ) + if not include_private: + sql += " AND is_public = ?" 
+ args += [True] - def get_rooms_for_summary_by_category(self, group_id, include_private=False): + txn.execute(sql, args) + + return [ + {"room_id": room_id, "is_public": is_public} + for room_id, is_public in txn + ] + + return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn) + + def get_rooms_for_summary_by_category( + self, group_id: str, include_private: bool = False, + ): """Get the rooms and categories that should be included in a summary request - Returns ([rooms], [categories]) + Args: + group_id: The ID of the group to query the summary for + include_private: Whether to return private rooms in results + + Returns: + Deferred[Tuple[List, Dict]]: A tuple containing: + + * A list of dictionaries with the keys: + * "room_id": str, the room ID + * "is_public": bool, whether the room is public + * "category_id": str|None, the category ID if set, else None + * "order": int, the sort order of rooms + + * A dictionary with the key: + * category_id (str): a dictionary with the keys: + * "is_public": bool, whether the category is public + * "profile": str, the category profile + * "order": int, the sort order of rooms in this category """ def _get_rooms_for_summary_txn(txn): @@ -97,13 +151,23 @@ class GroupServerWorkerStore(SQLBaseStore): SELECT room_id, is_public, category_id, room_order FROM group_summary_rooms WHERE group_id = ? + AND room_id IN ( + SELECT group_rooms.room_id FROM group_rooms + LEFT JOIN room_stats_current ON + group_rooms.room_id = room_stats_current.room_id + AND joined_members > 0 + AND local_users_in_room > 0 + LEFT JOIN rooms ON + group_rooms.room_id = rooms.room_id + AND (room_version <> '') = ? + ) """ if not include_private: sql += " AND is_public = ?" - txn.execute(sql, (group_id, True)) + txn.execute(sql, (group_id, False, True)) else: - txn.execute(sql, (group_id,)) + txn.execute(sql, (group_id, False)) rooms = [ { From 03aff4c75ed3b0b106ed1395b3d03b1ab9b013a6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 15 May 2020 17:22:47 +0100 Subject: [PATCH 02/14] Add a worker store for search insertion. (#7516) This is required as both event persistence and the background update needs access to this function. It should be perfectly safe for two workers to write to that table at the same time. --- changelog.d/7516.misc | 1 + synapse/app/generic_worker.py | 2 + synapse/storage/data_stores/main/search.py | 96 +++++++++++----------- 3 files changed, 52 insertions(+), 47 deletions(-) create mode 100644 changelog.d/7516.misc diff --git a/changelog.d/7516.misc b/changelog.d/7516.misc new file mode 100644 index 0000000000..94b0fd49b2 --- /dev/null +++ b/changelog.d/7516.misc @@ -0,0 +1 @@ +Add a worker store for search insertion, required for moving event persistence off master. 
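The pattern this commit applies is worth spelling out before the diff: a write helper that several processes need is hoisted out of a master-only class into a `*WorkerStore` base class, and every store that needs it — including the generic worker's combined store — inherits from that base. A minimal sketch of the layout follows, using simplified stand-ins rather than Synapse's actual class definitions:

```python
class SQLBaseStore:
    """Stand-in for Synapse's real SQLBaseStore."""


class SearchWorkerStore(SQLBaseStore):
    # Safe to call from any process: concurrent plain INSERTs into
    # event_search from two workers do not conflict.
    def store_search_entries_txn(self, txn, entries):
        txn.executemany(
            "INSERT INTO event_search (event_id, room_id, key, value)"
            " VALUES (?,?,?,?)",
            [(e.event_id, e.room_id, e.key, e.value) for e in entries],
        )


class SearchBackgroundUpdateStore(SearchWorkerStore):
    """The background update inherits the helper instead of owning it."""


class GenericWorkerSlavedStore(SearchWorkerStore):
    """Workers mix in the same base, so event persistence can write
    search entries off the master process."""
```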
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 2e3add7ac5..ab801108ca 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -122,6 +122,7 @@ from synapse.storage.data_stores.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) from synapse.storage.data_stores.main.presence import UserPresenceState +from synapse.storage.data_stores.main.search import SearchWorkerStore from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt @@ -451,6 +452,7 @@ class GenericWorkerSlavedStore( SlavedFilteringStore, MonthlyActiveUsersWorkerStore, MediaRepositoryStore, + SearchWorkerStore, BaseSlavedStore, ): def __init__(self, database, db_conn, hs): diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index ee75b92344..13f49d8060 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -37,7 +37,55 @@ SearchEntry = namedtuple( ) -class SearchBackgroundUpdateStore(SQLBaseStore): +class SearchWorkerStore(SQLBaseStore): + def store_search_entries_txn(self, txn, entries): + """Add entries to the search table + + Args: + txn (cursor): + entries (iterable[SearchEntry]): + entries to be added to the table + """ + if not self.hs.config.enable_search: + return + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "INSERT INTO event_search" + " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" + " VALUES (?,?,?,to_tsvector('english', ?),?,?)" + ) + + args = ( + ( + entry.event_id, + entry.room_id, + entry.key, + entry.value, + entry.stream_ordering, + entry.origin_server_ts, + ) + for entry in entries + ) + + txn.executemany(sql, args) + + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)" + ) + args = ( + (entry.event_id, entry.room_id, entry.key, entry.value) + for entry in entries + ) + + txn.executemany(sql, args) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + +class SearchBackgroundUpdateStore(SearchWorkerStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" @@ -296,52 +344,6 @@ class SearchBackgroundUpdateStore(SQLBaseStore): return num_rows - def store_search_entries_txn(self, txn, entries): - """Add entries to the search table - - Args: - txn (cursor): - entries (iterable[SearchEntry]): - entries to be added to the table - """ - if not self.hs.config.enable_search: - return - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" - " VALUES (?,?,?,to_tsvector('english', ?),?,?)" - ) - - args = ( - ( - entry.event_id, - entry.room_id, - entry.key, - entry.value, - entry.stream_ordering, - entry.origin_server_ts, - ) - for entry in entries - ) - - txn.executemany(sql, args) - - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - args = ( - (entry.event_id, entry.room_id, entry.key, entry.value) - for entry in entries - ) - - txn.executemany(sql, args) - else: - # This should be unreachable. 
- raise Exception("Unrecognized database engine") - class SearchStore(SearchBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): From a3cf36f76ed41222241393adf608d0e640bb51b8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 15 May 2020 12:26:02 -0400 Subject: [PATCH 03/14] Support UI Authentication for OpenID Connect accounts (#7457) --- changelog.d/7457.feature | 1 + synapse/handlers/auth.py | 4 +- synapse/handlers/oidc_handler.py | 76 +++++++++++++++++++++------- synapse/rest/client/v1/login.py | 31 +++++++----- synapse/rest/client/v2_alpha/auth.py | 19 +++++-- tests/handlers/test_oidc.py | 15 ++++-- 6 files changed, 105 insertions(+), 41 deletions(-) create mode 100644 changelog.d/7457.feature diff --git a/changelog.d/7457.feature b/changelog.d/7457.feature new file mode 100644 index 0000000000..7ad767bf71 --- /dev/null +++ b/changelog.d/7457.feature @@ -0,0 +1 @@ +Add OpenID Connect login/registration support. Contributed by Quentin Gliech, on behalf of [les Connecteurs](https://connecteu.rs). diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 524281d2f1..75b39e878c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -80,7 +80,9 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled - self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled + self._sso_enabled = ( + hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled + ) # we keep this as a list despite the O(N^2) implication so that we can # keep PASSWORD first and avoid confusing clients which pick the first diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 178f263439..4ba8c7fda5 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -311,7 +311,7 @@ class OidcHandler: ``ClientAuth`` to authenticate with the client with its ID and secret. Args: - code: The autorization code we got from the callback. + code: The authorization code we got from the callback. Returns: A dict containing various tokens. @@ -497,11 +497,14 @@ class OidcHandler: return UserInfo(claims) async def handle_redirect_request( - self, request: SynapseRequest, client_redirect_url: bytes - ) -> None: + self, + request: SynapseRequest, + client_redirect_url: bytes, + ui_auth_session_id: Optional[str] = None, + ) -> str: """Handle an incoming request to /login/sso/redirect - It redirects the browser to the authorization endpoint with a few + It returns a redirect to the authorization endpoint with a few parameters: - ``client_id``: the client ID set in ``oidc_config.client_id`` @@ -511,24 +514,32 @@ class OidcHandler: - ``state``: a random string - ``nonce``: a random string - In addition to redirecting the client, we are setting a cookie with + In addition generating a redirect URL, we are setting a cookie with a signed macaroon token containing the state, the nonce and the client_redirect_url params. Those are then checked when the client comes back from the provider. - Args: request: the incoming request from the browser. We'll respond to it with a redirect and a cookie. client_redirect_url: the URL that we should redirect the client to when everything is done + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). + + Returns: + The redirect URL to the authorization endpoint. 
+ """ state = generate_token() nonce = generate_token() cookie = self._generate_oidc_session_token( - state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(), + state=state, + nonce=nonce, + client_redirect_url=client_redirect_url.decode(), + ui_auth_session_id=ui_auth_session_id, ) request.addCookie( SESSION_COOKIE_NAME, @@ -541,7 +552,7 @@ class OidcHandler: metadata = await self.load_metadata() authorization_endpoint = metadata.get("authorization_endpoint") - uri = prepare_grant_uri( + return prepare_grant_uri( authorization_endpoint, client_id=self._client_auth.client_id, response_type="code", @@ -550,8 +561,6 @@ class OidcHandler: state=state, nonce=nonce, ) - request.redirect(uri) - finish_request(request) async def handle_oidc_callback(self, request: SynapseRequest) -> None: """Handle an incoming request to /_synapse/oidc/callback @@ -625,7 +634,11 @@ class OidcHandler: # Deserialize the session token and verify it. try: - nonce, client_redirect_url = self._verify_oidc_session_token(session, state) + ( + nonce, + client_redirect_url, + ui_auth_session_id, + ) = self._verify_oidc_session_token(session, state) except MacaroonDeserializationException as e: logger.exception("Invalid session") self._render_error(request, "invalid_session", str(e)) @@ -678,15 +691,21 @@ class OidcHandler: return # and finally complete the login - await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url - ) + if ui_auth_session_id: + await self._auth_handler.complete_sso_ui_auth( + user_id, ui_auth_session_id, request + ) + else: + await self._auth_handler.complete_sso_login( + user_id, request, client_redirect_url + ) def _generate_oidc_session_token( self, state: str, nonce: str, client_redirect_url: str, + ui_auth_session_id: Optional[str], duration_in_ms: int = (60 * 60 * 1000), ) -> str: """Generates a signed token storing data about an OIDC session. @@ -702,6 +721,8 @@ class OidcHandler: nonce: The ``nonce`` parameter passed to the OIDC provider. client_redirect_url: The URL the client gave when it initiated the flow. + ui_auth_session_id: The session ID of the ongoing UI Auth (or + None if this is a login). duration_in_ms: An optional duration for the token in milliseconds. Defaults to an hour. @@ -718,12 +739,19 @@ class OidcHandler: macaroon.add_first_party_caveat( "client_redirect_url = %s" % (client_redirect_url,) ) + if ui_auth_session_id: + macaroon.add_first_party_caveat( + "ui_auth_session_id = %s" % (ui_auth_session_id,) + ) now = self._clock.time_msec() expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) + return macaroon.serialize() - def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]: + def _verify_oidc_session_token( + self, session: str, state: str + ) -> Tuple[str, str, Optional[str]]: """Verifies and extract an OIDC session token. 
This verifies that a given session token was issued by this homeserver @@ -734,7 +762,7 @@ class OidcHandler: state: The state the OIDC provider gave back Returns: - The nonce and the client_redirect_url for this session + The nonce, client_redirect_url, and ui_auth_session_id for this session """ macaroon = pymacaroons.Macaroon.deserialize(session) @@ -744,17 +772,27 @@ class OidcHandler: v.satisfy_exact("state = %s" % (state,)) v.satisfy_general(lambda c: c.startswith("nonce = ")) v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) + # Sometimes there's a UI auth session ID, it seems to be OK to attempt + # to always satisfy this. + v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) v.satisfy_general(self._verify_expiry) v.verify(macaroon, self._macaroon_secret_key) - # Extract the `nonce` and `client_redirect_url` from the token + # Extract the `nonce`, `client_redirect_url`, and maybe the + # `ui_auth_session_id` from the token. nonce = self._get_value_from_macaroon(macaroon, "nonce") client_redirect_url = self._get_value_from_macaroon( macaroon, "client_redirect_url" ) + try: + ui_auth_session_id = self._get_value_from_macaroon( + macaroon, "ui_auth_session_id" + ) # type: Optional[str] + except ValueError: + ui_auth_session_id = None - return nonce, client_redirect_url + return nonce, client_redirect_url, ui_auth_session_id def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: """Extracts a caveat value from a macaroon token. @@ -773,7 +811,7 @@ class OidcHandler: for caveat in macaroon.caveats: if caveat.caveat_id.startswith(prefix): return caveat.caveat_id[len(prefix) :] - raise Exception("No %s caveat in macaroon" % (key,)) + raise ValueError("No %s caveat in macaroon" % (key,)) def _verify_expiry(self, caveat: str) -> bool: prefix = "time < " diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index de7eca21f8..d89b2e5532 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -401,19 +401,22 @@ class BaseSSORedirectServlet(RestServlet): PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) - def on_GET(self, request: SynapseRequest): + async def on_GET(self, request: SynapseRequest): args = request.args if b"redirectUrl" not in args: return 400, "Redirect URL not specified for SSO auth" client_redirect_url = args[b"redirectUrl"][0] - sso_url = self.get_sso_url(client_redirect_url) + sso_url = await self.get_sso_url(request, client_redirect_url) request.redirect(sso_url) finish_request(request) - def get_sso_url(self, client_redirect_url: bytes) -> bytes: + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: """Get the URL to redirect to, to perform SSO auth Args: + request: The client request to redirect. 
client_redirect_url: the URL that we should redirect the client to when everything is done @@ -428,7 +431,9 @@ class CasRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): self._cas_handler = hs.get_cas_handler() - def get_sso_url(self, client_redirect_url: bytes) -> bytes: + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: return self._cas_handler.get_redirect_url( {"redirectUrl": client_redirect_url} ).encode("ascii") @@ -465,11 +470,13 @@ class SAMLRedirectServlet(BaseSSORedirectServlet): def __init__(self, hs): self._saml_handler = hs.get_saml_handler() - def get_sso_url(self, client_redirect_url: bytes) -> bytes: + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: return self._saml_handler.handle_redirect_request(client_redirect_url) -class OIDCRedirectServlet(RestServlet): +class OIDCRedirectServlet(BaseSSORedirectServlet): """Implementation for /login/sso/redirect for the OIDC login flow.""" PATTERNS = client_patterns("/login/sso/redirect", v1=True) @@ -477,12 +484,12 @@ class OIDCRedirectServlet(RestServlet): def __init__(self, hs): self._oidc_handler = hs.get_oidc_handler() - async def on_GET(self, request): - args = request.args - if b"redirectUrl" not in args: - return 400, "Redirect URL not specified for SSO auth" - client_redirect_url = args[b"redirectUrl"][0] - await self._oidc_handler.handle_redirect_request(request, client_redirect_url) + async def get_sso_url( + self, request: SynapseRequest, client_redirect_url: bytes + ) -> bytes: + return await self._oidc_handler.handle_redirect_request( + request, client_redirect_url + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 24dd3d3e96..7bca1326d5 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -131,14 +131,19 @@ class AuthRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() # SSO configuration. 
- self._saml_enabled = hs.config.saml2_enabled - if self._saml_enabled: - self._saml_handler = hs.get_saml_handler() self._cas_enabled = hs.config.cas_enabled if self._cas_enabled: self._cas_handler = hs.get_cas_handler() self._cas_server_url = hs.config.cas_server_url self._cas_service_url = hs.config.cas_service_url + self._saml_enabled = hs.config.saml2_enabled + if self._saml_enabled: + self._saml_handler = hs.get_saml_handler() + self._oidc_enabled = hs.config.oidc_enabled + if self._oidc_enabled: + self._oidc_handler = hs.get_oidc_handler() + self._cas_server_url = hs.config.cas_server_url + self._cas_service_url = hs.config.cas_service_url async def on_GET(self, request, stagetype): session = parse_string(request, "session") @@ -172,11 +177,17 @@ class AuthRestServlet(RestServlet): ) elif self._saml_enabled: - client_redirect_url = "" + client_redirect_url = b"" sso_redirect_url = self._saml_handler.handle_redirect_request( client_redirect_url, session ) + elif self._oidc_enabled: + client_redirect_url = b"" + sso_redirect_url = await self._oidc_handler.handle_redirect_request( + request, client_redirect_url, session + ) + else: raise SynapseError(400, "Homeserver not configured for SSO.") diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 61963aa90d..1bb25ab684 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -292,11 +292,10 @@ class OidcHandlerTestCase(HomeserverTestCase): @defer.inlineCallbacks def test_redirect_request(self): """The redirect request has the right arguments & generates a valid session cookie.""" - req = Mock(spec=["addCookie", "redirect", "finish"]) - yield defer.ensureDeferred( + req = Mock(spec=["addCookie"]) + url = yield defer.ensureDeferred( self.handler.handle_redirect_request(req, b"http://client/redirect") ) - url = req.redirect.call_args[0][0] url = urlparse(url) auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT) @@ -382,7 +381,10 @@ class OidcHandlerTestCase(HomeserverTestCase): nonce = "nonce" client_redirect_url = "http://client/redirect" session = self.handler._generate_oidc_session_token( - state=state, nonce=nonce, client_redirect_url=client_redirect_url, + state=state, + nonce=nonce, + client_redirect_url=client_redirect_url, + ui_auth_session_id=None, ) request.getCookie.return_value = session @@ -472,7 +474,10 @@ class OidcHandlerTestCase(HomeserverTestCase): # Mismatching session session = self.handler._generate_oidc_session_token( - state="state", nonce="nonce", client_redirect_url="http://client/redirect", + state="state", + nonce="nonce", + client_redirect_url="http://client/redirect", + ui_auth_session_id=None, ) request.args = {} request.args[b"state"] = [b"mismatching state"] From 34a43f0084ae2e16eb6e8f6eb2ecf4bb37bd4cf0 Mon Sep 17 00:00:00 2001 From: Andrew Morgan Date: Fri, 15 May 2020 18:53:31 +0100 Subject: [PATCH 04/14] Fix a couple of small typos --- synapse/appservice/__init__.py | 2 +- synapse/storage/data_stores/main/appservice.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index aea3985a5f..1b13e84425 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -270,7 +270,7 @@ class ApplicationService(object): def is_exclusive_room(self, room_id): return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) - def get_exlusive_user_regexes(self): + def get_exclusive_user_regexes(self): """Get the list of regexes used to determine if a user is exclusively 
registered by the AS """ diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py index efbc06c796..7a1fe8cdd2 100644 --- a/synapse/storage/data_stores/main/appservice.py +++ b/synapse/storage/data_stores/main/appservice.py @@ -30,12 +30,12 @@ logger = logging.getLogger(__name__) def _make_exclusive_regex(services_cache): - # We precompie a regex constructed from all the regexes that the AS's + # We precompile a regex constructed from all the regexes that the AS's # have registered for exclusive users. exclusive_user_regexes = [ regex.pattern for service in services_cache - for regex in service.get_exlusive_user_regexes() + for regex in service.get_exclusive_user_regexes() ] if exclusive_user_regexes: exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) From 6c1f7c722f0baade9aecf41f600fcced670c4fcb Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 15 May 2020 19:03:25 +0100 Subject: [PATCH 05/14] Fix limit logic for AccountDataStream (#7384) Make sure that the AccountDataStream presents complete updates, in the right order. This is much the same fix as #7337 and #7358, but applied to a different stream. --- changelog.d/7384.bugfix | 1 + synapse/replication/tcp/streams/_base.py | 68 ++++++++-- .../storage/data_stores/main/account_data.py | 68 ++++++---- .../tcp/streams/test_account_data.py | 117 ++++++++++++++++++ 4 files changed, 220 insertions(+), 34 deletions(-) create mode 100644 changelog.d/7384.bugfix create mode 100644 tests/replication/tcp/streams/test_account_data.py diff --git a/changelog.d/7384.bugfix b/changelog.d/7384.bugfix new file mode 100644 index 0000000000..f49c600173 --- /dev/null +++ b/changelog.d/7384.bugfix @@ -0,0 +1 @@ +Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind. diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index b48a6a3e91..d42aaff055 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -14,14 +14,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import heapq import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + List, + Optional, + Tuple, + TypeVar, +) import attr from synapse.replication.http.streams import ReplicationGetStreamUpdates +if TYPE_CHECKING: + import synapse.server + logger = logging.getLogger(__name__) # the number of rows to request from an update_function. @@ -37,7 +50,7 @@ Token = int # parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's # just a row from a database query, though this is dependent on the stream in question. # -StreamRow = Tuple +StreamRow = TypeVar("StreamRow", bound=Tuple) # The type returned by the update_function of a stream, as well as get_updates(), # get_updates_since, etc. 
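The heart of the fix appears in the next hunk: fetch the global and per-room account data updates separately, clamp both result sets to the lower of the two limit tokens, then merge them by stream ID so the stream presents one complete, ordered sequence. A self-contained sketch of that idea — a simplified variant over in-memory rows, not the storage-layer code from the hunk itself:

```python
import heapq


def merge_stream_updates(global_rows, room_rows, limit):
    """Merge two lists of (stream_id, payload) rows, each sorted by
    stream_id and fetched with the same LIMIT.

    If either side hit the limit, rows from the other side beyond its
    last stream_id might hide a gap, so the merged output is clamped to
    that stream_id and flagged as limited.
    """
    limited = False
    to_token = None

    if len(global_rows) >= limit:
        to_token = global_rows[-1][0]
        limited = True
    if len(room_rows) >= limit:
        room_token = room_rows[-1][0]
        if to_token is None or room_token < to_token:
            to_token = room_token
        limited = True

    if to_token is not None:
        global_rows = [r for r in global_rows if r[0] <= to_token]
        room_rows = [r for r in room_rows if r[0] <= to_token]

    # heapq.merge keeps the combined sequence ordered by stream_id.
    merged = list(heapq.merge(global_rows, room_rows, key=lambda r: r[0]))
    return merged, to_token, limited


# merge_stream_updates([(1, "a"), (3, "b")], [(2, "c")], limit=100)
# -> ([(1, 'a'), (2, 'c'), (3, 'b')], None, False)
```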
@@ -533,32 +546,63 @@ class AccountDataStream(Stream): """ AccountDataStreamRow = namedtuple( - "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str + "AccountDataStream", + ("user_id", "room_id", "data_type"), # str # Optional[str] # str ) NAME = "account_data" ROW_TYPE = AccountDataStreamRow - def __init__(self, hs): + def __init__(self, hs: "synapse.server.HomeServer"): self.store = hs.get_datastore() super().__init__( hs.get_instance_name(), current_token_without_instance(self.store.get_max_account_data_stream_id), - db_query_to_update_function(self._update_function), + self._update_function, ) - async def _update_function(self, from_token, to_token, limit): - global_results, room_results = await self.store.get_all_updated_account_data( - from_token, from_token, to_token, limit + async def _update_function( + self, instance_name: str, from_token: int, to_token: int, limit: int + ) -> StreamUpdateResult: + limited = False + global_results = await self.store.get_updated_global_account_data( + from_token, to_token, limit ) - results = list(room_results) - results.extend( - (stream_id, user_id, None, account_data_type) + # if the global results hit the limit, we'll need to limit the room results to + # the same stream token. + if len(global_results) >= limit: + to_token = global_results[-1][0] + limited = True + + room_results = await self.store.get_updated_room_account_data( + from_token, to_token, limit + ) + + # likewise, if the room results hit the limit, limit the global results to + # the same stream token. + if len(room_results) >= limit: + to_token = room_results[-1][0] + limited = True + + # convert the global results to the right format, and limit them to the to_token + # at the same time + global_rows = ( + (stream_id, (user_id, None, account_data_type)) for stream_id, user_id, account_data_type in global_results + if stream_id <= to_token ) - return results + # we know that the room_results are already limited to `to_token` so no need + # for a check on `stream_id` here. + room_rows = ( + (stream_id, (user_id, room_id, account_data_type)) + for stream_id, user_id, room_id, account_data_type in room_results + ) + + # we need to return a sorted list, so merge them together. + updates = list(heapq.merge(room_rows, global_rows)) + return updates, to_token, limited class GroupServerStream(Stream): diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py index 46b494b334..f9eef1b78e 100644 --- a/synapse/storage/data_stores/main/account_data.py +++ b/synapse/storage/data_stores/main/account_data.py @@ -16,6 +16,7 @@ import abc import logging +from typing import List, Tuple from canonicaljson import json @@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore): "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn ) - def get_all_updated_account_data( - self, last_global_id, last_room_id, current_id, limit - ): - """Get all the client account_data that has changed on the server - Args: - last_global_id(int): The position to fetch from for top level data - last_room_id(int): The position to fetch from for per room data - current_id(int): The position to fetch up to. - Returns: - A deferred pair of lists of tuples of stream_id int, user_id string, - room_id string, and type string. 
- """ - if last_room_id == current_id and last_global_id == current_id: - return defer.succeed(([], [])) + async def get_updated_global_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str]]: + """Get the global account_data that has changed, for the account_data stream - def get_updated_account_data_txn(txn): + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: + A list of tuples of stream_id int, user_id string, + and type string. + """ + if last_id == current_id: + return [] + + def get_updated_global_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_global_id, current_id, limit)) - global_results = txn.fetchall() + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() + return await self.db.runInteraction( + "get_updated_global_account_data", get_updated_global_account_data_txn + ) + + async def get_updated_room_account_data( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple[int, str, str, str]]: + """Get the global account_data that has changed, for the account_data stream + + Args: + last_id: the last stream_id from the previous batch. + current_id: the maximum stream_id to return up to + limit: the maximum number of rows to return + + Returns: + A list of tuples of stream_id int, user_id string, + room_id string and type string. + """ + if last_id == current_id: + return [] + + def get_updated_room_account_data_txn(txn): sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) - txn.execute(sql, (last_room_id, current_id, limit)) - room_results = txn.fetchall() - return global_results, room_results + txn.execute(sql, (last_id, current_id, limit)) + return txn.fetchall() - return self.db.runInteraction( - "get_all_updated_account_data_txn", get_updated_account_data_txn + return await self.db.runInteraction( + "get_updated_room_account_data", get_updated_room_account_data_txn ) def get_updated_account_data_for_user(self, user_id, stream_id): diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py new file mode 100644 index 0000000000..6a5116dd2a --- /dev/null +++ b/tests/replication/tcp/streams/test_account_data.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from synapse.replication.tcp.streams._base import ( + _STREAM_UPDATE_TARGET_ROW_COUNT, + AccountDataStream, +) + +from tests.replication._base import BaseStreamTestCase + + +class AccountDataStreamTestCase(BaseStreamTestCase): + def test_update_function_room_account_data_limit(self): + """Test replication with many room account data updates + """ + store = self.hs.get_datastore() + + # generate lots of account data updates + updates = [] + for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5): + update = "m.test_type.%i" % (i,) + self.get_success( + store.add_account_data_to_room("test_user", "test_room", update, {}) + ) + updates.append(update) + + # also one global update + self.get_success(store.add_account_data_for_user("test_user", "m.global", {})) + + # tell the notifier to catch up to avoid duplicate rows. + # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + + for t in updates: + (stream_name, token, row) = received_rows.pop(0) + self.assertEqual(stream_name, AccountDataStream.NAME) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, t) + self.assertEqual(row.room_id, "test_room") + + (stream_name, token, row) = received_rows.pop(0) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, "m.global") + self.assertIsNone(row.room_id) + + self.assertEqual([], received_rows) + + def test_update_function_global_account_data_limit(self): + """Test replication with many global account data updates + """ + store = self.hs.get_datastore() + + # generate lots of account data updates + updates = [] + for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5): + update = "m.test_type.%i" % (i,) + self.get_success(store.add_account_data_for_user("test_user", update, {})) + updates.append(update) + + # also one per-room update + self.get_success( + store.add_account_data_to_room("test_user", "test_room", "m.per_room", {}) + ) + + # tell the notifier to catch up to avoid duplicate rows. 
+ # workaround for https://github.com/matrix-org/synapse/issues/7360 + # FIXME remove this when the above is fixed + self.replicate() + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # now reconnect to pull the updates + self.reconnect() + self.replicate() + + # we should have received all the expected rows in the right order + received_rows = self.test_handler.received_rdata_rows + + for t in updates: + (stream_name, token, row) = received_rows.pop(0) + self.assertEqual(stream_name, AccountDataStream.NAME) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, t) + self.assertIsNone(row.room_id) + + (stream_name, token, row) = received_rows.pop(0) + self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow) + self.assertEqual(row.data_type, "m.per_room") + self.assertEqual(row.room_id, "test_room") + + self.assertEqual([], received_rows) From 08fa96f03037178620f5f0dd609fac52fbf7f2d1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:07:24 +0100 Subject: [PATCH 06/14] Remove `exception_to_unicode` this is a no-op on python 3. --- synapse/storage/database.py | 15 +++------------ synapse/util/stringutils.py | 36 ------------------------------------ 2 files changed, 3 insertions(+), 48 deletions(-) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index c3d0863429..9947dbce77 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -50,7 +50,6 @@ from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor from synapse.types import Collection -from synapse.util.stringutils import exception_to_unicode logger = logging.getLogger(__name__) @@ -424,20 +423,14 @@ class Database(object): # This can happen if the database disappears mid # transaction. logger.warning( - "[TXN OPERROR] {%s} %s %d/%d", - name, - exception_to_unicode(e), - i, - N, + "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N, ) if i < N: i += 1 try: conn.rollback() except self.engine.module.Error as e1: - logger.warning( - "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1) - ) + logger.warning("[TXN EROLL] {%s} %s", name, e1) continue raise except self.engine.module.DatabaseError as e: @@ -449,9 +442,7 @@ class Database(object): conn.rollback() except self.engine.module.Error as e1: logger.warning( - "[TXN EROLL] {%s} %s", - name, - exception_to_unicode(e1), + "[TXN EROLL] {%s} %s", name, e1, ) continue raise diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 6899bcb788..2cfa5cf721 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -85,42 +85,6 @@ def to_ascii(s): return s -def exception_to_unicode(e): - """Helper function to extract the text of an exception as a unicode string - - Args: - e (Exception): exception to be stringified - - Returns: - unicode - """ - # urgh, this is a mess. The basic problem here is that psycopg2 constructs its - # exceptions with PyErr_SetString, with a (possibly non-ascii) argument. str() will - # then produce the raw byte sequence. 
Under Python 2, this will then cause another - # error if it gets mixed with a `unicode` object, as per - # https://github.com/matrix-org/synapse/issues/4252 - - # First of all, if we're under python3, everything is fine because it will sort this - # nonsense out for us. - if not PY2: - return str(e) - - # otherwise let's have a stab at decoding the exception message. We'll circumvent - # Exception.__str__(), which would explode if someone raised Exception(u'non-ascii') - # and instead look at what is in the args member. - - if len(e.args) == 0: - return "" - elif len(e.args) > 1: - return six.text_type(repr(e.args)) - - msg = e.args[0] - if isinstance(msg, bytes): - return msg.decode("utf-8", errors="replace") - else: - return msg - - def assert_valid_client_secret(client_secret): """Validate that a given string matches the client_secret regex defined by the spec""" if client_secret_regex.match(client_secret) is None: From 65902e08c3f4449de9baa4e6466f126585f688b3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:12:03 +0100 Subject: [PATCH 07/14] remove to_ascii this is a no-op on python 3. --- .../storage/data_stores/main/roommember.py | 25 ++++++++----------- synapse/storage/data_stores/main/state.py | 5 +--- synapse/util/stringutils.py | 20 +-------------- 3 files changed, 12 insertions(+), 38 deletions(-) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 48810a3e91..1e9c850152 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -45,7 +45,6 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.metrics import Measure -from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -179,7 +178,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): """ txn.execute(sql, (room_id, Membership.JOIN)) - return [to_ascii(r[0]) for r in txn] + return [r[0] for r in txn] @cached(max_entries=100000) def get_room_summary(self, room_id): @@ -223,7 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): txn.execute(sql, (room_id,)) res = {} for count, membership in txn: - summary = res.setdefault(to_ascii(membership), MemberSummary([], count)) + summary = res.setdefault(membership, MemberSummary([], count)) # we order by membership and then fairly arbitrarily by event_id so # heroes are consistent @@ -255,11 +254,11 @@ class RoomMemberWorkerStore(EventsWorkerStore): # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user. txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6)) for user_id, membership, event_id in txn: - summary = res[to_ascii(membership)] + summary = res[membership] # we will always have a summary for this membership type at this # point given the summary currently contains the counts. 
members = summary.members - members.append((to_ascii(user_id), to_ascii(event_id))) + members.append((user_id, event_id)) return res @@ -584,13 +583,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): ev_entry = event_map.get(event_id) if ev_entry: if ev_entry.event.membership == Membership.JOIN: - users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( - display_name=to_ascii( - ev_entry.event.content.get("displayname", None) - ), - avatar_url=to_ascii( - ev_entry.event.content.get("avatar_url", None) - ), + users_in_room[ev_entry.event.state_key] = ProfileInfo( + display_name=ev_entry.event.content.get("displayname", None), + avatar_url=ev_entry.event.content.get("avatar_url", None), ) else: missing_member_event_ids.append(event_id) @@ -604,9 +599,9 @@ class RoomMemberWorkerStore(EventsWorkerStore): if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: if event.event_id in member_event_ids: - users_in_room[to_ascii(event.state_key)] = ProfileInfo( - display_name=to_ascii(event.content.get("displayname", None)), - avatar_url=to_ascii(event.content.get("avatar_url", None)), + users_in_room[event.state_key] = ProfileInfo( + display_name=event.content.get("displayname", None), + avatar_url=event.content.get("avatar_url", None), ) return users_in_room diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 21052fcc7a..347cc50778 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -29,7 +29,6 @@ from synapse.storage.database import Database from synapse.storage.state import StateFilter from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -185,9 +184,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): (room_id,), ) - return { - (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn - } + return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} return self.db.runInteraction( "get_current_state_ids", _get_current_state_ids_txn diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 2cfa5cf721..81a44184ca 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -19,8 +19,7 @@ import re import string from collections import Iterable -import six -from six import PY2, PY3 +from six import PY3 from six.moves import range from synapse.api.errors import Codes, SynapseError @@ -68,23 +67,6 @@ def is_ascii(s): return True -def to_ascii(s): - """Converts a string to ascii if it is ascii, otherwise leave it alone. - - If given None then will return None. 
- """ - if PY3: - return s - - if s is None: - return None - - try: - return s.encode("ascii") - except UnicodeEncodeError: - return s - - def assert_valid_client_secret(client_secret): """Validate that a given string matches the client_secret regex defined by the spec""" if client_secret_regex.match(client_secret) is None: From 91f51c611c86ff7e85b20b6acb0b1025b65edcdf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:18:22 +0100 Subject: [PATCH 08/14] remove redundant `__func__` this is a no-op under python 3 --- synapse/replication/slave/storage/_base.py | 9 --------- synapse/replication/slave/storage/presence.py | 8 ++++---- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index 2904bd0235..f9e2533e96 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -16,8 +16,6 @@ import logging from typing import Optional -import six - from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine @@ -26,13 +24,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator logger = logging.getLogger(__name__) -def __func__(inp): - if six.PY3: - return inp - else: - return inp.__func__ - - class BaseSlavedStore(CacheInvalidationWorkerStore): def __init__(self, database: Database, db_conn, hs): super(BaseSlavedStore, self).__init__(database, db_conn, hs) diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py index bd79ba99be..4e0124842d 100644 --- a/synapse/replication/slave/storage/presence.py +++ b/synapse/replication/slave/storage/presence.py @@ -18,7 +18,7 @@ from synapse.storage.data_stores.main.presence import PresenceStore from synapse.storage.database import Database from synapse.util.caches.stream_change_cache import StreamChangeCache -from ._base import BaseSlavedStore, __func__ +from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -27,14 +27,14 @@ class SlavedPresenceStore(BaseSlavedStore): super(SlavedPresenceStore, self).__init__(database, db_conn, hs) self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id") - self._presence_on_startup = self._get_active_presence(db_conn) + self._presence_on_startup = self._get_active_presence(db_conn) # type: ignore self.presence_stream_cache = StreamChangeCache( "PresenceStreamChangeCache", self._presence_id_gen.get_current_token() ) - _get_active_presence = __func__(DataStore._get_active_presence) - take_presence_startup_info = __func__(DataStore.take_presence_startup_info) + _get_active_presence = DataStore._get_active_presence + take_presence_startup_info = DataStore.take_presence_startup_info _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"] get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"] From e6027562e2a1964bcaa0163f1615ab72bfc6630b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:26:54 +0100 Subject: [PATCH 09/14] remove `builtins.buffer` code from storage code this is no longer needed on python 3 --- scripts-dev/convert_server_keys.py | 9 ++------- synapse/storage/_base.py | 8 -------- synapse/storage/data_stores/main/keys.py | 10 ++-------- synapse/storage/data_stores/main/transactions.py | 9 +-------- 4 files changed, 5 insertions(+), 31 deletions(-) 
diff --git a/scripts-dev/convert_server_keys.py b/scripts-dev/convert_server_keys.py index 06b4c1e2ff..961dc59f11 100644 --- a/scripts-dev/convert_server_keys.py +++ b/scripts-dev/convert_server_keys.py @@ -3,8 +3,6 @@ import json import sys import time -import six - import psycopg2 import yaml from canonicaljson import encode_canonical_json @@ -12,10 +10,7 @@ from signedjson.key import read_signing_keys from signedjson.sign import sign_json from unpaddedbase64 import encode_base64 -if six.PY2: - db_type = six.moves.builtins.buffer -else: - db_type = memoryview +db_binary_type = memoryview def select_v1_keys(connection): @@ -72,7 +67,7 @@ def rows_v2(server, json): valid_until = json["valid_until_ts"] key_json = encode_canonical_json(json) for key_id in json["verify_keys"]: - yield (server, key_id, "-", valid_until, valid_until, db_type(key_json)) + yield (server, key_id, "-", valid_until, valid_until, db_binary_type(key_json)) def main(): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 59073c0a42..bfce541ca7 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,9 +19,6 @@ import random from abc import ABCMeta from typing import Any, Optional -from six import PY2 -from six.moves import builtins - from canonicaljson import json from synapse.storage.database import LoggingTransaction # noqa: F401 @@ -103,11 +100,6 @@ def db_to_json(db_content): if isinstance(db_content, memoryview): db_content = db_content.tobytes() - # psycopg2 on Python 2 returns buffer objects, which we need to cast to - # bytes to decode - if PY2 and isinstance(db_content, builtins.buffer): - db_content = bytes(db_content) - # Decode it to a Unicode string before feeding it to json.loads, so we # consistenty get a Unicode-containing object out. 
if isinstance(db_content, (bytes, bytearray)): diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py index ba89c68c9f..4e1642a27a 100644 --- a/synapse/storage/data_stores/main/keys.py +++ b/synapse/storage/data_stores/main/keys.py @@ -17,8 +17,6 @@ import itertools import logging -import six - from signedjson.key import decode_verify_key_bytes from synapse.storage._base import SQLBaseStore @@ -28,12 +26,8 @@ from synapse.util.iterutils import batch_iter logger = logging.getLogger(__name__) -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview + +db_binary_type = memoryview class KeyStore(SQLBaseStore): diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py index 5b07c2fbc0..a9bf457939 100644 --- a/synapse/storage/data_stores/main/transactions.py +++ b/synapse/storage/data_stores/main/transactions.py @@ -16,8 +16,6 @@ import logging from collections import namedtuple -import six - from canonicaljson import encode_canonical_json from twisted.internet import defer @@ -27,12 +25,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import Database from synapse.util.caches.expiringcache import ExpiringCache -# py2 sqlite has buffer hardcoded as only binary type, so we must use it, -# despite being deprecated and removed in favor of memoryview -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview +db_binary_type = memoryview logger = logging.getLogger(__name__) From d4676910c91dd492ca5cc7c207969fa7bfe1bbee Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:17:06 +0100 Subject: [PATCH 10/14] remove miscellaneous PY2 code --- synapse/http/matrixfederationclient.py | 8 ++------ synapse/logging/utils.py | 10 ++------- synapse/push/httppusher.py | 11 +++------- synapse/rest/media/v1/_base.py | 27 +++++++++---------------- synapse/util/caches/__init__.py | 7 +------ synapse/util/stringutils.py | 28 +++++++------------------- 6 files changed, 24 insertions(+), 67 deletions(-) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 225a47e3c3..44077f5349 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -19,7 +19,7 @@ import random import sys from io import BytesIO -from six import PY3, raise_from, string_types +from six import raise_from, string_types from six.moves import urllib import attr @@ -70,11 +70,7 @@ incoming_responses_counter = Counter( MAX_LONG_RETRIES = 10 MAX_SHORT_RETRIES = 3 - -if PY3: - MAXINT = sys.maxsize -else: - MAXINT = sys.maxint +MAXINT = sys.maxsize _next_id = 1 diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py index 0c2527bd86..99049bb5d8 100644 --- a/synapse/logging/utils.py +++ b/synapse/logging/utils.py @@ -20,8 +20,6 @@ import time from functools import wraps from inspect import getcallargs -from six import PY3 - _TIME_FUNC_ID = 0 @@ -30,12 +28,8 @@ def _log_debug_as_f(f, msg, msg_args): logger = logging.getLogger(name) if logger.isEnabledFor(logging.DEBUG): - if PY3: - lineno = f.__code__.co_firstlineno - pathname = f.__code__.co_filename - else: - lineno = f.func_code.co_firstlineno - pathname = f.func_code.co_filename + lineno = f.__code__.co_firstlineno + pathname = 
f.__code__.co_filename record = logging.LogRecord( name=name, diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 5bb17d1228..eaaa7afc91 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -15,8 +15,6 @@ # limitations under the License. import logging -import six - from prometheus_client import Counter from twisted.internet import defer @@ -28,9 +26,6 @@ from synapse.push import PusherConfigException from . import push_rule_evaluator, push_tools -if six.PY3: - long = int - logger = logging.getLogger(__name__) http_push_processed_counter = Counter( @@ -318,7 +313,7 @@ class HttpPusher(object): { "app_id": self.app_id, "pushkey": self.pushkey, - "pushkey_ts": long(self.pushkey_ts / 1000), + "pushkey_ts": int(self.pushkey_ts / 1000), "data": self.data_minus_url, } ], @@ -347,7 +342,7 @@ class HttpPusher(object): { "app_id": self.app_id, "pushkey": self.pushkey, - "pushkey_ts": long(self.pushkey_ts / 1000), + "pushkey_ts": int(self.pushkey_ts / 1000), "data": self.data_minus_url, "tweaks": tweaks, } @@ -409,7 +404,7 @@ class HttpPusher(object): { "app_id": self.app_id, "pushkey": self.pushkey, - "pushkey_ts": long(self.pushkey_ts / 1000), + "pushkey_ts": int(self.pushkey_ts / 1000), "data": self.data_minus_url, } ], diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 503f2bed98..3689777266 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -17,7 +17,6 @@ import logging import os -from six import PY3 from six.moves import urllib from twisted.internet import defer @@ -324,23 +323,15 @@ def get_filename_from_headers(headers): upload_name_utf8 = upload_name_utf8[7:] # We have a filename*= section. This MUST be ASCII, and any UTF-8 # bytes are %-quoted. - if PY3: - try: - # Once it is decoded, we can then unquote the %-encoded - # parts strictly into a unicode string. - upload_name = urllib.parse.unquote( - upload_name_utf8.decode("ascii"), errors="strict" - ) - except UnicodeDecodeError: - # Incorrect UTF-8. - pass - else: - # On Python 2, we first unquote the %-encoded parts and then - # decode it strictly using UTF-8. - try: - upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8") - except UnicodeDecodeError: - pass + try: + # Once it is decoded, we can then unquote the %-encoded + # parts strictly into a unicode string. + upload_name = urllib.parse.unquote( + upload_name_utf8.decode("ascii"), errors="strict" + ) + except UnicodeDecodeError: + # Incorrect UTF-8. + pass # If there isn't check for an ascii name. if not upload_name: diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 4b8a0c7a8f..dd356bf156 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -15,11 +15,9 @@ # limitations under the License. 
import logging +from sys import intern from typing import Callable, Dict, Optional -import six -from six.moves import intern - import attr from prometheus_client.core import Gauge @@ -154,9 +152,6 @@ def intern_string(string): return None try: - if six.PY2: - string = string.encode("ascii") - return intern(string) except UnicodeEncodeError: return string diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 81a44184ca..08c86e92b8 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -19,9 +19,6 @@ import re import string from collections import Iterable -from six import PY3 -from six.moves import range - from synapse.api.errors import Codes, SynapseError _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" @@ -46,24 +43,13 @@ def random_string_with_symbols(length): def is_ascii(s): - - if PY3: - if isinstance(s, bytes): - try: - s.decode("ascii").encode("ascii") - except UnicodeDecodeError: - return False - except UnicodeEncodeError: - return False - return True - - try: - s.encode("ascii") - except UnicodeEncodeError: - return False - except UnicodeDecodeError: - return False - else: + if isinstance(s, bytes): + try: + s.decode("ascii").encode("ascii") + except UnicodeDecodeError: + return False + except UnicodeEncodeError: + return False return True From ab57353de329e2d05c377f8949deb56a0829be78 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 15 May 2020 19:29:28 +0100 Subject: [PATCH 11/14] changelog --- changelog.d/7519.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/7519.misc diff --git a/changelog.d/7519.misc b/changelog.d/7519.misc new file mode 100644 index 0000000000..c730b5e507 --- /dev/null +++ b/changelog.d/7519.misc @@ -0,0 +1 @@ +Remove some redundant Python 2 support code. From c29915bd05513a329e099d7e2970768113595830 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 15 May 2020 15:05:25 -0400 Subject: [PATCH 12/14] Add type hints to room member handlers (#7513) --- changelog.d/7513.misc | 1 + synapse/handlers/room_member.py | 284 +++++++++++++------------ synapse/handlers/room_member_worker.py | 28 ++- tox.ini | 2 + 4 files changed, 176 insertions(+), 139 deletions(-) create mode 100644 changelog.d/7513.misc diff --git a/changelog.d/7513.misc b/changelog.d/7513.misc new file mode 100644 index 0000000000..2ea7373e29 --- /dev/null +++ b/changelog.d/7513.misc @@ -0,0 +1 @@ +Add type hints to room member handler. 
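The conversion in the diff below follows one mechanical pattern: types that previously lived only in docstring `Args:`/`Returns:` sections move into the signature, and `Deferred[X]` return documentation becomes a plain `X` annotation on the now-`async` method. A compressed before/after sketch of that pattern (the `"Requester"`/`"UserID"` strings are forward references standing in for the real `synapse.types` classes):

```python
from typing import List, Optional


class RoomMemberHandler:
    # Before: def _remote_join(self, requester, remote_room_hosts,
    #                          room_id, user, content), with
    # "remote_room_hosts (list[str])" and "Returns: Deferred" noted
    # only in the docstring.
    #
    # After: the same information, machine-checkable:
    async def _remote_join(
        self,
        requester: "Requester",
        remote_room_hosts: List[str],
        room_id: str,
        user: "UserID",
        content: dict,
    ) -> Optional[dict]:
        raise NotImplementedError()
```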
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 4ddeba4c97..e51e1c32fe 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -17,13 +17,16 @@ import abc import logging +from typing import Dict, Iterable, List, Optional, Tuple, Union from six.moves import http_client from synapse import types from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.types import Collection, RoomID, UserID +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -74,84 +77,84 @@ class RoomMemberHandler(object): self.base_handler = BaseHandler(hs) @abc.abstractmethod - async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> Optional[dict]: """Try and join a room that this server is not in Args: - requester (Requester) - remote_room_hosts (list[str]): List of servers that can be used - to join via. - room_id (str): Room that we are trying to join - user (UserID): User who is trying to join - content (dict): A dict that should be used as the content of the - join event. - - Returns: - Deferred + requester + remote_room_hosts: List of servers that can be used to join via. + room_id: Room that we are trying to join + user: User who is trying to join + content: A dict that should be used as the content of the join event. """ raise NotImplementedError() @abc.abstractmethod async def _remote_reject_invite( - self, requester, remote_room_hosts, room_id, target, content - ): + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + target: UserID, + content: dict, + ) -> dict: """Attempt to reject an invite for a room this server is not in. If we fail to do so we locally mark the invite as rejected. Args: - requester (Requester) - remote_room_hosts (list[str]): List of servers to use to try and - reject invite - room_id (str) - target (UserID): The user rejecting the invite - content (dict): The content for the rejection event + requester + remote_room_hosts: List of servers to use to try and reject invite + room_id + target: The user rejecting the invite + content: The content for the rejection event Returns: - Deferred[dict]: A dictionary to be returned to the client, may + A dictionary to be returned to the client, may include event_id etc, or nothing if we locally rejected """ raise NotImplementedError() @abc.abstractmethod - async def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has joined the room. Args: - target (UserID) - room_id (str) - - Returns: - None + target + room_id """ raise NotImplementedError() @abc.abstractmethod - async def _user_left_room(self, target, room_id): + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Notifies distributor on master process that the user has left the room. 
Args: - target (UserID) - room_id (str) - - Returns: - None + target + room_id """ raise NotImplementedError() async def _local_membership_update( self, - requester, - target, - room_id, - membership, + requester: Requester, + target: UserID, + room_id: str, + membership: str, prev_event_ids: Collection[str], - txn_id=None, - ratelimit=True, - content=None, - require_consent=True, - ): + txn_id: Optional[str] = None, + ratelimit: bool = True, + content: Optional[dict] = None, + require_consent: bool = True, + ) -> EventBase: user_id = target.to_string() if content is None: @@ -214,16 +217,13 @@ class RoomMemberHandler(object): async def copy_room_tags_and_direct_to_room( self, old_room_id, new_room_id, user_id - ): + ) -> None: """Copies the tags and direct room state from one room to another. Args: - old_room_id (str) - new_room_id (str) - user_id (str) - - Returns: - Deferred[None] + old_room_id: The room ID of the old room. + new_room_id: The room ID of the new room. + user_id: The user's ID. """ # Retrieve user account data for predecessor room user_account_data, _ = await self.store.get_account_data_for_user(user_id) @@ -253,17 +253,17 @@ class RoomMemberHandler(object): async def update_membership( self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, - ): + requester: Requester, + target: UserID, + room_id: str, + action: str, + txn_id: Optional[str] = None, + remote_room_hosts: Optional[List[str]] = None, + third_party_signed: Optional[dict] = None, + ratelimit: bool = True, + content: Optional[dict] = None, + require_consent: bool = True, + ) -> Union[EventBase, Optional[dict]]: key = (room_id,) with (await self.member_linearizer.queue(key)): @@ -284,17 +284,17 @@ class RoomMemberHandler(object): async def _update_membership( self, - requester, - target, - room_id, - action, - txn_id=None, - remote_room_hosts=None, - third_party_signed=None, - ratelimit=True, - content=None, - require_consent=True, - ): + requester: Requester, + target: UserID, + room_id: str, + action: str, + txn_id: Optional[str] = None, + remote_room_hosts: Optional[List[str]] = None, + third_party_signed: Optional[dict] = None, + ratelimit: bool = True, + content: Optional[dict] = None, + require_consent: bool = True, + ) -> Union[EventBase, Optional[dict]]: content_specified = bool(content) if content is None: content = {} @@ -468,12 +468,11 @@ class RoomMemberHandler(object): else: # send the rejection to the inviter's HS. remote_room_hosts = remote_room_hosts + [inviter.domain] - res = await self._remote_reject_invite( + return await self._remote_reject_invite( requester, remote_room_hosts, room_id, target, content, ) - return res - res = await self._local_membership_update( + return await self._local_membership_update( requester=requester, target=target, room_id=room_id, @@ -484,9 +483,10 @@ class RoomMemberHandler(object): content=content, require_consent=require_consent, ) - return res - async def transfer_room_state_on_room_upgrade(self, old_room_id, room_id): + async def transfer_room_state_on_room_upgrade( + self, old_room_id: str, room_id: str + ) -> None: """Upon our server becoming aware of an upgraded room, either by upgrading a room ourselves or joining one, we can transfer over information from the previous room. @@ -494,12 +494,8 @@ class RoomMemberHandler(object): well as migrating the room directory state. 
Args: - old_room_id (str): The ID of the old room - - room_id (str): The ID of the new room - - Returns: - Deferred + old_room_id: The ID of the old room + room_id: The ID of the new room """ logger.info("Transferring room state from %s to %s", old_room_id, room_id) @@ -526,17 +522,16 @@ class RoomMemberHandler(object): # Remove the old room from those groups await self.store.remove_room_from_group(group_id, old_room_id) - async def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids): + async def copy_user_state_on_room_upgrade( + self, old_room_id: str, new_room_id: str, user_ids: Iterable[str] + ) -> None: """Copy user-specific information when they join a new room when that new room is the result of a room upgrade Args: - old_room_id (str): The ID of upgraded room - new_room_id (str): The ID of the new room - user_ids (Iterable[str]): User IDs to copy state for - - Returns: - Deferred + old_room_id: The ID of upgraded room + new_room_id: The ID of the new room + user_ids: User IDs to copy state for """ logger.debug( @@ -566,17 +561,23 @@ class RoomMemberHandler(object): ) continue - async def send_membership_event(self, requester, event, context, ratelimit=True): + async def send_membership_event( + self, + requester: Requester, + event: EventBase, + context: EventContext, + ratelimit: bool = True, + ): """ Change the membership status of a user in a room. Args: - requester (Requester): The local user who requested the membership + requester: The local user who requested the membership event. If None, certain checks, like whether this homeserver can act as the sender, will be skipped. - event (SynapseEvent): The membership event. + event: The membership event. context: The context of the event. - ratelimit (bool): Whether to rate limit this request. + ratelimit: Whether to rate limit this request. Raises: SynapseError if there was a problem changing the membership. """ @@ -636,7 +637,9 @@ class RoomMemberHandler(object): if prev_member_event.membership == Membership.JOIN: await self._user_left_room(target_user, room_id) - async def _can_guest_join(self, current_state_ids): + async def _can_guest_join( + self, current_state_ids: Dict[Tuple[str, str], str] + ) -> bool: """ Returns whether a guest can join a room based on its current state. """ @@ -653,12 +656,14 @@ class RoomMemberHandler(object): and guest_access.content["guest_access"] == "can_join" ) - async def lookup_room_alias(self, room_alias): + async def lookup_room_alias( + self, room_alias: RoomAlias + ) -> Tuple[RoomID, List[str]]: """ Get the room ID associated with a room alias. Args: - room_alias (RoomAlias): The alias to look up. + room_alias: The alias to look up. Returns: A tuple of: The room ID as a RoomID object. 
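# Illustrative aside (the event IDs below are made up): the
# Dict[Tuple[str, str], str] annotations added in this patch describe Matrix
# state maps, keyed by (event type, state key) and valued by event ID. Both
# _can_guest_join and _is_host_in_room index the map this way:
current_state_ids = {
    ("m.room.create", ""): "$create-event-id",
    ("m.room.guest_access", ""): "$guest-access-event-id",
    ("m.room.member", "@alice:example.com"): "$alice-join-event-id",
}
guest_access_id = current_state_ids.get(("m.room.guest_access", ""))
create_event_id = current_state_ids.get(("m.room.create", ""))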
@@ -682,24 +687,25 @@ class RoomMemberHandler(object): return RoomID.from_string(room_id), servers - async def _get_inviter(self, user_id, room_id): + async def _get_inviter(self, user_id: str, room_id: str) -> Optional[UserID]: invite = await self.store.get_invite_for_local_user_in_room( user_id=user_id, room_id=room_id ) if invite: return UserID.from_string(invite.sender) + return None async def do_3pid_invite( self, - room_id, - inviter, - medium, - address, - id_server, - requester, - txn_id, - id_access_token=None, - ): + room_id: str, + inviter: UserID, + medium: str, + address: str, + id_server: str, + requester: Requester, + txn_id: Optional[str], + id_access_token: Optional[str] = None, + ) -> None: if self.config.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: @@ -748,15 +754,15 @@ class RoomMemberHandler(object): async def _make_and_store_3pid_invite( self, - requester, - id_server, - medium, - address, - room_id, - user, - txn_id, - id_access_token=None, - ): + requester: Requester, + id_server: str, + medium: str, + address: str, + room_id: str, + user: UserID, + txn_id: Optional[str], + id_access_token: Optional[str] = None, + ) -> None: room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" @@ -830,7 +836,9 @@ class RoomMemberHandler(object): txn_id=txn_id, ) - async def _is_host_in_room(self, current_state_ids): + async def _is_host_in_room( + self, current_state_ids: Dict[Tuple[str, str], str] + ) -> bool: # Have we just created the room, and is this about to be the very # first member event? create_event_id = current_state_ids.get(("m.room.create", "")) @@ -852,7 +860,7 @@ class RoomMemberHandler(object): return False - async def _is_server_notice_room(self, room_id): + async def _is_server_notice_room(self, room_id: str) -> bool: if self._server_notices_mxid is None: return False user_ids = await self.store.get_users_in_room(room_id) @@ -867,13 +875,15 @@ class RoomMemberMasterHandler(RoomMemberHandler): self.distributor.declare("user_joined_room") self.distributor.declare("user_left_room") - async def _is_remote_room_too_complex(self, room_id, remote_room_hosts): + async def _is_remote_room_too_complex( + self, room_id: str, remote_room_hosts: List[str] + ) -> Optional[bool]: """ Check if complexity of a remote room is too great. Args: - room_id (str) - remote_room_hosts (list[str]) + room_id + remote_room_hosts Returns: bool of whether the complexity is too great, or None if unable to be fetched @@ -887,21 +897,26 @@ class RoomMemberMasterHandler(RoomMemberHandler): return complexity["v1"] > max_complexity return None - async def _is_local_room_too_complex(self, room_id): + async def _is_local_room_too_complex(self, room_id: str) -> bool: """ Check if the complexity of a local room is too great. Args: - room_id (str) - - Returns: bool + room_id: The room ID to check for complexity. 
""" max_complexity = self.hs.config.limit_remote_rooms.complexity complexity = await self.store.get_room_complexity(room_id) return complexity["v1"] > max_complexity - async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> None: """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -961,8 +976,13 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) async def _remote_reject_invite( - self, requester, remote_room_hosts, room_id, target, content - ): + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + target: UserID, + content: dict, + ) -> dict: """Implements RoomMemberHandler._remote_reject_invite """ fed_handler = self.federation_handler @@ -983,17 +1003,17 @@ class RoomMemberMasterHandler(RoomMemberHandler): await self.store.locally_reject_invite(target.to_string(), room_id) return {} - async def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room """ - return user_joined_room(self.distributor, target, room_id) + user_joined_room(self.distributor, target, room_id) - async def _user_left_room(self, target, room_id): + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ - return user_left_room(self.distributor, target, room_id) + user_left_room(self.distributor, target, room_id) - async def forget(self, user, room_id): + async def forget(self, user: UserID, room_id: str) -> None: user_id = user.to_string() member = await self.state_handler.get_current_state( diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index 0fc54349ab..5c776cc0be 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -14,6 +14,7 @@ # limitations under the License. 
import logging +from typing import List, Optional from synapse.api.errors import SynapseError from synapse.handlers.room_member import RoomMemberHandler @@ -22,6 +23,7 @@ from synapse.replication.http.membership import ( ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite, ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft, ) +from synapse.types import Requester, UserID logger = logging.getLogger(__name__) @@ -34,7 +36,14 @@ class RoomMemberWorkerHandler(RoomMemberHandler): self._remote_reject_client = ReplRejectInvite.make_client(hs) self._notify_change_client = ReplJoinedLeft.make_client(hs) - async def _remote_join(self, requester, remote_room_hosts, room_id, user, content): + async def _remote_join( + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + user: UserID, + content: dict, + ) -> Optional[dict]: """Implements RoomMemberHandler._remote_join """ if len(remote_room_hosts) == 0: @@ -53,8 +62,13 @@ class RoomMemberWorkerHandler(RoomMemberHandler): return ret async def _remote_reject_invite( - self, requester, remote_room_hosts, room_id, target, content - ): + self, + requester: Requester, + remote_room_hosts: List[str], + room_id: str, + target: UserID, + content: dict, + ) -> dict: """Implements RoomMemberHandler._remote_reject_invite """ return await self._remote_reject_client( @@ -65,16 +79,16 @@ class RoomMemberWorkerHandler(RoomMemberHandler): content=content, ) - async def _user_joined_room(self, target, room_id): + async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room """ - return await self._notify_change_client( + await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="joined" ) - async def _user_left_room(self, target, room_id): + async def _user_left_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_left_room """ - return await self._notify_change_client( + await self._notify_change_client( user_id=target.to_string(), room_id=room_id, change="left" ) diff --git a/tox.ini b/tox.ini index a69bc04334..5a1fa610b6 100644 --- a/tox.ini +++ b/tox.ini @@ -188,6 +188,8 @@ commands = mypy \ synapse/handlers/directory.py \ synapse/handlers/oidc_handler.py \ synapse/handlers/presence.py \ + synapse/handlers/room_member.py \ + synapse/handlers/room_member_worker.py \ synapse/handlers/saml_handler.py \ synapse/handlers/sync.py \ synapse/handlers/ui_auth \ From 164f50f5f25a3204cce5fd2c8f196e9a9d4deb5d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 18 May 2020 10:43:05 +0100 Subject: [PATCH 13/14] fix mypy for tests/replication (#7518) --- changelog.d/7518.misc | 1 + tests/replication/slave/storage/test_events.py | 16 +++++----------- tests/replication/tcp/test_commands.py | 4 ++-- tox.ini | 2 +- 4 files changed, 9 insertions(+), 14 deletions(-) create mode 100644 changelog.d/7518.misc diff --git a/changelog.d/7518.misc b/changelog.d/7518.misc new file mode 100644 index 0000000000..f6e143fe1c --- /dev/null +++ b/changelog.d/7518.misc @@ -0,0 +1 @@ +Fix typing annotations in `tests.replication`. 
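The central trick in this patch is swapping unittest's assertIsInstance for a bare assert: both check at runtime, but only the bare assert narrows the static type for mypy, so later attribute access on cmd type-checks. A minimal sketch, reusing the RDATA line from the test below (the import path is an assumption based on the names the test uses):

from synapse.replication.tcp.commands import RdataCommand, parse_command_from_line

line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
cmd = parse_command_from_line(line)

# self.assertIsInstance(cmd, RdataCommand) would leave cmd's static type
# unchanged; the assert below tells mypy that cmd is an RdataCommand.
assert isinstance(cmd, RdataCommand)
print(cmd.stream_name, cmd.instance_name, cmd.token)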
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 0fee8a71c4..1a88c7fb80 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -17,11 +17,12 @@ from canonicaljson import encode_canonical_json from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict -from synapse.events.snapshot import EventContext from synapse.handlers.room import RoomEventSource from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.roommember import RoomsForUser +from tests.server import FakeTransport + from ._base import BaseSlavedStoreTestCase USER_ID = "@feeling:test" @@ -240,6 +241,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): # limit the replication rate repl_transport = self._server_transport + assert isinstance(repl_transport, FakeTransport) repl_transport.autoflush = False # build the join and message events and persist them in the same batch. @@ -322,7 +324,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): type="m.room.message", key=None, internal={}, - state=None, depth=None, prev_events=[], auth_events=[], @@ -362,15 +363,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event = make_event_from_dict(event_dict, internal_metadata_dict=internal) self.event_id += 1 - - if state is not None: - state_ids = {key: e.event_id for key, e in state.items()} - context = EventContext.with_state( - state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids - ) - else: - state_handler = self.hs.get_state_handler() - context = self.get_success(state_handler.compute_event_context(event)) + state_handler = self.hs.get_state_handler() + context = self.get_success(state_handler.compute_event_context(event)) self.master_store.add_push_actions_to_staging( event.event_id, {user_id: actions for user_id, actions in push_actions} diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py index 7ddfd0a733..60c10a441a 100644 --- a/tests/replication/tcp/test_commands.py +++ b/tests/replication/tcp/test_commands.py @@ -30,7 +30,7 @@ class ParseCommandTestCase(TestCase): def test_parse_rdata(self): line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]' cmd = parse_command_from_line(line) - self.assertIsInstance(cmd, RdataCommand) + assert isinstance(cmd, RdataCommand) self.assertEqual(cmd.stream_name, "events") self.assertEqual(cmd.instance_name, "master") self.assertEqual(cmd.token, 6287863) @@ -38,7 +38,7 @@ class ParseCommandTestCase(TestCase): def test_parse_rdata_batch(self): line = 'RDATA presence master batch ["@foo:example.com", "online"]' cmd = parse_command_from_line(line) - self.assertIsInstance(cmd, RdataCommand) + assert isinstance(cmd, RdataCommand) self.assertEqual(cmd.stream_name, "presence") self.assertEqual(cmd.instance_name, "master") self.assertIsNone(cmd.token) diff --git a/tox.ini b/tox.ini index 5a1fa610b6..3bb4d45e2a 100644 --- a/tox.ini +++ b/tox.ini @@ -207,7 +207,7 @@ commands = mypy \ synapse/storage/util \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ - tests/replication/tcp/streams \ + tests/replication \ tests/test_utils \ tests/rest/client/v2_alpha/test_auth.py \ tests/util/test_stream_change_cache.py From 51055c8c4409e70e8f310fce420b2f2f7f7a257a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 18 May 2020 12:24:48 +0100 
Subject: [PATCH 14/14] Allow ReplicationRestResource to be added to workers (#7515) This allows workers to talk to each other over HTTP replication. --- changelog.d/7515.misc | 1 + synapse/app/generic_worker.py | 4 ++++ synapse/replication/http/__init__.py | 13 ++++++++----- 3 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 changelog.d/7515.misc diff --git a/changelog.d/7515.misc b/changelog.d/7515.misc new file mode 100644 index 0000000000..48f3044f90 --- /dev/null +++ b/changelog.d/7515.misc @@ -0,0 +1 @@ +Allow `ReplicationRestResource` to be added to workers. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index ab801108ca..506b70443b 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -47,6 +47,7 @@ from synapse.http.site import SynapseSite from synapse.logging.context import LoggingContext from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage.account_data import SlavedAccountDataStore from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore @@ -570,6 +571,9 @@ class GenericWorkerServer(HomeServer): if name in ["keys", "federation"]: resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) + if name == "replication": + resources[REPLICATION_PREFIX] = ReplicationRestResource(self) + root_resource = create_resource_tree(resources, NoResource()) _base.listen_tcp( diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index 4613b2538c..a909744e93 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -34,9 +34,12 @@ class ReplicationRestResource(JsonResource): def register_servlets(self, hs): send_event.register_servlets(hs, self) - membership.register_servlets(hs, self) federation.register_servlets(hs, self) - login.register_servlets(hs, self) - register.register_servlets(hs, self) - devices.register_servlets(hs, self) - streams.register_servlets(hs, self) + + # The following can't currently be instantiated on workers. + if hs.config.worker.worker_app is None: + membership.register_servlets(hs, self) + login.register_servlets(hs, self) + register.register_servlets(hs, self) + devices.register_servlets(hs, self) + streams.register_servlets(hs, self)
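To make the gating concrete, here is a minimal, runnable sketch of the rule the new register_servlets codifies: a process is the master iff worker_app is unset, and only the master registers the servlets that cannot yet run on workers. WorkerConfig below is a stand-in, not Synapse's real config class:

from typing import List, Optional

class WorkerConfig:
    """Stand-in for the real config object; only worker_app matters here."""

    def __init__(self, worker_app: Optional[str] = None):
        self.worker_app = worker_app

def servlet_groups(config: WorkerConfig) -> List[str]:
    groups = ["send_event", "federation"]  # safe on any process
    if config.worker_app is None:  # master only, per the diff above
        groups += ["membership", "login", "register", "devices", "streams"]
    return groups

assert "membership" in servlet_groups(WorkerConfig())
assert "membership" not in servlet_groups(WorkerConfig("synapse.app.generic_worker"))

Once a worker's listener also names the replication resource (as in the generic_worker change above), the worker-safe endpoints become reachable on that worker under REPLICATION_PREFIX, which is what lets workers talk to each other over HTTP replication.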