Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

michaelkaye/remove_warning
Erik Johnston 2021-01-19 10:19:25 +00:00
commit bed4fa29fd
46 changed files with 1055 additions and 370 deletions

changelog.d/9104.feature (new file)
@@ -0,0 +1 @@
+Add experimental support for moving receipts and account data persistence off master.

changelog.d/9127.feature (new file)
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.

changelog.d/9128.bugfix (new file)
@@ -0,0 +1 @@
+Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login.

changelog.d/9144.misc (new file)
@@ -0,0 +1 @@
+Enforce that replication HTTP clients are called with keyword arguments only.

changelog.d/9145.bugfix (new file)
@@ -0,0 +1 @@
+Fix "UnboundLocalError: local variable 'length' referenced before assignment" errors when the response body exceeds the expected size. This bug was introduced in v1.25.0.

changelog.d/9146.misc (new file)
@@ -0,0 +1 @@
+Fix the Python 3.5 + old dependencies build in CI.

changelog.d/9151.doc (new file)
@@ -0,0 +1 @@
+Quote `pip install` packages when extras are used to avoid shells interpreting bracket characters.

@@ -18,7 +18,7 @@ connect to a postgres database.
 virtualenv](../INSTALL.md#installing-from-source), you can install
 the library with:

-    ~/synapse/env/bin/pip install matrix-synapse[postgres]
+    ~/synapse/env/bin/pip install "matrix-synapse[postgres]"

 (substituting the path to your virtualenv for `~/synapse/env`, if
 you used a different path). You will require the postgres

@@ -59,7 +59,7 @@ The appropriate dependencies must also be installed for Synapse. If using a
 virtualenv, these can be installed with:

 ```sh
-pip install matrix-synapse[redis]
+pip install "matrix-synapse[redis]"
 ```

 Note that these dependencies are included when synapse is installed with `pip

@@ -100,7 +100,16 @@ from synapse.rest.client.v1.profile import (
 )
 from synapse.rest.client.v1.push_rule import PushRuleRestServlet
 from synapse.rest.client.v1.voip import VoipRestServlet
-from synapse.rest.client.v2_alpha import groups, room_keys, sync, user_directory
+from synapse.rest.client.v2_alpha import (
+    account_data,
+    groups,
+    read_marker,
+    receipts,
+    room_keys,
+    sync,
+    tags,
+    user_directory,
+)
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
 from synapse.rest.client.v2_alpha.account_data import (
@@ -531,6 +540,10 @@ class GenericWorkerServer(HomeServer):
                     room.register_deprecated_servlets(self, resource)
                     InitialSyncRestServlet(self).register(resource)
                     room_keys.register_servlets(self, resource)
+                    tags.register_servlets(self, resource)
+                    account_data.register_servlets(self, resource)
+                    receipts.register_servlets(self, resource)
+                    read_marker.register_servlets(self, resource)

                     SendToDeviceRestServlet(self).register(resource)

@@ -56,6 +56,12 @@ class WriterLocations:
     to_device = attr.ib(
         default=["master"], type=List[str], converter=_instance_to_list_converter,
     )
+    account_data = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )
+    receipts = attr.ib(
+        default=["master"], type=List[str], converter=_instance_to_list_converter,
+    )


 class WorkerConfig(Config):
@@ -127,7 +133,7 @@ class WorkerConfig(Config):

         # Check that the configured writers for events and typing also appears in
         # `instance_map`.
-        for stream in ("events", "typing", "to_device"):
+        for stream in ("events", "typing", "to_device", "account_data", "receipts"):
             instances = _instance_to_list_converter(getattr(self.writers, stream))
             for instance in instances:
                 if instance != "master" and instance not in self.instance_map:
@@ -141,6 +147,16 @@ class WorkerConfig(Config):
                 "Must only specify one instance to handle `to_device` messages."
             )

+        if len(self.writers.account_data) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `account_data` messages."
+            )
+
+        if len(self.writers.receipts) != 1:
+            raise ConfigError(
+                "Must only specify one instance to handle `receipts` messages."
+            )
+
         self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)

         # Whether this worker should run background tasks or not.
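The new `account_data` and `receipts` writer options reuse `_instance_to_list_converter`, so, like the existing writer settings, they should accept either a single instance name or a list of names. A minimal sketch of that assumed normalisation behaviour:

```python
from typing import List, Union


def instance_to_list(obj: Union[str, List[str]]) -> List[str]:
    # A bare instance name becomes a one-element list; lists pass through.
    if isinstance(obj, str):
        return [obj]
    return obj


assert instance_to_list("account_data_writer") == ["account_data_writer"]
assert instance_to_list(["master"]) == ["master"]
```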

@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,14 +13,157 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
 from typing import TYPE_CHECKING, List, Tuple

+from synapse.replication.http.account_data import (
+    ReplicationAddTagRestServlet,
+    ReplicationRemoveTagRestServlet,
+    ReplicationRoomAccountDataRestServlet,
+    ReplicationUserAccountDataRestServlet,
+)
 from synapse.types import JsonDict, UserID

 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer


+class AccountDataHandler:
+    def __init__(self, hs: "HomeServer"):
+        self._store = hs.get_datastore()
+        self._instance_name = hs.get_instance_name()
+        self._notifier = hs.get_notifier()
+
+        self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
+        self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
+        self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
+        self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
+        self._account_data_writers = hs.config.worker.writers.account_data
+
+    async def add_account_data_to_room(
+        self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
+        """Add some account_data to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The maximum stream ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_account_data_to_room(
+                user_id, room_id, account_data_type, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+
+            return max_stream_id
+        else:
+            response = await self._room_data_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                account_data_type=account_data_type,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def add_account_data_for_user(
+        self, user_id: str, account_data_type: str, content: JsonDict
+    ) -> int:
+        """Add some account_data to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            account_data_type: The type of account_data to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The maximum stream ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_account_data_for_user(
+                user_id, account_data_type, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+
+            return max_stream_id
+        else:
+            response = await self._user_data_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                account_data_type=account_data_type,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def add_tag_to_room(
+        self, user_id: str, room_id: str, tag: str, content: JsonDict
+    ) -> int:
+        """Add a tag to a room for a user.
+
+        Args:
+            user_id: The user to add a tag for.
+            room_id: The room to add a tag for.
+            tag: The tag name to add.
+            content: A json object to associate with the tag.
+
+        Returns:
+            The next account data ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.add_tag_to_room(
+                user_id, room_id, tag, content
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+
+            return max_stream_id
+        else:
+            response = await self._add_tag_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                tag=tag,
+                content=content,
+            )
+            return response["max_stream_id"]
+
+    async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
+        """Remove a tag from a room for a user.
+
+        Returns:
+            The next account data ID.
+        """
+        if self._instance_name in self._account_data_writers:
+            max_stream_id = await self._store.remove_tag_from_room(
+                user_id, room_id, tag
+            )
+
+            self._notifier.on_new_event(
+                "account_data_key", max_stream_id, users=[user_id]
+            )
+
+            return max_stream_id
+        else:
+            response = await self._remove_tag_client(
+                instance_name=random.choice(self._account_data_writers),
+                user_id=user_id,
+                room_id=room_id,
+                tag=tag,
+            )
+            return response["max_stream_id"]
+
+
 class AccountDataEventSource:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
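The handler gives every process a single entry point for account data writes: if the current instance is one of the configured writers it writes to the database directly and pokes the notifier, otherwise it proxies the call to a randomly chosen writer over replication HTTP. A minimal usage sketch (assumed caller code, not part of this diff):

```python
async def mark_room_as_favourite(hs, user_id: str, room_id: str) -> None:
    # Any worker can call this; the handler decides whether to write locally
    # or to forward the request to the configured account_data writer.
    handler = hs.get_account_data_handler()
    await handler.add_tag_to_room(user_id, room_id, "m.favourite", {"order": 0.5})
```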

@@ -1504,8 +1504,8 @@ class AuthHandler(BaseHandler):
     @staticmethod
     def add_query_param_to_url(url: str, param_name: str, param: Any):
         url_parts = list(urllib.parse.urlparse(url))
-        query = dict(urllib.parse.parse_qsl(url_parts[4]))
-        query.update({param_name: param})
+        query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
+        query.append((param_name, param))
         url_parts[4] = urllib.parse.urlencode(query)
         return urllib.parse.urlunparse(url_parts)
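The previous implementation round-tripped the query string through a dict, which drops duplicate keys and (by default) parameters with empty values; keeping the parsed pairs as a list with `keep_blank_values=True` preserves the original query. A quick stdlib illustration:

```python
from urllib.parse import parse_qsl, urlencode

qs = "redirectUrl=&flow=sso&flow=cas"

# dict-based round trip loses the blank value and one of the duplicate keys:
print(urlencode(dict(parse_qsl(qs))))  # -> flow=cas

# list-based round trip with keep_blank_values=True keeps everything:
print(urlencode(parse_qsl(qs, keep_blank_values=True)))  # -> redirectUrl=&flow=sso&flow=cas
```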

@@ -85,7 +85,7 @@ class OidcHandler:
         self._token_generator = OidcSessionTokenGenerator(hs)
         self._providers = {
             p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
-        }
+        }  # type: Dict[str, OidcProvider]

     async def load_metadata(self) -> None:
         """Validate the config and load the metadata from the remote endpoint.

@@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler):
         super().__init__(hs)
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
+        self.account_data_handler = hs.get_account_data_handler()
         self.read_marker_linearizer = Linearizer(name="read_marker")
-        self.notifier = hs.get_notifier()

     async def received_client_read_marker(
         self, room_id: str, user_id: str, event_id: str
@@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler):
             if should_update:
                 content = {"event_id": event_id}
-                max_id = await self.store.add_account_data_to_room(
+                await self.account_data_handler.add_account_data_to_room(
                     user_id, room_id, "m.fully_read", content
                 )
-                self.notifier.on_new_event("account_data_key", max_id, users=[user_id])

@@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler):
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
         self.hs = hs
-        self.federation = hs.get_federation_sender()
-        hs.get_federation_registry().register_edu_handler(
-            "m.receipt", self._received_remote_receipt
-        )
+
+        # We only need to poke the federation sender explicitly if its on the
+        # same instance. Other federation sender instances will get notified by
+        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
+        # in the receipts stream.
+        self.federation_sender = None
+        if hs.should_send_federation():
+            self.federation_sender = hs.get_federation_sender()
+
+        # If we can handle the receipt EDUs we do so, otherwise we route them
+        # to the appropriate worker.
+        if hs.get_instance_name() in hs.config.worker.writers.receipts:
+            hs.get_federation_registry().register_edu_handler(
+                "m.receipt", self._received_remote_receipt
+            )
+        else:
+            hs.get_federation_registry().register_instances_for_edu(
+                "m.receipt", hs.config.worker.writers.receipts,
+            )
+
         self.clock = self.hs.get_clock()
         self.state = hs.get_state_handler()
@@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler):
         if not is_new:
             return

-        await self.federation.send_read_receipt(receipt)
+        if self.federation_sender:
+            await self.federation_sender.send_read_receipt(receipt)


 class ReceiptEventSource:

@@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.registration_handler = hs.get_registration_handler()
         self.profile_handler = hs.get_profile_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
+        self.account_data_handler = hs.get_account_data_handler()

         self.member_linearizer = Linearizer(name="member")
         self.member_limiter = Linearizer(max_count=10, name="member_as_limiter")
@@ -254,7 +255,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 direct_rooms[key].append(new_room_id)

                 # Save back to user's m.direct account data
-                await self.store.add_account_data_for_user(
+                await self.account_data_handler.add_account_data_for_user(
                     user_id, AccountDataTypes.DIRECT, direct_rooms
                 )
                 break
@@ -264,7 +265,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):

         # Copy each room tag to the new room
         for tag, tag_content in room_tags.items():
-            await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
+            await self.account_data_handler.add_tag_to_room(
+                user_id, new_room_id, tag, tag_content
+            )

     async def update_membership(
         self,

@@ -724,7 +724,7 @@ class SimpleHttpClient:
                 read_body_with_max_size(response, output_stream, max_size)
             )
         except BodyExceededMaxSize:
-            SynapseError(
+            raise SynapseError(
                 502,
                 "Requested file is too large > %r bytes" % (max_size,),
                 Codes.TOO_LARGE,
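The bug fixed in this hunk (and the next one) is easy to miss: instantiating an exception without `raise` creates the object and throws it away, so the oversized-response case fell through to the code below and produced the `UnboundLocalError` described in the 9145 changelog entry. A tiny illustration of the difference:

```python
def check(size: int, max_size: int) -> None:
    if size > max_size:
        ValueError("too large")  # constructed but never raised: execution continues
    print("reached even when size > max_size")


def check_fixed(size: int, max_size: int) -> None:
    if size > max_size:
        raise ValueError("too large")  # actually aborts the caller
    print("only reached when size <= max_size")
```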

@@ -996,7 +996,7 @@ class MatrixFederationHttpClient:
             logger.warning(
                 "{%s} [%s] %s", request.txn_id, request.destination, msg,
             )
-            SynapseError(502, msg, Codes.TOO_LARGE)
+            raise SynapseError(502, msg, Codes.TOO_LARGE)
         except Exception as e:
             logger.warning(
                 "{%s} [%s] Error reading response: %s",

@@ -15,6 +15,7 @@
 from synapse.http.server import JsonResource
 from synapse.replication.http import (
+    account_data,
     devices,
     federation,
     login,
@@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource):
         presence.register_servlets(hs, self)
         membership.register_servlets(hs, self)
         streams.register_servlets(hs, self)
+        account_data.register_servlets(hs, self)

         # The following can't currently be instantiated on workers.
         if hs.config.worker.worker_app is None:

@@ -177,7 +177,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):

         @trace(opname="outgoing_replication_request")
         @outgoing_gauge.track_inprogress()
-        async def send_request(instance_name="master", **kwargs):
+        async def send_request(*, instance_name="master", **kwargs):
             if instance_name == local_instance_name:
                 raise Exception("Trying to send HTTP request to self")
             if instance_name == "master":
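The bare `*` makes every argument of the generated `send_request` keyword-only (the change tracked by the 9144 changelog entry), so a call site that passed the instance name positionally now fails loudly instead of silently binding it to the wrong parameter. Standard Python behaviour, shown with a plain function:

```python
def send_request(*, instance_name="master", **kwargs):
    # Everything after the bare * must be passed by keyword.
    return instance_name, kwargs


send_request(instance_name="other_worker", user_id="@alice:example.com")  # OK
try:
    send_request("other_worker")  # positional use is now rejected
except TypeError as e:
    print(e)  # send_request() takes 0 positional arguments but 1 was given
```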

@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
+    """Add user account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_user_account_data/:user_id/:type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_user_account_data"
+    PATH_ARGS = ("user_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_for_user(
+            user_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
+    """Add room account data on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_room_account_data"
+    PATH_ARGS = ("user_id", "room_id", "account_data_type")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, account_data_type, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, account_data_type):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_account_data_to_room(
+            user_id, room_id, account_data_type, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationAddTagRestServlet(ReplicationEndpoint):
+    """Add tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
+
+        {
+            "content": { ... },
+        }
+
+    """
+
+    NAME = "add_tag"
+    PATH_ARGS = ("user_id", "room_id", "tag")
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag, content):
+        payload = {
+            "content": content,
+        }
+
+        return payload
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        content = parse_json_object_from_request(request)
+
+        max_stream_id = await self.handler.add_tag_to_room(
+            user_id, room_id, tag, content["content"]
+        )
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
+    """Remove tag on the appropriate account data worker.
+
+    Request format:
+
+        POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
+
+        {}
+
+    """
+
+    NAME = "remove_tag"
+    PATH_ARGS = (
+        "user_id",
+        "room_id",
+        "tag",
+    )
+    CACHE = False
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
+        self.handler = hs.get_account_data_handler()
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_id, room_id, tag):
+        return {}
+
+    async def _handle_request(self, request, user_id, room_id, tag):
+        max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
+
+        return 200, {"max_stream_id": max_stream_id}
+
+
+def register_servlets(hs, http_server):
+    ReplicationUserAccountDataRestServlet(hs).register(http_server)
+    ReplicationRoomAccountDataRestServlet(hs).register(http_server)
+    ReplicationAddTagRestServlet(hs).register(http_server)
+    ReplicationRemoveTagRestServlet(hs).register(http_server)

@@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )  # type: Optional[MultiWriterIdGenerator]

@@ -15,47 +15,9 @@
 # limitations under the License.

 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.tags import TagsWorkerStore


 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        self._account_data_id_gen = SlavedIdTracker(
-            db_conn,
-            "account_data",
-            "stream_id",
-            extra_tables=[
-                ("room_account_data", "stream_id"),
-                ("room_tags_revisions", "stream_id"),
-            ],
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_account_data_stream_id(self):
-        return self._account_data_id_gen.get_current_token()
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.get_tags_for_user.invalidate((row.user_id,))
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        elif stream_name == AccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-            for row in rows:
-                if not row.room_id:
-                    self.get_global_account_data_by_type_for_user.invalidate(
-                        (row.data_type, row.user_id)
-                    )
-                self.get_account_data_for_user.invalidate((row.user_id,))
-                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
-                self.get_account_data_for_room_and_type.invalidate(
-                    (row.user_id, row.room_id, row.data_type)
-                )
-                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass

@@ -14,43 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from synapse.replication.tcp.streams import ReceiptsStream
-from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.receipts import ReceiptsWorkerStore

 from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker


 class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = SlavedIdTracker(
-            db_conn, "receipts_linearized", "stream_id"
-        )
-        super().__init__(database, db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
-
-    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
-        self.get_receipts_for_user.invalidate((user_id, receipt_type))
-        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
-        self.get_last_receipt_event_id_for_user.invalidate(
-            (user_id, room_id, receipt_type)
-        )
-        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
-        self.get_receipts_for_room.invalidate((room_id, receipt_type))
-
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
-        if stream_name == ReceiptsStream.NAME:
-            self._receipts_id_gen.advance(instance_name, token)
-            for row in rows:
-                self.invalidate_caches_for_receipt(
-                    row.room_id, row.receipt_type, row.user_id
-                )
-                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
-
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+    pass

@@ -51,11 +51,14 @@ from synapse.replication.tcp.commands import (
 from synapse.replication.tcp.protocol import AbstractConnection
 from synapse.replication.tcp.streams import (
     STREAMS_MAP,
+    AccountDataStream,
     BackfillStream,
     CachesStream,
     EventsStream,
     FederationStream,
+    ReceiptsStream,
     Stream,
+    TagAccountDataStream,
     ToDeviceStream,
     TypingStream,
 )
@@ -132,6 +135,22 @@ class ReplicationCommandHandler:

                 continue

+            if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
+                # Only add AccountDataStream and TagAccountDataStream as a source on the
+                # instance in charge of account_data persistence.
+                if hs.get_instance_name() in hs.config.worker.writers.account_data:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
+            if isinstance(stream, ReceiptsStream):
+                # Only add ReceiptsStream as a source on the instance in charge of
+                # receipts.
+                if hs.get_instance_name() in hs.config.worker.writers.receipts:
+                    self._streams_to_replicate.append(stream)
+
+                continue
+
             # Only add any other streams if we're on master.
             if hs.config.worker_app is not None:
                 continue

@@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()

     async def on_PUT(self, request, user_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")

         body = parse_json_object_from_request(request)

-        max_id = await self.store.add_account_data_for_user(
-            user_id, account_data_type, body
-        )
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_account_data_for_user(user_id, account_data_type, body)

         return 200, {}
@@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()

     async def on_PUT(self, request, user_id, room_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
                 " Use /rooms/!roomId:server.name/read_markers",
             )

-        max_id = await self.store.add_account_data_to_room(
+        await self.handler.add_account_data_to_room(
             user_id, room_id, account_data_type, body
         )

-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
         return 200, {}

     async def on_GET(self, request, user_id, room_id, account_data_type):

@@ -58,8 +58,7 @@ class TagServlet(RestServlet):
     def __init__(self, hs):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
+        self.handler = hs.get_account_data_handler()

     async def on_PUT(self, request, user_id, room_id, tag):
         requester = await self.auth.get_user_by_req(request)
@@ -68,9 +67,7 @@ class TagServlet(RestServlet):
         body = parse_json_object_from_request(request)

-        max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_tag_to_room(user_id, room_id, tag, body)

         return 200, {}
@@ -79,9 +76,7 @@ class TagServlet(RestServlet):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add tags for other users.")

-        max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.remove_tag_from_room(user_id, room_id, tag)

         return 200, {}

@@ -45,7 +45,9 @@ class PickIdpResource(DirectServeHtmlResource):
         self._server_name = hs.hostname

     async def _async_render_GET(self, request: SynapseRequest) -> None:
-        client_redirect_url = parse_string(request, "redirectUrl", required=True)
+        client_redirect_url = parse_string(
+            request, "redirectUrl", required=True, encoding="utf-8"
+        )
         idp = parse_string(request, "idp", required=False)

         # if we need to pick an IdP, do so

@@ -55,6 +55,7 @@ from synapse.federation.sender import FederationSender
 from synapse.federation.transport.client import TransportLayerClient
 from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
 from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
+from synapse.handlers.account_data import AccountDataHandler
 from synapse.handlers.account_validity import AccountValidityHandler
 from synapse.handlers.acme import AcmeHandler
 from synapse.handlers.admin import AdminHandler
@@ -711,6 +712,10 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_module_api(self) -> ModuleApi:
         return ModuleApi(self, self.get_auth_handler())

+    @cache_in_self
+    def get_account_data_handler(self) -> AccountDataHandler:
+        return AccountDataHandler(self)
+
     async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
         return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

@@ -160,9 +160,13 @@ class DataStore(
                 database,
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
-                table="cache_invalidation_stream_by_instance",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[
+                    (
+                        "cache_invalidation_stream_by_instance",
+                        "instance_name",
+                        "stream_id",
+                    )
+                ],
                 sequence_name="cache_invalidation_stream_seq",
                 writers=[],
             )

@@ -14,14 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import abc
 import logging
 from typing import Dict, List, Optional, Set, Tuple

 from synapse.api.constants import AccountDataTypes
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -30,14 +32,57 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)


-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
+class AccountDataWorkerStore(SQLBaseStore):
     """This is an abstract base class where subclasses must implement
     `get_max_account_data_stream_id` which can be called in the initializer.
     """

     def __init__(self, database: DatabasePool, db_conn, hs):
+        self._instance_name = hs.get_instance_name()
+
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_account_data = (
+                self._instance_name in hs.config.worker.writers.account_data
+            )
+
+            self._account_data_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="account_data",
+                instance_name=self._instance_name,
+                tables=[
+                    ("room_account_data", "instance_name", "stream_id"),
+                    ("room_tags_revisions", "instance_name", "stream_id"),
+                    ("account_data", "instance_name", "stream_id"),
+                ],
+                sequence_name="account_data_sequence",
+                writers=hs.config.worker.writers.account_data,
+            )
+        else:
+            self._can_write_to_account_data = True
+
+            # We shouldn't be running in worker mode with SQLite, but its useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer than we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.events:
+                self._account_data_id_gen = StreamIdGenerator(
+                    db_conn,
+                    "room_account_data",
+                    "stream_id",
+                    extra_tables=[("room_tags_revisions", "stream_id")],
+                )
+            else:
+                self._account_data_id_gen = SlavedIdTracker(
+                    db_conn,
+                    "room_account_data",
+                    "stream_id",
+                    extra_tables=[("room_tags_revisions", "stream_id")],
+                )
+
         account_max = self.get_max_account_data_stream_id()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max
@@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):

         super().__init__(database, db_conn, hs)

-    @abc.abstractmethod
-    def get_max_account_data_stream_id(self):
+    def get_max_account_data_stream_id(self) -> int:
         """Get the current max stream ID for account data stream

         Returns:
             int
         """
-        raise NotImplementedError()
+        return self._account_data_id_gen.get_current_token()

     @cached()
     async def get_account_data_for_user(
@@ -307,25 +351,26 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
             )
         )

-
-class AccountDataStore(AccountDataWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        self._account_data_id_gen = StreamIdGenerator(
-            db_conn,
-            "room_account_data",
-            "stream_id",
-            extra_tables=[("room_tags_revisions", "stream_id")],
-        )
-
-        super().__init__(database, db_conn, hs)
-
-    def get_max_account_data_stream_id(self) -> int:
-        """Get the current max stream id for the private user data stream
-
-        Returns:
-            The maximum stream ID.
-        """
-        return self._account_data_id_gen.get_current_token()
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+        elif stream_name == AccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+            for row in rows:
+                if not row.room_id:
+                    self.get_global_account_data_by_type_for_user.invalidate(
+                        (row.data_type, row.user_id)
+                    )
+                self.get_account_data_for_user.invalidate((row.user_id,))
+                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
+                self.get_account_data_for_room_and_type.invalidate(
+                    (row.user_id, row.room_id, row.data_type)
+                )
+                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
+
+        return super().process_replication_rows(stream_name, instance_name, token, rows)

     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@@ -341,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore):
         Returns:
             The maximum stream ID.
         """
+        assert self._can_write_to_account_data
+
         content_json = json_encoder.encode(content)

         async with self._account_data_id_gen.get_next() as next_id:
@@ -381,6 +428,8 @@ class AccountDataStore(AccountDataWorkerStore):
         Returns:
             The maximum stream ID.
         """
+        assert self._can_write_to_account_data
+
         async with self._account_data_id_gen.get_next() as next_id:
             await self.db_pool.runInteraction(
                 "add_user_account_data",
@@ -463,3 +512,7 @@ class AccountDataStore(AccountDataWorkerStore):
         # Invalidate the cache for any ignored users which were added or removed.
         for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
             self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
+
+
+class AccountDataStore(AccountDataWorkerStore):
+    pass

@@ -54,9 +54,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="to_device",
                 instance_name=self._instance_name,
-                table="device_inbox",
-                instance_column="instance_name",
-                id_column="stream_id",
+                tables=[("device_inbox", "instance_name", "stream_id")],
                 sequence_name="device_inbox_sequence",
                 writers=hs.config.worker.writers.to_device,
             )

@@ -835,6 +835,52 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             (rotate_to_stream_ordering,),
         )

+    def _remove_old_push_actions_before_txn(
+        self, txn, room_id, user_id, stream_ordering
+    ):
+        """
+        Purges old push actions for a user and room before a given
+        stream_ordering.
+
+        We however keep a months worth of highlighted notifications, so that
+        users can still get a list of recent highlights.
+
+        Args:
+            txn: The transcation
+            room_id: Room ID to delete from
+            user_id: user ID to delete for
+            stream_ordering: The lowest stream ordering which will
+                not be deleted.
+        """
+        txn.call_after(
+            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+            (room_id, user_id),
+        )
+
+        # We need to join on the events table to get the received_ts for
+        # event_push_actions and sqlite won't let us use a join in a delete so
+        # we can't just delete where received_ts < x. Furthermore we can
+        # only identify event_push_actions by a tuple of room_id, event_id
+        # we we can't use a subquery.
+        # Instead, we look up the stream ordering for the last event in that
+        # room received before the threshold time and delete event_push_actions
+        # in the room with a stream_odering before that.
+        txn.execute(
+            "DELETE FROM event_push_actions "
+            " WHERE user_id = ? AND room_id = ? AND "
+            " stream_ordering <= ?"
+            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
+            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
+        )
+
+        txn.execute(
+            """
+            DELETE FROM event_push_summary
+            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
+        """,
+            (room_id, user_id, stream_ordering),
+        )
+

 class EventPushActionsStore(EventPushActionsWorkerStore):
     EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@@ -894,52 +940,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
             pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
         return push_actions

-    def _remove_old_push_actions_before_txn(
-        self, txn, room_id, user_id, stream_ordering
-    ):
-        """
-        Purges old push actions for a user and room before a given
-        stream_ordering.
-
-        We however keep a months worth of highlighted notifications, so that
-        users can still get a list of recent highlights.
-
-        Args:
-            txn: The transcation
-            room_id: Room ID to delete from
-            user_id: user ID to delete for
-            stream_ordering: The lowest stream ordering which will
-                not be deleted.
-        """
-        txn.call_after(
-            self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
-            (room_id, user_id),
-        )
-
-        # We need to join on the events table to get the received_ts for
-        # event_push_actions and sqlite won't let us use a join in a delete so
-        # we can't just delete where received_ts < x. Furthermore we can
-        # only identify event_push_actions by a tuple of room_id, event_id
-        # we we can't use a subquery.
-        # Instead, we look up the stream ordering for the last event in that
-        # room received before the threshold time and delete event_push_actions
-        # in the room with a stream_odering before that.
-        txn.execute(
-            "DELETE FROM event_push_actions "
-            " WHERE user_id = ? AND room_id = ? AND "
-            " stream_ordering <= ?"
-            " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
-            (user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
-        )
-
-        txn.execute(
-            """
-            DELETE FROM event_push_summary
-            WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
-        """,
-            (room_id, user_id, stream_ordering),
-        )
-

 def _action_has_highlight(actions):
     for action in actions:

@@ -96,9 +96,7 @@ class EventsWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="events",
                 instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
+                tables=[("events", "instance_name", "stream_ordering")],
                 sequence_name="events_stream_seq",
                 writers=hs.config.worker.writers.events,
             )
@@ -107,9 +105,7 @@ class EventsWorkerStore(SQLBaseStore):
                 db=database,
                 stream_name="backfill",
                 instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
+                tables=[("events", "instance_name", "stream_ordering")],
                 sequence_name="events_backfill_stream_seq",
                 positive=False,
                 writers=hs.config.worker.writers.events,

@@ -14,15 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import abc
 import logging
 from typing import Any, Dict, List, Optional, Tuple

 from twisted.internet import defer

+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.replication.tcp.streams import ReceiptsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 logger = logging.getLogger(__name__)


-# The ABCMeta metaclass ensures that it cannot be instantiated without
-# the abstract methods being implemented.
-class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
-    """This is an abstract base class where subclasses must implement
-    `get_max_receipt_stream_id` which can be called in the initializer.
-    """
-
+class ReceiptsWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
+        self._instance_name = hs.get_instance_name()
+
+        if isinstance(database.engine, PostgresEngine):
+            self._can_write_to_receipts = (
+                self._instance_name in hs.config.worker.writers.receipts
+            )
+
+            self._receipts_id_gen = MultiWriterIdGenerator(
+                db_conn=db_conn,
+                db=database,
+                stream_name="account_data",
+                instance_name=self._instance_name,
+                tables=[("receipts_linearized", "instance_name", "stream_id")],
+                sequence_name="receipts_sequence",
+                writers=hs.config.worker.writers.receipts,
+            )
+        else:
+            self._can_write_to_receipts = True
+
+            # We shouldn't be running in worker mode with SQLite, but its useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer than we need to use
+            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+            # updated over replication. (Multiple writers are not supported for
+            # SQLite).
+            if hs.get_instance_name() in hs.config.worker.writers.events:
+                self._receipts_id_gen = StreamIdGenerator(
+                    db_conn, "receipts_linearized", "stream_id"
+                )
+            else:
+                self._receipts_id_gen = SlavedIdTracker(
+                    db_conn, "receipts_linearized", "stream_id"
+                )
+
         super().__init__(database, db_conn, hs)

         self._receipts_stream_cache = StreamChangeCache(
             "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
         )

-    @abc.abstractmethod
     def get_max_receipt_stream_id(self):
         """Get the current max stream ID for receipts stream

         Returns:
             int
         """
-        raise NotImplementedError()
+        return self._receipts_id_gen.get_current_token()

     @cached()
     async def get_users_with_read_receipts_in_room(self, room_id):
@@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):

         self.get_users_with_read_receipts_in_room.invalidate((room_id,))

+    def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
+        self.get_receipts_for_user.invalidate((user_id, receipt_type))
+        self._get_linearized_receipts_for_room.invalidate_many((room_id,))
+        self.get_last_receipt_event_id_for_user.invalidate(
+            (user_id, room_id, receipt_type)
+        )
+        self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
+        self.get_receipts_for_room.invalidate((room_id, receipt_type))

-class ReceiptsStore(ReceiptsWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
-        # We instantiate this first as the ReceiptsWorkerStore constructor
-        # needs to be able to call get_max_receipt_stream_id
-        self._receipts_id_gen = StreamIdGenerator(
-            db_conn, "receipts_linearized", "stream_id"
-        )
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
+        if stream_name == ReceiptsStream.NAME:
+            self._receipts_id_gen.advance(instance_name, token)
+            for row in rows:
+                self.invalidate_caches_for_receipt(
+                    row.room_id, row.receipt_type, row.user_id
+                )
+                self._receipts_stream_cache.entity_has_changed(row.room_id, token)

-        super().__init__(database, db_conn, hs)
-
-    def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_current_token()
+        return super().process_replication_rows(stream_name, instance_name, token, rows)

     def insert_linearized_receipt_txn(
         self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
@@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
             otherwise, the rx timestamp of the event that the RR corresponds to
                 (or 0 if the event is unknown)
         """
+        assert self._can_write_to_receipts
+
         res = self.db_pool.simple_select_one_txn(
             txn,
             table="events",
@@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
             )
             return None

-        txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
-        txn.call_after(
-            self._invalidate_get_users_with_receipts_in_room,
-            room_id,
-            receipt_type,
-            user_id,
-        )
-        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
-        # FIXME: This shouldn't invalidate the whole cache
-        txn.call_after(
-            self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
-        )
+        txn.call_after(
+            self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
+        )

         txn.call_after(
             self._receipts_stream_cache.entity_has_changed, room_id, stream_id
         )
-        txn.call_after(
-            self.get_last_receipt_event_id_for_user.invalidate,
-            (user_id, room_id, receipt_type),
-        )

         self.db_pool.simple_upsert_txn(
             txn,
             table="receipts_linearized",
@@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
         Automatically does conversion between linearized and graph
         representations.
         """
+        assert self._can_write_to_receipts
+
         if not event_ids:
             return None
@@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
     async def insert_graph_receipt(
         self, room_id, receipt_type, user_id, event_ids, data
     ):
+        assert self._can_write_to_receipts
+
         return await self.db_pool.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
@@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
     def insert_graph_receipt_txn(
         self, txn, room_id, receipt_type, user_id, event_ids, data
     ):
+        assert self._can_write_to_receipts
+
         txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
         txn.call_after(
             self._invalidate_get_users_with_receipts_in_room,
@@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
                 "data": json_encoder.encode(data),
             },
         )
+
+
+class ReceiptsStore(ReceiptsWorkerStore):
+    pass


@@ -0,0 +1,20 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE room_account_data ADD COLUMN instance_name TEXT;
ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT;
ALTER TABLE account_data ADD COLUMN instance_name TEXT;
ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;


@@ -0,0 +1,32 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE SEQUENCE IF NOT EXISTS account_data_sequence;
-- We need to take the max across all the account_data tables as they share the
-- ID generator
SELECT setval('account_data_sequence', (
SELECT GREATEST(
(SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data),
(SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions),
(SELECT COALESCE(MAX(stream_id), 1) FROM account_data)
)
));
CREATE SEQUENCE IF NOT EXISTS receipts_sequence;
SELECT setval('receipts_sequence', (
SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized
));
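As a hypothetical sanity check (not part of this migration), the seeded value can be compared against the per-table maxima; this sketch assumes a DB-API connection to the Synapse Postgres database:

```python
ACCOUNT_DATA_TABLES = ["room_account_data", "room_tags_revisions", "account_data"]

def account_data_sequence_is_seeded(conn) -> bool:
    """Check that account_data_sequence is at or past the max stream_id
    across the three account-data tables, mirroring the GREATEST() above."""
    with conn.cursor() as cur:
        maxes = []
        for table in ACCOUNT_DATA_TABLES:
            cur.execute(f"SELECT COALESCE(MAX(stream_id), 1) FROM {table}")
            maxes.append(cur.fetchone()[0])
        # a Postgres sequence exposes its current value via last_value
        cur.execute("SELECT last_value FROM account_data_sequence")
        (last_value,) = cur.fetchone()
    return last_value >= max(maxes)
```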


@@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore):
)
return {row["tag"]: db_to_json(row["content"]) for row in rows}
-class TagsStore(TagsWorkerStore):
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int:
@@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
assert self._can_write_to_account_data
content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id):
@@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore):
Returns:
The next account data ID.
"""
assert self._can_write_to_account_data
def remove_tag_txn(txn, next_id):
sql = (
@@ -250,6 +251,7 @@ class TagsStore(TagsWorkerStore):
room_id: The ID of the room.
next_id: The the revision to advance to.
"""
assert self._can_write_to_account_data
txn.call_after(
self._account_data_stream_cache.entity_has_changed, user_id, next_id
@@ -278,3 +280,7 @@ class TagsStore(TagsWorkerStore):
# which stream_id ends up in the table, as long as it is higher
# than the id that the client has.
pass
class TagsStore(TagsWorkerStore):
pass


@@ -17,7 +17,7 @@ import logging
import threading
from collections import deque
from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Union
from typing import Dict, List, Optional, Set, Tuple, Union
import attr
from typing_extensions import Deque
@@ -186,11 +186,12 @@ class MultiWriterIdGenerator:
Args:
db_conn
db
-stream_name: A name for the stream.
stream_name: A name for the stream, for use in the `stream_positions`
table. (Does not need to be the same as the replication stream name)
instance_name: The name of this instance.
-table: Database table associated with stream.
-instance_column: Column that stores the row's writer's instance name
-id_column: Column that stores the stream ID.
tables: List of tables associated with the stream. Tuple of table
name, column name that stores the writer's instance name, and
column name that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
writers: A list of known writers to use to populate current positions
@@ -206,9 +207,7 @@ class MultiWriterIdGenerator:
db: DatabasePool,
stream_name: str,
instance_name: str,
-table: str,
-instance_column: str,
-id_column: str,
tables: List[Tuple[str, str, str]],
sequence_name: str,
writers: List[str],
positive: bool = True,
@@ -260,15 +259,16 @@ class MultiWriterIdGenerator:
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
for table, _, id_column in tables:
self._sequence_gen.check_consistency(
db_conn, table=table, id_column=id_column, positive=positive
)
# This goes and fills out the above state from the database.
-self._load_current_ids(db_conn, table, instance_column, id_column)
self._load_current_ids(db_conn, tables)
def _load_current_ids(
-self, db_conn, table: str, instance_column: str, id_column: str
self, db_conn, tables: List[Tuple[str, str, str]],
):
cur = db_conn.cursor(txn_name="_load_current_ids")
@@ -306,6 +306,8 @@ class MultiWriterIdGenerator:
# We add a GREATEST here to ensure that the result is always
# positive. (This can be a problem for e.g. backfill streams where
# the server has never backfilled).
max_stream_id = 1
for table, _, id_column in tables:
sql = """
SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
FROM %(table)s
@@ -316,7 +318,10 @@ class MultiWriterIdGenerator:
}
cur.execute(sql)
(stream_id,) = cur.fetchone()
-self._persisted_upto_position = stream_id
max_stream_id = max(max_stream_id, stream_id)
self._persisted_upto_position = max_stream_id
else:
# If we have a min_stream_id then we pull out everything greater
# than it from the DB so that we can prefill
@@ -329,6 +334,10 @@ class MultiWriterIdGenerator:
# stream positions table before restart (or the stream position
# table otherwise got out of date).
self._persisted_upto_position = min_stream_id
rows = []
for table, instance_column, id_column in tables:
sql = """
SELECT %(instance)s, %(id)s FROM %(table)s
WHERE ? %(cmp)s %(id)s
@@ -340,10 +349,13 @@ class MultiWriterIdGenerator:
}
cur.execute(sql, (min_stream_id * self._return_factor,))
-self._persisted_upto_position = min_stream_id
rows.extend(cur)
# Sort so that we handle rows in order for each instance.
rows.sort()
with self._lock:
-for (instance, stream_id,) in cur:
for (instance, stream_id,) in rows:
stream_id = self._return_factor * stream_id
self._add_persisted_position(stream_id)
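To make the new constructor shape concrete, here is a minimal sketch of building a `MultiWriterIdGenerator` over two tables. It mirrors the test fixtures further down (the `foobar1`/`foobar2` tables and `foobar_seq` sequence are assumptions taken from that test setup, not production names):

```python
from synapse.storage.util.id_generators import MultiWriterIdGenerator

def build_id_generator(conn, db_pool):
    return MultiWriterIdGenerator(
        conn,
        db_pool,
        stream_name="test_stream",
        instance_name="master",
        # one (table, instance column, id column) tuple per backing table
        tables=[
            ("foobar1", "instance_name", "stream_id"),
            ("foobar2", "instance_name", "stream_id"),
        ],
        sequence_name="foobar_seq",
        writers=["master"],
    )
```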


@@ -15,9 +15,8 @@
import time
import urllib.parse
-from html.parser import HTMLParser
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
-from urllib.parse import parse_qs, urlencode, urlparse
from typing import Any, Dict, Union
from urllib.parse import urlencode
from mock import Mock
@@ -38,6 +37,7 @@ from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
try:
@@ -69,6 +69,12 @@ TEST_SAML_METADATA = """
LOGIN_URL = b"/_matrix/client/r0/login"
TEST_URL = b"/_matrix/client/r0/account/whoami"
# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is +
TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
# the query params in TEST_CLIENT_REDIRECT_URL
EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
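For reference, a small sketch (not part of the test file) of how the query in `TEST_CLIENT_REDIRECT_URL` decodes into `EXPECTED_CLIENT_REDIRECT_URL_PARAMS`; `strict_parsing` is omitted here because the raw `<ab c>` field carries no `=`:

```python
from urllib.parse import parse_qsl, urlsplit

url = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
# keep_blank_values so the value-less "<ab c>" field survives as ("<ab c>", "")
params = parse_qsl(urlsplit(url).query, keep_blank_values=True)
assert params == [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
```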
class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -389,23 +395,44 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
},
}
# default OIDC provider
config["oidc_config"] = TEST_OIDC_CONFIG
# additional OIDC providers
config["oidc_providers"] = [
{
"idp_id": "idp1",
"idp_name": "IDP1",
"discover": False,
"issuer": "https://issuer1",
"client_id": "test-client-id",
"client_secret": "test-client-secret",
"scopes": ["profile"],
"authorization_endpoint": "https://issuer1/auth",
"token_endpoint": "https://issuer1/token",
"userinfo_endpoint": "https://issuer1/userinfo",
"user_mapping_provider": {
"config": {"localpart_template": "{{ user.sub }}"}
},
}
]
return config
def create_resource_dict(self) -> Dict[str, Resource]:
from synapse.rest.oidc import OIDCResource
d = super().create_resource_dict()
d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
d["/_synapse/oidc"] = OIDCResource(self.hs)
return d
def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker"""
-client_redirect_url = "https://x?<abc>"
# first hit the redirect url, which should redirect to our idp picker
channel = self.make_request(
"GET",
-"/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url,
"/_matrix/client/r0/login/sso/redirect?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 302, channel.result)
uri = channel.headers.getRawHeaders("Location")[0]
@@ -415,46 +442,22 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
-class FormPageParser(HTMLParser):
-def __init__(self):
-super().__init__()
-# the values of the hidden inputs: map from name to value
-self.hiddens = {}  # type: Dict[str, Optional[str]]
-# the values of the radio buttons
-self.radios = []  # type: List[Optional[str]]
-def handle_starttag(
-self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
-) -> None:
-attr_dict = dict(attrs)
-if tag == "input":
-if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
-self.radios.append(attr_dict["value"])
-elif attr_dict["type"] == "hidden":
-input_name = attr_dict["name"]
-assert input_name
-self.hiddens[input_name] = attr_dict["value"]
-def error(_, message):
-self.fail(message)
-p = FormPageParser()
p = TestHtmlParser()
p.feed(channel.result["body"].decode("utf-8"))
p.close()
-self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])
self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"])
-self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)
self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
def test_multi_sso_redirect_to_cas(self):
"""If CAS is chosen, should redirect to the CAS server"""
-client_redirect_url = "https://x?<abc>"
channel = self.make_request(
"GET",
-"/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas",
"/_synapse/client/pick_idp?redirectUrl="
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=cas",
shorthand=False,
)
self.assertEqual(channel.code, 302, channel.result)
@@ -470,16 +473,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
service_uri = cas_uri_params["service"][0]
_, service_uri_query = service_uri.split("?", 1)
service_uri_params = urllib.parse.parse_qs(service_uri_query)
-self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url)
self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
def test_multi_sso_redirect_to_saml(self):
"""If SAML is chosen, should redirect to the SAML server"""
-client_redirect_url = "https://x?<abc>"
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
-+ client_redirect_url
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml",
)
self.assertEqual(channel.code, 302, channel.result)
@@ -492,16 +493,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# the RelayState is used to carry the client redirect url
saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
relay_state_param = saml_uri_params["RelayState"][0]
-self.assertEqual(relay_state_param, client_redirect_url)
self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
-def test_multi_sso_redirect_to_oidc(self):
def test_login_via_oidc(self):
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
-client_redirect_url = "https://x?<abc>"
# pick the default OIDC provider
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
-+ client_redirect_url
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
@@ -521,9 +522,41 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
self.assertEqual(
self._get_value_from_macaroon(macaroon, "client_redirect_url"),
-client_redirect_url,
TEST_CLIENT_REDIRECT_URL,
)
channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
# that should serve a confirmation page
self.assertEqual(channel.code, 200, channel.result)
self.assertTrue(
channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
)
p = TestHtmlParser()
p.feed(channel.text_body)
p.close()
# ... which should contain our redirect link
self.assertEqual(len(p.links), 1)
path, query = p.links[0].split("?", 1)
self.assertEqual(path, "https://x")
# it will have url-encoded the params properly, so we'll have to parse them
params = urllib.parse.parse_qsl(
query, keep_blank_values=True, strict_parsing=True, errors="strict"
)
self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
self.assertEqual(params[2][0], "loginToken")
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.
login_token = params[2][1]
chan = self.make_request(
"POST", "/login", content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self):
"""An unknown IdP should cause a 400"""
channel = self.make_request(
@@ -1082,7 +1115,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
# whitelist this client URI so we redirect straight to it rather than
# serving a confirmation page
-config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
config["sso"] = {"client_whitelist": ["https://x"]}
return config
def create_resource_dict(self) -> Dict[str, Resource]:
@@ -1095,11 +1128,10 @@
def test_username_picker(self):
"""Test the happy path of a username picker flow."""
-client_redirect_url = "https://whitelisted.client"
# do the start of the login flow
channel = self.helper.auth_via_oidc(
-{"sub": "tester", "displayname": "Jonny"}, client_redirect_url
{"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
)
# that should redirect to the username picker
@@ -1122,7 +1154,7 @@
session = username_mapping_sessions[session_id]
self.assertEqual(session.remote_user_id, "tester")
self.assertEqual(session.display_name, "Jonny")
-self.assertEqual(session.client_redirect_url, client_redirect_url)
self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
# the expiry time should be about 15 minutes away
expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
@@ -1146,15 +1178,19 @@
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
-# ensure that the returned location starts with the requested redirect URL
-self.assertEqual(
-location_headers[0][: len(client_redirect_url)], client_redirect_url
-)
# ensure that the returned location matches the requested redirect URL
path, query = location_headers[0].split("?", 1)
self.assertEqual(path, "https://x")
# it will have url-encoded the params properly, so we'll have to parse them
params = urllib.parse.parse_qsl(
query, keep_blank_values=True, strict_parsing=True, errors="strict"
)
self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
self.assertEqual(params[2][0], "loginToken")
# fish the login token out of the returned redirect uri
-parts = urlparse(location_headers[0])
-query = parse_qs(parts.query)
-login_token = query["loginToken"][0]
login_token = params[2][1]
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.


@@ -20,8 +20,7 @@ import json
import re
import time
import urllib.parse
-from html.parser import HTMLParser
-from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple
from typing import Any, Dict, Mapping, MutableMapping, Optional
from mock import patch
@@ -35,6 +34,7 @@ from synapse.types import JsonDict
from tests.server import FakeChannel, FakeSite, make_request
from tests.test_utils import FakeResponse
from tests.test_utils.html_parsers import TestHtmlParser
@attr.s
@@ -440,10 +440,36 @@ class RestHelper:
# param that synapse passes to the IdP via query params, as well as the cookie
# that synapse passes to the client.
-oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1)
oauth_uri_path, _ = oauth_uri.split("?", 1)
assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
"unexpected SSO URI " + oauth_uri_path
)
return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
def complete_oidc_auth(
self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
) -> FakeChannel:
"""Mock out an OIDC authentication flow
Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
sent back to the OIDC provider.
Requires the OIDC callback resource to be mounted at the normal place.
Args:
oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
from initiate_sso_login or initiate_sso_ui_auth).
cookies: the cookies set by synapse's redirect endpoint, which will be
sent back to the callback endpoint.
user_info_dict: the remote userinfo that the OIDC provider should present.
Typically this should be '{"sub": "<remote user id>"}'.
Returns:
A FakeChannel containing the result of calling the OIDC callback endpoint.
"""
_, oauth_uri_qs = oauth_uri.split("?", 1)
params = urllib.parse.parse_qs(oauth_uri_qs)
callback_uri = "%s?%s" % (
urllib.parse.urlparse(params["redirect_uri"][0]).path,
@@ -456,9 +482,9 @@ class RestHelper:
expected_requests = [
# first we get a hit to the token endpoint, which we tell to return
# a dummy OIDC access token
-("https://issuer.test/token", {"access_token": "TEST"}),
(TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
# and then one to the user_info endpoint, which returns our remote user id.
-("https://issuer.test/userinfo", user_info_dict),
(TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
]
async def mock_req(method: str, uri: str, data=None, headers=None):
@@ -542,25 +568,7 @@ class RestHelper:
channel.extract_cookies(cookies)
# parse the confirmation page to fish out the link.
-class ConfirmationPageParser(HTMLParser):
-def __init__(self):
-super().__init__()
-self.links = []  # type: List[str]
-def handle_starttag(
-self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
-) -> None:
-attr_dict = dict(attrs)
-if tag == "a":
-href = attr_dict["href"]
-if href:
-self.links.append(href)
-def error(_, message):
-raise AssertionError(message)
-p = ConfirmationPageParser()
p = TestHtmlParser()
p.feed(channel.text_body)
p.close()
assert len(p.links) == 1, "not exactly one link in confirmation page"
@@ -570,6 +578,8 @@ class RestHelper:
# an 'oidc_config' suitable for login_via_oidc.
TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
TEST_OIDC_CONFIG = {
"enabled": True,
"discover": False,
@@ -578,7 +588,7 @@ TEST_OIDC_CONFIG = {
"client_secret": "test-client-secret",
"scopes": ["profile"],
"authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
-"token_endpoint": "https://issuer.test/token",
"token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
-"userinfo_endpoint": "https://issuer.test/userinfo",
"userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
}
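A minimal sketch of driving the new `complete_oidc_auth` helper from a test, modelled on `test_login_via_oidc` above; the wrapper method name and the redirect URL are assumptions for illustration, not part of the patch:

```python
# Inside a HomeserverTestCase that mounts the OIDC callback resource, roughly
# as MultiSSOTestCase does above; `self.helper` is the RestHelper fixture.
def _login_via_default_oidc_provider(self):
    # kick off the SSO flow and capture the redirect to the IdP plus cookies
    channel = self.make_request(
        "GET",
        "/_synapse/client/pick_idp?redirectUrl=https%3A%2F%2Fx&idp=oidc",
    )
    oidc_uri = channel.headers.getRawHeaders("Location")[0]
    cookies = {}
    channel.extract_cookies(cookies)

    # complete the OIDC leg with a fake remote user id; the returned channel
    # holds whatever the callback endpoint served (confirmation page or redirect)
    return self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
```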


@@ -74,7 +74,7 @@ class FakeChannel:
return int(self.result["code"])
@property
-def headers(self):
def headers(self) -> Headers:
if not self.result:
raise Exception("No result yet.")
h = Headers()


@@ -51,9 +51,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.db_pool,
stream_name="test_stream",
instance_name=instance_name,
-table="foobar",
-instance_column="instance_name",
-id_column="stream_id",
tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers,
)
@@ -487,9 +485,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.db_pool,
stream_name="test_stream",
instance_name=instance_name,
-table="foobar",
-instance_column="instance_name",
-id_column="stream_id",
tables=[("foobar", "instance_name", "stream_id")],
sequence_name="foobar_seq",
writers=writers,
positive=False,
@@ -579,3 +575,107 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.db_pool = self.store.db_pool # type: DatabasePool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn):
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
CREATE TABLE foobar1 (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
txn.execute(
"""
CREATE TABLE foobar2 (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
def _create_id_generator(
self, instance_name="master", writers=["master"]
) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
stream_name="test_stream",
instance_name=instance_name,
tables=[
("foobar1", "instance_name", "stream_id"),
("foobar2", "instance_name", "stream_id"),
],
sequence_name="foobar_seq",
writers=writers,
)
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def _insert_rows(
self,
table: str,
instance_name: str,
number: int,
update_stream_table: bool = True,
):
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""
def _insert(txn):
for _ in range(number):
txn.execute(
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
(instance_name,),
)
if update_stream_table:
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
""",
(instance_name,),
)
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
def test_load_existing_stream(self):
"""Test creating ID gens with multiple tables that have rows from after
the position in `stream_positions` table.
"""
self._insert_rows("foobar1", "first", 3)
self._insert_rows("foobar2", "second", 3)
self._insert_rows("foobar2", "second", 1, update_stream_table=False)
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
# The first ID gen will notice that it can advance its token to 7 as it
# has no in progress writes...
self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
# ... but the second ID gen doesn't know that.
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)


@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from html.parser import HTMLParser
from typing import Dict, Iterable, List, Optional, Tuple
class TestHtmlParser(HTMLParser):
"""A generic HTML page parser which extracts useful things from the HTML"""
def __init__(self):
super().__init__()
# a list of links found in the doc
self.links = [] # type: List[str]
# the values of any hidden <input>s: map from name to value
self.hiddens = {} # type: Dict[str, Optional[str]]
# the values of any radio buttons: map from name to list of values
self.radios = {} # type: Dict[str, List[Optional[str]]]
def handle_starttag(
self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
) -> None:
attr_dict = dict(attrs)
if tag == "a":
href = attr_dict["href"]
if href:
self.links.append(href)
elif tag == "input":
input_name = attr_dict.get("name")
if attr_dict["type"] == "radio":
assert input_name
self.radios.setdefault(input_name, []).append(attr_dict["value"])
elif attr_dict["type"] == "hidden":
assert input_name
self.hiddens[input_name] = attr_dict["value"]
def error(_, message):
raise AssertionError(message)
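A quick usage sketch for this parser; the HTML snippet is hypothetical, shaped like the IdP picker form that the login tests above feed into it:

```python
from tests.test_utils.html_parsers import TestHtmlParser

# Hypothetical page resembling the IdP picker form parsed in MultiSSOTestCase.
html = """
<form>
  <input type="hidden" name="redirectUrl" value="https://x">
  <input type="radio" name="idp" value="cas">
  <input type="radio" name="idp" value="oidc">
  <a href="https://x?loginToken=abc">Continue</a>
</form>
"""

p = TestHtmlParser()
p.feed(html)
p.close()

assert p.hiddens["redirectUrl"] == "https://x"
assert p.radios["idp"] == ["cas", "oidc"]
assert p.links == ["https://x?loginToken=abc"]
```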


@@ -103,6 +103,9 @@ usedevelop=true
[testenv:py35-old]
skip_install=True
deps =
# Ensure a version of setuptools that supports Python 3.5 is installed.
setuptools < 51.0.0
# Old automat version for Twisted
Automat == 0.3.0