Compare commits

...

19 Commits

Author SHA1 Message Date
Erik Johnston d2012df31c Rename config var to stream_writers 2020-05-18 13:22:41 +01:00
Erik Johnston c42b180f4d
Fix typo
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
2020-05-18 12:46:04 +01:00
Erik Johnston 266755226b Fixup review comments 2020-05-18 12:31:27 +01:00
Erik Johnston 7ed308a281 Merge branch 'develop' of github.com:matrix-org/synapse into erikj/move_events 2020-05-18 12:24:43 +01:00
Erik Johnston 51055c8c44
Allow ReplicationRestResource to be added to workers (#7515)
This allows workers to talk to each other over HTTP replication.
2020-05-18 12:24:48 +01:00
Richard van der Hoff 4d1afb1dfe
Merge pull request #7519 from matrix-org/rav/kill_py2_code
Kill off some old python 2 code
2020-05-18 10:45:30 +01:00
Richard van der Hoff 164f50f5f2
fix mypy for tests/replication (#7518) 2020-05-18 10:43:05 +01:00
Patrick Cloke c29915bd05
Add type hints to room member handlers (#7513) 2020-05-15 15:05:25 -04:00
Richard van der Hoff ab57353de3 changelog 2020-05-15 19:37:41 +01:00
Richard van der Hoff d4676910c9 remove miscellaneous PY2 code 2020-05-15 19:37:41 +01:00
Richard van der Hoff e6027562e2 remove `builtins.buffer` code from storage code
this is no longer needed on python 3
2020-05-15 19:37:41 +01:00
Richard van der Hoff 91f51c611c remove redundant `__func__`
this is a no-op under python 3
2020-05-15 19:37:41 +01:00
Richard van der Hoff 65902e08c3 remove to_ascii
this is a no-op on python 3.
2020-05-15 19:12:03 +01:00
Richard van der Hoff 08fa96f030 Remove `exception_to_unicode`
this is a no-op on python 3.
2020-05-15 19:07:24 +01:00
Richard van der Hoff 6c1f7c722f
Fix limit logic for AccountDataStream (#7384)
Make sure that the AccountDataStream presents complete updates, in the right
order.

This is much the same fix as #7337 and #7358, but applied to a different stream.
2020-05-15 19:03:25 +01:00
Andrew Morgan 34a43f0084 Fix a couple of small typos 2020-05-15 18:54:32 +01:00
Patrick Cloke a3cf36f76e
Support UI Authentication for OpenID Connect accounts (#7457) 2020-05-15 12:26:02 -04:00
Erik Johnston 03aff4c75e
Add a worker store for search insertion. (#7516)
This is required as both event persistence and the background update needs access to this function. It should be perfectly safe for two workers to write to that table at the same time.
2020-05-15 17:22:47 +01:00
Andrew Morgan 16090a077f
Prevent 0-member/null room_version rooms from appearing in group room queries (#7465) 2020-05-15 17:17:42 +01:00
46 changed files with 711 additions and 497 deletions

1
changelog.d/7384.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.

1
changelog.d/7457.feature Normal file
View File

@ -0,0 +1 @@
Add OpenID Connect login/registration support. Contributed by Quentin Gliech, on behalf of [les Connecteurs](https://connecteu.rs).

1
changelog.d/7465.bugfix Normal file
View File

@ -0,0 +1 @@
Prevent rooms with 0 members or with invalid version strings from breaking group queries.

1
changelog.d/7513.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to room member handler.

1
changelog.d/7515.misc Normal file
View File

@ -0,0 +1 @@
Allow `ReplicationRestResource` to be added to workers.

1
changelog.d/7516.misc Normal file
View File

@ -0,0 +1 @@
Add a worker store for search insertion, required for moving event persistence off master.

1
changelog.d/7518.misc Normal file
View File

@ -0,0 +1 @@
Fix typing annotations in `tests.replication`.

1
changelog.d/7519.misc Normal file
View File

@ -0,0 +1 @@
Remove some redundant Python 2 support code.

View File

@ -3,8 +3,6 @@ import json
import sys
import time
import six
import psycopg2
import yaml
from canonicaljson import encode_canonical_json
@ -12,10 +10,7 @@ from signedjson.key import read_signing_keys
from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64
if six.PY2:
db_type = six.moves.builtins.buffer
else:
db_type = memoryview
db_binary_type = memoryview
def select_v1_keys(connection):
@ -72,7 +67,7 @@ def rows_v2(server, json):
valid_until = json["valid_until_ts"]
key_json = encode_canonical_json(json)
for key_id in json["verify_keys"]:
yield (server, key_id, "-", valid_until, valid_until, db_type(key_json))
yield (server, key_id, "-", valid_until, valid_until, db_binary_type(key_json))
def main():

View File

@ -47,6 +47,7 @@ from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@ -122,6 +123,7 @@ from synapse.storage.data_stores.main.monthly_active_users import (
MonthlyActiveUsersWorkerStore,
)
from synapse.storage.data_stores.main.presence import UserPresenceState
from synapse.storage.data_stores.main.search import SearchWorkerStore
from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.types import ReadReceipt
@ -451,6 +453,7 @@ class GenericWorkerSlavedStore(
SlavedFilteringStore,
MonthlyActiveUsersWorkerStore,
MediaRepositoryStore,
SearchWorkerStore,
BaseSlavedStore,
):
def __init__(self, database, db_conn, hs):
@ -568,6 +571,9 @@ class GenericWorkerServer(HomeServer):
if name in ["keys", "federation"]:
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(

View File

@ -270,7 +270,7 @@ class ApplicationService(object):
def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
def get_exlusive_user_regexes(self):
def get_exclusive_user_regexes(self):
"""Get the list of regexes used to determine if a user is exclusively
registered by the AS
"""

View File

@ -94,13 +94,13 @@ class WorkerConfig(Config):
bind_addresses.append("")
# A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map", {}) or {}
instance_map = config.get("instance_map") or {}
self.instance_map = {
name: InstanceLocationConfig(**c) for name, c in instance_map.items()
}
# Map from type of streams to source, c.f. WriterLocations.
writers = config.get("writers", {}) or {}
writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers)
# Check that the configured writer for events also appears in

View File

@ -80,7 +80,9 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled
self._sso_enabled = (
hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
)
# we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first

View File

@ -888,7 +888,7 @@ class EventCreationHandler(object):
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
This should only be run on master.
This should only be run on the instance in charge of persisting events.
"""
assert self.config.worker.writers.events == self._instance_name

View File

@ -311,7 +311,7 @@ class OidcHandler:
``ClientAuth`` to authenticate with the client with its ID and secret.
Args:
code: The autorization code we got from the callback.
code: The authorization code we got from the callback.
Returns:
A dict containing various tokens.
@ -497,11 +497,14 @@ class OidcHandler:
return UserInfo(claims)
async def handle_redirect_request(
self, request: SynapseRequest, client_redirect_url: bytes
) -> None:
self,
request: SynapseRequest,
client_redirect_url: bytes,
ui_auth_session_id: Optional[str] = None,
) -> str:
"""Handle an incoming request to /login/sso/redirect
It redirects the browser to the authorization endpoint with a few
It returns a redirect to the authorization endpoint with a few
parameters:
- ``client_id``: the client ID set in ``oidc_config.client_id``
@ -511,24 +514,32 @@ class OidcHandler:
- ``state``: a random string
- ``nonce``: a random string
In addition to redirecting the client, we are setting a cookie with
In addition generating a redirect URL, we are setting a cookie with
a signed macaroon token containing the state, the nonce and the
client_redirect_url params. Those are then checked when the client
comes back from the provider.
Args:
request: the incoming request from the browser.
We'll respond to it with a redirect and a cookie.
client_redirect_url: the URL that we should redirect the client to
when everything is done
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
Returns:
The redirect URL to the authorization endpoint.
"""
state = generate_token()
nonce = generate_token()
cookie = self._generate_oidc_session_token(
state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(),
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id,
)
request.addCookie(
SESSION_COOKIE_NAME,
@ -541,7 +552,7 @@ class OidcHandler:
metadata = await self.load_metadata()
authorization_endpoint = metadata.get("authorization_endpoint")
uri = prepare_grant_uri(
return prepare_grant_uri(
authorization_endpoint,
client_id=self._client_auth.client_id,
response_type="code",
@ -550,8 +561,6 @@ class OidcHandler:
state=state,
nonce=nonce,
)
request.redirect(uri)
finish_request(request)
async def handle_oidc_callback(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_synapse/oidc/callback
@ -625,7 +634,11 @@ class OidcHandler:
# Deserialize the session token and verify it.
try:
nonce, client_redirect_url = self._verify_oidc_session_token(session, state)
(
nonce,
client_redirect_url,
ui_auth_session_id,
) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
self._render_error(request, "invalid_session", str(e))
@ -678,15 +691,21 @@ class OidcHandler:
return
# and finally complete the login
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url
)
if ui_auth_session_id:
await self._auth_handler.complete_sso_ui_auth(
user_id, ui_auth_session_id, request
)
else:
await self._auth_handler.complete_sso_login(
user_id, request, client_redirect_url
)
def _generate_oidc_session_token(
self,
state: str,
nonce: str,
client_redirect_url: str,
ui_auth_session_id: Optional[str],
duration_in_ms: int = (60 * 60 * 1000),
) -> str:
"""Generates a signed token storing data about an OIDC session.
@ -702,6 +721,8 @@ class OidcHandler:
nonce: The ``nonce`` parameter passed to the OIDC provider.
client_redirect_url: The URL the client gave when it initiated the
flow.
ui_auth_session_id: The session ID of the ongoing UI Auth (or
None if this is a login).
duration_in_ms: An optional duration for the token in milliseconds.
Defaults to an hour.
@ -718,12 +739,19 @@ class OidcHandler:
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (client_redirect_url,)
)
if ui_auth_session_id:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]:
def _verify_oidc_session_token(
self, session: str, state: str
) -> Tuple[str, str, Optional[str]]:
"""Verifies and extract an OIDC session token.
This verifies that a given session token was issued by this homeserver
@ -734,7 +762,7 @@ class OidcHandler:
state: The state the OIDC provider gave back
Returns:
The nonce and the client_redirect_url for this session
The nonce, client_redirect_url, and ui_auth_session_id for this session
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
@ -744,17 +772,27 @@ class OidcHandler:
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the `nonce` and `client_redirect_url` from the token
# Extract the `nonce`, `client_redirect_url`, and maybe the
# `ui_auth_session_id` from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
try:
ui_auth_session_id = self._get_value_from_macaroon(
macaroon, "ui_auth_session_id"
) # type: Optional[str]
except ValueError:
ui_auth_session_id = None
return nonce, client_redirect_url
return nonce, client_redirect_url, ui_auth_session_id
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
@ -773,7 +811,7 @@ class OidcHandler:
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise Exception("No %s caveat in macaroon" % (key,))
raise ValueError("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "

View File

@ -17,13 +17,16 @@
import abc
import logging
from typing import Dict, Iterable, List, Optional, Tuple, Union
from six.moves import http_client
from synapse import types
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.types import Collection, RoomID, UserID
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@ -74,84 +77,84 @@ class RoomMemberHandler(object):
self.base_handler = BaseHandler(hs)
@abc.abstractmethod
async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
async def _remote_join(
self,
requester: Requester,
remote_room_hosts: List[str],
room_id: str,
user: UserID,
content: dict,
) -> Optional[dict]:
"""Try and join a room that this server is not in
Args:
requester (Requester)
remote_room_hosts (list[str]): List of servers that can be used
to join via.
room_id (str): Room that we are trying to join
user (UserID): User who is trying to join
content (dict): A dict that should be used as the content of the
join event.
Returns:
Deferred
requester
remote_room_hosts: List of servers that can be used to join via.
room_id: Room that we are trying to join
user: User who is trying to join
content: A dict that should be used as the content of the join event.
"""
raise NotImplementedError()
@abc.abstractmethod
async def _remote_reject_invite(
self, requester, remote_room_hosts, room_id, target, content
):
self,
requester: Requester,
remote_room_hosts: List[str],
room_id: str,
target: UserID,
content: dict,
) -> dict:
"""Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected.
Args:
requester (Requester)
remote_room_hosts (list[str]): List of servers to use to try and
reject invite
room_id (str)
target (UserID): The user rejecting the invite
content (dict): The content for the rejection event
requester
remote_room_hosts: List of servers to use to try and reject invite
room_id
target: The user rejecting the invite
content: The content for the rejection event
Returns:
Deferred[dict]: A dictionary to be returned to the client, may
A dictionary to be returned to the client, may
include event_id etc, or nothing if we locally rejected
"""
raise NotImplementedError()
@abc.abstractmethod
async def _user_joined_room(self, target, room_id):
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has joined the
room.
Args:
target (UserID)
room_id (str)
Returns:
None
target
room_id
"""
raise NotImplementedError()
@abc.abstractmethod
async def _user_left_room(self, target, room_id):
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has left the
room.
Args:
target (UserID)
room_id (str)
Returns:
None
target
room_id
"""
raise NotImplementedError()
async def _local_membership_update(
self,
requester,
target,
room_id,
membership,
requester: Requester,
target: UserID,
room_id: str,
membership: str,
prev_event_ids: Collection[str],
txn_id=None,
ratelimit=True,
content=None,
require_consent=True,
):
txn_id: Optional[str] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
) -> EventBase:
user_id = target.to_string()
if content is None:
@ -214,16 +217,13 @@ class RoomMemberHandler(object):
async def copy_room_tags_and_direct_to_room(
self, old_room_id, new_room_id, user_id
):
) -> None:
"""Copies the tags and direct room state from one room to another.
Args:
old_room_id (str)
new_room_id (str)
user_id (str)
Returns:
Deferred[None]
old_room_id: The room ID of the old room.
new_room_id: The room ID of the new room.
user_id: The user's ID.
"""
# Retrieve user account data for predecessor room
user_account_data, _ = await self.store.get_account_data_for_user(user_id)
@ -253,17 +253,17 @@ class RoomMemberHandler(object):
async def update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
content=None,
require_consent=True,
):
requester: Requester,
target: UserID,
room_id: str,
action: str,
txn_id: Optional[str] = None,
remote_room_hosts: Optional[List[str]] = None,
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
) -> Union[EventBase, Optional[dict]]:
key = (room_id,)
with (await self.member_linearizer.queue(key)):
@ -284,17 +284,17 @@ class RoomMemberHandler(object):
async def _update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
content=None,
require_consent=True,
):
requester: Requester,
target: UserID,
room_id: str,
action: str,
txn_id: Optional[str] = None,
remote_room_hosts: Optional[List[str]] = None,
third_party_signed: Optional[dict] = None,
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
) -> Union[EventBase, Optional[dict]]:
content_specified = bool(content)
if content is None:
content = {}
@ -468,12 +468,11 @@ class RoomMemberHandler(object):
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
res = await self._remote_reject_invite(
return await self._remote_reject_invite(
requester, remote_room_hosts, room_id, target, content,
)
return res
res = await self._local_membership_update(
return await self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
@ -484,9 +483,10 @@ class RoomMemberHandler(object):
content=content,
require_consent=require_consent,
)
return res
async def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
async def transfer_room_state_on_room_upgrade(
self, old_room_id: str, room_id: str
) -> None:
"""Upon our server becoming aware of an upgraded room, either by upgrading a room
ourselves or joining one, we can transfer over information from the previous room.
@ -494,12 +494,8 @@ class RoomMemberHandler(object):
well as migrating the room directory state.
Args:
old_room_id (str): The ID of the old room
room_id (str): The ID of the new room
Returns:
Deferred
old_room_id: The ID of the old room
room_id: The ID of the new room
"""
logger.info("Transferring room state from %s to %s", old_room_id, room_id)
@ -526,17 +522,16 @@ class RoomMemberHandler(object):
# Remove the old room from those groups
await self.store.remove_room_from_group(group_id, old_room_id)
async def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
async def copy_user_state_on_room_upgrade(
self, old_room_id: str, new_room_id: str, user_ids: Iterable[str]
) -> None:
"""Copy user-specific information when they join a new room when that new room is the
result of a room upgrade
Args:
old_room_id (str): The ID of upgraded room
new_room_id (str): The ID of the new room
user_ids (Iterable[str]): User IDs to copy state for
Returns:
Deferred
old_room_id: The ID of upgraded room
new_room_id: The ID of the new room
user_ids: User IDs to copy state for
"""
logger.debug(
@ -566,17 +561,23 @@ class RoomMemberHandler(object):
)
continue
async def send_membership_event(self, requester, event, context, ratelimit=True):
async def send_membership_event(
self,
requester: Requester,
event: EventBase,
context: EventContext,
ratelimit: bool = True,
):
"""
Change the membership status of a user in a room.
Args:
requester (Requester): The local user who requested the membership
requester: The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
event (SynapseEvent): The membership event.
event: The membership event.
context: The context of the event.
ratelimit (bool): Whether to rate limit this request.
ratelimit: Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
@ -636,7 +637,9 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)
async def _can_guest_join(self, current_state_ids):
async def _can_guest_join(
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
"""
Returns whether a guest can join a room based on its current state.
"""
@ -653,12 +656,14 @@ class RoomMemberHandler(object):
and guest_access.content["guest_access"] == "can_join"
)
async def lookup_room_alias(self, room_alias):
async def lookup_room_alias(
self, room_alias: RoomAlias
) -> Tuple[RoomID, List[str]]:
"""
Get the room ID associated with a room alias.
Args:
room_alias (RoomAlias): The alias to look up.
room_alias: The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
@ -682,24 +687,25 @@ class RoomMemberHandler(object):
return RoomID.from_string(room_id), servers
async def _get_inviter(self, user_id, room_id):
async def _get_inviter(self, user_id: str, room_id: str) -> Optional[UserID]:
invite = await self.store.get_invite_for_local_user_in_room(
user_id=user_id, room_id=room_id
)
if invite:
return UserID.from_string(invite.sender)
return None
async def do_3pid_invite(
self,
room_id,
inviter,
medium,
address,
id_server,
requester,
txn_id,
id_access_token=None,
):
room_id: str,
inviter: UserID,
medium: str,
address: str,
id_server: str,
requester: Requester,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> None:
if self.config.block_non_admin_invites:
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
@ -748,15 +754,15 @@ class RoomMemberHandler(object):
async def _make_and_store_3pid_invite(
self,
requester,
id_server,
medium,
address,
room_id,
user,
txn_id,
id_access_token=None,
):
requester: Requester,
id_server: str,
medium: str,
address: str,
room_id: str,
user: UserID,
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> None:
room_state = await self.state_handler.get_current_state(room_id)
inviter_display_name = ""
@ -830,7 +836,9 @@ class RoomMemberHandler(object):
txn_id=txn_id,
)
async def _is_host_in_room(self, current_state_ids):
async def _is_host_in_room(
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
@ -852,7 +860,7 @@ class RoomMemberHandler(object):
return False
async def _is_server_notice_room(self, room_id):
async def _is_server_notice_room(self, room_id: str) -> bool:
if self._server_notices_mxid is None:
return False
user_ids = await self.store.get_users_in_room(room_id)
@ -867,13 +875,15 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
async def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
async def _is_remote_room_too_complex(
self, room_id: str, remote_room_hosts: List[str]
) -> Optional[bool]:
"""
Check if complexity of a remote room is too great.
Args:
room_id (str)
remote_room_hosts (list[str])
room_id
remote_room_hosts
Returns: bool of whether the complexity is too great, or None
if unable to be fetched
@ -887,21 +897,26 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return complexity["v1"] > max_complexity
return None
async def _is_local_room_too_complex(self, room_id):
async def _is_local_room_too_complex(self, room_id: str) -> bool:
"""
Check if the complexity of a local room is too great.
Args:
room_id (str)
Returns: bool
room_id: The room ID to check for complexity.
"""
max_complexity = self.hs.config.limit_remote_rooms.complexity
complexity = await self.store.get_room_complexity(room_id)
return complexity["v1"] > max_complexity
async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
async def _remote_join(
self,
requester: Requester,
remote_room_hosts: List[str],
room_id: str,
user: UserID,
content: dict,
) -> None:
"""Implements RoomMemberHandler._remote_join
"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
@ -961,8 +976,13 @@ class RoomMemberMasterHandler(RoomMemberHandler):
)
async def _remote_reject_invite(
self, requester, remote_room_hosts, room_id, target, content
):
self,
requester: Requester,
remote_room_hosts: List[str],
room_id: str,
target: UserID,
content: dict,
) -> dict:
"""Implements RoomMemberHandler._remote_reject_invite
"""
fed_handler = self.federation_handler
@ -983,17 +1003,17 @@ class RoomMemberMasterHandler(RoomMemberHandler):
await self.store.locally_reject_invite(target.to_string(), room_id)
return {}
async def _user_joined_room(self, target, room_id):
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room
"""
return user_joined_room(self.distributor, target, room_id)
user_joined_room(self.distributor, target, room_id)
async def _user_left_room(self, target, room_id):
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
return user_left_room(self.distributor, target, room_id)
user_left_room(self.distributor, target, room_id)
async def forget(self, user, room_id):
async def forget(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
member = await self.state_handler.get_current_state(

View File

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import List, Optional
from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
@ -22,6 +23,7 @@ from synapse.replication.http.membership import (
ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
)
from synapse.types import Requester, UserID
logger = logging.getLogger(__name__)
@ -34,7 +36,14 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
self._remote_reject_client = ReplRejectInvite.make_client(hs)
self._notify_change_client = ReplJoinedLeft.make_client(hs)
async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
async def _remote_join(
self,
requester: Requester,
remote_room_hosts: List[str],
room_id: str,
user: UserID,
content: dict,
) -> Optional[dict]:
"""Implements RoomMemberHandler._remote_join
"""
if len(remote_room_hosts) == 0:
@ -53,8 +62,13 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
return ret
async def _remote_reject_invite(
self, requester, remote_room_hosts, room_id, target, content
):
self,
requester: Requester,
remote_room_hosts: List[str],
room_id: str,
target: UserID,
content: dict,
) -> dict:
"""Implements RoomMemberHandler._remote_reject_invite
"""
return await self._remote_reject_client(
@ -65,16 +79,16 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
content=content,
)
async def _user_joined_room(self, target, room_id):
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room
"""
return await self._notify_change_client(
await self._notify_change_client(
user_id=target.to_string(), room_id=room_id, change="joined"
)
async def _user_left_room(self, target, room_id):
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
return await self._notify_change_client(
await self._notify_change_client(
user_id=target.to_string(), room_id=room_id, change="left"
)

View File

@ -19,7 +19,7 @@ import random
import sys
from io import BytesIO
from six import PY3, raise_from, string_types
from six import raise_from, string_types
from six.moves import urllib
import attr
@ -70,11 +70,7 @@ incoming_responses_counter = Counter(
MAX_LONG_RETRIES = 10
MAX_SHORT_RETRIES = 3
if PY3:
MAXINT = sys.maxsize
else:
MAXINT = sys.maxint
MAXINT = sys.maxsize
_next_id = 1

View File

@ -20,8 +20,6 @@ import time
from functools import wraps
from inspect import getcallargs
from six import PY3
_TIME_FUNC_ID = 0
@ -30,12 +28,8 @@ def _log_debug_as_f(f, msg, msg_args):
logger = logging.getLogger(name)
if logger.isEnabledFor(logging.DEBUG):
if PY3:
lineno = f.__code__.co_firstlineno
pathname = f.__code__.co_filename
else:
lineno = f.func_code.co_firstlineno
pathname = f.func_code.co_filename
lineno = f.__code__.co_firstlineno
pathname = f.__code__.co_filename
record = logging.LogRecord(
name=name,

View File

@ -15,8 +15,6 @@
# limitations under the License.
import logging
import six
from prometheus_client import Counter
from twisted.internet import defer
@ -28,9 +26,6 @@ from synapse.push import PusherConfigException
from . import push_rule_evaluator, push_tools
if six.PY3:
long = int
logger = logging.getLogger(__name__)
http_push_processed_counter = Counter(
@ -318,7 +313,7 @@ class HttpPusher(object):
{
"app_id": self.app_id,
"pushkey": self.pushkey,
"pushkey_ts": long(self.pushkey_ts / 1000),
"pushkey_ts": int(self.pushkey_ts / 1000),
"data": self.data_minus_url,
}
],
@ -347,7 +342,7 @@ class HttpPusher(object):
{
"app_id": self.app_id,
"pushkey": self.pushkey,
"pushkey_ts": long(self.pushkey_ts / 1000),
"pushkey_ts": int(self.pushkey_ts / 1000),
"data": self.data_minus_url,
"tweaks": tweaks,
}
@ -409,7 +404,7 @@ class HttpPusher(object):
{
"app_id": self.app_id,
"pushkey": self.pushkey,
"pushkey_ts": long(self.pushkey_ts / 1000),
"pushkey_ts": int(self.pushkey_ts / 1000),
"data": self.data_minus_url,
}
],

View File

@ -34,9 +34,12 @@ class ReplicationRestResource(JsonResource):
def register_servlets(self, hs):
send_event.register_servlets(hs, self)
membership.register_servlets(hs, self)
federation.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
streams.register_servlets(hs, self)
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
membership.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
streams.register_servlets(hs, self)

View File

@ -16,8 +16,6 @@
import logging
from typing import Optional
import six
from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
@ -26,13 +24,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
logger = logging.getLogger(__name__)
def __func__(inp):
if six.PY3:
return inp
else:
return inp.__func__
class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)

View File

@ -18,7 +18,7 @@ from synapse.storage.data_stores.main.presence import PresenceStore
from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
@ -27,14 +27,14 @@ class SlavedPresenceStore(BaseSlavedStore):
super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
self._presence_on_startup = self._get_active_presence(db_conn)
self._presence_on_startup = self._get_active_presence(db_conn) # type: ignore
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)
_get_active_presence = __func__(DataStore._get_active_presence)
take_presence_startup_info = __func__(DataStore.take_presence_startup_info)
_get_active_presence = DataStore._get_active_presence
take_presence_startup_info = DataStore.take_presence_startup_info
_get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]

View File

@ -89,6 +89,8 @@ class ReplicationCommandHandler:
self._streams_to_replicate.append(stream)
continue
# Only add EventStream and BackfillStream as a source on the
# instance in charge of event persistence.
if (
isinstance(stream, (EventsStream, BackfillStream))
and hs.config.worker.writers.events == hs.get_instance_name()

View File

@ -14,14 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
import logging
from collections import namedtuple
from typing import Any, Awaitable, Callable, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
List,
Optional,
Tuple,
TypeVar,
)
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
if TYPE_CHECKING:
import synapse.server
logger = logging.getLogger(__name__)
# the number of rows to request from an update_function.
@ -37,7 +50,7 @@ Token = int
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
# just a row from a database query, though this is dependent on the stream in question.
#
StreamRow = Tuple
StreamRow = TypeVar("StreamRow", bound=Tuple)
# The type returned by the update_function of a stream, as well as get_updates(),
# get_updates_since, etc.
@ -533,32 +546,63 @@ class AccountDataStream(Stream):
"""
AccountDataStreamRow = namedtuple(
"AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
"AccountDataStream",
("user_id", "room_id", "data_type"), # str # Optional[str] # str
)
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
def __init__(self, hs):
def __init__(self, hs: "synapse.server.HomeServer"):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
current_token_without_instance(self.store.get_max_account_data_stream_id),
db_query_to_update_function(self._update_function),
self._update_function,
)
async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
async def _update_function(
self, instance_name: str, from_token: int, to_token: int, limit: int
) -> StreamUpdateResult:
limited = False
global_results = await self.store.get_updated_global_account_data(
from_token, to_token, limit
)
results = list(room_results)
results.extend(
(stream_id, user_id, None, account_data_type)
# if the global results hit the limit, we'll need to limit the room results to
# the same stream token.
if len(global_results) >= limit:
to_token = global_results[-1][0]
limited = True
room_results = await self.store.get_updated_room_account_data(
from_token, to_token, limit
)
# likewise, if the room results hit the limit, limit the global results to
# the same stream token.
if len(room_results) >= limit:
to_token = room_results[-1][0]
limited = True
# convert the global results to the right format, and limit them to the to_token
# at the same time
global_rows = (
(stream_id, (user_id, None, account_data_type))
for stream_id, user_id, account_data_type in global_results
if stream_id <= to_token
)
return results
# we know that the room_results are already limited to `to_token` so no need
# for a check on `stream_id` here.
room_rows = (
(stream_id, (user_id, room_id, account_data_type))
for stream_id, user_id, room_id, account_data_type in room_results
)
# we need to return a sorted list, so merge them together.
updates = list(heapq.merge(room_rows, global_rows))
return updates, to_token, limited
class GroupServerStream(Stream):

View File

@ -401,19 +401,22 @@ class BaseSSORedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def on_GET(self, request: SynapseRequest):
async def on_GET(self, request: SynapseRequest):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
client_redirect_url = args[b"redirectUrl"][0]
sso_url = self.get_sso_url(client_redirect_url)
sso_url = await self.get_sso_url(request, client_redirect_url)
request.redirect(sso_url)
finish_request(request)
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
async def get_sso_url(
self, request: SynapseRequest, client_redirect_url: bytes
) -> bytes:
"""Get the URL to redirect to, to perform SSO auth
Args:
request: The client request to redirect.
client_redirect_url: the URL that we should redirect the
client to when everything is done
@ -428,7 +431,9 @@ class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._cas_handler = hs.get_cas_handler()
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
async def get_sso_url(
self, request: SynapseRequest, client_redirect_url: bytes
) -> bytes:
return self._cas_handler.get_redirect_url(
{"redirectUrl": client_redirect_url}
).encode("ascii")
@ -465,11 +470,13 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
def get_sso_url(self, client_redirect_url: bytes) -> bytes:
async def get_sso_url(
self, request: SynapseRequest, client_redirect_url: bytes
) -> bytes:
return self._saml_handler.handle_redirect_request(client_redirect_url)
class OIDCRedirectServlet(RestServlet):
class OIDCRedirectServlet(BaseSSORedirectServlet):
"""Implementation for /login/sso/redirect for the OIDC login flow."""
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@ -477,12 +484,12 @@ class OIDCRedirectServlet(RestServlet):
def __init__(self, hs):
self._oidc_handler = hs.get_oidc_handler()
async def on_GET(self, request):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
client_redirect_url = args[b"redirectUrl"][0]
await self._oidc_handler.handle_redirect_request(request, client_redirect_url)
async def get_sso_url(
self, request: SynapseRequest, client_redirect_url: bytes
) -> bytes:
return await self._oidc_handler.handle_redirect_request(
request, client_redirect_url
)
def register_servlets(hs, http_server):

View File

@ -131,14 +131,19 @@ class AuthRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
# SSO configuration.
self._saml_enabled = hs.config.saml2_enabled
if self._saml_enabled:
self._saml_handler = hs.get_saml_handler()
self._cas_enabled = hs.config.cas_enabled
if self._cas_enabled:
self._cas_handler = hs.get_cas_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
self._saml_enabled = hs.config.saml2_enabled
if self._saml_enabled:
self._saml_handler = hs.get_saml_handler()
self._oidc_enabled = hs.config.oidc_enabled
if self._oidc_enabled:
self._oidc_handler = hs.get_oidc_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
async def on_GET(self, request, stagetype):
session = parse_string(request, "session")
@ -172,11 +177,17 @@ class AuthRestServlet(RestServlet):
)
elif self._saml_enabled:
client_redirect_url = ""
client_redirect_url = b""
sso_redirect_url = self._saml_handler.handle_redirect_request(
client_redirect_url, session
)
elif self._oidc_enabled:
client_redirect_url = b""
sso_redirect_url = await self._oidc_handler.handle_redirect_request(
request, client_redirect_url, session
)
else:
raise SynapseError(400, "Homeserver not configured for SSO.")

View File

@ -17,7 +17,6 @@
import logging
import os
from six import PY3
from six.moves import urllib
from twisted.internet import defer
@ -324,23 +323,15 @@ def get_filename_from_headers(headers):
upload_name_utf8 = upload_name_utf8[7:]
# We have a filename*= section. This MUST be ASCII, and any UTF-8
# bytes are %-quoted.
if PY3:
try:
# Once it is decoded, we can then unquote the %-encoded
# parts strictly into a unicode string.
upload_name = urllib.parse.unquote(
upload_name_utf8.decode("ascii"), errors="strict"
)
except UnicodeDecodeError:
# Incorrect UTF-8.
pass
else:
# On Python 2, we first unquote the %-encoded parts and then
# decode it strictly using UTF-8.
try:
upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8")
except UnicodeDecodeError:
pass
try:
# Once it is decoded, we can then unquote the %-encoded
# parts strictly into a unicode string.
upload_name = urllib.parse.unquote(
upload_name_utf8.decode("ascii"), errors="strict"
)
except UnicodeDecodeError:
# Incorrect UTF-8.
pass
# If there isn't check for an ascii name.
if not upload_name:

View File

@ -19,9 +19,6 @@ import random
from abc import ABCMeta
from typing import Any, Optional
from six import PY2
from six.moves import builtins
from canonicaljson import json
from synapse.storage.database import LoggingTransaction # noqa: F401
@ -103,11 +100,6 @@ def db_to_json(db_content):
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
# psycopg2 on Python 2 returns buffer objects, which we need to cast to
# bytes to decode
if PY2 and isinstance(db_content, builtins.buffer):
db_content = bytes(db_content)
# Decode it to a Unicode string before feeding it to json.loads, so we
# consistenty get a Unicode-containing object out.
if isinstance(db_content, (bytes, bytearray)):

View File

@ -67,7 +67,7 @@ class DataStores(object):
self.main = main_store_class(database, db_conn, hs)
# If we're on a process that can persist events also
# instansiate a `PersistEventsStore`
# instantiate a `PersistEventsStore`
if hs.config.worker.writers.events == hs.get_instance_name():
self.persist_events = PersistEventsStore(
hs, database, self.main

View File

@ -16,6 +16,7 @@
import abc
import logging
from typing import List, Tuple
from canonicaljson import json
@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
def get_all_updated_account_data(
self, last_global_id, last_room_id, current_id, limit
):
"""Get all the client account_data that has changed on the server
Args:
last_global_id(int): The position to fetch from for top level data
last_room_id(int): The position to fetch from for per room data
current_id(int): The position to fetch up to.
Returns:
A deferred pair of lists of tuples of stream_id int, user_id string,
room_id string, and type string.
"""
if last_room_id == current_id and last_global_id == current_id:
return defer.succeed(([], []))
async def get_updated_global_account_data(
self, last_id: int, current_id: int, limit: int
) -> List[Tuple[int, str, str]]:
"""Get the global account_data that has changed, for the account_data stream
def get_updated_account_data_txn(txn):
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return
Returns:
A list of tuples of stream_id int, user_id string,
and type string.
"""
if last_id == current_id:
return []
def get_updated_global_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_global_id, current_id, limit))
global_results = txn.fetchall()
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return await self.db.runInteraction(
"get_updated_global_account_data", get_updated_global_account_data_txn
)
async def get_updated_room_account_data(
self, last_id: int, current_id: int, limit: int
) -> List[Tuple[int, str, str, str]]:
"""Get the global account_data that has changed, for the account_data stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return
Returns:
A list of tuples of stream_id int, user_id string,
room_id string and type string.
"""
if last_id == current_id:
return []
def get_updated_room_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_room_id, current_id, limit))
room_results = txn.fetchall()
return global_results, room_results
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_updated_account_data_txn", get_updated_account_data_txn
return await self.db.runInteraction(
"get_updated_room_account_data", get_updated_room_account_data_txn
)
def get_updated_account_data_for_user(self, user_id, stream_id):

View File

@ -30,12 +30,12 @@ logger = logging.getLogger(__name__)
def _make_exclusive_regex(services_cache):
# We precompie a regex constructed from all the regexes that the AS's
# We precompile a regex constructed from all the regexes that the AS's
# have registered for exclusive users.
exclusive_user_regexes = [
regex.pattern
for service in services_cache
for regex in service.get_exlusive_user_regexes()
for regex in service.get_exclusive_user_regexes()
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)

View File

@ -68,24 +68,78 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_invited_users_in_group",
)
def get_rooms_in_group(self, group_id, include_private=False):
def get_rooms_in_group(self, group_id: str, include_private: bool = False):
"""Retrieve the rooms that belong to a given group. Does not return rooms that
lack members.
Args:
group_id: The ID of the group to query for rooms
include_private: Whether to return private rooms in results
Returns:
Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
form of:
{
"room_id": "!a_room_id:example.com", # The ID of the room
"is_public": False # Whether this is a public room or not
}
"""
# TODO: Pagination
keyvalues = {"group_id": group_id}
if not include_private:
keyvalues["is_public"] = True
def _get_rooms_in_group_txn(txn):
sql = """
SELECT room_id, is_public FROM group_rooms
WHERE group_id = ?
AND room_id IN (
SELECT group_rooms.room_id FROM group_rooms
LEFT JOIN room_stats_current ON
group_rooms.room_id = room_stats_current.room_id
AND joined_members > 0
AND local_users_in_room > 0
LEFT JOIN rooms ON
group_rooms.room_id = rooms.room_id
AND (room_version <> '') = ?
)
"""
args = [group_id, False]
return self.db.simple_select_list(
table="group_rooms",
keyvalues=keyvalues,
retcols=("room_id", "is_public"),
desc="get_rooms_in_group",
)
if not include_private:
sql += " AND is_public = ?"
args += [True]
def get_rooms_for_summary_by_category(self, group_id, include_private=False):
txn.execute(sql, args)
return [
{"room_id": room_id, "is_public": is_public}
for room_id, is_public in txn
]
return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn)
def get_rooms_for_summary_by_category(
self, group_id: str, include_private: bool = False,
):
"""Get the rooms and categories that should be included in a summary request
Returns ([rooms], [categories])
Args:
group_id: The ID of the group to query the summary for
include_private: Whether to return private rooms in results
Returns:
Deferred[Tuple[List, Dict]]: A tuple containing:
* A list of dictionaries with the keys:
* "room_id": str, the room ID
* "is_public": bool, whether the room is public
* "category_id": str|None, the category ID if set, else None
* "order": int, the sort order of rooms
* A dictionary with the key:
* category_id (str): a dictionary with the keys:
* "is_public": bool, whether the category is public
* "profile": str, the category profile
* "order": int, the sort order of rooms in this category
"""
def _get_rooms_for_summary_txn(txn):
@ -97,13 +151,23 @@ class GroupServerWorkerStore(SQLBaseStore):
SELECT room_id, is_public, category_id, room_order
FROM group_summary_rooms
WHERE group_id = ?
AND room_id IN (
SELECT group_rooms.room_id FROM group_rooms
LEFT JOIN room_stats_current ON
group_rooms.room_id = room_stats_current.room_id
AND joined_members > 0
AND local_users_in_room > 0
LEFT JOIN rooms ON
group_rooms.room_id = rooms.room_id
AND (room_version <> '') = ?
)
"""
if not include_private:
sql += " AND is_public = ?"
txn.execute(sql, (group_id, True))
txn.execute(sql, (group_id, False, True))
else:
txn.execute(sql, (group_id,))
txn.execute(sql, (group_id, False))
rooms = [
{

View File

@ -17,8 +17,6 @@
import itertools
import logging
import six
from signedjson.key import decode_verify_key_bytes
from synapse.storage._base import SQLBaseStore
@ -28,12 +26,8 @@ from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
db_binary_type = six.moves.builtins.buffer
else:
db_binary_type = memoryview
db_binary_type = memoryview
class KeyStore(SQLBaseStore):

View File

@ -45,7 +45,6 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.metrics import Measure
from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__)
@ -179,7 +178,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (room_id, Membership.JOIN))
return [to_ascii(r[0]) for r in txn]
return [r[0] for r in txn]
@cached(max_entries=100000)
def get_room_summary(self, room_id):
@ -223,7 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (room_id,))
res = {}
for count, membership in txn:
summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
summary = res.setdefault(membership, MemberSummary([], count))
# we order by membership and then fairly arbitrarily by event_id so
# heroes are consistent
@ -255,11 +254,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
for user_id, membership, event_id in txn:
summary = res[to_ascii(membership)]
summary = res[membership]
# we will always have a summary for this membership type at this
# point given the summary currently contains the counts.
members = summary.members
members.append((to_ascii(user_id), to_ascii(event_id)))
members.append((user_id, event_id))
return res
@ -584,13 +583,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
ev_entry = event_map.get(event_id)
if ev_entry:
if ev_entry.event.membership == Membership.JOIN:
users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
display_name=to_ascii(
ev_entry.event.content.get("displayname", None)
),
avatar_url=to_ascii(
ev_entry.event.content.get("avatar_url", None)
),
users_in_room[ev_entry.event.state_key] = ProfileInfo(
display_name=ev_entry.event.content.get("displayname", None),
avatar_url=ev_entry.event.content.get("avatar_url", None),
)
else:
missing_member_event_ids.append(event_id)
@ -604,9 +599,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
if event.event_id in member_event_ids:
users_in_room[to_ascii(event.state_key)] = ProfileInfo(
display_name=to_ascii(event.content.get("displayname", None)),
avatar_url=to_ascii(event.content.get("avatar_url", None)),
users_in_room[event.state_key] = ProfileInfo(
display_name=event.content.get("displayname", None),
avatar_url=event.content.get("avatar_url", None),
)
return users_in_room

View File

@ -37,7 +37,55 @@ SearchEntry = namedtuple(
)
class SearchBackgroundUpdateStore(SQLBaseStore):
class SearchWorkerStore(SQLBaseStore):
def store_search_entries_txn(self, txn, entries):
"""Add entries to the search table
Args:
txn (cursor):
entries (iterable[SearchEntry]):
entries to be added to the table
"""
if not self.hs.config.enable_search:
return
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search"
" (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
args = (
(
entry.event_id,
entry.room_id,
entry.key,
entry.value,
entry.stream_ordering,
entry.origin_server_ts,
)
for entry in entries
)
txn.executemany(sql, args)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
args = (
(entry.event_id, entry.room_id, entry.key, entry.value)
for entry in entries
)
txn.executemany(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@ -296,52 +344,6 @@ class SearchBackgroundUpdateStore(SQLBaseStore):
return num_rows
def store_search_entries_txn(self, txn, entries):
"""Add entries to the search table
Args:
txn (cursor):
entries (iterable[SearchEntry]):
entries to be added to the table
"""
if not self.hs.config.enable_search:
return
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search"
" (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
args = (
(
entry.event_id,
entry.room_id,
entry.key,
entry.value,
entry.stream_ordering,
entry.origin_server_ts,
)
for entry in entries
)
txn.executemany(sql, args)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
args = (
(entry.event_id, entry.room_id, entry.key, entry.value)
for entry in entries
)
txn.executemany(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):

View File

@ -29,7 +29,6 @@ from synapse.storage.database import Database
from synapse.storage.state import StateFilter
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__)
@ -185,9 +184,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
(room_id,),
)
return {
(intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
}
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
return self.db.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn

View File

@ -16,8 +16,6 @@
import logging
from collections import namedtuple
import six
from canonicaljson import encode_canonical_json
from twisted.internet import defer
@ -27,12 +25,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
db_binary_type = six.moves.builtins.buffer
else:
db_binary_type = memoryview
db_binary_type = memoryview
logger = logging.getLogger(__name__)

View File

@ -50,7 +50,6 @@ from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.types import Collection
from synapse.util.stringutils import exception_to_unicode
logger = logging.getLogger(__name__)
@ -424,20 +423,14 @@ class Database(object):
# This can happen if the database disappears mid
# transaction.
logger.warning(
"[TXN OPERROR] {%s} %s %d/%d",
name,
exception_to_unicode(e),
i,
N,
"[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
)
if i < N:
i += 1
try:
conn.rollback()
except self.engine.module.Error as e1:
logger.warning(
"[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
)
logger.warning("[TXN EROLL] {%s} %s", name, e1)
continue
raise
except self.engine.module.DatabaseError as e:
@ -449,9 +442,7 @@ class Database(object):
conn.rollback()
except self.engine.module.Error as e1:
logger.warning(
"[TXN EROLL] {%s} %s",
name,
exception_to_unicode(e1),
"[TXN EROLL] {%s} %s", name, e1,
)
continue
raise

View File

@ -15,11 +15,9 @@
# limitations under the License.
import logging
from sys import intern
from typing import Callable, Dict, Optional
import six
from six.moves import intern
import attr
from prometheus_client.core import Gauge
@ -154,9 +152,6 @@ def intern_string(string):
return None
try:
if six.PY2:
string = string.encode("ascii")
return intern(string)
except UnicodeEncodeError:
return string

View File

@ -19,10 +19,6 @@ import re
import string
from collections import Iterable
import six
from six import PY2, PY3
from six.moves import range
from synapse.api.errors import Codes, SynapseError
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@ -47,80 +43,16 @@ def random_string_with_symbols(length):
def is_ascii(s):
if PY3:
if isinstance(s, bytes):
try:
s.decode("ascii").encode("ascii")
except UnicodeDecodeError:
return False
except UnicodeEncodeError:
return False
return True
try:
s.encode("ascii")
except UnicodeEncodeError:
return False
except UnicodeDecodeError:
return False
else:
if isinstance(s, bytes):
try:
s.decode("ascii").encode("ascii")
except UnicodeDecodeError:
return False
except UnicodeEncodeError:
return False
return True
def to_ascii(s):
"""Converts a string to ascii if it is ascii, otherwise leave it alone.
If given None then will return None.
"""
if PY3:
return s
if s is None:
return None
try:
return s.encode("ascii")
except UnicodeEncodeError:
return s
def exception_to_unicode(e):
"""Helper function to extract the text of an exception as a unicode string
Args:
e (Exception): exception to be stringified
Returns:
unicode
"""
# urgh, this is a mess. The basic problem here is that psycopg2 constructs its
# exceptions with PyErr_SetString, with a (possibly non-ascii) argument. str() will
# then produce the raw byte sequence. Under Python 2, this will then cause another
# error if it gets mixed with a `unicode` object, as per
# https://github.com/matrix-org/synapse/issues/4252
# First of all, if we're under python3, everything is fine because it will sort this
# nonsense out for us.
if not PY2:
return str(e)
# otherwise let's have a stab at decoding the exception message. We'll circumvent
# Exception.__str__(), which would explode if someone raised Exception(u'non-ascii')
# and instead look at what is in the args member.
if len(e.args) == 0:
return ""
elif len(e.args) > 1:
return six.text_type(repr(e.args))
msg = e.args[0]
if isinstance(msg, bytes):
return msg.decode("utf-8", errors="replace")
else:
return msg
def assert_valid_client_secret(client_secret):
"""Validate that a given string matches the client_secret regex defined by the spec"""
if client_secret_regex.match(client_secret) is None:

View File

@ -292,11 +292,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
@defer.inlineCallbacks
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["addCookie", "redirect", "finish"])
yield defer.ensureDeferred(
req = Mock(spec=["addCookie"])
url = yield defer.ensureDeferred(
self.handler.handle_redirect_request(req, b"http://client/redirect")
)
url = req.redirect.call_args[0][0]
url = urlparse(url)
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@ -382,7 +381,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
nonce = "nonce"
client_redirect_url = "http://client/redirect"
session = self.handler._generate_oidc_session_token(
state=state, nonce=nonce, client_redirect_url=client_redirect_url,
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
request.getCookie.return_value = session
@ -472,7 +474,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Mismatching session
session = self.handler._generate_oidc_session_token(
state="state", nonce="nonce", client_redirect_url="http://client/redirect",
state="state",
nonce="nonce",
client_redirect_url="http://client/redirect",
ui_auth_session_id=None,
)
request.args = {}
request.args[b"state"] = [b"mismatching state"]

View File

@ -17,11 +17,12 @@ from canonicaljson import encode_canonical_json
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser
from tests.server import FakeTransport
from ._base import BaseSlavedStoreTestCase
USER_ID = "@feeling:test"
@ -240,6 +241,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
# limit the replication rate
repl_transport = self._server_transport
assert isinstance(repl_transport, FakeTransport)
repl_transport.autoflush = False
# build the join and message events and persist them in the same batch.
@ -322,7 +324,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
type="m.room.message",
key=None,
internal={},
state=None,
depth=None,
prev_events=[],
auth_events=[],
@ -362,15 +363,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event = make_event_from_dict(event_dict, internal_metadata_dict=internal)
self.event_id += 1
if state is not None:
state_ids = {key: e.event_id for key, e in state.items()}
context = EventContext.with_state(
state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids
)
else:
state_handler = self.hs.get_state_handler()
context = self.get_success(state_handler.compute_event_context(event))
state_handler = self.hs.get_state_handler()
context = self.get_success(state_handler.compute_event_context(event))
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}

View File

@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams._base import (
_STREAM_UPDATE_TARGET_ROW_COUNT,
AccountDataStream,
)
from tests.replication._base import BaseStreamTestCase
class AccountDataStreamTestCase(BaseStreamTestCase):
def test_update_function_room_account_data_limit(self):
"""Test replication with many room account data updates
"""
store = self.hs.get_datastore()
# generate lots of account data updates
updates = []
for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
update = "m.test_type.%i" % (i,)
self.get_success(
store.add_account_data_to_room("test_user", "test_room", update, {})
)
updates.append(update)
# also one global update
self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))
# tell the notifier to catch up to avoid duplicate rows.
# workaround for https://github.com/matrix-org/synapse/issues/7360
# FIXME remove this when the above is fixed
self.replicate()
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
# now reconnect to pull the updates
self.reconnect()
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
for t in updates:
(stream_name, token, row) = received_rows.pop(0)
self.assertEqual(stream_name, AccountDataStream.NAME)
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
self.assertEqual(row.data_type, t)
self.assertEqual(row.room_id, "test_room")
(stream_name, token, row) = received_rows.pop(0)
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
self.assertEqual(row.data_type, "m.global")
self.assertIsNone(row.room_id)
self.assertEqual([], received_rows)
def test_update_function_global_account_data_limit(self):
"""Test replication with many global account data updates
"""
store = self.hs.get_datastore()
# generate lots of account data updates
updates = []
for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
update = "m.test_type.%i" % (i,)
self.get_success(store.add_account_data_for_user("test_user", update, {}))
updates.append(update)
# also one per-room update
self.get_success(
store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
)
# tell the notifier to catch up to avoid duplicate rows.
# workaround for https://github.com/matrix-org/synapse/issues/7360
# FIXME remove this when the above is fixed
self.replicate()
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
# now reconnect to pull the updates
self.reconnect()
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
for t in updates:
(stream_name, token, row) = received_rows.pop(0)
self.assertEqual(stream_name, AccountDataStream.NAME)
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
self.assertEqual(row.data_type, t)
self.assertIsNone(row.room_id)
(stream_name, token, row) = received_rows.pop(0)
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
self.assertEqual(row.data_type, "m.per_room")
self.assertEqual(row.room_id, "test_room")
self.assertEqual([], received_rows)

View File

@ -30,7 +30,7 @@ class ParseCommandTestCase(TestCase):
def test_parse_rdata(self):
line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
cmd = parse_command_from_line(line)
self.assertIsInstance(cmd, RdataCommand)
assert isinstance(cmd, RdataCommand)
self.assertEqual(cmd.stream_name, "events")
self.assertEqual(cmd.instance_name, "master")
self.assertEqual(cmd.token, 6287863)
@ -38,7 +38,7 @@ class ParseCommandTestCase(TestCase):
def test_parse_rdata_batch(self):
line = 'RDATA presence master batch ["@foo:example.com", "online"]'
cmd = parse_command_from_line(line)
self.assertIsInstance(cmd, RdataCommand)
assert isinstance(cmd, RdataCommand)
self.assertEqual(cmd.stream_name, "presence")
self.assertEqual(cmd.instance_name, "master")
self.assertIsNone(cmd.token)

View File

@ -188,6 +188,8 @@ commands = mypy \
synapse/handlers/directory.py \
synapse/handlers/oidc_handler.py \
synapse/handlers/presence.py \
synapse/handlers/room_member.py \
synapse/handlers/room_member_worker.py \
synapse/handlers/saml_handler.py \
synapse/handlers/sync.py \
synapse/handlers/ui_auth \
@ -205,7 +207,7 @@ commands = mypy \
synapse/storage/util \
synapse/streams \
synapse/util/caches/stream_change_cache.py \
tests/replication/tcp/streams \
tests/replication \
tests/test_utils \
tests/rest/client/v2_alpha/test_auth.py \
tests/util/test_stream_change_cache.py