Allow moving account data and receipts streams off master (#9104)

pull/9161/head
Erik Johnston 2021-01-18 15:47:59 +00:00 committed by GitHub
parent f08ef64926
commit 6633a4015a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 854 additions and 279 deletions

1
changelog.d/9104.feature Normal file
View File

@ -0,0 +1 @@
Add experimental support for moving off receipts and account data persistence off master.

View File

@ -100,7 +100,16 @@ from synapse.rest.client.v1.profile import (
) )
from synapse.rest.client.v1.push_rule import PushRuleRestServlet from synapse.rest.client.v1.push_rule import PushRuleRestServlet
from synapse.rest.client.v1.voip import VoipRestServlet from synapse.rest.client.v1.voip import VoipRestServlet
from synapse.rest.client.v2_alpha import groups, room_keys, sync, user_directory from synapse.rest.client.v2_alpha import (
account_data,
groups,
read_marker,
receipts,
room_keys,
sync,
tags,
user_directory,
)
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.client.v2_alpha.account import ThreepidRestServlet from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
from synapse.rest.client.v2_alpha.account_data import ( from synapse.rest.client.v2_alpha.account_data import (
@ -531,6 +540,10 @@ class GenericWorkerServer(HomeServer):
room.register_deprecated_servlets(self, resource) room.register_deprecated_servlets(self, resource)
InitialSyncRestServlet(self).register(resource) InitialSyncRestServlet(self).register(resource)
room_keys.register_servlets(self, resource) room_keys.register_servlets(self, resource)
tags.register_servlets(self, resource)
account_data.register_servlets(self, resource)
receipts.register_servlets(self, resource)
read_marker.register_servlets(self, resource)
SendToDeviceRestServlet(self).register(resource) SendToDeviceRestServlet(self).register(resource)

View File

@ -56,6 +56,12 @@ class WriterLocations:
to_device = attr.ib( to_device = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter, default=["master"], type=List[str], converter=_instance_to_list_converter,
) )
account_data = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter,
)
receipts = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter,
)
class WorkerConfig(Config): class WorkerConfig(Config):
@ -127,7 +133,7 @@ class WorkerConfig(Config):
# Check that the configured writers for events and typing also appears in # Check that the configured writers for events and typing also appears in
# `instance_map`. # `instance_map`.
for stream in ("events", "typing", "to_device"): for stream in ("events", "typing", "to_device", "account_data", "receipts"):
instances = _instance_to_list_converter(getattr(self.writers, stream)) instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances: for instance in instances:
if instance != "master" and instance not in self.instance_map: if instance != "master" and instance not in self.instance_map:
@ -141,6 +147,16 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `to_device` messages." "Must only specify one instance to handle `to_device` messages."
) )
if len(self.writers.account_data) != 1:
raise ConfigError(
"Must only specify one instance to handle `account_data` messages."
)
if len(self.writers.receipts) != 1:
raise ConfigError(
"Must only specify one instance to handle `receipts` messages."
)
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
# Whether this worker should run background tasks or not. # Whether this worker should run background tasks or not.

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd # Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,14 +13,157 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random
from typing import TYPE_CHECKING, List, Tuple from typing import TYPE_CHECKING, List, Tuple
from synapse.replication.http.account_data import (
ReplicationAddTagRestServlet,
ReplicationRemoveTagRestServlet,
ReplicationRoomAccountDataRestServlet,
ReplicationUserAccountDataRestServlet,
)
from synapse.types import JsonDict, UserID from synapse.types import JsonDict, UserID
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer from synapse.app.homeserver import HomeServer
class AccountDataHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
self._instance_name = hs.get_instance_name()
self._notifier = hs.get_notifier()
self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
self._account_data_writers = hs.config.worker.writers.account_data
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
"""Add some account_data to a room for a user.
Args:
user_id: The user to add a tag for.
room_id: The room to add a tag for.
account_data_type: The type of account_data to add.
content: A json object to associate with the tag.
Returns:
The maximum stream ID.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.add_account_data_to_room(
user_id, room_id, account_data_type, content
)
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
return max_stream_id
else:
response = await self._room_data_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
room_id=room_id,
account_data_type=account_data_type,
content=content,
)
return response["max_stream_id"]
async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
"""Add some account_data to a room for a user.
Args:
user_id: The user to add a tag for.
account_data_type: The type of account_data to add.
content: A json object to associate with the tag.
Returns:
The maximum stream ID.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.add_account_data_for_user(
user_id, account_data_type, content
)
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
return max_stream_id
else:
response = await self._user_data_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
account_data_type=account_data_type,
content=content,
)
return response["max_stream_id"]
async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int:
"""Add a tag to a room for a user.
Args:
user_id: The user to add a tag for.
room_id: The room to add a tag for.
tag: The tag name to add.
content: A json object to associate with the tag.
Returns:
The next account data ID.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.add_tag_to_room(
user_id, room_id, tag, content
)
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
return max_stream_id
else:
response = await self._add_tag_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
room_id=room_id,
tag=tag,
content=content,
)
return response["max_stream_id"]
async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
"""Remove a tag from a room for a user.
Returns:
The next account data ID.
"""
if self._instance_name in self._account_data_writers:
max_stream_id = await self._store.remove_tag_from_room(
user_id, room_id, tag
)
self._notifier.on_new_event(
"account_data_key", max_stream_id, users=[user_id]
)
return max_stream_id
else:
response = await self._remove_tag_client(
instance_name=random.choice(self._account_data_writers),
user_id=user_id,
room_id=room_id,
tag=tag,
)
return response["max_stream_id"]
class AccountDataEventSource: class AccountDataEventSource:
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()

View File

@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler):
super().__init__(hs) super().__init__(hs)
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.account_data_handler = hs.get_account_data_handler()
self.read_marker_linearizer = Linearizer(name="read_marker") self.read_marker_linearizer = Linearizer(name="read_marker")
self.notifier = hs.get_notifier()
async def received_client_read_marker( async def received_client_read_marker(
self, room_id: str, user_id: str, event_id: str self, room_id: str, user_id: str, event_id: str
@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler):
if should_update: if should_update:
content = {"event_id": event_id} content = {"event_id": event_id}
max_id = await self.store.add_account_data_to_room( await self.account_data_handler.add_account_data_to_room(
user_id, room_id, "m.fully_read", content user_id, room_id, "m.fully_read", content
) )
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])

View File

@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler):
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self.federation = hs.get_federation_sender()
hs.get_federation_registry().register_edu_handler( # We only need to poke the federation sender explicitly if its on the
"m.receipt", self._received_remote_receipt # same instance. Other federation sender instances will get notified by
) # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
# in the receipts stream.
self.federation_sender = None
if hs.should_send_federation():
self.federation_sender = hs.get_federation_sender()
# If we can handle the receipt EDUs we do so, otherwise we route them
# to the appropriate worker.
if hs.get_instance_name() in hs.config.worker.writers.receipts:
hs.get_federation_registry().register_edu_handler(
"m.receipt", self._received_remote_receipt
)
else:
hs.get_federation_registry().register_instances_for_edu(
"m.receipt", hs.config.worker.writers.receipts,
)
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler):
if not is_new: if not is_new:
return return
await self.federation.send_read_receipt(receipt) if self.federation_sender:
await self.federation_sender.send_read_receipt(receipt)
class ReceiptEventSource: class ReceiptEventSource:

View File

@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.account_data_handler = hs.get_account_data_handler()
self.member_linearizer = Linearizer(name="member") self.member_linearizer = Linearizer(name="member")
@ -253,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
direct_rooms[key].append(new_room_id) direct_rooms[key].append(new_room_id)
# Save back to user's m.direct account data # Save back to user's m.direct account data
await self.store.add_account_data_for_user( await self.account_data_handler.add_account_data_for_user(
user_id, AccountDataTypes.DIRECT, direct_rooms user_id, AccountDataTypes.DIRECT, direct_rooms
) )
break break
@ -263,7 +264,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Copy each room tag to the new room # Copy each room tag to the new room
for tag, tag_content in room_tags.items(): for tag, tag_content in room_tags.items():
await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) await self.account_data_handler.add_tag_to_room(
user_id, new_room_id, tag, tag_content
)
async def update_membership( async def update_membership(
self, self,

View File

@ -15,6 +15,7 @@
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.replication.http import ( from synapse.replication.http import (
account_data,
devices, devices,
federation, federation,
login, login,
@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource):
presence.register_servlets(hs, self) presence.register_servlets(hs, self)
membership.register_servlets(hs, self) membership.register_servlets(hs, self)
streams.register_servlets(hs, self) streams.register_servlets(hs, self)
account_data.register_servlets(hs, self)
# The following can't currently be instantiated on workers. # The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None: if hs.config.worker.worker_app is None:

View File

@ -0,0 +1,187 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
"""Add user account data on the appropriate account data worker.
Request format:
POST /_synapse/replication/add_user_account_data/:user_id/:type
{
"content": { ... },
}
"""
NAME = "add_user_account_data"
PATH_ARGS = ("user_id", "account_data_type")
CACHE = False
def __init__(self, hs):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, account_data_type, content):
payload = {
"content": content,
}
return payload
async def _handle_request(self, request, user_id, account_data_type):
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_account_data_for_user(
user_id, account_data_type, content["content"]
)
return 200, {"max_stream_id": max_stream_id}
class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
"""Add room account data on the appropriate account data worker.
Request format:
POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type
{
"content": { ... },
}
"""
NAME = "add_room_account_data"
PATH_ARGS = ("user_id", "room_id", "account_data_type")
CACHE = False
def __init__(self, hs):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, room_id, account_data_type, content):
payload = {
"content": content,
}
return payload
async def _handle_request(self, request, user_id, room_id, account_data_type):
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, content["content"]
)
return 200, {"max_stream_id": max_stream_id}
class ReplicationAddTagRestServlet(ReplicationEndpoint):
"""Add tag on the appropriate account data worker.
Request format:
POST /_synapse/replication/add_tag/:user_id/:room_id/:tag
{
"content": { ... },
}
"""
NAME = "add_tag"
PATH_ARGS = ("user_id", "room_id", "tag")
CACHE = False
def __init__(self, hs):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, room_id, tag, content):
payload = {
"content": content,
}
return payload
async def _handle_request(self, request, user_id, room_id, tag):
content = parse_json_object_from_request(request)
max_stream_id = await self.handler.add_tag_to_room(
user_id, room_id, tag, content["content"]
)
return 200, {"max_stream_id": max_stream_id}
class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
"""Remove tag on the appropriate account data worker.
Request format:
POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag
{}
"""
NAME = "remove_tag"
PATH_ARGS = (
"user_id",
"room_id",
"tag",
)
CACHE = False
def __init__(self, hs):
super().__init__(hs)
self.handler = hs.get_account_data_handler()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload(user_id, room_id, tag):
return {}
async def _handle_request(self, request, user_id, room_id, tag):
max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,)
return 200, {"max_stream_id": max_stream_id}
def register_servlets(hs, http_server):
ReplicationUserAccountDataRestServlet(hs).register(http_server)
ReplicationRoomAccountDataRestServlet(hs).register(http_server)
ReplicationAddTagRestServlet(hs).register(http_server)
ReplicationRemoveTagRestServlet(hs).register(http_server)

View File

@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
database, database,
stream_name="caches", stream_name="caches",
instance_name=hs.get_instance_name(), instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance", tables=[
instance_column="instance_name", (
id_column="stream_id", "cache_invalidation_stream_by_instance",
"instance_name",
"stream_id",
)
],
sequence_name="cache_invalidation_stream_seq", sequence_name="cache_invalidation_stream_seq",
writers=[], writers=[],
) # type: Optional[MultiWriterIdGenerator] ) # type: Optional[MultiWriterIdGenerator]

View File

@ -15,47 +15,9 @@
# limitations under the License. # limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs): pass
self._account_data_id_gen = SlavedIdTracker(
db_conn,
"account_data",
"stream_id",
extra_tables=[
("room_account_data", "stream_id"),
("room_tags_revisions", "stream_id"),
],
)
super().__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
(row.data_type, row.user_id)
)
self.get_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
self.get_account_data_for_room_and_type.invalidate(
(row.user_id, row.room_id, row.data_type)
)
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@ -14,43 +14,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs): pass
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
super().__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self._get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
self.invalidate_caches_for_receipt(
row.room_id, row.receipt_type, row.user_id
)
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@ -51,11 +51,14 @@ from synapse.replication.tcp.commands import (
from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.streams import ( from synapse.replication.tcp.streams import (
STREAMS_MAP, STREAMS_MAP,
AccountDataStream,
BackfillStream, BackfillStream,
CachesStream, CachesStream,
EventsStream, EventsStream,
FederationStream, FederationStream,
ReceiptsStream,
Stream, Stream,
TagAccountDataStream,
ToDeviceStream, ToDeviceStream,
TypingStream, TypingStream,
) )
@ -132,6 +135,22 @@ class ReplicationCommandHandler:
continue continue
if isinstance(stream, (AccountDataStream, TagAccountDataStream)):
# Only add AccountDataStream and TagAccountDataStream as a source on the
# instance in charge of account_data persistence.
if hs.get_instance_name() in hs.config.worker.writers.account_data:
self._streams_to_replicate.append(stream)
continue
if isinstance(stream, ReceiptsStream):
# Only add ReceiptsStream as a source on the instance in charge of
# receipts.
if hs.get_instance_name() in hs.config.worker.writers.receipts:
self._streams_to_replicate.append(stream)
continue
# Only add any other streams if we're on master. # Only add any other streams if we're on master.
if hs.config.worker_app is not None: if hs.config.worker_app is not None:
continue continue

View File

@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.handler = hs.get_account_data_handler()
self._is_worker = hs.config.worker_app is not None
async def on_PUT(self, request, user_id, account_data_type): async def on_PUT(self, request, user_id, account_data_type):
if self._is_worker:
raise Exception("Cannot handle PUT /account_data on worker")
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
max_id = await self.store.add_account_data_for_user( await self.handler.add_account_data_for_user(user_id, account_data_type, body)
user_id, account_data_type, body
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
return 200, {} return 200, {}
@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.handler = hs.get_account_data_handler()
self._is_worker = hs.config.worker_app is not None
async def on_PUT(self, request, user_id, room_id, account_data_type): async def on_PUT(self, request, user_id, room_id, account_data_type):
if self._is_worker:
raise Exception("Cannot handle PUT /account_data on worker")
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
" Use /rooms/!roomId:server.name/read_markers", " Use /rooms/!roomId:server.name/read_markers",
) )
max_id = await self.store.add_account_data_to_room( await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, body user_id, room_id, account_data_type, body
) )
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
return 200, {} return 200, {}
async def on_GET(self, request, user_id, room_id, account_data_type): async def on_GET(self, request, user_id, room_id, account_data_type):

View File

@ -58,8 +58,7 @@ class TagServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.handler = hs.get_account_data_handler()
self.notifier = hs.get_notifier()
async def on_PUT(self, request, user_id, room_id, tag): async def on_PUT(self, request, user_id, room_id, tag):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
@ -68,9 +67,7 @@ class TagServlet(RestServlet):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body) await self.handler.add_tag_to_room(user_id, room_id, tag, body)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
return 200, {} return 200, {}
@ -79,9 +76,7 @@ class TagServlet(RestServlet):
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")
max_id = await self.store.remove_tag_from_room(user_id, room_id, tag) await self.handler.remove_tag_from_room(user_id, room_id, tag)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
return 200, {} return 200, {}

View File

@ -55,6 +55,7 @@ from synapse.federation.sender import FederationSender
from synapse.federation.transport.client import TransportLayerClient from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
from synapse.handlers.account_data import AccountDataHandler
from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.account_validity import AccountValidityHandler
from synapse.handlers.acme import AcmeHandler from synapse.handlers.acme import AcmeHandler
from synapse.handlers.admin import AdminHandler from synapse.handlers.admin import AdminHandler
@ -711,6 +712,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_module_api(self) -> ModuleApi: def get_module_api(self) -> ModuleApi:
return ModuleApi(self, self.get_auth_handler()) return ModuleApi(self, self.get_auth_handler())
@cache_in_self
def get_account_data_handler(self) -> AccountDataHandler:
return AccountDataHandler(self)
async def remove_pusher(self, app_id: str, push_key: str, user_id: str): async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View File

@ -160,9 +160,13 @@ class DataStore(
database, database,
stream_name="caches", stream_name="caches",
instance_name=hs.get_instance_name(), instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance", tables=[
instance_column="instance_name", (
id_column="stream_id", "cache_invalidation_stream_by_instance",
"instance_name",
"stream_id",
)
],
sequence_name="cache_invalidation_stream_seq", sequence_name="cache_invalidation_stream_seq",
writers=[], writers=[],
) )

View File

@ -14,14 +14,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import logging import logging
from typing import Dict, List, Optional, Set, Tuple from typing import Dict, List, Optional, Set, Tuple
from synapse.api.constants import AccountDataTypes from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -30,14 +32,57 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# The ABCMeta metaclass ensures that it cannot be instantiated without class AccountDataWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement """This is an abstract base class where subclasses must implement
`get_max_account_data_stream_id` which can be called in the initializer. `get_max_account_data_stream_id` which can be called in the initializer.
""" """
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):
self._can_write_to_account_data = (
self._instance_name in hs.config.worker.writers.account_data
)
self._account_data_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
stream_name="account_data",
instance_name=self._instance_name,
tables=[
("room_account_data", "instance_name", "stream_id"),
("room_tags_revisions", "instance_name", "stream_id"),
("account_data", "instance_name", "stream_id"),
],
sequence_name="account_data_sequence",
writers=hs.config.worker.writers.account_data,
)
else:
self._can_write_to_account_data = True
# We shouldn't be running in worker mode with SQLite, but its useful
# to support it for unit tests.
#
# If this process is the writer than we need to use
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.events:
self._account_data_id_gen = StreamIdGenerator(
db_conn,
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
)
else:
self._account_data_id_gen = SlavedIdTracker(
db_conn,
"room_account_data",
"stream_id",
extra_tables=[("room_tags_revisions", "stream_id")],
)
account_max = self.get_max_account_data_stream_id() account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache( self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max "AccountDataAndTagsChangeCache", account_max
@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
@abc.abstractmethod def get_max_account_data_stream_id(self) -> int:
def get_max_account_data_stream_id(self):
"""Get the current max stream ID for account data stream """Get the current max stream ID for account data stream
Returns: Returns:
int int
""" """
raise NotImplementedError() return self._account_data_id_gen.get_current_token()
@cached() @cached()
async def get_account_data_for_user( async def get_account_data_for_user(
@ -307,25 +351,26 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
) )
) )
def process_replication_rows(self, stream_name, instance_name, token, rows):
class AccountDataStore(AccountDataWorkerStore): if stream_name == TagAccountDataStream.NAME:
def __init__(self, database: DatabasePool, db_conn, hs): self._account_data_id_gen.advance(instance_name, token)
self._account_data_id_gen = StreamIdGenerator( for row in rows:
db_conn, self.get_tags_for_user.invalidate((row.user_id,))
"room_account_data", self._account_data_stream_cache.entity_has_changed(row.user_id, token)
"stream_id", elif stream_name == AccountDataStream.NAME:
extra_tables=[("room_tags_revisions", "stream_id")], self._account_data_id_gen.advance(instance_name, token)
) for row in rows:
if not row.room_id:
super().__init__(database, db_conn, hs) self.get_global_account_data_by_type_for_user.invalidate(
(row.data_type, row.user_id)
def get_max_account_data_stream_id(self) -> int: )
"""Get the current max stream id for the private user data stream self.get_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
Returns: self.get_account_data_for_room_and_type.invalidate(
The maximum stream ID. (row.user_id, row.room_id, row.data_type)
""" )
return self._account_data_id_gen.get_current_token() self._account_data_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
async def add_account_data_to_room( async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
@ -341,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore):
Returns: Returns:
The maximum stream ID. The maximum stream ID.
""" """
assert self._can_write_to_account_data
content_json = json_encoder.encode(content) content_json = json_encoder.encode(content)
async with self._account_data_id_gen.get_next() as next_id: async with self._account_data_id_gen.get_next() as next_id:
@ -381,6 +428,8 @@ class AccountDataStore(AccountDataWorkerStore):
Returns: Returns:
The maximum stream ID. The maximum stream ID.
""" """
assert self._can_write_to_account_data
async with self._account_data_id_gen.get_next() as next_id: async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_user_account_data", "add_user_account_data",
@ -463,3 +512,7 @@ class AccountDataStore(AccountDataWorkerStore):
# Invalidate the cache for any ignored users which were added or removed. # Invalidate the cache for any ignored users which were added or removed.
for ignored_user_id in previously_ignored_users ^ currently_ignored_users: for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
class AccountDataStore(AccountDataWorkerStore):
pass

View File

@ -54,9 +54,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
db=database, db=database,
stream_name="to_device", stream_name="to_device",
instance_name=self._instance_name, instance_name=self._instance_name,
table="device_inbox", tables=[("device_inbox", "instance_name", "stream_id")],
instance_column="instance_name",
id_column="stream_id",
sequence_name="device_inbox_sequence", sequence_name="device_inbox_sequence",
writers=hs.config.worker.writers.to_device, writers=hs.config.worker.writers.to_device,
) )

View File

@ -835,6 +835,52 @@ class EventPushActionsWorkerStore(SQLBaseStore):
(rotate_to_stream_ordering,), (rotate_to_stream_ordering,),
) )
def _remove_old_push_actions_before_txn(
self, txn, room_id, user_id, stream_ordering
):
"""
Purges old push actions for a user and room before a given
stream_ordering.
We however keep a months worth of highlighted notifications, so that
users can still get a list of recent highlights.
Args:
txn: The transcation
room_id: Room ID to delete from
user_id: user ID to delete for
stream_ordering: The lowest stream ordering which will
not be deleted.
"""
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id, user_id),
)
# We need to join on the events table to get the received_ts for
# event_push_actions and sqlite won't let us use a join in a delete so
# we can't just delete where received_ts < x. Furthermore we can
# only identify event_push_actions by a tuple of room_id, event_id
# we we can't use a subquery.
# Instead, we look up the stream ordering for the last event in that
# room received before the threshold time and delete event_push_actions
# in the room with a stream_odering before that.
txn.execute(
"DELETE FROM event_push_actions "
" WHERE user_id = ? AND room_id = ? AND "
" stream_ordering <= ?"
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
(user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
)
txn.execute(
"""
DELETE FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
""",
(room_id, user_id, stream_ordering),
)
class EventPushActionsStore(EventPushActionsWorkerStore): class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index" EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
@ -894,52 +940,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions return push_actions
def _remove_old_push_actions_before_txn(
self, txn, room_id, user_id, stream_ordering
):
"""
Purges old push actions for a user and room before a given
stream_ordering.
We however keep a months worth of highlighted notifications, so that
users can still get a list of recent highlights.
Args:
txn: The transcation
room_id: Room ID to delete from
user_id: user ID to delete for
stream_ordering: The lowest stream ordering which will
not be deleted.
"""
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id, user_id),
)
# We need to join on the events table to get the received_ts for
# event_push_actions and sqlite won't let us use a join in a delete so
# we can't just delete where received_ts < x. Furthermore we can
# only identify event_push_actions by a tuple of room_id, event_id
# we we can't use a subquery.
# Instead, we look up the stream ordering for the last event in that
# room received before the threshold time and delete event_push_actions
# in the room with a stream_odering before that.
txn.execute(
"DELETE FROM event_push_actions "
" WHERE user_id = ? AND room_id = ? AND "
" stream_ordering <= ?"
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
(user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
)
txn.execute(
"""
DELETE FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
""",
(room_id, user_id, stream_ordering),
)
def _action_has_highlight(actions): def _action_has_highlight(actions):
for action in actions: for action in actions:

View File

@ -96,9 +96,7 @@ class EventsWorkerStore(SQLBaseStore):
db=database, db=database,
stream_name="events", stream_name="events",
instance_name=hs.get_instance_name(), instance_name=hs.get_instance_name(),
table="events", tables=[("events", "instance_name", "stream_ordering")],
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_stream_seq", sequence_name="events_stream_seq",
writers=hs.config.worker.writers.events, writers=hs.config.worker.writers.events,
) )
@ -107,9 +105,7 @@ class EventsWorkerStore(SQLBaseStore):
db=database, db=database,
stream_name="backfill", stream_name="backfill",
instance_name=hs.get_instance_name(), instance_name=hs.get_instance_name(),
table="events", tables=[("events", "instance_name", "stream_ordering")],
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_backfill_stream_seq", sequence_name="events_backfill_stream_seq",
positive=False, positive=False,
writers=hs.config.worker.writers.events, writers=hs.config.worker.writers.events,

View File

@ -14,15 +14,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import logging import logging
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from twisted.internet import defer from twisted.internet import defer
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# The ABCMeta metaclass ensures that it cannot be instantiated without class ReceiptsWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_max_receipt_stream_id` which can be called in the initializer.
"""
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self._instance_name = hs.get_instance_name()
if isinstance(database.engine, PostgresEngine):
self._can_write_to_receipts = (
self._instance_name in hs.config.worker.writers.receipts
)
self._receipts_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
stream_name="account_data",
instance_name=self._instance_name,
tables=[("receipts_linearized", "instance_name", "stream_id")],
sequence_name="receipts_sequence",
writers=hs.config.worker.writers.receipts,
)
else:
self._can_write_to_receipts = True
# We shouldn't be running in worker mode with SQLite, but its useful
# to support it for unit tests.
#
# If this process is the writer than we need to use
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.events:
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
else:
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache( self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
) )
@abc.abstractmethod
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self):
"""Get the current max stream ID for receipts stream """Get the current max stream ID for receipts stream
Returns: Returns:
int int
""" """
raise NotImplementedError() return self._receipts_id_gen.get_current_token()
@cached() @cached()
async def get_users_with_read_receipts_in_room(self, room_id): async def get_users_with_read_receipts_in_room(self, room_id):
@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
self.get_users_with_read_receipts_in_room.invalidate((room_id,)) self.get_users_with_read_receipts_in_room.invalidate((room_id,))
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
class ReceiptsStore(ReceiptsWorkerStore): self.get_receipts_for_user.invalidate((user_id, receipt_type))
def __init__(self, database: DatabasePool, db_conn, hs): self._get_linearized_receipts_for_room.invalidate_many((room_id,))
# We instantiate this first as the ReceiptsWorkerStore constructor self.get_last_receipt_event_id_for_user.invalidate(
# needs to be able to call get_max_receipt_stream_id (user_id, room_id, receipt_type)
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
) )
self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
self.get_receipts_for_room.invalidate((room_id, receipt_type))
super().__init__(database, db_conn, hs) def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
self.invalidate_caches_for_receipt(
row.room_id, row.receipt_type, row.user_id
)
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
def get_max_receipt_stream_id(self): return super().process_replication_rows(stream_name, instance_name, token, rows)
return self._receipts_id_gen.get_current_token()
def insert_linearized_receipt_txn( def insert_linearized_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_id, data, stream_id self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown) (or 0 if the event is unknown)
""" """
assert self._can_write_to_receipts
res = self.db_pool.simple_select_one_txn( res = self.db_pool.simple_select_one_txn(
txn, txn,
table="events", table="events",
@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
) )
return None return None
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
txn.call_after( txn.call_after(
self._invalidate_get_users_with_receipts_in_room, self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
room_id,
receipt_type,
user_id,
)
txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
) )
txn.call_after( txn.call_after(
self._receipts_stream_cache.entity_has_changed, room_id, stream_id self._receipts_stream_cache.entity_has_changed, room_id, stream_id
) )
txn.call_after(
self.get_last_receipt_event_id_for_user.invalidate,
(user_id, room_id, receipt_type),
)
self.db_pool.simple_upsert_txn( self.db_pool.simple_upsert_txn(
txn, txn,
table="receipts_linearized", table="receipts_linearized",
@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
Automatically does conversion between linearized and graph Automatically does conversion between linearized and graph
representations. representations.
""" """
assert self._can_write_to_receipts
if not event_ids: if not event_ids:
return None return None
@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
async def insert_graph_receipt( async def insert_graph_receipt(
self, room_id, receipt_type, user_id, event_ids, data self, room_id, receipt_type, user_id, event_ids, data
): ):
assert self._can_write_to_receipts
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"insert_graph_receipt", "insert_graph_receipt",
self.insert_graph_receipt_txn, self.insert_graph_receipt_txn,
@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore):
def insert_graph_receipt_txn( def insert_graph_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_ids, data self, txn, room_id, receipt_type, user_id, event_ids, data
): ):
assert self._can_write_to_receipts
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
txn.call_after( txn.call_after(
self._invalidate_get_users_with_receipts_in_room, self._invalidate_get_users_with_receipts_in_room,
@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"data": json_encoder.encode(data), "data": json_encoder.encode(data),
}, },
) )
class ReceiptsStore(ReceiptsWorkerStore):
pass

View File

@ -0,0 +1,20 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE room_account_data ADD COLUMN instance_name TEXT;
ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT;
ALTER TABLE account_data ADD COLUMN instance_name TEXT;
ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT;

View File

@ -0,0 +1,32 @@
/* Copyright 2021 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE SEQUENCE IF NOT EXISTS account_data_sequence;
-- We need to take the max across all the account_data tables as they share the
-- ID generator
SELECT setval('account_data_sequence', (
SELECT GREATEST(
(SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data),
(SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions),
(SELECT COALESCE(MAX(stream_id), 1) FROM account_data)
)
));
CREATE SEQUENCE IF NOT EXISTS receipts_sequence;
SELECT setval('receipts_sequence', (
SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized
));

View File

@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore):
) )
return {row["tag"]: db_to_json(row["content"]) for row in rows} return {row["tag"]: db_to_json(row["content"]) for row in rows}
class TagsStore(TagsWorkerStore):
async def add_tag_to_room( async def add_tag_to_room(
self, user_id: str, room_id: str, tag: str, content: JsonDict self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int: ) -> int:
@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore):
Returns: Returns:
The next account data ID. The next account data ID.
""" """
assert self._can_write_to_account_data
content_json = json_encoder.encode(content) content_json = json_encoder.encode(content)
def add_tag_txn(txn, next_id): def add_tag_txn(txn, next_id):
@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore):
Returns: Returns:
The next account data ID. The next account data ID.
""" """
assert self._can_write_to_account_data
def remove_tag_txn(txn, next_id): def remove_tag_txn(txn, next_id):
sql = ( sql = (
@ -250,6 +251,7 @@ class TagsStore(TagsWorkerStore):
room_id: The ID of the room. room_id: The ID of the room.
next_id: The the revision to advance to. next_id: The the revision to advance to.
""" """
assert self._can_write_to_account_data
txn.call_after( txn.call_after(
self._account_data_stream_cache.entity_has_changed, user_id, next_id self._account_data_stream_cache.entity_has_changed, user_id, next_id
@ -278,3 +280,7 @@ class TagsStore(TagsWorkerStore):
# which stream_id ends up in the table, as long as it is higher # which stream_id ends up in the table, as long as it is higher
# than the id that the client has. # than the id that the client has.
pass pass
class TagsStore(TagsWorkerStore):
pass

View File

@ -17,7 +17,7 @@ import logging
import threading import threading
from collections import deque from collections import deque
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Union from typing import Dict, List, Optional, Set, Tuple, Union
import attr import attr
from typing_extensions import Deque from typing_extensions import Deque
@ -186,11 +186,12 @@ class MultiWriterIdGenerator:
Args: Args:
db_conn db_conn
db db
stream_name: A name for the stream. stream_name: A name for the stream, for use in the `stream_positions`
table. (Does not need to be the same as the replication stream name)
instance_name: The name of this instance. instance_name: The name of this instance.
table: Database table associated with stream. tables: List of tables associated with the stream. Tuple of table
instance_column: Column that stores the row's writer's instance name name, column name that stores the writer's instance name, and
id_column: Column that stores the stream ID. column name that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new sequence_name: The name of the postgres sequence used to generate new
IDs. IDs.
writers: A list of known writers to use to populate current positions writers: A list of known writers to use to populate current positions
@ -206,9 +207,7 @@ class MultiWriterIdGenerator:
db: DatabasePool, db: DatabasePool,
stream_name: str, stream_name: str,
instance_name: str, instance_name: str,
table: str, tables: List[Tuple[str, str, str]],
instance_column: str,
id_column: str,
sequence_name: str, sequence_name: str,
writers: List[str], writers: List[str],
positive: bool = True, positive: bool = True,
@ -260,15 +259,16 @@ class MultiWriterIdGenerator:
self._sequence_gen = PostgresSequenceGenerator(sequence_name) self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged. # We check that the table and sequence haven't diverged.
self._sequence_gen.check_consistency( for table, _, id_column in tables:
db_conn, table=table, id_column=id_column, positive=positive self._sequence_gen.check_consistency(
) db_conn, table=table, id_column=id_column, positive=positive
)
# This goes and fills out the above state from the database. # This goes and fills out the above state from the database.
self._load_current_ids(db_conn, table, instance_column, id_column) self._load_current_ids(db_conn, tables)
def _load_current_ids( def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str self, db_conn, tables: List[Tuple[str, str, str]],
): ):
cur = db_conn.cursor(txn_name="_load_current_ids") cur = db_conn.cursor(txn_name="_load_current_ids")
@ -306,17 +306,22 @@ class MultiWriterIdGenerator:
# We add a GREATEST here to ensure that the result is always # We add a GREATEST here to ensure that the result is always
# positive. (This can be a problem for e.g. backfill streams where # positive. (This can be a problem for e.g. backfill streams where
# the server has never backfilled). # the server has never backfilled).
sql = """ max_stream_id = 1
SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) for table, _, id_column in tables:
FROM %(table)s sql = """
""" % { SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
"id": id_column, FROM %(table)s
"table": table, """ % {
"agg": "MAX" if self._positive else "-MIN", "id": id_column,
} "table": table,
cur.execute(sql) "agg": "MAX" if self._positive else "-MIN",
(stream_id,) = cur.fetchone() }
self._persisted_upto_position = stream_id cur.execute(sql)
(stream_id,) = cur.fetchone()
max_stream_id = max(max_stream_id, stream_id)
self._persisted_upto_position = max_stream_id
else: else:
# If we have a min_stream_id then we pull out everything greater # If we have a min_stream_id then we pull out everything greater
# than it from the DB so that we can prefill # than it from the DB so that we can prefill
@ -329,21 +334,28 @@ class MultiWriterIdGenerator:
# stream positions table before restart (or the stream position # stream positions table before restart (or the stream position
# table otherwise got out of date). # table otherwise got out of date).
sql = """
SELECT %(instance)s, %(id)s FROM %(table)s
WHERE ? %(cmp)s %(id)s
""" % {
"id": id_column,
"table": table,
"instance": instance_column,
"cmp": "<=" if self._positive else ">=",
}
cur.execute(sql, (min_stream_id * self._return_factor,))
self._persisted_upto_position = min_stream_id self._persisted_upto_position = min_stream_id
rows = []
for table, instance_column, id_column in tables:
sql = """
SELECT %(instance)s, %(id)s FROM %(table)s
WHERE ? %(cmp)s %(id)s
""" % {
"id": id_column,
"table": table,
"instance": instance_column,
"cmp": "<=" if self._positive else ">=",
}
cur.execute(sql, (min_stream_id * self._return_factor,))
rows.extend(cur)
# Sort so that we handle rows in order for each instance.
rows.sort()
with self._lock: with self._lock:
for (instance, stream_id,) in cur: for (instance, stream_id,) in rows:
stream_id = self._return_factor * stream_id stream_id = self._return_factor * stream_id
self._add_persisted_position(stream_id) self._add_persisted_position(stream_id)

View File

@ -51,9 +51,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.db_pool, self.db_pool,
stream_name="test_stream", stream_name="test_stream",
instance_name=instance_name, instance_name=instance_name,
table="foobar", tables=[("foobar", "instance_name", "stream_id")],
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq", sequence_name="foobar_seq",
writers=writers, writers=writers,
) )
@ -487,9 +485,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.db_pool, self.db_pool,
stream_name="test_stream", stream_name="test_stream",
instance_name=instance_name, instance_name=instance_name,
table="foobar", tables=[("foobar", "instance_name", "stream_id")],
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq", sequence_name="foobar_seq",
writers=writers, writers=writers,
positive=False, positive=False,
@ -579,3 +575,107 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.db_pool = self.store.db_pool # type: DatabasePool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn):
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
CREATE TABLE foobar1 (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
txn.execute(
"""
CREATE TABLE foobar2 (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
def _create_id_generator(
self, instance_name="master", writers=["master"]
) -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
stream_name="test_stream",
instance_name=instance_name,
tables=[
("foobar1", "instance_name", "stream_id"),
("foobar2", "instance_name", "stream_id"),
],
sequence_name="foobar_seq",
writers=writers,
)
return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
def _insert_rows(
self,
table: str,
instance_name: str,
number: int,
update_stream_table: bool = True,
):
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""
def _insert(txn):
for _ in range(number):
txn.execute(
"INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
(instance_name,),
)
if update_stream_table:
txn.execute(
"""
INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
""",
(instance_name,),
)
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
def test_load_existing_stream(self):
"""Test creating ID gens with multiple tables that have rows from after
the position in `stream_positions` table.
"""
self._insert_rows("foobar1", "first", 3)
self._insert_rows("foobar2", "second", 3)
self._insert_rows("foobar2", "second", 1, update_stream_table=False)
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
# The first ID gen will notice that it can advance its token to 7 as it
# has no in progress writes...
self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
# ... but the second ID gen doesn't know that.
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)