Compare commits
19 Commits
64949e78b1 ... d2012df31c
Commits in this range:

- d2012df31c
- c42b180f4d
- 266755226b
- 7ed308a281
- 51055c8c44
- 4d1afb1dfe
- 164f50f5f2
- c29915bd05
- ab57353de3
- d4676910c9
- e6027562e2
- 91f51c611c
- 65902e08c3
- 08fa96f030
- 6c1f7c722f
- 34a43f0084
- a3cf36f76e
- 03aff4c75e
- 16090a077f
Each of the following changelog entries is added as a new one-line file (`@@ -0,0 +1 @@`):

- Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.
- Add OpenID Connect login/registration support. Contributed by Quentin Gliech, on behalf of [les Connecteurs](https://connecteu.rs).
- Prevent rooms with 0 members or with invalid version strings from breaking group queries.
- Add type hints to room member handler.
- Allow `ReplicationRestResource` to be added to workers.
- Add a worker store for search insertion, required for moving event persistence off master.
- Fix typing annotations in `tests.replication`.
- Remove some redundant Python 2 support code.

@@ -3,8 +3,6 @@ import json
 import sys
 import time
 
-import six
-
 import psycopg2
 import yaml
 from canonicaljson import encode_canonical_json
@@ -12,10 +10,7 @@ from signedjson.key import read_signing_keys
 from signedjson.sign import sign_json
 from unpaddedbase64 import encode_base64
 
-if six.PY2:
-    db_type = six.moves.builtins.buffer
-else:
-    db_type = memoryview
+db_binary_type = memoryview
 
 
 def select_v1_keys(connection):
@@ -72,7 +67,7 @@ def rows_v2(server, json):
     valid_until = json["valid_until_ts"]
     key_json = encode_canonical_json(json)
     for key_id in json["verify_keys"]:
-        yield (server, key_id, "-", valid_until, valid_until, db_type(key_json))
+        yield (server, key_id, "-", valid_until, valid_until, db_binary_type(key_json))
 
 
 def main():

@@ -47,6 +47,7 @@ from synapse.http.site import SynapseSite
 from synapse.logging.context import LoggingContext
 from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -122,6 +123,7 @@ from synapse.storage.data_stores.main.monthly_active_users import (
     MonthlyActiveUsersWorkerStore,
 )
 from synapse.storage.data_stores.main.presence import UserPresenceState
+from synapse.storage.data_stores.main.search import SearchWorkerStore
 from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore
 from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
 from synapse.types import ReadReceipt
@@ -451,6 +453,7 @@ class GenericWorkerSlavedStore(
     SlavedFilteringStore,
     MonthlyActiveUsersWorkerStore,
     MediaRepositoryStore,
+    SearchWorkerStore,
     BaseSlavedStore,
 ):
     def __init__(self, database, db_conn, hs):
@@ -568,6 +571,9 @@ class GenericWorkerServer(HomeServer):
                 if name in ["keys", "federation"]:
                     resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
 
+                if name == "replication":
+                    resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
+
         root_resource = create_resource_tree(resources, NoResource())
 
         _base.listen_tcp(

@@ -270,7 +270,7 @@ class ApplicationService(object):
     def is_exclusive_room(self, room_id):
         return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
 
-    def get_exlusive_user_regexes(self):
+    def get_exclusive_user_regexes(self):
         """Get the list of regexes used to determine if a user is exclusively
         registered by the AS
         """

@@ -94,13 +94,13 @@ class WorkerConfig(Config):
             bind_addresses.append("")
 
         # A map from instance name to host/port of their HTTP replication endpoint.
-        instance_map = config.get("instance_map", {}) or {}
+        instance_map = config.get("instance_map") or {}
         self.instance_map = {
             name: InstanceLocationConfig(**c) for name, c in instance_map.items()
         }
 
         # Map from type of streams to source, c.f. WriterLocations.
-        writers = config.get("writers", {}) or {}
+        writers = config.get("stream_writers") or {}
         self.writers = WriterLocations(**writers)
 
         # Check that the configured writer for events also appears in

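A note on the two replaced `config.get(...)` calls: `config.get(key, {}) or {}` and `config.get(key) or {}` behave identically, since `dict.get` already returns `None` for a missing key and the `or {}` catches both a missing key and an explicitly empty value. A quick sketch, with plain dicts standing in for the parsed YAML config:

```python
# Plain dicts standing in for the parsed YAML config; None models a key that
# is present in the file but left with no body ("instance_map:" alone).
config = {"instance_map": None}

# Both spellings normalise a missing or empty section to {}.
assert (config.get("instance_map", {}) or {}) == {}
assert (config.get("instance_map") or {}) == {}
assert (config.get("stream_writers") or {}) == {}  # key absent entirely
```
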
@@ -80,7 +80,9 @@ class AuthHandler(BaseHandler):
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
-        self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled
+        self._sso_enabled = (
+            hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
+        )
 
         # we keep this as a list despite the O(N^2) implication so that we can
         # keep PASSWORD first and avoid confusing clients which pick the first

@@ -888,7 +888,7 @@ class EventCreationHandler(object):
         """Called when we have fully built the event, have already
         calculated the push actions for the event, and checked auth.
 
-        This should only be run on master.
+        This should only be run on the instance in charge of persisting events.
         """
         assert self.config.worker.writers.events == self._instance_name

@@ -311,7 +311,7 @@ class OidcHandler:
         ``ClientAuth`` to authenticate with the client with its ID and secret.
 
         Args:
-            code: The autorization code we got from the callback.
+            code: The authorization code we got from the callback.
 
         Returns:
             A dict containing various tokens.
@@ -497,11 +497,14 @@ class OidcHandler:
         return UserInfo(claims)
 
     async def handle_redirect_request(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> None:
+        self,
+        request: SynapseRequest,
+        client_redirect_url: bytes,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
         """Handle an incoming request to /login/sso/redirect
 
-        It redirects the browser to the authorization endpoint with a few
+        It returns a redirect to the authorization endpoint with a few
         parameters:
 
         - ``client_id``: the client ID set in ``oidc_config.client_id``
@@ -511,24 +514,32 @@ class OidcHandler:
         - ``state``: a random string
         - ``nonce``: a random string
 
-        In addition to redirecting the client, we are setting a cookie with
+        In addition generating a redirect URL, we are setting a cookie with
         a signed macaroon token containing the state, the nonce and the
         client_redirect_url params. Those are then checked when the client
         comes back from the provider.
 
         Args:
             request: the incoming request from the browser.
                 We'll respond to it with a redirect and a cookie.
             client_redirect_url: the URL that we should redirect the client to
                 when everything is done
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
+
+        Returns:
+            The redirect URL to the authorization endpoint.
         """
 
         state = generate_token()
         nonce = generate_token()
 
         cookie = self._generate_oidc_session_token(
-            state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(),
+            state=state,
+            nonce=nonce,
+            client_redirect_url=client_redirect_url.decode(),
+            ui_auth_session_id=ui_auth_session_id,
         )
         request.addCookie(
             SESSION_COOKIE_NAME,
@@ -541,7 +552,7 @@ class OidcHandler:
 
         metadata = await self.load_metadata()
         authorization_endpoint = metadata.get("authorization_endpoint")
-        uri = prepare_grant_uri(
+        return prepare_grant_uri(
             authorization_endpoint,
             client_id=self._client_auth.client_id,
             response_type="code",
@@ -550,8 +561,6 @@ class OidcHandler:
             state=state,
             nonce=nonce,
         )
-        request.redirect(uri)
-        finish_request(request)
 
     async def handle_oidc_callback(self, request: SynapseRequest) -> None:
         """Handle an incoming request to /_synapse/oidc/callback
@@ -625,7 +634,11 @@ class OidcHandler:
 
         # Deserialize the session token and verify it.
         try:
-            nonce, client_redirect_url = self._verify_oidc_session_token(session, state)
+            (
+                nonce,
+                client_redirect_url,
+                ui_auth_session_id,
+            ) = self._verify_oidc_session_token(session, state)
         except MacaroonDeserializationException as e:
             logger.exception("Invalid session")
             self._render_error(request, "invalid_session", str(e))
@@ -678,15 +691,21 @@ class OidcHandler:
             return
 
         # and finally complete the login
-        await self._auth_handler.complete_sso_login(
-            user_id, request, client_redirect_url
-        )
+        if ui_auth_session_id:
+            await self._auth_handler.complete_sso_ui_auth(
+                user_id, ui_auth_session_id, request
+            )
+        else:
+            await self._auth_handler.complete_sso_login(
+                user_id, request, client_redirect_url
+            )
 
     def _generate_oidc_session_token(
         self,
         state: str,
         nonce: str,
         client_redirect_url: str,
+        ui_auth_session_id: Optional[str],
         duration_in_ms: int = (60 * 60 * 1000),
     ) -> str:
         """Generates a signed token storing data about an OIDC session.
@@ -702,6 +721,8 @@ class OidcHandler:
             nonce: The ``nonce`` parameter passed to the OIDC provider.
             client_redirect_url: The URL the client gave when it initiated the
                 flow.
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
             duration_in_ms: An optional duration for the token in milliseconds.
                 Defaults to an hour.
 
@@ -718,12 +739,19 @@ class OidcHandler:
         macaroon.add_first_party_caveat(
             "client_redirect_url = %s" % (client_redirect_url,)
         )
+        if ui_auth_session_id:
+            macaroon.add_first_party_caveat(
+                "ui_auth_session_id = %s" % (ui_auth_session_id,)
+            )
         now = self._clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
 
         return macaroon.serialize()
 
-    def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]:
+    def _verify_oidc_session_token(
+        self, session: str, state: str
+    ) -> Tuple[str, str, Optional[str]]:
         """Verifies and extract an OIDC session token.
 
         This verifies that a given session token was issued by this homeserver
@@ -734,7 +762,7 @@ class OidcHandler:
             state: The state the OIDC provider gave back
 
         Returns:
-            The nonce and the client_redirect_url for this session
+            The nonce, client_redirect_url, and ui_auth_session_id for this session
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
 
@@ -744,17 +772,27 @@ class OidcHandler:
         v.satisfy_exact("state = %s" % (state,))
         v.satisfy_general(lambda c: c.startswith("nonce = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+        # Sometimes there's a UI auth session ID, it seems to be OK to attempt
+        # to always satisfy this.
+        v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
         v.satisfy_general(self._verify_expiry)
 
         v.verify(macaroon, self._macaroon_secret_key)
 
-        # Extract the `nonce` and `client_redirect_url` from the token
+        # Extract the `nonce`, `client_redirect_url`, and maybe the
+        # `ui_auth_session_id` from the token.
         nonce = self._get_value_from_macaroon(macaroon, "nonce")
         client_redirect_url = self._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
+        try:
+            ui_auth_session_id = self._get_value_from_macaroon(
+                macaroon, "ui_auth_session_id"
+            )  # type: Optional[str]
+        except ValueError:
+            ui_auth_session_id = None
 
-        return nonce, client_redirect_url
+        return nonce, client_redirect_url, ui_auth_session_id
 
     def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
         """Extracts a caveat value from a macaroon token.
@@ -773,7 +811,7 @@ class OidcHandler:
         for caveat in macaroon.caveats:
             if caveat.caveat_id.startswith(prefix):
                 return caveat.caveat_id[len(prefix) :]
-        raise Exception("No %s caveat in macaroon" % (key,))
+        raise ValueError("No %s caveat in macaroon" % (key,))
 
     def _verify_expiry(self, caveat: str) -> bool:
         prefix = "time < "

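For readers unfamiliar with the session-token scheme above: the handler packs `state`, `nonce`, `client_redirect_url`, an optional `ui_auth_session_id`, and an expiry into first-party caveats of a macaroon, and the callback later verifies each caveat. A condensed, self-contained sketch with `pymacaroons` — the key, location, and caveat values here are placeholders, not Synapse's actual parameters:

```python
import time

import pymacaroons

SECRET_KEY = "not-a-real-key"  # placeholder; Synapse uses its macaroon secret

# Issue: one first-party caveat per piece of session state, plus an expiry.
macaroon = pymacaroons.Macaroon(location="example.com", identifier="key", key=SECRET_KEY)
macaroon.add_first_party_caveat("state = abc123")
macaroon.add_first_party_caveat("nonce = xyz789")
macaroon.add_first_party_caveat("client_redirect_url = https://client.example/done")
macaroon.add_first_party_caveat("time < %d" % (int(time.time() * 1000) + 3600 * 1000,))
token = macaroon.serialize()

# Verify: exact-match the state, pattern-match the rest. satisfy_general is
# also registered for the optional ui_auth_session_id caveat, so verification
# succeeds whether or not that caveat was ever added.
v = pymacaroons.Verifier()
v.satisfy_exact("state = abc123")
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(
    lambda c: c.startswith("time < ") and int(c[len("time < "):]) > time.time() * 1000
)
v.verify(pymacaroons.Macaroon.deserialize(token), SECRET_KEY)
```
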
@@ -17,13 +17,16 @@
 
 import abc
 import logging
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 from six.moves import http_client
 
 from synapse import types
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.types import Collection, RoomID, UserID
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room, user_left_room
 
@@ -74,84 +77,84 @@ class RoomMemberHandler(object):
         self.base_handler = BaseHandler(hs)
 
     @abc.abstractmethod
-    async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+    async def _remote_join(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        user: UserID,
+        content: dict,
+    ) -> Optional[dict]:
         """Try and join a room that this server is not in
 
         Args:
-            requester (Requester)
-            remote_room_hosts (list[str]): List of servers that can be used
-                to join via.
-            room_id (str): Room that we are trying to join
-            user (UserID): User who is trying to join
-            content (dict): A dict that should be used as the content of the
-                join event.
-
-        Returns:
-            Deferred
+            requester
+            remote_room_hosts: List of servers that can be used to join via.
+            room_id: Room that we are trying to join
+            user: User who is trying to join
+            content: A dict that should be used as the content of the join event.
         """
         raise NotImplementedError()
 
     @abc.abstractmethod
     async def _remote_reject_invite(
-        self, requester, remote_room_hosts, room_id, target, content
-    ):
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        target: UserID,
+        content: dict,
+    ) -> dict:
         """Attempt to reject an invite for a room this server is not in. If we
         fail to do so we locally mark the invite as rejected.
 
         Args:
-            requester (Requester)
-            remote_room_hosts (list[str]): List of servers to use to try and
-                reject invite
-            room_id (str)
-            target (UserID): The user rejecting the invite
-            content (dict): The content for the rejection event
+            requester
+            remote_room_hosts: List of servers to use to try and reject invite
+            room_id
+            target: The user rejecting the invite
+            content: The content for the rejection event
 
         Returns:
-            Deferred[dict]: A dictionary to be returned to the client, may
+            A dictionary to be returned to the client, may
                 include event_id etc, or nothing if we locally rejected
         """
         raise NotImplementedError()
 
     @abc.abstractmethod
-    async def _user_joined_room(self, target, room_id):
+    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Notifies distributor on master process that the user has joined the
         room.
 
         Args:
-            target (UserID)
-            room_id (str)
-
-        Returns:
-            None
+            target
+            room_id
         """
         raise NotImplementedError()
 
     @abc.abstractmethod
-    async def _user_left_room(self, target, room_id):
+    async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Notifies distributor on master process that the user has left the
         room.
 
         Args:
-            target (UserID)
-            room_id (str)
-
-        Returns:
-            None
+            target
+            room_id
         """
         raise NotImplementedError()
 
     async def _local_membership_update(
         self,
-        requester,
-        target,
-        room_id,
-        membership,
+        requester: Requester,
+        target: UserID,
+        room_id: str,
+        membership: str,
         prev_event_ids: Collection[str],
-        txn_id=None,
-        ratelimit=True,
-        content=None,
-        require_consent=True,
-    ):
+        txn_id: Optional[str] = None,
+        ratelimit: bool = True,
+        content: Optional[dict] = None,
+        require_consent: bool = True,
+    ) -> EventBase:
         user_id = target.to_string()
 
         if content is None:
@@ -214,16 +217,13 @@ class RoomMemberHandler(object):
 
     async def copy_room_tags_and_direct_to_room(
         self, old_room_id, new_room_id, user_id
-    ):
+    ) -> None:
         """Copies the tags and direct room state from one room to another.
 
         Args:
-            old_room_id (str)
-            new_room_id (str)
-            user_id (str)
-
-        Returns:
-            Deferred[None]
+            old_room_id: The room ID of the old room.
+            new_room_id: The room ID of the new room.
+            user_id: The user's ID.
         """
         # Retrieve user account data for predecessor room
         user_account_data, _ = await self.store.get_account_data_for_user(user_id)
@@ -253,17 +253,17 @@ class RoomMemberHandler(object):
 
     async def update_membership(
         self,
-        requester,
-        target,
-        room_id,
-        action,
-        txn_id=None,
-        remote_room_hosts=None,
-        third_party_signed=None,
-        ratelimit=True,
-        content=None,
-        require_consent=True,
-    ):
+        requester: Requester,
+        target: UserID,
+        room_id: str,
+        action: str,
+        txn_id: Optional[str] = None,
+        remote_room_hosts: Optional[List[str]] = None,
+        third_party_signed: Optional[dict] = None,
+        ratelimit: bool = True,
+        content: Optional[dict] = None,
+        require_consent: bool = True,
+    ) -> Union[EventBase, Optional[dict]]:
         key = (room_id,)
 
         with (await self.member_linearizer.queue(key)):
@@ -284,17 +284,17 @@ class RoomMemberHandler(object):
 
     async def _update_membership(
         self,
-        requester,
-        target,
-        room_id,
-        action,
-        txn_id=None,
-        remote_room_hosts=None,
-        third_party_signed=None,
-        ratelimit=True,
-        content=None,
-        require_consent=True,
-    ):
+        requester: Requester,
+        target: UserID,
+        room_id: str,
+        action: str,
+        txn_id: Optional[str] = None,
+        remote_room_hosts: Optional[List[str]] = None,
+        third_party_signed: Optional[dict] = None,
+        ratelimit: bool = True,
+        content: Optional[dict] = None,
+        require_consent: bool = True,
+    ) -> Union[EventBase, Optional[dict]]:
         content_specified = bool(content)
         if content is None:
             content = {}
@@ -468,12 +468,11 @@ class RoomMemberHandler(object):
             else:
                 # send the rejection to the inviter's HS.
                 remote_room_hosts = remote_room_hosts + [inviter.domain]
-                res = await self._remote_reject_invite(
+                return await self._remote_reject_invite(
                     requester, remote_room_hosts, room_id, target, content,
                 )
-                return res
 
-        res = await self._local_membership_update(
+        return await self._local_membership_update(
             requester=requester,
             target=target,
             room_id=room_id,
@@ -484,9 +483,10 @@ class RoomMemberHandler(object):
             content=content,
             require_consent=require_consent,
         )
-        return res
 
-    async def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
+    async def transfer_room_state_on_room_upgrade(
+        self, old_room_id: str, room_id: str
+    ) -> None:
         """Upon our server becoming aware of an upgraded room, either by upgrading a room
         ourselves or joining one, we can transfer over information from the previous room.
 
@@ -494,12 +494,8 @@ class RoomMemberHandler(object):
         well as migrating the room directory state.
 
         Args:
-            old_room_id (str): The ID of the old room
-
-            room_id (str): The ID of the new room
-
-        Returns:
-            Deferred
+            old_room_id: The ID of the old room
+            room_id: The ID of the new room
         """
         logger.info("Transferring room state from %s to %s", old_room_id, room_id)
 
@@ -526,17 +522,16 @@ class RoomMemberHandler(object):
                 # Remove the old room from those groups
                 await self.store.remove_room_from_group(group_id, old_room_id)
 
-    async def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
+    async def copy_user_state_on_room_upgrade(
+        self, old_room_id: str, new_room_id: str, user_ids: Iterable[str]
+    ) -> None:
         """Copy user-specific information when they join a new room when that new room is the
         result of a room upgrade
 
         Args:
-            old_room_id (str): The ID of upgraded room
-            new_room_id (str): The ID of the new room
-            user_ids (Iterable[str]): User IDs to copy state for
-
-        Returns:
-            Deferred
+            old_room_id: The ID of upgraded room
+            new_room_id: The ID of the new room
+            user_ids: User IDs to copy state for
         """
 
         logger.debug(
@@ -566,17 +561,23 @@ class RoomMemberHandler(object):
             )
             continue
 
-    async def send_membership_event(self, requester, event, context, ratelimit=True):
+    async def send_membership_event(
+        self,
+        requester: Requester,
+        event: EventBase,
+        context: EventContext,
+        ratelimit: bool = True,
+    ):
         """
         Change the membership status of a user in a room.
 
         Args:
-            requester (Requester): The local user who requested the membership
+            requester: The local user who requested the membership
                 event. If None, certain checks, like whether this homeserver can
                 act as the sender, will be skipped.
-            event (SynapseEvent): The membership event.
+            event: The membership event.
             context: The context of the event.
-            ratelimit (bool): Whether to rate limit this request.
+            ratelimit: Whether to rate limit this request.
         Raises:
             SynapseError if there was a problem changing the membership.
         """
@@ -636,7 +637,9 @@ class RoomMemberHandler(object):
         if prev_member_event.membership == Membership.JOIN:
             await self._user_left_room(target_user, room_id)
 
-    async def _can_guest_join(self, current_state_ids):
+    async def _can_guest_join(
+        self, current_state_ids: Dict[Tuple[str, str], str]
+    ) -> bool:
         """
         Returns whether a guest can join a room based on its current state.
         """
@@ -653,12 +656,14 @@ class RoomMemberHandler(object):
             and guest_access.content["guest_access"] == "can_join"
         )
 
-    async def lookup_room_alias(self, room_alias):
+    async def lookup_room_alias(
+        self, room_alias: RoomAlias
+    ) -> Tuple[RoomID, List[str]]:
         """
         Get the room ID associated with a room alias.
 
         Args:
-            room_alias (RoomAlias): The alias to look up.
+            room_alias: The alias to look up.
         Returns:
             A tuple of:
                 The room ID as a RoomID object.
@@ -682,24 +687,25 @@ class RoomMemberHandler(object):
 
         return RoomID.from_string(room_id), servers
 
-    async def _get_inviter(self, user_id, room_id):
+    async def _get_inviter(self, user_id: str, room_id: str) -> Optional[UserID]:
         invite = await self.store.get_invite_for_local_user_in_room(
             user_id=user_id, room_id=room_id
         )
         if invite:
             return UserID.from_string(invite.sender)
+        return None
 
     async def do_3pid_invite(
         self,
-        room_id,
-        inviter,
-        medium,
-        address,
-        id_server,
-        requester,
-        txn_id,
-        id_access_token=None,
-    ):
+        room_id: str,
+        inviter: UserID,
+        medium: str,
+        address: str,
+        id_server: str,
+        requester: Requester,
+        txn_id: Optional[str],
+        id_access_token: Optional[str] = None,
+    ) -> None:
         if self.config.block_non_admin_invites:
             is_requester_admin = await self.auth.is_server_admin(requester.user)
             if not is_requester_admin:
@@ -748,15 +754,15 @@ class RoomMemberHandler(object):
 
     async def _make_and_store_3pid_invite(
         self,
-        requester,
-        id_server,
-        medium,
-        address,
-        room_id,
-        user,
-        txn_id,
-        id_access_token=None,
-    ):
+        requester: Requester,
+        id_server: str,
+        medium: str,
+        address: str,
+        room_id: str,
+        user: UserID,
+        txn_id: Optional[str],
+        id_access_token: Optional[str] = None,
+    ) -> None:
         room_state = await self.state_handler.get_current_state(room_id)
 
         inviter_display_name = ""
@@ -830,7 +836,9 @@ class RoomMemberHandler(object):
             txn_id=txn_id,
         )
 
-    async def _is_host_in_room(self, current_state_ids):
+    async def _is_host_in_room(
+        self, current_state_ids: Dict[Tuple[str, str], str]
+    ) -> bool:
         # Have we just created the room, and is this about to be the very
         # first member event?
         create_event_id = current_state_ids.get(("m.room.create", ""))
@@ -852,7 +860,7 @@ class RoomMemberHandler(object):
 
         return False
 
-    async def _is_server_notice_room(self, room_id):
+    async def _is_server_notice_room(self, room_id: str) -> bool:
         if self._server_notices_mxid is None:
             return False
         user_ids = await self.store.get_users_in_room(room_id)
@@ -867,13 +875,15 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         self.distributor.declare("user_joined_room")
         self.distributor.declare("user_left_room")
 
-    async def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
+    async def _is_remote_room_too_complex(
+        self, room_id: str, remote_room_hosts: List[str]
+    ) -> Optional[bool]:
         """
         Check if complexity of a remote room is too great.
 
         Args:
-            room_id (str)
-            remote_room_hosts (list[str])
+            room_id
+            remote_room_hosts
 
         Returns: bool of whether the complexity is too great, or None
             if unable to be fetched
@@ -887,21 +897,26 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             return complexity["v1"] > max_complexity
         return None
 
-    async def _is_local_room_too_complex(self, room_id):
+    async def _is_local_room_too_complex(self, room_id: str) -> bool:
         """
         Check if the complexity of a local room is too great.
 
         Args:
-            room_id (str)
-
-        Returns: bool
+            room_id: The room ID to check for complexity.
         """
         max_complexity = self.hs.config.limit_remote_rooms.complexity
        complexity = await self.store.get_room_complexity(room_id)
 
         return complexity["v1"] > max_complexity
 
-    async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+    async def _remote_join(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        user: UserID,
+        content: dict,
+    ) -> None:
         """Implements RoomMemberHandler._remote_join
         """
         # filter ourselves out of remote_room_hosts: do_invite_join ignores it
@@ -961,8 +976,13 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         )
 
     async def _remote_reject_invite(
-        self, requester, remote_room_hosts, room_id, target, content
-    ):
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        target: UserID,
+        content: dict,
+    ) -> dict:
         """Implements RoomMemberHandler._remote_reject_invite
         """
         fed_handler = self.federation_handler
@@ -983,17 +1003,17 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             await self.store.locally_reject_invite(target.to_string(), room_id)
             return {}
 
-    async def _user_joined_room(self, target, room_id):
+    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_joined_room
         """
-        return user_joined_room(self.distributor, target, room_id)
+        user_joined_room(self.distributor, target, room_id)
 
-    async def _user_left_room(self, target, room_id):
+    async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
-        return user_left_room(self.distributor, target, room_id)
+        user_left_room(self.distributor, target, room_id)
 
-    async def forget(self, user, room_id):
+    async def forget(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
 
         member = await self.state_handler.get_current_state(

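The pattern behind the hunks above is an abstract base handler with typed abstract coroutines, implemented once for the master process and once for workers. A stripped-down sketch of that split — all names here are illustrative, not Synapse's real API:

```python
# A toy master/worker split: the base class declares a typed abstract
# coroutine, and each deployment flavour fills it in.
import abc
import asyncio
from typing import List, Optional


class MemberHandlerBase(abc.ABC):
    @abc.abstractmethod
    async def _remote_join(
        self, remote_room_hosts: List[str], room_id: str, user_id: str
    ) -> Optional[dict]:
        """Try to join a room this server is not in."""
        raise NotImplementedError()


class MasterHandler(MemberHandlerBase):
    async def _remote_join(self, remote_room_hosts, room_id, user_id):
        # On the master, do the federation work directly.
        return {"room_id": room_id, "via": remote_room_hosts}


class WorkerHandler(MemberHandlerBase):
    async def _remote_join(self, remote_room_hosts, room_id, user_id):
        # On a worker, this would proxy to the master over replication HTTP.
        await asyncio.sleep(0)  # stand-in for the replication call
        return None


print(asyncio.run(MasterHandler()._remote_join(["other.example"], "!r:x", "@u:x")))
```
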
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import List, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.handlers.room_member import RoomMemberHandler
@@ -22,6 +23,7 @@ from synapse.replication.http.membership import (
     ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
     ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
 )
+from synapse.types import Requester, UserID
 
 logger = logging.getLogger(__name__)
 
@@ -34,7 +36,14 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         self._remote_reject_client = ReplRejectInvite.make_client(hs)
         self._notify_change_client = ReplJoinedLeft.make_client(hs)
 
-    async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+    async def _remote_join(
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        user: UserID,
+        content: dict,
+    ) -> Optional[dict]:
         """Implements RoomMemberHandler._remote_join
         """
         if len(remote_room_hosts) == 0:
@@ -53,8 +62,13 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         return ret
 
     async def _remote_reject_invite(
-        self, requester, remote_room_hosts, room_id, target, content
-    ):
+        self,
+        requester: Requester,
+        remote_room_hosts: List[str],
+        room_id: str,
+        target: UserID,
+        content: dict,
+    ) -> dict:
         """Implements RoomMemberHandler._remote_reject_invite
         """
         return await self._remote_reject_client(
@@ -65,16 +79,16 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
             content=content,
         )
 
-    async def _user_joined_room(self, target, room_id):
+    async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_joined_room
         """
-        return await self._notify_change_client(
+        await self._notify_change_client(
             user_id=target.to_string(), room_id=room_id, change="joined"
         )
 
-    async def _user_left_room(self, target, room_id):
+    async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
         """
-        return await self._notify_change_client(
+        await self._notify_change_client(
             user_id=target.to_string(), room_id=room_id, change="left"
         )

@@ -19,7 +19,7 @@ import random
 import sys
 from io import BytesIO
 
-from six import PY3, raise_from, string_types
+from six import raise_from, string_types
 from six.moves import urllib
 
 import attr
@@ -70,11 +70,7 @@ incoming_responses_counter = Counter(
 
 MAX_LONG_RETRIES = 10
 MAX_SHORT_RETRIES = 3
-
-if PY3:
-    MAXINT = sys.maxsize
-else:
-    MAXINT = sys.maxint
+MAXINT = sys.maxsize
 
 
 _next_id = 1

@@ -20,8 +20,6 @@ import time
 from functools import wraps
 from inspect import getcallargs
 
-from six import PY3
-
 _TIME_FUNC_ID = 0
 
 
@@ -30,12 +28,8 @@ def _log_debug_as_f(f, msg, msg_args):
     logger = logging.getLogger(name)
 
     if logger.isEnabledFor(logging.DEBUG):
-        if PY3:
-            lineno = f.__code__.co_firstlineno
-            pathname = f.__code__.co_filename
-        else:
-            lineno = f.func_code.co_firstlineno
-            pathname = f.func_code.co_filename
+        lineno = f.__code__.co_firstlineno
+        pathname = f.__code__.co_filename
 
         record = logging.LogRecord(
             name=name,

@@ -15,8 +15,6 @@
 # limitations under the License.
 import logging
 
-import six
-
 from prometheus_client import Counter
 
 from twisted.internet import defer
@@ -28,9 +26,6 @@ from synapse.push import PusherConfigException
 
 from . import push_rule_evaluator, push_tools
 
-if six.PY3:
-    long = int
-
 logger = logging.getLogger(__name__)
 
 http_push_processed_counter = Counter(
@@ -318,7 +313,7 @@ class HttpPusher(object):
                 {
                     "app_id": self.app_id,
                     "pushkey": self.pushkey,
-                    "pushkey_ts": long(self.pushkey_ts / 1000),
+                    "pushkey_ts": int(self.pushkey_ts / 1000),
                     "data": self.data_minus_url,
                 }
             ],
@@ -347,7 +342,7 @@ class HttpPusher(object):
                 {
                     "app_id": self.app_id,
                     "pushkey": self.pushkey,
-                    "pushkey_ts": long(self.pushkey_ts / 1000),
+                    "pushkey_ts": int(self.pushkey_ts / 1000),
                     "data": self.data_minus_url,
                     "tweaks": tweaks,
                 }
@@ -409,7 +404,7 @@ class HttpPusher(object):
                 {
                     "app_id": self.app_id,
                     "pushkey": self.pushkey,
-                    "pushkey_ts": long(self.pushkey_ts / 1000),
+                    "pushkey_ts": int(self.pushkey_ts / 1000),
                     "data": self.data_minus_url,
                 }
             ],

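The `long(...)` → `int(...)` swap works because Python 3's `int` is arbitrary precision, so the Python 2 compatibility alias (`long = int`) added nothing. A one-line check with an example millisecond timestamp:

```python
# Python 3 ints never overflow, so a millisecond timestamp divided down to
# seconds needs no special "long" type; the old alias was a no-op.
pushkey_ts = 1_588_000_000_000  # example value in milliseconds
assert int(pushkey_ts / 1000) == 1_588_000_000
```
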
@@ -34,9 +34,12 @@ class ReplicationRestResource(JsonResource):
 
     def register_servlets(self, hs):
         send_event.register_servlets(hs, self)
-        membership.register_servlets(hs, self)
         federation.register_servlets(hs, self)
-        login.register_servlets(hs, self)
-        register.register_servlets(hs, self)
-        devices.register_servlets(hs, self)
-        streams.register_servlets(hs, self)
+
+        # The following can't currently be instantiated on workers.
+        if hs.config.worker.worker_app is None:
+            membership.register_servlets(hs, self)
+            login.register_servlets(hs, self)
+            register.register_servlets(hs, self)
+            devices.register_servlets(hs, self)
+            streams.register_servlets(hs, self)

@@ -16,8 +16,6 @@
 import logging
 from typing import Optional
 
-import six
-
 from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
@@ -26,13 +24,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
 logger = logging.getLogger(__name__)
 
 
-def __func__(inp):
-    if six.PY3:
-        return inp
-    else:
-        return inp.__func__
-
-
 class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)

@@ -18,7 +18,7 @@ from synapse.storage.data_stores.main.presence import PresenceStore
 from synapse.storage.database import Database
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
-from ._base import BaseSlavedStore, __func__
+from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
 
@@ -27,14 +27,14 @@ class SlavedPresenceStore(BaseSlavedStore):
         super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
         self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
 
-        self._presence_on_startup = self._get_active_presence(db_conn)
+        self._presence_on_startup = self._get_active_presence(db_conn)  # type: ignore
 
         self.presence_stream_cache = StreamChangeCache(
             "PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
         )
 
-    _get_active_presence = __func__(DataStore._get_active_presence)
-    take_presence_startup_info = __func__(DataStore.take_presence_startup_info)
+    _get_active_presence = DataStore._get_active_presence
+    take_presence_startup_info = DataStore.take_presence_startup_info
     _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
     get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]

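The deleted `__func__` helper existed because Python 2 returned an *unbound method* when a function was looked up on another class, while Python 3 returns the plain function. On Python 3 the cross-class assignment therefore works directly — a minimal demonstration with toy classes:

```python
# On Python 3, looking a method up on a class yields a plain function, so it
# can be re-used on an unrelated class without the old __func__ unwrapping.
class DataStore:
    def _get_active_presence(self, db_conn):
        return "presence rows from %r" % (db_conn,)


class SlavedPresenceStore:
    # Direct assignment; Python 2 would have needed
    # DataStore._get_active_presence.__func__ here.
    _get_active_presence = DataStore._get_active_presence


print(SlavedPresenceStore()._get_active_presence("conn"))
```
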
@@ -89,6 +89,8 @@ class ReplicationCommandHandler:
                 self._streams_to_replicate.append(stream)
                 continue
 
+            # Only add EventStream and BackfillStream as a source on the
+            # instance in charge of event persistence.
             if (
                 isinstance(stream, (EventsStream, BackfillStream))
                 and hs.config.worker.writers.events == hs.get_instance_name()

@@ -14,14 +14,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import heapq
 import logging
 from collections import namedtuple
-from typing import Any, Awaitable, Callable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+)
 
 import attr
 
 from synapse.replication.http.streams import ReplicationGetStreamUpdates
 
+if TYPE_CHECKING:
+    import synapse.server
+
 logger = logging.getLogger(__name__)
 
 # the number of rows to request from an update_function.
@@ -37,7 +50,7 @@ Token = int
 # parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
 # just a row from a database query, though this is dependent on the stream in question.
 #
-StreamRow = Tuple
+StreamRow = TypeVar("StreamRow", bound=Tuple)
 
 # The type returned by the update_function of a stream, as well as get_updates(),
 # get_updates_since, etc.
@@ -533,32 +546,63 @@ class AccountDataStream(Stream):
     """
 
     AccountDataStreamRow = namedtuple(
-        "AccountDataStream", ("user_id", "room_id", "data_type")  # str  # str  # str
+        "AccountDataStream",
+        ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
     )
 
     NAME = "account_data"
     ROW_TYPE = AccountDataStreamRow
 
-    def __init__(self, hs):
+    def __init__(self, hs: "synapse.server.HomeServer"):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
             current_token_without_instance(self.store.get_max_account_data_stream_id),
-            db_query_to_update_function(self._update_function),
+            self._update_function,
         )
 
-    async def _update_function(self, from_token, to_token, limit):
-        global_results, room_results = await self.store.get_all_updated_account_data(
-            from_token, from_token, to_token, limit
-        )
-
-        results = list(room_results)
-        results.extend(
-            (stream_id, user_id, None, account_data_type)
-            for stream_id, user_id, account_data_type in global_results
-        )
-
-        return results
+    async def _update_function(
+        self, instance_name: str, from_token: int, to_token: int, limit: int
+    ) -> StreamUpdateResult:
+        limited = False
+        global_results = await self.store.get_updated_global_account_data(
+            from_token, to_token, limit
+        )
+
+        # if the global results hit the limit, we'll need to limit the room results to
+        # the same stream token.
+        if len(global_results) >= limit:
+            to_token = global_results[-1][0]
+            limited = True
+
+        room_results = await self.store.get_updated_room_account_data(
+            from_token, to_token, limit
+        )
+
+        # likewise, if the room results hit the limit, limit the global results to
+        # the same stream token.
+        if len(room_results) >= limit:
+            to_token = room_results[-1][0]
+            limited = True
+
+        # convert the global results to the right format, and limit them to the to_token
+        # at the same time
+        global_rows = (
+            (stream_id, (user_id, None, account_data_type))
+            for stream_id, user_id, account_data_type in global_results
+            if stream_id <= to_token
+        )
+
+        # we know that the room_results are already limited to `to_token` so no need
+        # for a check on `stream_id` here.
+        room_rows = (
+            (stream_id, (user_id, room_id, account_data_type))
+            for stream_id, user_id, room_id, account_data_type in room_results
+        )
+
+        # we need to return a sorted list, so merge them together.
+        updates = list(heapq.merge(room_rows, global_rows))
+        return updates, to_token, limited
 
 
 class GroupServerStream(Stream):

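The new `_update_function` relies on `heapq.merge` to interleave two already-sorted row streams by `stream_id` without materialising and re-sorting them. A tiny standalone illustration with made-up rows:

```python
import heapq

# Two update streams, each already sorted by stream_id (the first element).
global_rows = [(1, ("@a:x", None, "m.push_rules")), (4, ("@b:x", None, "m.direct"))]
room_rows = [(2, ("@a:x", "!r:x", "m.tag")), (3, ("@b:x", "!r:x", "m.tag"))]

# heapq.merge lazily yields the combined stream in stream_id order.
updates = list(heapq.merge(room_rows, global_rows))
assert [stream_id for stream_id, _ in updates] == [1, 2, 3, 4]
```
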
@@ -401,19 +401,22 @@ class BaseSSORedirectServlet(RestServlet):
 
     PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
 
-    def on_GET(self, request: SynapseRequest):
+    async def on_GET(self, request: SynapseRequest):
         args = request.args
         if b"redirectUrl" not in args:
             return 400, "Redirect URL not specified for SSO auth"
         client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = self.get_sso_url(client_redirect_url)
+        sso_url = await self.get_sso_url(request, client_redirect_url)
         request.redirect(sso_url)
         finish_request(request)
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         """Get the URL to redirect to, to perform SSO auth
 
         Args:
+            request: The client request to redirect.
             client_redirect_url: the URL that we should redirect the
                 client to when everything is done
 
@@ -428,7 +431,9 @@ class CasRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._cas_handler = hs.get_cas_handler()
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         return self._cas_handler.get_redirect_url(
             {"redirectUrl": client_redirect_url}
         ).encode("ascii")
@@ -465,11 +470,13 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._saml_handler = hs.get_saml_handler()
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         return self._saml_handler.handle_redirect_request(client_redirect_url)
 
 
-class OIDCRedirectServlet(RestServlet):
+class OIDCRedirectServlet(BaseSSORedirectServlet):
     """Implementation for /login/sso/redirect for the OIDC login flow."""
 
     PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -477,12 +484,12 @@ class OIDCRedirectServlet(RestServlet):
     def __init__(self, hs):
         self._oidc_handler = hs.get_oidc_handler()
 
-    async def on_GET(self, request):
-        args = request.args
-        if b"redirectUrl" not in args:
-            return 400, "Redirect URL not specified for SSO auth"
-        client_redirect_url = args[b"redirectUrl"][0]
-        await self._oidc_handler.handle_redirect_request(request, client_redirect_url)
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
+        return await self._oidc_handler.handle_redirect_request(
+            request, client_redirect_url
+        )
 
 
 def register_servlets(hs, http_server):

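The refactor above is a classic template method: the base servlet owns the whole HTTP flow (argument parsing, redirect, finishing the request) and subclasses only supply the SSO URL, which is why `OIDCRedirectServlet` can shrink to a single override. A minimal sketch of the same pattern with illustrative names (not Synapse's real API):

```python
from abc import ABC, abstractmethod
import asyncio


class BaseRedirect(ABC):
    async def on_get(self, request: dict) -> str:
        # Shared flow: validate, then delegate URL construction to the subclass.
        if "redirectUrl" not in request:
            return "400 Redirect URL not specified"
        url = await self.get_sso_url(request, request["redirectUrl"])
        return "302 -> " + url

    @abstractmethod
    async def get_sso_url(self, request: dict, client_redirect_url: str) -> str:
        ...


class OidcRedirect(BaseRedirect):
    async def get_sso_url(self, request: dict, client_redirect_url: str) -> str:
        # An OIDC provider needs the request (for cookies), hence the async hook
        # taking `request` in the refactored signature.
        return "https://provider.example/authorize?redirect=" + client_redirect_url


print(asyncio.run(OidcRedirect().on_get({"redirectUrl": "https://client.example"})))
```
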
@@ -131,14 +131,19 @@ class AuthRestServlet(RestServlet):
         self.registration_handler = hs.get_registration_handler()
 
         # SSO configuration.
-        self._saml_enabled = hs.config.saml2_enabled
-        if self._saml_enabled:
-            self._saml_handler = hs.get_saml_handler()
         self._cas_enabled = hs.config.cas_enabled
         if self._cas_enabled:
             self._cas_handler = hs.get_cas_handler()
             self._cas_server_url = hs.config.cas_server_url
             self._cas_service_url = hs.config.cas_service_url
+        self._saml_enabled = hs.config.saml2_enabled
+        if self._saml_enabled:
+            self._saml_handler = hs.get_saml_handler()
+        self._oidc_enabled = hs.config.oidc_enabled
+        if self._oidc_enabled:
+            self._oidc_handler = hs.get_oidc_handler()
+            self._cas_server_url = hs.config.cas_server_url
+            self._cas_service_url = hs.config.cas_service_url
 
     async def on_GET(self, request, stagetype):
         session = parse_string(request, "session")
@@ -172,11 +177,17 @@ class AuthRestServlet(RestServlet):
             )
 
         elif self._saml_enabled:
-            client_redirect_url = ""
+            client_redirect_url = b""
             sso_redirect_url = self._saml_handler.handle_redirect_request(
                 client_redirect_url, session
             )
 
+        elif self._oidc_enabled:
+            client_redirect_url = b""
+            sso_redirect_url = await self._oidc_handler.handle_redirect_request(
+                request, client_redirect_url, session
+            )
+
         else:
             raise SynapseError(400, "Homeserver not configured for SSO.")
 

@@ -17,7 +17,6 @@
 import logging
 import os
 
-from six import PY3
 from six.moves import urllib
 
 from twisted.internet import defer
@@ -324,23 +323,15 @@ def get_filename_from_headers(headers):
             upload_name_utf8 = upload_name_utf8[7:]
             # We have a filename*= section. This MUST be ASCII, and any UTF-8
             # bytes are %-quoted.
-            if PY3:
-                try:
-                    # Once it is decoded, we can then unquote the %-encoded
-                    # parts strictly into a unicode string.
-                    upload_name = urllib.parse.unquote(
-                        upload_name_utf8.decode("ascii"), errors="strict"
-                    )
-                except UnicodeDecodeError:
-                    # Incorrect UTF-8.
-                    pass
-            else:
-                # On Python 2, we first unquote the %-encoded parts and then
-                # decode it strictly using UTF-8.
-                try:
-                    upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8")
-                except UnicodeDecodeError:
-                    pass
+            try:
+                # Once it is decoded, we can then unquote the %-encoded
+                # parts strictly into a unicode string.
+                upload_name = urllib.parse.unquote(
+                    upload_name_utf8.decode("ascii"), errors="strict"
+                )
+            except UnicodeDecodeError:
+                # Incorrect UTF-8.
+                pass
 
     # If there isn't check for an ascii name.
     if not upload_name:

@@ -19,9 +19,6 @@ import random
 from abc import ABCMeta
 from typing import Any, Optional
 
-from six import PY2
-from six.moves import builtins
-
 from canonicaljson import json
 
 from synapse.storage.database import LoggingTransaction  # noqa: F401
@@ -103,11 +100,6 @@ def db_to_json(db_content):
     if isinstance(db_content, memoryview):
         db_content = db_content.tobytes()
 
-    # psycopg2 on Python 2 returns buffer objects, which we need to cast to
-    # bytes to decode
-    if PY2 and isinstance(db_content, builtins.buffer):
-        db_content = bytes(db_content)
-
     # Decode it to a Unicode string before feeding it to json.loads, so we
     # consistenty get a Unicode-containing object out.
     if isinstance(db_content, (bytes, bytearray)):

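What remains of `db_to_json` after the Python 2 branch is removed is a two-step normalisation: flatten a `memoryview` (what psycopg2 returns for `bytea` columns) to `bytes`, then decode to `str` before `json.loads`. A standalone sketch of that surviving logic — the tail of the real function is paraphrased:

```python
import json

def db_to_json(db_content):
    # psycopg2 returns bytea columns as memoryview; flatten to bytes first.
    if isinstance(db_content, memoryview):
        db_content = db_content.tobytes()

    # Decode to a Unicode string before json.loads, so we consistently get
    # a Unicode-containing object out.
    if isinstance(db_content, (bytes, bytearray)):
        db_content = db_content.decode("utf8")

    return json.loads(db_content)

assert db_to_json(memoryview(b'{"a": 1}')) == {"a": 1}
```
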
@@ -67,7 +67,7 @@ class DataStores(object):
                 self.main = main_store_class(database, db_conn, hs)
 
                 # If we're on a process that can persist events also
-                # instansiate a `PersistEventsStore`
+                # instantiate a `PersistEventsStore`
                 if hs.config.worker.writers.events == hs.get_instance_name():
                     self.persist_events = PersistEventsStore(
                         hs, database, self.main

@@ -16,6 +16,7 @@
 
 import abc
 import logging
+from typing import List, Tuple
 
 from canonicaljson import json
 
@@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
             "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
         )
 
-    def get_all_updated_account_data(
-        self, last_global_id, last_room_id, current_id, limit
-    ):
-        """Get all the client account_data that has changed on the server
-        Args:
-            last_global_id(int): The position to fetch from for top level data
-            last_room_id(int): The position to fetch from for per room data
-            current_id(int): The position to fetch up to.
-        Returns:
-            A deferred pair of lists of tuples of stream_id int, user_id string,
-            room_id string, and type string.
-        """
-        if last_room_id == current_id and last_global_id == current_id:
-            return defer.succeed(([], []))
+    async def get_updated_global_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
 
-        def get_updated_account_data_txn(txn):
+        Args:
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
+        Returns:
+            A list of tuples of stream_id int, user_id string,
+            and type string.
+        """
+        if last_id == current_id:
+            return []
+
+        def get_updated_global_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, account_data_type"
                 " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_global_id, current_id, limit))
-            global_results = txn.fetchall()
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
 
+        return await self.db.runInteraction(
+            "get_updated_global_account_data", get_updated_global_account_data_txn
+        )
+
+    async def get_updated_room_account_data(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple[int, str, str, str]]:
+        """Get the global account_data that has changed, for the account_data stream
+
+        Args:
+            last_id: the last stream_id from the previous batch.
+            current_id: the maximum stream_id to return up to
+            limit: the maximum number of rows to return
+
+        Returns:
+            A list of tuples of stream_id int, user_id string,
+            room_id string and type string.
+        """
+        if last_id == current_id:
+            return []
+
+        def get_updated_room_account_data_txn(txn):
             sql = (
                 "SELECT stream_id, user_id, room_id, account_data_type"
                 " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                 " ORDER BY stream_id ASC LIMIT ?"
             )
-            txn.execute(sql, (last_room_id, current_id, limit))
-            room_results = txn.fetchall()
-            return global_results, room_results
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
 
-        return self.db.runInteraction(
-            "get_all_updated_account_data_txn", get_updated_account_data_txn
+        return await self.db.runInteraction(
+            "get_updated_room_account_data", get_updated_room_account_data_txn
         )
 
     def get_updated_account_data_for_user(self, user_id, stream_id):

@@ -30,12 +30,12 @@ logger = logging.getLogger(__name__)
 
 
 def _make_exclusive_regex(services_cache):
-    # We precompie a regex constructed from all the regexes that the AS's
+    # We precompile a regex constructed from all the regexes that the AS's
     # have registered for exclusive users.
     exclusive_user_regexes = [
         regex.pattern
         for service in services_cache
-        for regex in service.get_exlusive_user_regexes()
+        for regex in service.get_exclusive_user_regexes()
     ]
     if exclusive_user_regexes:
         exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)

@@ -68,24 +68,78 @@ class GroupServerWorkerStore(SQLBaseStore):
            desc="get_invited_users_in_group",
        )

    def get_rooms_in_group(self, group_id, include_private=False):
    def get_rooms_in_group(self, group_id: str, include_private: bool = False):
        """Retrieve the rooms that belong to a given group. Does not return rooms that
        lack members.

        Args:
            group_id: The ID of the group to query for rooms
            include_private: Whether to return private rooms in results

        Returns:
            Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
            form of:

            {
                "room_id": "!a_room_id:example.com",  # The ID of the room
                "is_public": False  # Whether this is a public room or not
            }
        """
        # TODO: Pagination

        keyvalues = {"group_id": group_id}
        if not include_private:
            keyvalues["is_public"] = True
        def _get_rooms_in_group_txn(txn):
            sql = """
            SELECT room_id, is_public FROM group_rooms
                WHERE group_id = ?
                AND room_id IN (
                    SELECT group_rooms.room_id FROM group_rooms
                    LEFT JOIN room_stats_current ON
                        group_rooms.room_id = room_stats_current.room_id
                        AND joined_members > 0
                        AND local_users_in_room > 0
                    LEFT JOIN rooms ON
                        group_rooms.room_id = rooms.room_id
                        AND (room_version <> '') = ?
                )
            """
            args = [group_id, False]

        return self.db.simple_select_list(
            table="group_rooms",
            keyvalues=keyvalues,
            retcols=("room_id", "is_public"),
            desc="get_rooms_in_group",
        )
            if not include_private:
                sql += " AND is_public = ?"
                args += [True]

    def get_rooms_for_summary_by_category(self, group_id, include_private=False):
            txn.execute(sql, args)

            return [
                {"room_id": room_id, "is_public": is_public}
                for room_id, is_public in txn
            ]

        return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn)

    def get_rooms_for_summary_by_category(
        self, group_id: str, include_private: bool = False,
    ):
        """Get the rooms and categories that should be included in a summary request

        Returns ([rooms], [categories])
        Args:
            group_id: The ID of the group to query the summary for
            include_private: Whether to return private rooms in results

        Returns:
            Deferred[Tuple[List, Dict]]: A tuple containing:

            * A list of dictionaries with the keys:
                * "room_id": str, the room ID
                * "is_public": bool, whether the room is public
                * "category_id": str|None, the category ID if set, else None
                * "order": int, the sort order of rooms

            * A dictionary with the key:
                * category_id (str): a dictionary with the keys:
                    * "is_public": bool, whether the category is public
                    * "profile": str, the category profile
                    * "order": int, the sort order of rooms in this category
        """

        def _get_rooms_for_summary_txn(txn):
@@ -97,13 +151,23 @@ class GroupServerWorkerStore(SQLBaseStore):
            SELECT room_id, is_public, category_id, room_order
            FROM group_summary_rooms
            WHERE group_id = ?
                AND room_id IN (
                    SELECT group_rooms.room_id FROM group_rooms
                    LEFT JOIN room_stats_current ON
                        group_rooms.room_id = room_stats_current.room_id
                        AND joined_members > 0
                        AND local_users_in_room > 0
                    LEFT JOIN rooms ON
                        group_rooms.room_id = rooms.room_id
                        AND (room_version <> '') = ?
                )
            """

            if not include_private:
                sql += " AND is_public = ?"
                txn.execute(sql, (group_id, True))
                txn.execute(sql, (group_id, False, True))
            else:
                txn.execute(sql, (group_id,))
                txn.execute(sql, (group_id, False))

            rooms = [
                {
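The key change in both queries is the `room_id IN (...)` subquery, which filters out the rooms that previously broke group queries: rooms with no joined members and rooms whose stored `room_version` is an invalid empty string. The shape of that filter is easy to demo against a toy schema (the tables and values below are invented for the example and are much simpler than Synapse's real ones):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE group_rooms (group_id TEXT, room_id TEXT, is_public BOOLEAN);
    CREATE TABLE room_stats (room_id TEXT, joined_members INTEGER);
    INSERT INTO group_rooms VALUES ('+g:test', '!alive:test', 1);
    INSERT INTO group_rooms VALUES ('+g:test', '!empty:test', 1);
    INSERT INTO room_stats VALUES ('!alive:test', 3);
    INSERT INTO room_stats VALUES ('!empty:test', 0);
    """
)

sql = """
SELECT room_id, is_public FROM group_rooms
WHERE group_id = ?
  AND room_id IN (
      SELECT room_id FROM room_stats WHERE joined_members > 0
  )
"""
args = ["+g:test"]
include_private = False
if not include_private:
    sql += " AND is_public = ?"
    args += [True]

# the zero-member room is excluded by the subquery
print(conn.execute(sql, args).fetchall())  # [('!alive:test', 1)]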
@@ -17,8 +17,6 @@
import itertools
import logging

import six

from signedjson.key import decode_verify_key_bytes

from synapse.storage._base import SQLBaseStore
@@ -28,12 +26,8 @@ from synapse.util.iterutils import batch_iter

logger = logging.getLogger(__name__)

# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
    db_binary_type = six.moves.builtins.buffer
else:
    db_binary_type = memoryview

db_binary_type = memoryview


class KeyStore(SQLBaseStore):
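With Python 2 support gone, `memoryview` is the one binary wrapper needed for both database engines. A quick standalone sketch of why it matters for sqlite (not Synapse code): wrapping the bytes ensures they are bound as a BLOB parameter.

import sqlite3

db_binary_type = memoryview  # the only binary type needed on Python 3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE server_keys (server TEXT, key_json BLOB)")

key_json = b'{"verify_keys": {}}'
conn.execute(
    "INSERT INTO server_keys VALUES (?, ?)",
    ("example.com", db_binary_type(key_json)),
)

stored = conn.execute("SELECT key_json FROM server_keys").fetchone()[0]
print(type(stored), stored)  # <class 'bytes'> b'{"verify_keys": {}}'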
@@ -45,7 +45,6 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.metrics import Measure
from synapse.util.stringutils import to_ascii

logger = logging.getLogger(__name__)

@@ -179,7 +178,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
            """

            txn.execute(sql, (room_id, Membership.JOIN))
            return [to_ascii(r[0]) for r in txn]
            return [r[0] for r in txn]

    @cached(max_entries=100000)
    def get_room_summary(self, room_id):
@@ -223,7 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
            txn.execute(sql, (room_id,))
            res = {}
            for count, membership in txn:
                summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
                summary = res.setdefault(membership, MemberSummary([], count))

            # we order by membership and then fairly arbitrarily by event_id so
            # heroes are consistent
@@ -255,11 +254,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
            # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
            txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
            for user_id, membership, event_id in txn:
                summary = res[to_ascii(membership)]
                summary = res[membership]
                # we will always have a summary for this membership type at this
                # point given the summary currently contains the counts.
                members = summary.members
                members.append((to_ascii(user_id), to_ascii(event_id)))
                members.append((user_id, event_id))

            return res

@@ -584,13 +583,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                ev_entry = event_map.get(event_id)
                if ev_entry:
                    if ev_entry.event.membership == Membership.JOIN:
                        users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
                            display_name=to_ascii(
                                ev_entry.event.content.get("displayname", None)
                            ),
                            avatar_url=to_ascii(
                                ev_entry.event.content.get("avatar_url", None)
                            ),
                        users_in_room[ev_entry.event.state_key] = ProfileInfo(
                            display_name=ev_entry.event.content.get("displayname", None),
                            avatar_url=ev_entry.event.content.get("avatar_url", None),
                        )
                else:
                    missing_member_event_ids.append(event_id)
@@ -604,9 +599,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
        if event is not None and event.type == EventTypes.Member:
            if event.membership == Membership.JOIN:
                if event.event_id in member_event_ids:
                    users_in_room[to_ascii(event.state_key)] = ProfileInfo(
                        display_name=to_ascii(event.content.get("displayname", None)),
                        avatar_url=to_ascii(event.content.get("avatar_url", None)),
                    users_in_room[event.state_key] = ProfileInfo(
                        display_name=event.content.get("displayname", None),
                        avatar_url=event.content.get("avatar_url", None),
                    )

        return users_in_room
@@ -37,7 +37,55 @@ SearchEntry = namedtuple(
)


class SearchBackgroundUpdateStore(SQLBaseStore):
class SearchWorkerStore(SQLBaseStore):
    def store_search_entries_txn(self, txn, entries):
        """Add entries to the search table

        Args:
            txn (cursor):
            entries (iterable[SearchEntry]):
                entries to be added to the table
        """
        if not self.hs.config.enable_search:
            return
        if isinstance(self.database_engine, PostgresEngine):
            sql = (
                "INSERT INTO event_search"
                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
            )

            args = (
                (
                    entry.event_id,
                    entry.room_id,
                    entry.key,
                    entry.value,
                    entry.stream_ordering,
                    entry.origin_server_ts,
                )
                for entry in entries
            )

            txn.executemany(sql, args)

        elif isinstance(self.database_engine, Sqlite3Engine):
            sql = (
                "INSERT INTO event_search (event_id, room_id, key, value)"
                " VALUES (?,?,?,?)"
            )
            args = (
                (entry.event_id, entry.room_id, entry.key, entry.value)
                for entry in entries
            )

            txn.executemany(sql, args)
        else:
            # This should be unreachable.
            raise Exception("Unrecognized database engine")


class SearchBackgroundUpdateStore(SearchWorkerStore):

    EVENT_SEARCH_UPDATE_NAME = "event_search"
    EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@@ -296,52 +344,6 @@ class SearchBackgroundUpdateStore(SQLBaseStore):

        return num_rows

    def store_search_entries_txn(self, txn, entries):
        """Add entries to the search table

        Args:
            txn (cursor):
            entries (iterable[SearchEntry]):
                entries to be added to the table
        """
        if not self.hs.config.enable_search:
            return
        if isinstance(self.database_engine, PostgresEngine):
            sql = (
                "INSERT INTO event_search"
                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
            )

            args = (
                (
                    entry.event_id,
                    entry.room_id,
                    entry.key,
                    entry.value,
                    entry.stream_ordering,
                    entry.origin_server_ts,
                )
                for entry in entries
            )

            txn.executemany(sql, args)

        elif isinstance(self.database_engine, Sqlite3Engine):
            sql = (
                "INSERT INTO event_search (event_id, room_id, key, value)"
                " VALUES (?,?,?,?)"
            )
            args = (
                (entry.event_id, entry.room_id, entry.key, entry.value)
                for entry in entries
            )

            txn.executemany(sql, args)
        else:
            # This should be unreachable.
            raise Exception("Unrecognized database engine")


class SearchStore(SearchBackgroundUpdateStore):
    def __init__(self, database: Database, db_conn, hs):
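The method is moved verbatim into a new `SearchWorkerStore` base class so that a worker persisting events can insert search entries without inheriting the background-update machinery. The engine-dispatch pattern it uses is easy to demo in isolation; here is a hedged sketch of the sqlite branch only, with a stand-in `SearchEntry` (the Postgres branch would additionally build the `to_tsvector` column):

import sqlite3
from collections import namedtuple

# stand-in for Synapse's SearchEntry namedtuple
SearchEntry = namedtuple("SearchEntry", ["event_id", "room_id", "key", "value"])

def store_search_entries_txn(txn, entries):
    # generator of parameter tuples: executemany consumes it lazily
    args = ((e.event_id, e.room_id, e.key, e.value) for e in entries)
    txn.executemany(
        "INSERT INTO event_search (event_id, room_id, key, value)"
        " VALUES (?,?,?,?)",
        args,
    )

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE event_search (event_id TEXT, room_id TEXT, key TEXT, value TEXT)"
)
entries = [
    SearchEntry("$e1", "!r:test", "content.body", "hello world"),
    SearchEntry("$e2", "!r:test", "content.body", "goodbye"),
]
store_search_entries_txn(conn, entries)
print(conn.execute("SELECT count(*) FROM event_search").fetchone())  # (2,)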
@@ -29,7 +29,6 @@ from synapse.storage.database import Database
from synapse.storage.state import StateFilter
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.stringutils import to_ascii

logger = logging.getLogger(__name__)

@@ -185,9 +184,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                (room_id,),
            )

            return {
                (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
            }
            return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}

        return self.db.runInteraction(
            "get_current_state_ids", _get_current_state_ids_txn
@@ -16,8 +16,6 @@
import logging
from collections import namedtuple

import six

from canonicaljson import encode_canonical_json

from twisted.internet import defer
@@ -27,12 +25,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache

# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
    db_binary_type = six.moves.builtins.buffer
else:
    db_binary_type = memoryview
db_binary_type = memoryview

logger = logging.getLogger(__name__)
@@ -50,7 +50,6 @@ from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.types import Collection
from synapse.util.stringutils import exception_to_unicode

logger = logging.getLogger(__name__)

@@ -424,20 +423,14 @@ class Database(object):
                    # This can happen if the database disappears mid
                    # transaction.
                    logger.warning(
                        "[TXN OPERROR] {%s} %s %d/%d",
                        name,
                        exception_to_unicode(e),
                        i,
                        N,
                        "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
                    )
                    if i < N:
                        i += 1
                        try:
                            conn.rollback()
                        except self.engine.module.Error as e1:
                            logger.warning(
                                "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
                            )
                            logger.warning("[TXN EROLL] {%s} %s", name, e1)
                        continue
                    raise
                except self.engine.module.DatabaseError as e:
@@ -449,9 +442,7 @@ class Database(object):
                        conn.rollback()
                    except self.engine.module.Error as e1:
                        logger.warning(
                            "[TXN EROLL] {%s} %s",
                            name,
                            exception_to_unicode(e1),
                            "[TXN EROLL] {%s} %s", name, e1,
                        )
                        continue
                    raise
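On Python 3, `logging` renders the exception argument with `str()` directly, so the `exception_to_unicode` shim is redundant here. The surrounding retry loop (log, roll back, retry up to N times, then re-raise) is a common shape; this is a hedged standalone sketch with invented names, not the `Database` class itself:

import logging
import sqlite3

logger = logging.getLogger(__name__)

def run_with_retries(conn, txn_func, name, n_retries=5):
    """Run txn_func(cursor); on OperationalError, roll back and retry."""
    i = 0
    while True:
        try:
            cur = conn.cursor()
            result = txn_func(cur)
            conn.commit()
            return result
        except sqlite3.OperationalError as e:
            # this can happen if the database disappears mid transaction
            logger.warning("[TXN OPERROR] {%s} %s %d/%d", name, e, i, n_retries)
            if i < n_retries:
                i += 1
                try:
                    conn.rollback()
                except sqlite3.Error as e1:
                    logger.warning("[TXN EROLL] {%s} %s", name, e1)
                continue
            raise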
@@ -15,11 +15,9 @@
# limitations under the License.

import logging
from sys import intern
from typing import Callable, Dict, Optional

import six
from six.moves import intern

import attr
from prometheus_client.core import Gauge

@@ -154,9 +152,6 @@ def intern_string(string):
        return None

    try:
        if six.PY2:
            string = string.encode("ascii")

        return intern(string)
    except UnicodeEncodeError:
        return string
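`sys.intern` deduplicates equal strings so that repeated dict keys (state keys, event types) share one object, which saves memory and makes key lookups pointer-fast. A quick demonstration:

from sys import intern

# build two equal strings at runtime so CPython does not intern them for us
a = "".join(["m.room.", "member"])
b = "".join(["m.room.", "member"])
print(a == b, a is b)    # True False

a, b = intern(a), intern(b)
print(a == b, a is b)    # True True, one shared object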
@@ -19,10 +19,6 @@ import re
import string
from collections import Iterable

import six
from six import PY2, PY3
from six.moves import range

from synapse.api.errors import Codes, SynapseError

_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@@ -47,80 +43,16 @@ def random_string_with_symbols(length):


def is_ascii(s):

    if PY3:
        if isinstance(s, bytes):
            try:
                s.decode("ascii").encode("ascii")
            except UnicodeDecodeError:
                return False
            except UnicodeEncodeError:
                return False
            return True

    try:
        s.encode("ascii")
    except UnicodeEncodeError:
        return False
    except UnicodeDecodeError:
        return False
    else:
    if isinstance(s, bytes):
        try:
            s.decode("ascii").encode("ascii")
        except UnicodeDecodeError:
            return False
        except UnicodeEncodeError:
            return False
        return True


def to_ascii(s):
    """Converts a string to ascii if it is ascii, otherwise leave it alone.

    If given None then will return None.
    """
    if PY3:
        return s

    if s is None:
        return None

    try:
        return s.encode("ascii")
    except UnicodeEncodeError:
        return s


def exception_to_unicode(e):
    """Helper function to extract the text of an exception as a unicode string

    Args:
        e (Exception): exception to be stringified

    Returns:
        unicode
    """
    # urgh, this is a mess. The basic problem here is that psycopg2 constructs its
    # exceptions with PyErr_SetString, with a (possibly non-ascii) argument. str() will
    # then produce the raw byte sequence. Under Python 2, this will then cause another
    # error if it gets mixed with a `unicode` object, as per
    # https://github.com/matrix-org/synapse/issues/4252

    # First of all, if we're under python3, everything is fine because it will sort this
    # nonsense out for us.
    if not PY2:
        return str(e)

    # otherwise let's have a stab at decoding the exception message. We'll circumvent
    # Exception.__str__(), which would explode if someone raised Exception(u'non-ascii')
    # and instead look at what is in the args member.

    if len(e.args) == 0:
        return ""
    elif len(e.args) > 1:
        return six.text_type(repr(e.args))

    msg = e.args[0]
    if isinstance(msg, bytes):
        return msg.decode("utf-8", errors="replace")
    else:
        return msg


def assert_valid_client_secret(client_secret):
    """Validate that a given string matches the client_secret regex defined by the spec"""
    if client_secret_regex.match(client_secret) is None:
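All of these helpers existed only to paper over Python 2's bytes/unicode split; on Python 3 the plain built-ins already do what they did. A standalone check (using `str.isascii`, which is Python 3.7+, not part of the Synapse code above):

# str() of an exception with a non-ASCII message just works on Python 3,
# which is all exception_to_unicode() was compensating for under Python 2
e = Exception("café £5")
print(str(e))              # café £5

# and an ASCII check needs no PY2/PY3 branching
print("hello".isascii())   # True
print("café".isascii())    # False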
@@ -292,11 +292,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
    @defer.inlineCallbacks
    def test_redirect_request(self):
        """The redirect request has the right arguments & generates a valid session cookie."""
        req = Mock(spec=["addCookie", "redirect", "finish"])
        yield defer.ensureDeferred(
        req = Mock(spec=["addCookie"])
        url = yield defer.ensureDeferred(
            self.handler.handle_redirect_request(req, b"http://client/redirect")
        )
        url = req.redirect.call_args[0][0]
        url = urlparse(url)
        auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)

@@ -382,7 +381,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
        nonce = "nonce"
        client_redirect_url = "http://client/redirect"
        session = self.handler._generate_oidc_session_token(
            state=state, nonce=nonce, client_redirect_url=client_redirect_url,
            state=state,
            nonce=nonce,
            client_redirect_url=client_redirect_url,
            ui_auth_session_id=None,
        )
        request.getCookie.return_value = session

@@ -472,7 +474,10 @@ class OidcHandlerTestCase(HomeserverTestCase):

        # Mismatching session
        session = self.handler._generate_oidc_session_token(
            state="state", nonce="nonce", client_redirect_url="http://client/redirect",
            state="state",
            nonce="nonce",
            client_redirect_url="http://client/redirect",
            ui_auth_session_id=None,
        )
        request.args = {}
        request.args[b"state"] = [b"mismatching state"]
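The first change reflects an API change: `handle_redirect_request` now returns the authorization URL instead of calling `request.redirect` itself, so the mock only needs `addCookie` and the test can inspect the returned URL directly. That inspection style is plain standard library; a hedged sketch of the assertion pattern (the URL is invented, this is not the full test):

from urllib.parse import urlparse, parse_qs

url = "https://issuer.example.com/auth?response_type=code&state=abc&nonce=xyz"

parsed = urlparse(url)
params = parse_qs(parsed.query)

assert parsed.netloc == "issuer.example.com"
assert params["response_type"] == ["code"]
assert "state" in params and "nonce" in params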
@@ -17,11 +17,12 @@ from canonicaljson import encode_canonical_json

from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser

from tests.server import FakeTransport

from ._base import BaseSlavedStoreTestCase

USER_ID = "@feeling:test"
@@ -240,6 +241,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):

        # limit the replication rate
        repl_transport = self._server_transport
        assert isinstance(repl_transport, FakeTransport)
        repl_transport.autoflush = False

        # build the join and message events and persist them in the same batch.
@@ -322,7 +324,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
        type="m.room.message",
        key=None,
        internal={},
        state=None,
        depth=None,
        prev_events=[],
        auth_events=[],
@@ -362,15 +363,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
        event = make_event_from_dict(event_dict, internal_metadata_dict=internal)

        self.event_id += 1

        if state is not None:
            state_ids = {key: e.event_id for key, e in state.items()}
            context = EventContext.with_state(
                state_group=None, current_state_ids=state_ids, prev_state_ids=state_ids
            )
        else:
            state_handler = self.hs.get_state_handler()
            context = self.get_success(state_handler.compute_event_context(event))
        state_handler = self.hs.get_state_handler()
        context = self.get_success(state_handler.compute_event_context(event))

        self.master_store.add_push_actions_to_staging(
            event.event_id, {user_id: actions for user_id, actions in push_actions}
@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.replication.tcp.streams._base import (
    _STREAM_UPDATE_TARGET_ROW_COUNT,
    AccountDataStream,
)

from tests.replication._base import BaseStreamTestCase


class AccountDataStreamTestCase(BaseStreamTestCase):
    def test_update_function_room_account_data_limit(self):
        """Test replication with many room account data updates
        """
        store = self.hs.get_datastore()

        # generate lots of account data updates
        updates = []
        for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
            update = "m.test_type.%i" % (i,)
            self.get_success(
                store.add_account_data_to_room("test_user", "test_room", update, {})
            )
            updates.append(update)

        # also one global update
        self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))

        # tell the notifier to catch up to avoid duplicate rows.
        # workaround for https://github.com/matrix-org/synapse/issues/7360
        # FIXME remove this when the above is fixed
        self.replicate()

        # check we're testing what we think we are: no rows should yet have been
        # received
        self.assertEqual([], self.test_handler.received_rdata_rows)

        # now reconnect to pull the updates
        self.reconnect()
        self.replicate()

        # we should have received all the expected rows in the right order
        received_rows = self.test_handler.received_rdata_rows

        for t in updates:
            (stream_name, token, row) = received_rows.pop(0)
            self.assertEqual(stream_name, AccountDataStream.NAME)
            self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
            self.assertEqual(row.data_type, t)
            self.assertEqual(row.room_id, "test_room")

        (stream_name, token, row) = received_rows.pop(0)
        self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
        self.assertEqual(row.data_type, "m.global")
        self.assertIsNone(row.room_id)

        self.assertEqual([], received_rows)

    def test_update_function_global_account_data_limit(self):
        """Test replication with many global account data updates
        """
        store = self.hs.get_datastore()

        # generate lots of account data updates
        updates = []
        for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
            update = "m.test_type.%i" % (i,)
            self.get_success(store.add_account_data_for_user("test_user", update, {}))
            updates.append(update)

        # also one per-room update
        self.get_success(
            store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
        )

        # tell the notifier to catch up to avoid duplicate rows.
        # workaround for https://github.com/matrix-org/synapse/issues/7360
        # FIXME remove this when the above is fixed
        self.replicate()

        # check we're testing what we think we are: no rows should yet have been
        # received
        self.assertEqual([], self.test_handler.received_rdata_rows)

        # now reconnect to pull the updates
        self.reconnect()
        self.replicate()

        # we should have received all the expected rows in the right order
        received_rows = self.test_handler.received_rdata_rows

        for t in updates:
            (stream_name, token, row) = received_rows.pop(0)
            self.assertEqual(stream_name, AccountDataStream.NAME)
            self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
            self.assertEqual(row.data_type, t)
            self.assertIsNone(row.room_id)

        (stream_name, token, row) = received_rows.pop(0)
        self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
        self.assertEqual(row.data_type, "m.per_room")
        self.assertEqual(row.room_id, "test_room")

        self.assertEqual([], received_rows)
@@ -30,7 +30,7 @@ class ParseCommandTestCase(TestCase):
    def test_parse_rdata(self):
        line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
        cmd = parse_command_from_line(line)
        self.assertIsInstance(cmd, RdataCommand)
        assert isinstance(cmd, RdataCommand)
        self.assertEqual(cmd.stream_name, "events")
        self.assertEqual(cmd.instance_name, "master")
        self.assertEqual(cmd.token, 6287863)
@@ -38,7 +38,7 @@ class ParseCommandTestCase(TestCase):
    def test_parse_rdata_batch(self):
        line = 'RDATA presence master batch ["@foo:example.com", "online"]'
        cmd = parse_command_from_line(line)
        self.assertIsInstance(cmd, RdataCommand)
        assert isinstance(cmd, RdataCommand)
        self.assertEqual(cmd.stream_name, "presence")
        self.assertEqual(cmd.instance_name, "master")
        self.assertIsNone(cmd.token)
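Swapping `self.assertIsInstance` for a bare `assert isinstance(...)` looks cosmetic but matters to mypy: the method call does not narrow the type of `cmd`, while the plain assert does, so the later attribute accesses type-check. A minimal illustration (the classes here are invented, not the real replication commands):

from typing import Union

class PingCommand:
    pass

class RdataCommand:
    stream_name = "events"

def parse(line: str) -> Union[PingCommand, RdataCommand]:
    return RdataCommand() if line.startswith("RDATA") else PingCommand()

cmd = parse("RDATA ...")
assert isinstance(cmd, RdataCommand)  # mypy narrows cmd to RdataCommand here
print(cmd.stream_name)                # ...so this access type-checks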
tox.ini
@@ -188,6 +188,8 @@ commands = mypy \
            synapse/handlers/directory.py \
            synapse/handlers/oidc_handler.py \
            synapse/handlers/presence.py \
            synapse/handlers/room_member.py \
            synapse/handlers/room_member_worker.py \
            synapse/handlers/saml_handler.py \
            synapse/handlers/sync.py \
            synapse/handlers/ui_auth \
@@ -205,7 +207,7 @@ commands = mypy \
            synapse/storage/util \
            synapse/streams \
            synapse/util/caches/stream_change_cache.py \
            tests/replication/tcp/streams \
            tests/replication \
            tests/test_utils \
            tests/rest/client/v2_alpha/test_auth.py \
            tests/util/test_stream_change_cache.py