Merge branch 'develop' into matrix-org-hotfixes

pull/8675/head
Brendan Abolivier 2020-09-03 15:30:00 +01:00
commit 505ea932f5
109 changed files with 2064 additions and 921 deletions

View File

@ -1,3 +1,16 @@
Upgrading to v1.20.0
====================
Shared rooms endpoint (MSC2666)
-------------------------------
This release contains a new unstable endpoint `/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/.*`
for fetching rooms one user has in common with another. This feature requires the
`update_user_directory` config flag to be `True`. If you are you are using a `synapse.app.user_dir`
worker, requests to this endpoint must be handled by that worker.
See `docs/workers.md <docs/workers.md>`_ for more details.
Upgrading Synapse
=================

1
changelog.d/7785.feature Normal file
View File

@ -0,0 +1 @@
Add an endpoint to query your shared rooms with another user as an implementation of [MSC2666](https://github.com/matrix-org/matrix-doc/pull/2666).

1
changelog.d/8059.feature Normal file
View File

@ -0,0 +1 @@
Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654).

1
changelog.d/8156.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/8170.feature Normal file
View File

@ -0,0 +1 @@
Add experimental support for sharding event persister.

1
changelog.d/8189.doc Normal file
View File

@ -0,0 +1 @@
Explain better what GDPR-erased means when deactivating a user.

1
changelog.d/8196.misc Normal file
View File

@ -0,0 +1 @@
Fix `wait_for_stream_position` to allow multiple waiters on same stream ID.

1
changelog.d/8197.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/8199.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/8200.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/8201.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/8202.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/8203.misc Normal file
View File

@ -0,0 +1 @@
Make `MultiWriterIDGenerator` work for streams that use negative values.

1
changelog.d/8204.misc Normal file
View File

@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

1
changelog.d/8205.misc Normal file
View File

@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

1
changelog.d/8207.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/8213.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/8214.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/8222.misc Normal file
View File

@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

1
changelog.d/8223.bugfix Normal file
View File

@ -0,0 +1 @@
Fixes a longstanding bug where user directory updates could break when unexpected profile data was included in events.

1
changelog.d/8224.misc Normal file
View File

@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

1
changelog.d/8225.misc Normal file
View File

@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

1
changelog.d/8226.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a longstanding bug where stats updates could break when unexpected profile data was included in events.

1
changelog.d/8231.misc Normal file
View File

@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

1
changelog.d/8232.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `StreamStore`.

1
changelog.d/8235.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `StreamStore`.

1
changelog.d/8237.misc Normal file
View File

@ -0,0 +1 @@
Fix type hints in `SyncHandler`.

View File

@ -214,9 +214,11 @@ Deactivate Account
This API deactivates an account. It removes active access tokens, resets the
password, and deletes third-party IDs (to prevent the user requesting a
password reset). It can also mark the user as GDPR-erased (stopping their data
from distributed further, and deleting it entirely if there are no other
references to it).
password reset).
It can also mark the user as GDPR-erased. This means messages sent by the
user will still be visible by anyone that was in the room when these messages
were sent, but hidden from users joining the room afterwards.
The api is::

View File

@ -380,6 +380,7 @@ Handles searches in the user directory. It can handle REST endpoints matching
the following regular expressions:
^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$
^/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/.*$
When using this worker you must also set `update_user_directory: False` in the
shared configuration file to stop the main synapse running background

View File

@ -28,6 +28,7 @@ files =
synapse/handlers/saml_handler.py,
synapse/handlers/sync.py,
synapse/handlers/ui_auth,
synapse/http/federation/well_known_resolver.py,
synapse/http/server.py,
synapse/http/site.py,
synapse/logging/,
@ -42,6 +43,7 @@ files =
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
synapse/storage/engines,

View File

@ -334,6 +334,13 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
This is to workaround https://twistedmatrix.com/trac/ticket/9620, where we
can run out of file descriptors and infinite loop if we attempt to do too
many DNS queries at once
XXX: I'm confused by this. reactor.nameResolver does not use twisted.names unless
you explicitly install twisted.names as the resolver; rather it uses a GAIResolver
backed by the reactor's default threadpool (which is limited to 10 threads). So
(a) I don't understand why twisted ticket 9620 is relevant, and (b) I don't
understand why we would run out of FDs if we did too many lookups at once.
-- richvdh 2020/08/29
"""
new_resolver = _LimitedHostnameResolver(
reactor.nameResolver, max_dns_requests_in_flight

View File

@ -79,8 +79,7 @@ class AdminCmdServer(HomeServer):
pass
@defer.inlineCallbacks
def export_data_command(hs, args):
async def export_data_command(hs, args):
"""Export data for a user.
Args:
@ -91,10 +90,8 @@ def export_data_command(hs, args):
user_id = args.user_id
directory = args.output_directory
res = yield defer.ensureDeferred(
hs.get_handlers().admin_handler.export_user_data(
user_id, FileExfiltrationWriter(user_id, directory=directory)
)
res = await hs.get_handlers().admin_handler.export_user_data(
user_id, FileExfiltrationWriter(user_id, directory=directory)
)
print(res)
@ -232,14 +229,15 @@ def start(config_options):
# We also make sure that `_base.start` gets run before we actually run the
# command.
@defer.inlineCallbacks
def run(_reactor):
async def run():
with LoggingContext("command"):
yield _base.start(ss, [])
yield args.func(ss, args)
_base.start(ss, [])
await args.func(ss, args)
_base.start_worker_reactor(
"synapse-admin-cmd", config, run_command=lambda: task.react(run)
"synapse-admin-cmd",
config,
run_command=lambda: task.react(lambda _reactor: defer.ensureDeferred(run())),
)

View File

@ -411,26 +411,24 @@ def setup(config_options):
return provision
@defer.inlineCallbacks
def reprovision_acme():
async def reprovision_acme():
"""
Provision a certificate from ACME, if required, and reload the TLS
certificate if it's renewed.
"""
reprovisioned = yield defer.ensureDeferred(do_acme())
reprovisioned = await do_acme()
if reprovisioned:
_base.refresh_certificate(hs)
@defer.inlineCallbacks
def start():
async def start():
try:
# Run the ACME provisioning code, if it's enabled.
if hs.config.acme_enabled:
acme = hs.get_acme_handler()
# Start up the webservices which we will respond to ACME
# challenges with, and then provision.
yield defer.ensureDeferred(acme.start_listening())
yield defer.ensureDeferred(do_acme())
await acme.start_listening()
await do_acme()
# Check if it needs to be reprovisioned every day.
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
@ -439,8 +437,8 @@ def setup(config_options):
if hs.config.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
yield defer.ensureDeferred(oidc.load_metadata())
yield defer.ensureDeferred(oidc.load_jwks())
await oidc.load_metadata()
await oidc.load_jwks()
_base.start(hs, config.listeners)
@ -456,7 +454,7 @@ def setup(config_options):
reactor.stop()
sys.exit(1)
reactor.callWhenRunning(start)
reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
return hs

View File

@ -14,18 +14,20 @@
# limitations under the License.
import logging
import urllib
from typing import TYPE_CHECKING, Optional
from prometheus_client import Counter
from twisted.internet import defer
from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient
from synapse.types import ThirdPartyInstanceID
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
from synapse.appservice import ApplicationService
logger = logging.getLogger(__name__)
sent_transactions_counter = Counter(
@ -163,19 +165,20 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_3pe to %s threw exception %s", uri, ex)
return []
def get_3pe_protocol(self, service, protocol):
async def get_3pe_protocol(
self, service: "ApplicationService", protocol: str
) -> Optional[JsonDict]:
if service.url is None:
return {}
@defer.inlineCallbacks
def _get():
async def _get() -> Optional[JsonDict]:
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.parse.quote(protocol),
)
try:
info = yield defer.ensureDeferred(self.get_json(uri, {}))
info = await self.get_json(uri, {})
if not _is_valid_3pe_metadata(info):
logger.warning(
@ -196,7 +199,7 @@ class ApplicationServiceApi(SimpleHttpClient):
return None
key = (service.id, protocol)
return self.protocol_meta_cache.wrap(key, _get)
return await self.protocol_meta_cache.wrap(key, _get)
async def push_bulk(self, service, events, txn_id=None):
if service.url is None:

View File

@ -832,11 +832,26 @@ class ShardedWorkerHandlingConfig:
def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key.
"""
# If multiple instances are not defined we always return true.
# If multiple instances are not defined we always return true
if not self.instances or len(self.instances) == 1:
return True
return self.get_instance(key) == instance_name
def get_instance(self, key: str) -> str:
"""Get the instance responsible for handling the given key.
Note: For things like federation sending the config for which instance
is sending is known only to the sender instance if there is only one.
Therefore `should_handle` should be used where possible.
"""
if not self.instances:
return "master"
if len(self.instances) == 1:
return self.instances[0]
# We shard by taking the hash, modulo it by the number of instances and
# then checking whether this instance matches the instance at that
# index.
@ -846,7 +861,7 @@ class ShardedWorkerHandlingConfig:
dest_hash = sha256(key.encode("utf8")).digest()
dest_int = int.from_bytes(dest_hash, byteorder="little")
remainder = dest_int % (len(self.instances))
return self.instances[remainder] == instance_name
return self.instances[remainder]
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]

View File

@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig:
instances: List[str]
def __init__(self, instances: List[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ...
def get_instance(self, key: str) -> str: ...

View File

@ -13,12 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import attr
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
from .server import ListenerConfig, parse_listener_def
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
"""Helper for allowing parsing a string or list of strings to a config
option expecting a list of strings.
"""
if isinstance(obj, str):
return [obj]
return obj
@attr.s
class InstanceLocationConfig:
"""The host and port to talk to an instance via HTTP replication.
@ -33,11 +45,13 @@ class WriterLocations:
"""Specifies the instances that write various streams.
Attributes:
events: The instance that writes to the event and backfill streams.
events: The instance that writes to the typing stream.
events: The instances that write to the event and backfill streams.
typing: The instance that writes to the typing stream.
"""
events = attr.ib(default="master", type=str)
events = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter
)
typing = attr.ib(default="master", type=str)
@ -105,15 +119,18 @@ class WorkerConfig(Config):
writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers)
# Check that the configured writer for events and typing also appears in
# Check that the configured writers for events and typing also appears in
# `instance_map`.
for stream in ("events", "typing"):
instance = getattr(self.writers, stream)
if instance != "master" and instance not in self.instance_map:
raise ConfigError(
"Instance %r is configured to write %s but does not appear in `instance_map` config."
% (instance, stream)
)
instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances:
if instance != "master" and instance not in self.instance_map:
raise ConfigError(
"Instance %r is configured to write %s but does not appear in `instance_map` config."
% (instance, stream)
)
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\

View File

@ -18,7 +18,7 @@
import abc
import os
from distutils.util import strtobool
from typing import Dict, Optional, Type
from typing import Dict, Optional, Tuple, Type
from unpaddedbase64 import encode_base64
@ -120,7 +120,7 @@ class _EventInternalMetadata(object):
# be here
before = DictProperty("before") # type: str
after = DictProperty("after") # type: str
order = DictProperty("order") # type: int
order = DictProperty("order") # type: Tuple[int, int]
def get_dict(self) -> JsonDict:
return dict(self._dict)

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Any, Dict, List, Optional, Tuple, Union
import attr
from nacl.signing import SigningKey
@ -97,14 +97,14 @@ class EventBuilder(object):
def is_state(self):
return self._state_key is not None
async def build(self, prev_event_ids):
async def build(self, prev_event_ids: List[str]) -> EventBase:
"""Transform into a fully signed and hashed event
Args:
prev_event_ids (list[str]): The event IDs to use as the prev events
prev_event_ids: The event IDs to use as the prev events
Returns:
FrozenEvent
The signed and hashed event.
"""
state_ids = await self._state.get_current_state_ids(
@ -114,8 +114,13 @@ class EventBuilder(object):
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
auth_events = await self._store.add_event_hashes(auth_ids)
prev_events = await self._store.add_event_hashes(prev_event_ids)
# The types of auth/prev events changes between event versions.
auth_events = await self._store.add_event_hashes(
auth_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
prev_events = await self._store.add_event_hashes(
prev_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
else:
auth_events = auth_ids
prev_events = prev_event_ids
@ -138,7 +143,7 @@ class EventBuilder(object):
"unsigned": self.unsigned,
"depth": depth,
"prev_state": [],
}
} # type: Dict[str, Any]
if self.is_state():
event_dict["state_key"] = self._state_key

View File

@ -234,7 +234,9 @@ class DeviceWorkerHandler(BaseHandler):
return result
async def on_federation_query_user_devices(self, user_id):
stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
user_id
)
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = await self.store.get_e2e_cross_signing_key(
user_id, "self_signing"

View File

@ -353,7 +353,7 @@ class E2eKeysHandler(object):
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
results = await self.store.get_e2e_device_keys(local_query)
results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
# Build the result structure
for user_id, device_keys in results.items():
@ -734,7 +734,7 @@ class E2eKeysHandler(object):
# fetch our stored devices. This is used to 1. verify
# signatures on the master key, and 2. to compare with what
# was sent if the device was signed
devices = await self.store.get_e2e_device_keys([(user_id, None)])
devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)])
if user_id not in devices:
raise NotFoundError("No device keys found")

View File

@ -923,7 +923,8 @@ class FederationHandler(BaseHandler):
)
)
await self._handle_new_events(dest, ev_infos, backfilled=True)
if ev_infos:
await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@ -1216,7 +1217,7 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events(
destination, event_infos,
destination, room_id, event_infos,
)
def _sanity_check_event(self, ev):
@ -1363,15 +1364,15 @@ class FederationHandler(BaseHandler):
)
max_stream_id = await self._persist_auth_tree(
origin, auth_chain, state, event, room_version_obj
origin, room_id, auth_chain, state, event, room_version_obj
)
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
self.config.worker.writers.events, "events", max_stream_id
self.config.worker.events_shard_config.get_instance(room_id),
"events",
max_stream_id,
)
# Check whether this room is the result of an upgrade of a room we already know
@ -1625,7 +1626,7 @@ class FederationHandler(BaseHandler):
)
context = await self.state_handler.compute_event_context(event)
await self.persist_events_and_notify([(event, context)])
await self.persist_events_and_notify(event.room_id, [(event, context)])
return event
@ -1652,7 +1653,9 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
stream_id = await self.persist_events_and_notify([(event, context)])
stream_id = await self.persist_events_and_notify(
event.room_id, [(event, context)]
)
return event, stream_id
@ -1900,7 +1903,7 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
[(event, context)], backfilled=backfilled
event.room_id, [(event, context)], backfilled=backfilled
)
except Exception:
run_in_background(
@ -1913,6 +1916,7 @@ class FederationHandler(BaseHandler):
async def _handle_new_events(
self,
origin: str,
room_id: str,
event_infos: Iterable[_NewEventInfo],
backfilled: bool = False,
) -> None:
@ -1944,6 +1948,7 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
room_id,
[
(ev_info.event, context)
for ev_info, context in zip(event_infos, contexts)
@ -1954,6 +1959,7 @@ class FederationHandler(BaseHandler):
async def _persist_auth_tree(
self,
origin: str,
room_id: str,
auth_events: List[EventBase],
state: List[EventBase],
event: EventBase,
@ -1968,6 +1974,7 @@ class FederationHandler(BaseHandler):
Args:
origin: Where the events came from
room_id,
auth_events
state
event
@ -2042,17 +2049,20 @@ class FederationHandler(BaseHandler):
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
await self.persist_events_and_notify(
room_id,
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
]
],
)
new_event_context = await self.state_handler.compute_event_context(
event, old_state=state
)
return await self.persist_events_and_notify([(event, new_event_context)])
return await self.persist_events_and_notify(
room_id, [(event, new_event_context)]
)
async def _prep_event(
self,
@ -2903,6 +2913,7 @@ class FederationHandler(BaseHandler):
async def persist_events_and_notify(
self,
room_id: str,
event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
backfilled: bool = False,
) -> int:
@ -2910,14 +2921,19 @@ class FederationHandler(BaseHandler):
necessary.
Args:
event_and_contexts:
room_id: The room ID of events being persisted.
event_and_contexts: Sequence of events with their associated
context that should be persisted. All events must belong to
the same room.
backfilled: Whether these events are a result of
backfilling or not
"""
if self.config.worker.writers.events != self._instance_name:
instance = self.config.worker.events_shard_config.get_instance(room_id)
if instance != self._instance_name:
result = await self._send_events(
instance_name=self.config.worker.writers.events,
instance_name=instance,
store=self.store,
room_id=room_id,
event_and_contexts=event_and_contexts,
backfilled=backfilled,
)

View File

@ -49,14 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import (
Collection,
Requester,
RoomAlias,
StreamToken,
UserID,
create_requester,
)
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
@ -383,9 +376,8 @@ class EventCreationHandler(object):
self.notifier = hs.get_notifier()
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
self._is_event_writer = (
self.config.worker.writers.events == hs.get_instance_name()
)
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self.room_invite_state_types = self.hs.config.room_invite_state_types
@ -448,7 +440,7 @@ class EventCreationHandler(object):
event_dict: dict,
token_id: Optional[str] = None,
txn_id: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
prev_event_ids: Optional[List[str]] = None,
require_consent: bool = True,
) -> Tuple[EventBase, EventContext]:
"""
@ -788,7 +780,7 @@ class EventCreationHandler(object):
self,
builder: EventBuilder,
requester: Optional[Requester] = None,
prev_event_ids: Optional[Collection[str]] = None,
prev_event_ids: Optional[List[str]] = None,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client
@ -913,9 +905,10 @@ class EventCreationHandler(object):
try:
# If we're a worker we need to hit out to the master.
if not self._is_event_writer:
writer_instance = self._events_shard_config.get_instance(event.room_id)
if writer_instance != self._instance_name:
result = await self.send_event(
instance_name=self.config.worker.writers.events,
instance_name=writer_instance,
event_id=event.event_id,
store=self.store,
requester=requester,
@ -983,7 +976,9 @@ class EventCreationHandler(object):
This should only be run on the instance in charge of persisting events.
"""
assert self._is_event_writer
assert self._events_shard_config.should_handle(
self._instance_name, event.room_id
)
if ratelimit:
# We check if this is a room admin redacting an event so that we

View File

@ -14,15 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Optional
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.streams.config import PaginationConfig
from synapse.types import Requester, RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@ -247,15 +250,16 @@ class PaginationHandler(object):
)
return purge_id
async def _purge_history(self, purge_id, room_id, token, delete_local_events):
async def _purge_history(
self, purge_id: str, room_id: str, token: str, delete_local_events: bool
) -> None:
"""Carry out a history purge on a room.
Args:
purge_id (str): The id for this purge
room_id (str): The room to purge from
token (str): topological token to delete events before
delete_local_events (bool): True to delete local events as well as
remote ones
purge_id: The id for this purge
room_id: The room to purge from
token: topological token to delete events before
delete_local_events: True to delete local events as well as remote ones
"""
self._purges_in_progress_by_room.add(room_id)
try:
@ -291,9 +295,9 @@ class PaginationHandler(object):
"""
return self._purges_by_id.get(purge_id)
async def purge_room(self, room_id):
async def purge_room(self, room_id: str) -> None:
"""Purge the given room from the database"""
with (await self.pagination_lock.write(room_id)):
with await self.pagination_lock.write(room_id):
# check we know about the room
await self.store.get_room_version_id(room_id)
@ -307,23 +311,22 @@ class PaginationHandler(object):
async def get_messages(
self,
requester,
room_id=None,
pagin_config=None,
as_client_event=True,
event_filter=None,
):
requester: Requester,
room_id: Optional[str] = None,
pagin_config: Optional[PaginationConfig] = None,
as_client_event: bool = True,
event_filter: Optional[Filter] = None,
) -> Dict[str, Any]:
"""Get messages in a room.
Args:
requester (Requester): The user requesting messages.
room_id (str): The room they want messages from.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any.
as_client_event (bool): True to get events in client-server format.
event_filter (Filter): Filter to apply to results or None
requester: The user requesting messages.
room_id: The room they want messages from.
pagin_config: The pagination config rules to apply, if any.
as_client_event: True to get events in client-server format.
event_filter: Filter to apply to results or None
Returns:
dict: Pagination API results
Pagination API results
"""
user_id = requester.user.to_string()
@ -343,7 +346,7 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
with (await self.pagination_lock.read(room_id)):
with await self.pagination_lock.read(room_id):
(
membership,
member_event_id,

View File

@ -161,6 +161,9 @@ class BaseProfileHandler(BaseHandler):
Codes.FORBIDDEN,
)
if not isinstance(new_displayname, str):
raise SynapseError(400, "Invalid displayname")
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@ -235,6 +238,9 @@ class BaseProfileHandler(BaseHandler):
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
)
if not isinstance(new_avatar_url, str):
raise SynapseError(400, "Invalid displayname")
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)

View File

@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler):
# Always wait for room creation to progate before returning
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", last_stream_id
self.hs.config.worker.events_shard_config.get_instance(room_id),
"events",
last_stream_id,
)
return result, last_stream_id
@ -1260,10 +1262,10 @@ class RoomShutdownHandler(object):
# We now wait for the create room to come back in via replication so
# that we can assume that all the joins/invites have propogated before
# we try and auto join below.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", stream_id
self.hs.config.worker.events_shard_config.get_instance(new_room_id),
"events",
stream_id,
)
else:
new_room_id = None
@ -1293,7 +1295,9 @@ class RoomShutdownHandler(object):
# Wait for leave to come in over replication before trying to forget.
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", stream_id
self.hs.config.worker.events_shard_config.get_instance(room_id),
"events",
stream_id,
)
await self.room_member_handler.forget(target_requester.user, room_id)

View File

@ -38,15 +38,7 @@ from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
from synapse.types import (
Collection,
JsonDict,
Requester,
RoomAlias,
RoomID,
StateMap,
UserID,
)
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@ -91,13 +83,6 @@ class RoomMemberHandler(object):
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
self._event_stream_writer_instance = hs.config.worker.writers.events
self._is_on_event_persistence_instance = (
self._event_stream_writer_instance == hs.get_instance_name()
)
if self._is_on_event_persistence_instance:
self.persist_event_storage = hs.get_storage().persistence
self._join_rate_limiter_local = Ratelimiter(
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
@ -185,7 +170,7 @@ class RoomMemberHandler(object):
target: UserID,
room_id: str,
membership: str,
prev_event_ids: Collection[str],
prev_event_ids: List[str],
txn_id: Optional[str] = None,
ratelimit: bool = True,
content: Optional[dict] = None,

View File

@ -16,7 +16,7 @@
import itertools
import logging
from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@ -44,6 +44,9 @@ from synapse.util.caches.response_cache import ResponseCache
from synapse.util.metrics import Measure, measure_func
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
# Debug logger for https://github.com/matrix-org/synapse/issues/4422
@ -96,7 +99,12 @@ class TimelineBatch:
__bool__ = __nonzero__ # python3
@attr.s(slots=True, frozen=True)
# We can't freeze this class, because we need to update it after it's instantiated to
# update its unread count. This is because we calculate the unread count for a room only
# if there are updates for it, which we check after the instance has been created.
# This should not be a big deal because we update the notification counts afterwards as
# well anyway.
@attr.s(slots=True)
class JoinedSyncResult:
room_id = attr.ib(type=str)
timeline = attr.ib(type=TimelineBatch)
@ -105,6 +113,7 @@ class JoinedSyncResult:
account_data = attr.ib(type=List[JsonDict])
unread_notifications = attr.ib(type=JsonDict)
summary = attr.ib(type=Optional[JsonDict])
unread_count = attr.ib(type=int)
def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@ -239,7 +248,7 @@ class SyncResult:
class SyncHandler(object):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs_config = hs.config
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@ -714,9 +723,8 @@ class SyncHandler(object):
]
missing_hero_state = await self.store.get_events(missing_hero_event_ids)
missing_hero_state = missing_hero_state.values()
for s in missing_hero_state:
for s in missing_hero_state.values():
cache.set(s.state_key, s.event_id)
state[(EventTypes.Member, s.state_key)] = s
@ -934,7 +942,7 @@ class SyncHandler(object):
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
) -> Optional[Dict[str, str]]:
) -> Dict[str, int]:
with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
@ -942,15 +950,10 @@ class SyncHandler(object):
receipt_type="m.read",
)
if last_unread_event_id:
notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
)
return notifs
# There is no new information in this period, so your notification
# count is whatever it was last time.
return None
notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
)
return notifs
async def generate_sync_result(
self,
@ -1773,7 +1776,7 @@ class SyncHandler(object):
ignored_users: Set[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[List[JsonDict]],
tags: Optional[Dict[str, Dict[str, Any]]],
account_data: Dict[str, JsonDict],
always_include: bool = False,
):
@ -1889,7 +1892,7 @@ class SyncHandler(object):
)
if room_builder.rtype == "joined":
unread_notifications = {} # type: Dict[str, str]
unread_notifications = {} # type: Dict[str, int]
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@ -1898,14 +1901,16 @@ class SyncHandler(object):
account_data=account_data_events,
unread_notifications=unread_notifications,
summary=summary,
unread_count=0,
)
if room_sync or always_include:
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
if notifs is not None:
unread_notifications["notification_count"] = notifs["notify_count"]
unread_notifications["highlight_count"] = notifs["highlight_count"]
unread_notifications["notification_count"] = notifs["notify_count"]
unread_notifications["highlight_count"] = notifs["highlight_count"]
room_sync.unread_count = notifs["unread_count"]
sync_result_builder.joined.append(room_sync)

View File

@ -234,7 +234,7 @@ class UserDirectoryHandler(StateDeltasHandler):
async def _handle_room_publicity_change(
self, room_id, prev_event_id, event_id, typ
):
"""Handle a room having potentially changed from/to world_readable/publically
"""Handle a room having potentially changed from/to world_readable/publicly
joinable.
Args:
@ -388,9 +388,15 @@ class UserDirectoryHandler(StateDeltasHandler):
prev_name = prev_event.content.get("displayname")
new_name = event.content.get("displayname")
# If the new name is an unexpected form, do not update the directory.
if not isinstance(new_name, str):
new_name = prev_name
prev_avatar = prev_event.content.get("avatar_url")
new_avatar = event.content.get("avatar_url")
# If the new avatar is an unexpected form, do not update the directory.
if not isinstance(new_avatar, str):
new_avatar = prev_avatar
if prev_name != new_name or prev_avatar != new_avatar:
await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)

View File

@ -134,8 +134,8 @@ class MatrixFederationAgent(object):
and not _is_ip_literal(parsed_uri.hostname)
and not parsed_uri.port
):
well_known_result = yield self._well_known_resolver.get_well_known(
parsed_uri.hostname
well_known_result = yield defer.ensureDeferred(
self._well_known_resolver.get_well_known(parsed_uri.hostname)
)
delegated_server = well_known_result.delegated_server

View File

@ -16,6 +16,7 @@
import logging
import random
import time
from typing import Callable, Dict, Optional, Tuple
import attr
@ -23,6 +24,7 @@ from twisted.internet import defer
from twisted.web.client import RedirectAgent, readBody
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
@ -99,15 +101,14 @@ class WellKnownResolver(object):
self._well_known_agent = RedirectAgent(agent)
self.user_agent = user_agent
@defer.inlineCallbacks
def get_well_known(self, server_name):
async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult:
"""Attempt to fetch and parse a .well-known file for the given server
Args:
server_name (bytes): name of the server, from the requested url
server_name: name of the server, from the requested url
Returns:
Deferred[WellKnownLookupResult]: The result of the lookup
The result of the lookup
"""
if server_name == b"kde.org":
@ -128,7 +129,9 @@ class WellKnownResolver(object):
# requests for the same server in parallel?
try:
with Measure(self._clock, "get_well_known"):
result, cache_period = yield self._fetch_well_known(server_name)
result, cache_period = await self._fetch_well_known(
server_name
) # type: Tuple[Optional[bytes], float]
except _FetchWellKnownFailure as e:
if prev_result and e.temporary:
@ -157,18 +160,17 @@ class WellKnownResolver(object):
return WellKnownLookupResult(delegated_server=result)
@defer.inlineCallbacks
def _fetch_well_known(self, server_name):
async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]:
"""Actually fetch and parse a .well-known, without checking the cache
Args:
server_name (bytes): name of the server, from the requested url
server_name: name of the server, from the requested url
Raises:
_FetchWellKnownFailure if we fail to lookup a result
Returns:
Deferred[Tuple[bytes,int]]: The lookup result and cache period.
The lookup result and cache period.
"""
had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
@ -176,7 +178,7 @@ class WellKnownResolver(object):
# We do this in two steps to differentiate between possibly transient
# errors (e.g. can't connect to host, 503 response) and more permenant
# errors (such as getting a 404 response).
response, body = yield self._make_well_known_request(
response, body = await self._make_well_known_request(
server_name, retry=had_valid_well_known
)
@ -219,20 +221,20 @@ class WellKnownResolver(object):
return result, cache_period
@defer.inlineCallbacks
def _make_well_known_request(self, server_name, retry):
async def _make_well_known_request(
self, server_name: bytes, retry: bool
) -> Tuple[IResponse, bytes]:
"""Make the well known request.
This will retry the request if requested and it fails (with unable
to connect or receives a 5xx error).
Args:
server_name (bytes)
retry (bool): Whether to retry the request if it fails.
server_name: name of the server, from the requested url
retry: Whether to retry the request if it fails.
Returns:
Deferred[tuple[IResponse, bytes]] Returns the response object and
body. Response may be a non-200 response.
Returns the response object and body. Response may be a non-200 response.
"""
uri = b"https://%s/.well-known/matrix/server" % (server_name,)
uri_str = uri.decode("ascii")
@ -247,12 +249,12 @@ class WellKnownResolver(object):
logger.info("Fetching %s", uri_str)
try:
response = yield make_deferred_yieldable(
response = await make_deferred_yieldable(
self._well_known_agent.request(
b"GET", uri, headers=Headers(headers)
)
)
body = yield make_deferred_yieldable(readBody(response))
body = await make_deferred_yieldable(readBody(response))
if 500 <= response.code < 600:
raise Exception("Non-200 response %s" % (response.code,))
@ -269,21 +271,24 @@ class WellKnownResolver(object):
logger.info("Error fetching %s: %s. Retrying", uri_str, e)
# Sleep briefly in the hopes that they come back up
yield self._clock.sleep(0.5)
await self._clock.sleep(0.5)
def _cache_period_from_headers(headers, time_now=time.time):
def _cache_period_from_headers(
headers: Headers, time_now: Callable[[], float] = time.time
) -> Optional[float]:
cache_controls = _parse_cache_control(headers)
if b"no-store" in cache_controls:
return 0
if b"max-age" in cache_controls:
try:
max_age = int(cache_controls[b"max-age"])
return max_age
except ValueError:
pass
max_age = cache_controls[b"max-age"]
if max_age:
try:
return int(max_age)
except ValueError:
pass
expires = headers.getRawHeaders(b"expires")
if expires is not None:
@ -299,7 +304,7 @@ def _cache_period_from_headers(headers, time_now=time.time):
return None
def _parse_cache_control(headers):
def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
cache_controls = {}
for hdr in headers.getRawHeaders(b"cache-control", []):
for directive in hdr.split(b","):

View File

@ -19,8 +19,10 @@ from collections import namedtuple
from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.event_auth import get_user_power_level
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache
@ -51,6 +53,48 @@ push_rules_delta_state_cache_metric = register_cache(
)
STATE_EVENT_TYPES_TO_MARK_UNREAD = {
EventTypes.Topic,
EventTypes.Name,
EventTypes.RoomAvatar,
EventTypes.Tombstone,
}
def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
# Exclude rejected and soft-failed events.
if context.rejected or event.internal_metadata.is_soft_failed():
return False
# Exclude notices.
if (
not event.is_state()
and event.type == EventTypes.Message
and event.content.get("msgtype") == "m.notice"
):
return False
# Exclude edits.
relates_to = event.content.get("m.relates_to", {})
if relates_to.get("rel_type") == RelationTypes.REPLACE:
return False
# Mark events that have a non-empty string body as unread.
body = event.content.get("body")
if isinstance(body, str) and body:
return True
# Mark some state events as unread.
if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
return True
# Mark encrypted events as unread.
if not event.is_state() and event.type == EventTypes.Encrypted:
return True
return False
class BulkPushRuleEvaluator(object):
"""Calculates the outcome of push rules for an event for all users in the
room at once.
@ -133,9 +177,12 @@ class BulkPushRuleEvaluator(object):
return pl_event.content if pl_event else {}, sender_level
async def action_for_event_by_user(self, event, context) -> None:
"""Given an event and context, evaluate the push rules and insert the
results into the event_push_actions_staging table.
"""Given an event and context, evaluate the push rules, check if the message
should increment the unread count, and insert the results into the
event_push_actions_staging table.
"""
count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {}
@ -172,6 +219,8 @@ class BulkPushRuleEvaluator(object):
if event.type == EventTypes.Member and event.state_key == uid:
display_name = event.content.get("displayname", None)
actions_by_user[uid] = []
for rule in rules:
if "enabled" in rule and not rule["enabled"]:
continue
@ -189,7 +238,9 @@ class BulkPushRuleEvaluator(object):
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
# the event)
await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
await self.store.add_push_actions_to_staging(
event.event_id, actions_by_user, count_as_unread,
)
def _condition_checker(evaluator, conditions, uid, display_name, cache):
@ -369,8 +420,8 @@ class RulesForRoom(object):
Args:
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
updated with any new rules.
member_event_ids (list): List of event ids for membership events that
have happened since the last time we filled rules_by_user
member_event_ids (dict): Dict of user id to event id for membership events
that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules
for. Used when updating the cache.
"""
@ -390,34 +441,19 @@ class RulesForRoom(object):
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
interested_in_user_ids = {
user_ids = {
user_id
for user_id, membership in members.values()
if membership == Membership.JOIN
}
logger.debug("Joined: %r", interested_in_user_ids)
logger.debug("Joined: %r", user_ids)
if_users_with_pushers = await self.store.get_if_users_have_pushers(
interested_in_user_ids, on_invalidate=self.invalidate_all_cb
)
user_ids = {
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
}
logger.debug("With pushers: %r", user_ids)
users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
self.room_id, on_invalidate=self.invalidate_all_cb
)
logger.debug("With receipts: %r", users_with_receipts)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in interested_in_user_ids:
user_ids.add(uid)
# Previously we only considered users with pushers or read receipts in that
# room. We can't do this anymore because we use push actions to calculate unread
# counts, which don't rely on the user having pushers or sent a read receipt into
# the room. Therefore we just need to filter for local users here.
user_ids = list(filter(self.is_mine_id, user_ids))
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb

View File

@ -36,7 +36,7 @@ async def get_badge_count(store, user_id):
)
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
badge += 1 if notifs["notify_count"] else 0
badge += 1 if notifs["unread_count"] else 0
return badge

View File

@ -66,7 +66,9 @@ REQUIREMENTS = [
"msgpack>=0.5.2",
"phonenumbers>=8.2.0",
"prometheus_client>=0.0.18,<0.9.0",
# we use attr.validators.deep_iterable, which arrived in 19.1.0
# we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note:
# Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33
# is out in November.)
"attrs>=19.1.0",
"netaddr>=0.7.18",
"Jinja2>=2.9",

View File

@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler
@staticmethod
async def _serialize_payload(store, event_and_contexts, backfilled):
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
"""
Args:
store
room_id (str)
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
backfilled (bool): Whether or not the events are the result of
backfilling
@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
}
)
payload = {"events": event_payloads, "backfilled": backfilled}
payload = {
"events": event_payloads,
"backfilled": backfilled,
"room_id": room_id,
}
return payload
@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
room_id = content["room_id"]
backfilled = content["backfilled"]
event_payloads = content["events"]
@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts))
max_stream_id = await self.federation_handler.persist_events_and_notify(
event_and_contexts, backfilled
room_id, event_and_contexts, backfilled
)
return 200, {"max_stream_id": max_stream_id}

View File

@ -48,6 +48,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)

View File

@ -14,7 +14,6 @@
# limitations under the License.
"""A replication client for use by synapse workers.
"""
import heapq
import logging
from typing import TYPE_CHECKING, Dict, List, Tuple
@ -219,9 +218,8 @@ class ReplicationDataHandler:
waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
# We insert into the list using heapq as it is more efficient than
# pushing then resorting each time.
heapq.heappush(waiting_list, (position, deferred))
waiting_list.append((position, deferred))
waiting_list.sort(key=lambda t: t[0])
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):

View File

@ -109,7 +109,7 @@ class ReplicationCommandHandler:
if isinstance(stream, (EventsStream, BackfillStream)):
# Only add EventStream and BackfillStream as a source on the
# instance in charge of event persistence.
if hs.config.worker.writers.events == hs.get_instance_name():
if hs.get_instance_name() in hs.config.worker.writers.events:
self._streams_to_replicate.append(stream)
continue

View File

@ -19,7 +19,7 @@ from typing import List, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
from ._base import Stream, StreamUpdateResult, Token
"""Handling of the 'events' replication stream
@ -117,7 +117,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
current_token_without_instance(self._store.get_current_events_token),
self._store._stream_id_gen.get_current_token_for_writer,
self._update_function,
)

View File

@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import (
room_keys,
room_upgrade_rest_servlet,
sendtodevice,
shared_rooms,
sync,
tags,
thirdparty,
@ -125,3 +126,6 @@ class ClientRestResource(JsonResource):
synapse.rest.admin.register_servlets_for_client_rest_resource(
hs, client_resource
)
# unstable
shared_rooms.register_servlets(hs, client_resource)

View File

@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Half-Shot
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet
from synapse.types import UserID
from ._base import client_patterns
logger = logging.getLogger(__name__)
class UserSharedRoomsServlet(RestServlet):
"""
GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
"""
PATTERNS = client_patterns(
"/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
releases=(), # This is an unstable feature
)
def __init__(self, hs):
super(UserSharedRoomsServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.user_directory_active = hs.config.update_user_directory
async def on_GET(self, request, user_id):
if not self.user_directory_active:
raise SynapseError(
code=400,
msg="The user directory is disabled on this server. Cannot determine shared rooms.",
errcode=Codes.FORBIDDEN,
)
UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
if user_id == requester.user.to_string():
raise SynapseError(
code=400,
msg="You cannot request a list of shared rooms with yourself",
errcode=Codes.FORBIDDEN,
)
rooms = await self.store.get_shared_rooms_for_users(
requester.user.to_string(), user_id
)
return 200, {"joined": list(rooms)}
def register_servlets(hs, http_server):
UserSharedRoomsServlet(hs).register(http_server)

View File

@ -425,6 +425,7 @@ class SyncRestServlet(RestServlet):
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
result["org.matrix.msc2654.unread_count"] = room.unread_count
return result

View File

@ -60,6 +60,8 @@ class VersionsRestServlet(RestServlet):
"org.matrix.e2e_cross_signing": True,
# Implements additional endpoints as described in MSC2432
"org.matrix.msc2432": True,
# Implements additional endpoints as described in MSC2666
"uk.half-shot.msc2666": True,
},
},
)

View File

@ -433,7 +433,7 @@ class BackgroundUpdater(object):
"background_updates", keyvalues={"update_name": update_name}
)
def _background_update_progress(self, update_name: str, progress: dict):
async def _background_update_progress(self, update_name: str, progress: dict):
"""Update the progress of a background update
Args:
@ -441,7 +441,7 @@ class BackgroundUpdater(object):
progress: The progress of the update.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"background_update_progress",
self._background_update_progress_txn,
update_name,

View File

@ -28,6 +28,7 @@ from typing import (
Optional,
Tuple,
TypeVar,
cast,
overload,
)
@ -35,7 +36,6 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@ -507,8 +507,9 @@ class DatabasePool(object):
self._txn_perf_counters.update(desc, duration)
sql_txn_timer.labels(desc).observe(duration)
@defer.inlineCallbacks
def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
async def runInteraction(
self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
) -> R:
"""Starts a transaction on the database and runs a given function
Arguments:
@ -521,7 +522,7 @@ class DatabasePool(object):
kwargs: named args to pass to `func`
Returns:
Deferred: The result of func
The result of func
"""
after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry]
@ -530,16 +531,14 @@ class DatabasePool(object):
logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
result = yield defer.ensureDeferred(
self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
**kwargs
)
result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
**kwargs
)
for after_callback, after_args, after_kwargs in after_callbacks:
@ -549,7 +548,7 @@ class DatabasePool(object):
after_callback(*after_args, **after_kwargs)
raise
return result
return cast(R, result)
async def runWithConnection(
self, func: "Callable[..., R]", *args: Any, **kwargs: Any
@ -604,6 +603,18 @@ class DatabasePool(object):
results = [dict(zip(col_headers, row)) for row in cursor]
return results
@overload
async def execute(
self, desc: str, decoder: Literal[None], query: str, *args: Any
) -> List[Tuple[Any, ...]]:
...
@overload
async def execute(
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
) -> R:
...
async def execute(
self,
desc: str,
@ -1088,6 +1099,28 @@ class DatabasePool(object):
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
@overload
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one_onecol",
) -> Any:
...
@overload
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one_onecol",
) -> Optional[Any]:
...
async def simple_select_one_onecol(
self,
table: str,
@ -1116,6 +1149,30 @@ class DatabasePool(object):
allow_none=allow_none,
)
@overload
@classmethod
def simple_select_one_onecol_txn(
cls,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[False] = False,
) -> Any:
...
@overload
@classmethod
def simple_select_one_onecol_txn(
cls,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[True] = True,
) -> Optional[Any]:
...
@classmethod
def simple_select_one_onecol_txn(
cls,

View File

@ -68,7 +68,7 @@ class Databases(object):
# If we're on a process that can persist events also
# instantiate a `PersistEventsStore`
if hs.config.worker.writers.events == hs.get_instance_name():
if hs.get_instance_name() in hs.config.worker.writers.events:
persist_events = PersistEventsStore(hs, database, main)
if "state" in database_config.databases:

View File

@ -18,7 +18,7 @@
import calendar
import logging
import time
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
@ -264,6 +264,9 @@ class DataStore(
# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None
@ -291,16 +294,16 @@ class DataStore(
return [UserPresenceState(**row) for row in rows]
def count_daily_users(self):
async def count_daily_users(self) -> int:
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_daily_users", self._count_users, yesterday
)
def count_monthly_users(self):
async def count_monthly_users(self) -> int:
"""
Counts the number of users who used this homeserver in the last 30 days.
Note this method is intended for phonehome metrics only and is different
@ -308,7 +311,7 @@ class DataStore(
amongst other things, includes a 3 day grace period before a user counts.
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_monthly_users", self._count_users, thirty_days_ago
)
@ -327,15 +330,15 @@ class DataStore(
(count,) = txn.fetchone()
return count
def count_r30_users(self):
async def count_r30_users(self) -> Dict[str, int]:
"""
Counts the number of 30 day retained users, defined as:-
* Users who have created their accounts more than 30 days ago
* Where last seen at most 30 days ago
* Where account creation and last_seen are > 30 days apart
Returns counts globaly for a given user as well as breaking
by platform
Returns:
A mapping of counts globally as well as broken out by platform.
"""
def _count_r30_users(txn):
@ -408,7 +411,7 @@ class DataStore(
return results
return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
def _get_start_of_day(self):
"""
@ -418,7 +421,7 @@ class DataStore(
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
return today_start * 1000
def generate_user_daily_visits(self):
async def generate_user_daily_visits(self) -> None:
"""
Generates daily visit data for use in cohort/ retention analysis
"""
@ -473,7 +476,7 @@ class DataStore(
# frequently
self._last_user_visit_update = now
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits
)
@ -497,22 +500,28 @@ class DataStore(
desc="get_users",
)
def get_users_paginate(
self, start, limit, user_id=None, name=None, guests=True, deactivated=False
):
async def get_users_paginate(
self,
start: int,
limit: int,
user_id: Optional[str] = None,
name: Optional[str] = None,
guests: bool = True,
deactivated: bool = False,
) -> Tuple[List[Dict[str, Any]], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
total number of users matching the filter criteria.
Args:
start (int): start number to begin the query from
limit (int): number of rows to retrieve
user_id (string): search for user_id. ignored if name is not None
name (string): search for local part of user_id or display name
guests (bool): whether to in include guest users
deactivated (bool): whether to include deactivated users
start: start number to begin the query from
limit: number of rows to retrieve
user_id: search for user_id. ignored if name is not None
name: search for local part of user_id or display name
guests: whether to in include guest users
deactivated: whether to include deactivated users
Returns:
defer.Deferred: resolves to list[dict[str, Any]], int
A tuple of a list of mappings from user to information and a count of total users.
"""
def get_users_paginate_txn(txn):
@ -555,7 +564,7 @@ class DataStore(
users = self.db_pool.cursor_to_dict(txn)
return users, count
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_users_paginate_txn", get_users_paginate_txn
)

View File

@ -16,9 +16,7 @@
import abc
import logging
from typing import List, Optional, Tuple
from twisted.internet import defer
from typing import Dict, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cached()
def get_account_data_for_user(self, user_id):
async def get_account_data_for_user(
self, user_id: str
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a user.
Args:
user_id(str): The user to get the account_data for.
user_id: The user to get the account_data for.
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
A 2-tuple of a dict of global account_data and a dict mapping from
room_id string to per room account_data dicts.
"""
def get_account_data_for_user_txn(txn):
@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return global_account_data, by_room
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
return None
@cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
async def get_account_data_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
"""Get all the client account_data for a user for a room.
Args:
user_id(str): The user to get the account_data for.
room_id(str): The room to get the account_data for.
user_id: The user to get the account_data for.
room_id: The room to get the account_data for.
Returns:
A deferred dict of the room account_data
A dict of the room account_data
"""
def get_account_data_for_room_txn(txn):
@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@cached(num_args=3, max_entries=5000)
def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
async def get_account_data_for_room_and_type(
self, user_id: str, room_id: str, account_data_type: str
) -> Optional[JsonDict]:
"""Get the client account_data of given type for a user for a room.
Args:
user_id(str): The user to get the account_data for.
room_id(str): The room to get the account_data for.
account_data_type (str): The account data type to get.
user_id: The user to get the account_data for.
room_id: The room to get the account_data for.
account_data_type: The account data type to get.
Returns:
A deferred of the room account_data for that type, or None if
there isn't any set.
The room account_data for that type, or None if there isn't any set.
"""
def get_account_data_for_room_and_type_txn(txn):
@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return db_to_json(content_json) if content_json else None
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_updated_room_account_data", get_updated_room_account_data_txn
)
def get_updated_account_data_for_user(self, user_id, stream_id):
async def get_updated_account_data_for_user(
self, user_id: str, stream_id: int
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a that's changed for a user
Args:
user_id(str): The user to get the account_data for.
stream_id(int): The point in the stream since which to get updates
user_id: The user to get the account_data for.
stream_id: The point in the stream since which to get updates
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
user_id, int(stream_id)
)
if not changed:
return defer.succeed(({}, {}))
return ({}, {})
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
return self._account_data_id_gen.get_current_token()
def _update_max_stream_id(self, next_id: int):
async def _update_max_stream_id(self, next_id: int) -> None:
"""Update the max stream_id
Args:
@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)

View File

@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
@wrap_as_background_process("update_client_ips")
def _update_client_ips_batch(self):
async def _update_client_ips_batch(self) -> None:
# If the DB pool has already terminated, don't try updating
if not self.db_pool.is_running():
@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
to_update = self._batch_row_update
self._batch_row_update = {}
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)

View File

@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
update included in the response), and the list of updates, where
each update is a pair of EDU type and EDU contents.
"""
now_stream_id = self._device_list_id_gen.get_current_token()
now_stream_id = self.get_device_stream_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
@ -254,9 +255,7 @@ class DeviceWorkerStore(SQLBaseStore):
List of objects representing an device update EDU
"""
devices = (
await self.db_pool.runInteraction(
"_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn,
await self.get_e2e_device_keys_and_signatures(
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
@ -292,17 +291,17 @@ class DeviceWorkerStore(SQLBaseStore):
prev_id = stream_id
if device is not None:
key_json = device.get("key_json", None)
key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)
if "signatures" in device:
for sig_user_id, sigs in device["signatures"].items():
if device.signatures:
for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
device_display_name = device.get("device_display_name", None)
device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
else:
@ -312,9 +311,9 @@ class DeviceWorkerStore(SQLBaseStore):
return results
def _get_last_device_update_for_remote_user(
async def _get_last_device_update_for_remote_user(
self, destination: str, user_id: str, from_stream_id: int
):
) -> int:
def f(txn):
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
@ -325,12 +324,16 @@ class DeviceWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0]
return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
return await self.db_pool.runInteraction(
"get_last_device_update_for_remote_user", f
)
def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
async def mark_as_sent_devices_by_remote(
self, destination: str, stream_id: int
) -> None:
"""Mark that updates have successfully been sent to the destination.
"""
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
@ -412,8 +415,10 @@ class DeviceWorkerStore(SQLBaseStore):
},
)
@abc.abstractmethod
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
"""Get the current stream id from the _device_list_id_gen"""
...
@trace
async def get_user_devices_from_cache(
@ -481,51 +486,6 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices
}
def get_devices_with_keys_by_user(self, user_id: str):
"""Get all devices (with any device keys) for a user
Returns:
Deferred which resolves to (stream_id, devices)
"""
return self.db_pool.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn,
user_id,
)
def _get_devices_with_keys_by_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> Tuple[int, List[JsonDict]]:
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.items():
result = {"device_id": device_id}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
if "signatures" in device:
for sig_user_id, sigs in device["signatures"].items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
async def get_users_whose_devices_changed(
self, from_key: str, user_ids: Iterable[str]
) -> Set[str]:
@ -726,7 +686,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="make_remote_user_device_cache_as_stale",
)
def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user.
"""
@ -740,7 +700,7 @@ class DeviceWorkerStore(SQLBaseStore):
txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"mark_remote_user_device_list_as_unsubscribed",
_mark_remote_user_device_list_as_unsubscribed_txn,
)
@ -1001,9 +961,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
desc="update_device",
)
def update_remote_device_list_cache_entry(
async def update_remote_device_list_cache_entry(
self, user_id: str, device_id: str, content: JsonDict, stream_id: int
):
) -> None:
"""Updates a single device in the cache of a remote user's devicelist.
Note: assumes that we are the only thread that can be updating this user's
@ -1014,11 +974,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device_id: ID of decivice being updated
content: new data on this device
stream_id: the version of the device list
Returns:
Deferred[None]
"""
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id,
@ -1070,9 +1027,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
lock=False,
)
def update_remote_device_list_cache(
async def update_remote_device_list_cache(
self, user_id: str, devices: List[dict], stream_id: int
):
) -> None:
"""Replace the entire cache of the remote user's devices.
Note: assumes that we are the only thread that can be updating this user's
@ -1082,11 +1039,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: User to update device list for
devices: list of device objects supplied over federation
stream_id: the version of the device list
Returns:
Deferred[None]
"""
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id,
@ -1096,7 +1050,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
):
) -> None:
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)

View File

@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore):
return room_id
def update_aliases_for_room(
async def update_aliases_for_room(
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
):
) -> None:
"""Repoint all of the aliases for a given room, to a different room.
Args:
@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,)
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)

View File

@ -14,8 +14,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
from twisted.enterprise.adbapi import Connection
@ -23,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@ -31,19 +34,67 @@ if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
@attr.s
class DeviceKeyLookupResult:
"""The type returned by get_e2e_device_keys_and_signatures"""
display_name = attr.ib(type=Optional[str])
# the key data from e2e_device_keys_json. Typically includes fields like
# "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
# key) and "signatures" (a signature of the structure by the ed25519 key)
key_json = attr.ib(type=Optional[str])
# cross-signing sigs
signatures = attr.ib(type=Optional[Dict], default=None)
class EndToEndKeyWorkerStore(SQLBaseStore):
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
) -> Tuple[int, List[JsonDict]]:
"""Get all devices (with any device keys) for a user
Returns:
(stream_id, devices)
"""
now_stream_id = self.get_device_stream_token()
devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.items():
result = {"device_id": device_id}
key_json = device.key_json
if key_json:
result["keys"] = db_to_json(key_json)
if device.signatures:
for sig_user_id, sigs in device.signatures.items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
device_display_name = device.display_name
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
@trace
async def get_e2e_device_keys(
self, query_list, include_all_devices=False, include_deleted_devices=False
):
"""Fetch a list of device keys.
async def get_e2e_device_keys_for_cs_api(
self, query_list: List[Tuple[str, Optional[str]]]
) -> Dict[str, Dict[str, JsonDict]]:
"""Fetch a list of device keys, formatted suitably for the C/S API.
Args:
query_list(list): List of pairs of user_ids and device_ids.
include_all_devices (bool): whether to include entries for devices
that don't have device keys
include_deleted_devices (bool): whether to include null entries for
devices which no longer exist (but were in the query_list).
This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key data. The key data will be a dict in the same format as the
@ -53,13 +104,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
if not query_list:
return {}
results = await self.db_pool.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_txn,
query_list,
include_all_devices,
include_deleted_devices,
)
results = await self.get_e2e_device_keys_and_signatures(query_list)
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
@ -67,13 +112,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_keys in results.items():
rv[user_id] = {}
for device_id, device_info in device_keys.items():
r = db_to_json(device_info.pop("key_json"))
r = db_to_json(device_info.key_json)
r["unsigned"] = {}
display_name = device_info["device_display_name"]
display_name = device_info.display_name
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
if "signatures" in device_info:
for sig_user_id, sigs in device_info["signatures"].items():
if device_info.signatures:
for sig_user_id, sigs in device_info.signatures.items():
r.setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
@ -82,12 +127,45 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return rv
@trace
def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
):
async def get_e2e_device_keys_and_signatures(
self,
query_list: List[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Fetch a list of device keys, together with their cross-signatures.
Args:
query_list: List of pairs of user_ids and device_ids. Device id can be None
to indicate "all devices for this user"
include_all_devices: whether to return devices without device keys
include_deleted_devices: whether to include null entries for
devices which no longer exist (but were in the query_list).
This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key data.
"""
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)
result = await self.db_pool.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_and_signatures_txn,
query_list,
include_all_devices,
include_deleted_devices,
)
log_kv(result)
return result
def _get_e2e_device_keys_and_signatures_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
query_clauses = []
query_params = []
signature_query_clauses = []
@ -119,7 +197,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
sql = (
"SELECT user_id, device_id, "
" d.display_name AS device_display_name, "
" d.display_name, "
" k.key_json"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
@ -130,13 +208,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
txn.execute(sql, query_params)
rows = self.db_pool.cursor_to_dict(txn)
result = {}
for row in rows:
result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
for (user_id, device_id, display_name, key_json) in txn:
if include_deleted_devices:
deleted_devices.remove((row["user_id"], row["device_id"]))
result.setdefault(row["user_id"], {})[row["device_id"]] = row
deleted_devices.remove((user_id, device_id))
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
display_name, key_json
)
if include_deleted_devices:
for user_id, device_id in deleted_devices:
@ -167,13 +246,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# note that target_device_result will be None for deleted devices.
continue
target_device_signatures = target_device_result.setdefault("signatures", {})
target_device_signatures = target_device_result.signatures
if target_device_signatures is None:
target_device_signatures = target_device_result.signatures = {}
signing_user_signatures = target_device_signatures.setdefault(
signing_user_id, {}
)
signing_user_signatures[signing_key_id] = signature
log_kv(result)
return result
async def get_e2e_one_time_keys(
@ -252,10 +333,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
def count_e2e_one_time_keys(self, user_id, device_id):
async def count_e2e_one_time_keys(
self, user_id: str, device_id: str
) -> Dict[str, int]:
""" Count the number of one time keys the server has for a device
Returns:
Dict mapping from algorithm to number of keys for that algorithm.
A mapping from algorithm to number of keys for that algorithm.
"""
def _count_e2e_one_time_keys(txn):
@ -270,7 +353,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result[algorithm] = key_count
return result
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
@ -308,7 +391,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
list_name="user_ids",
num_args=1,
)
def _get_bare_e2e_cross_signing_keys_bulk(
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
@ -316,16 +399,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
the signatures for the calling user need to be fetched.
Args:
user_ids (list[str]): the users whose keys are being requested
user_ids: the users whose keys are being requested
Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. If a user's cross-signing keys were not found, either
their user ID will not be in the dict, or their user ID will map
to None.
A mapping from user ID to key type to key data. If a user's cross-signing
keys were not found, either their user ID will not be in the dict, or
their user ID will map to None.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
@ -541,9 +623,16 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
_get_all_user_signature_changes_for_remotes_txn,
)
@abc.abstractmethod
def get_device_stream_token(self) -> int:
"""Get the current stream id from the _device_list_id_gen"""
...
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
async def set_e2e_device_keys(
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
@ -579,12 +668,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"message": "Device keys stored."})
return True
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
) -> Dict[str, Dict[str, Dict[str, bytes]]]:
"""Take a list of one time keys out of the database.
Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
"""
@trace
def _claim_e2e_one_time_keys(txn):
@ -620,11 +718,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
return result
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
def delete_e2e_keys_by_device(self, user_id, device_id):
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn):
log_kv(
{
@ -647,7 +745,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)

View File

@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
if stream_ordering <= self.stream_ordering_month_ago:
raise StoreError(400, "stream_ordering too old")
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
SELECT event_id FROM stream_ordering_to_exterm

View File

@ -15,7 +15,9 @@
# limitations under the License.
import logging
from typing import List
from typing import Dict, List, Optional, Tuple, Union
import attr
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
self, room_id: str, user_id: str, last_read_event_id: Optional[str],
) -> Dict[str, int]:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt.
Note that this function assumes the user to be a current member of the room,
since it's either called by the sync handler to handle joined room entries, or by
the HTTP pusher to calculate the badge of unread joined rooms.
Args:
room_id: The room to retrieve the counts in.
user_id: The user to retrieve the counts for.
last_read_event_id: The event associated with the latest read receipt for
this user in this room. None if no receipt for this user in this room.
Returns
A dict containing the counts mentioned earlier in this docstring,
respectively under the keys "notify_count", "highlight_count" and
"unread_count".
"""
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id
self, txn, room_id, user_id, last_read_event_id,
):
sql = (
"SELECT stream_ordering"
" FROM events"
" WHERE room_id = ? AND event_id = ?"
)
txn.execute(sql, (room_id, last_read_event_id))
results = txn.fetchall()
if len(results) == 0:
return {"notify_count": 0, "highlight_count": 0}
stream_ordering = None
stream_ordering = results[0][0]
if last_read_event_id is not None:
stream_ordering = self.get_stream_id_for_event_txn(
txn, last_read_event_id, allow_none=True,
)
if stream_ordering is None:
# Either last_read_event_id is None, or it's an event we don't have (e.g.
# because it's been purged), in which case retrieve the stream ordering for
# the latest membership event from this user in this room (which we assume is
# a join).
event_id = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
retcol="event_id",
)
stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering
)
def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
# First get number of notifications.
# We don't need to put a notif=1 clause as all rows always have
# notif=1
sql = (
"SELECT count(*)"
"SELECT"
" COUNT(CASE WHEN notif = 1 THEN 1 END),"
" COUNT(CASE WHEN highlight = 1 THEN 1 END),"
" COUNT(CASE WHEN unread = 1 THEN 1 END)"
" FROM event_push_actions ea"
" WHERE"
" user_id = ?"
" AND room_id = ?"
" AND stream_ordering > ?"
" WHERE user_id = ?"
" AND room_id = ?"
" AND stream_ordering > ?"
)
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
notify_count = row[0] if row else 0
(notif_count, highlight_count, unread_count) = (0, 0, 0)
if row:
(notif_count, highlight_count, unread_count) = row
txn.execute(
"""
SELECT notif_count FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
""",
SELECT notif_count, unread_count FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
""",
(room_id, user_id, stream_ordering),
)
rows = txn.fetchall()
if rows:
notify_count += rows[0][0]
# Now get the number of highlights
sql = (
"SELECT count(*)"
" FROM event_push_actions ea"
" WHERE"
" highlight = 1"
" AND user_id = ?"
" AND room_id = ?"
" AND stream_ordering > ?"
)
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
highlight_count = row[0] if row else 0
return {"notify_count": notify_count, "highlight_count": highlight_count}
if row:
notif_count += row[0]
unread_count += row[1]
return {
"notify_count": notif_count,
"unread_count": unread_count,
"highlight_count": highlight_count,
}
async def get_push_action_users_in_range(
self, min_stream_ordering, max_stream_ordering
@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@ -383,62 +409,66 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# Now return the first `limit`
return notifs[:limit]
def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
async def get_if_maybe_push_in_range_for_user(
self, user_id: str, min_stream_ordering: int
) -> bool:
"""A fast check to see if there might be something to push for the
user since the given stream ordering. May return false positives.
Useful to know whether to bother starting a pusher on start up or not.
Args:
user_id (str)
min_stream_ordering (int)
user_id
min_stream_ordering
Returns:
Deferred[bool]: True if there may be push to process, False if
there definitely isn't.
True if there may be push to process, False if there definitely isn't.
"""
def _get_if_maybe_push_in_range_for_user_txn(txn):
sql = """
SELECT 1 FROM event_push_actions
WHERE user_id = ? AND stream_ordering > ?
WHERE user_id = ? AND stream_ordering > ? AND notif = 1
LIMIT 1
"""
txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone())
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn,
)
async def add_push_actions_to_staging(self, event_id, user_id_actions):
async def add_push_actions_to_staging(
self,
event_id: str,
user_id_actions: Dict[str, List[Union[dict, str]]],
count_as_unread: bool,
) -> None:
"""Add the push actions for the event to the push action staging area.
Args:
event_id (str)
user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
user_id to list of push actions, where an action can either be
a string or dict.
Returns:
Deferred
event_id
user_id_actions: A mapping of user_id to list of push actions, where
an action can either be a string or dict.
count_as_unread: Whether this event should increment unread counts.
"""
if not user_id_actions:
return
# This is a helper function for generating the necessary tuple that
# can be used to inert into the `event_push_actions_staging` table.
# can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(user_id, actions):
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
return (
event_id, # event_id column
user_id, # user_id column
_serialize_action(actions, is_highlight), # actions column
1, # notif column
notif, # notif column
is_highlight, # highlight column
int(count_as_unread), # unread column
)
def _add_push_actions_to_staging_txn(txn):
@ -447,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
sql = """
INSERT INTO event_push_actions_staging
(event_id, user_id, actions, notif, highlight)
VALUES (?, ?, ?, ?, ?)
(event_id, user_id, actions, notif, highlight, unread)
VALUES (?, ?, ?, ?, ?, ?)
"""
txn.executemany(
@ -507,7 +537,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
)
def find_first_stream_ordering_after_ts(self, ts):
async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
"""Gets the stream ordering corresponding to a given timestamp.
Specifically, finds the stream_ordering of the first event that was
@ -516,13 +546,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
relatively slow.
Args:
ts (int): timestamp in millis
ts: timestamp in millis
Returns:
Deferred[int]: stream ordering of the first event received on/after
the timestamp
stream ordering of the first event received on/after the timestamp
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn,
ts,
@ -813,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id,
coalesce(old.notif_count, 0) + upd.notif_count,
coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering,
old.user_id
FROM (
SELECT user_id, room_id, count(*) as notif_count,
SELECT user_id, room_id, count(*) as cnt,
max(stream_ordering) as stream_ordering
FROM event_push_actions
WHERE ? <= stream_ordering AND stream_ordering < ?
AND highlight = 0
AND %s = 1
GROUP BY user_id, room_id
) AS upd
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
"""
txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
rows = txn.fetchall()
# First get the count of unread messages.
txn.execute(
sql % ("unread_count", "unread"),
(old_rotate_stream_ordering, rotate_to_stream_ordering),
)
logger.info("Rotating notifications, handling %d rows", len(rows))
# We need to merge results from the two requests (the one that retrieves the
# unread count and the one that retrieves the notifications count) into a single
# object because we might not have the same amount of rows in each of them. To do
# this, we use a dict indexed on the user ID and room ID to make it easier to
# populate.
summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary]
for row in txn:
summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=row[2],
stream_ordering=row[3],
old_user_id=row[4],
notif_count=0,
)
# Then get the count of notifications.
txn.execute(
sql % ("notif_count", "notif"),
(old_rotate_stream_ordering, rotate_to_stream_ordering),
)
for row in txn:
if (row[0], row[1]) in summaries:
summaries[(row[0], row[1])].notif_count = row[2]
else:
# Because the rules on notifying are different than the rules on marking
# a message unread, we might end up with messages that notify but aren't
# marked unread, so we might not have a summary for this (user, room)
# tuple to complete.
summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=0,
stream_ordering=row[3],
old_user_id=row[4],
notif_count=row[2],
)
logger.info("Rotating notifications, handling %d rows", len(summaries))
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
@ -840,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
table="event_push_summary",
values=[
{
"user_id": row[0],
"room_id": row[1],
"notif_count": row[2],
"stream_ordering": row[3],
"user_id": user_id,
"room_id": room_id,
"notif_count": summary.notif_count,
"unread_count": summary.unread_count,
"stream_ordering": summary.stream_ordering,
}
for row in rows
if row[4] is None
for ((user_id, room_id), summary) in summaries.items()
if summary.old_user_id is None
],
)
txn.executemany(
"""
UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
UPDATE event_push_summary
SET notif_count = ?, unread_count = ?, stream_ordering = ?
WHERE user_id = ? AND room_id = ?
""",
((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
(
(
summary.notif_count,
summary.unread_count,
summary.stream_ordering,
user_id,
room_id,
)
for ((user_id, room_id), summary) in summaries.items()
if summary.old_user_id is not None
),
)
txn.execute(
@ -881,3 +961,15 @@ def _action_has_highlight(actions):
pass
return False
@attr.s
class _EventPushSummary:
"""Summary of pending event push actions for a given user in a given room.
Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
"""
unread_count = attr.ib(type=int)
stream_ordering = attr.ib(type=int)
old_user_id = attr.ib(type=str)
notif_count = attr.ib(type=int)

View File

@ -97,6 +97,7 @@ class PersistEventsStore:
self.store = main_data_store
self.database_engine = db.engine
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
@ -108,7 +109,7 @@ class PersistEventsStore:
# This should only exist on instances that are configured to write
assert (
hs.config.worker.writers.events == hs.get_instance_name()
hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
async def _persist_events_and_state_updates(
@ -800,6 +801,7 @@ class PersistEventsStore:
table="events",
values=[
{
"instance_name": self._instance_name,
"stream_ordering": event.internal_metadata.stream_ordering,
"topological_ordering": event.depth,
"depth": event.depth,
@ -1296,9 +1298,9 @@ class PersistEventsStore:
sql = """
INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering,
topological_ordering, notif, highlight
topological_ordering, notif, highlight, unread
)
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
FROM event_push_actions_staging
WHERE event_id = ?
"""

View File

@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.writers.events == hs.get_instance_name():
# We are the process in charge of generating stream ids for events,
# so instantiate ID generators based on the database
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres than we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_stream_seq",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_backfill_stream_seq",
positive=False,
)
else:
# Another process is in charge of persisting events and generating
# stream IDs: rely on the replication streams to let us know which
# IDs we can process.
self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
# We shouldn't be running in worker mode with SQLite, but its useful
# to support it for unit tests.
#
# If this process is the writer than we need to use
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.events:
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
)
else:
self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering"
)
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
self._get_event_cache = Cache(
"*getEvent*",
@ -823,20 +851,24 @@ class EventsWorkerStore(SQLBaseStore):
return event_dict
def _maybe_redact_event_row(self, original_ev, redactions, event_map):
def _maybe_redact_event_row(
self,
original_ev: EventBase,
redactions: Iterable[str],
event_map: Dict[str, EventBase],
) -> Optional[EventBase]:
"""Given an event object and a list of possible redacting event ids,
determine whether to honour any of those redactions and if so return a redacted
event.
Args:
original_ev (EventBase):
redactions (iterable[str]): list of event ids of potential redaction events
event_map (dict[str, EventBase]): other events which have been fetched, in
which we can look up the redaaction events. Map from event id to event.
original_ev: The original event.
redactions: list of event ids of potential redaction events
event_map: other events which have been fetched, in which we can
look up the redaaction events. Map from event id to event.
Returns:
Deferred[EventBase|None]: if the event should be redacted, a pruned
event object. Otherwise, None.
If the event should be redacted, a pruned event object. Otherwise, None.
"""
if original_ev.type == "m.room.create":
# we choose to ignore redactions of m.room.create events.
@ -946,17 +978,17 @@ class EventsWorkerStore(SQLBaseStore):
row = txn.fetchone()
return row[0] if row else 0
def get_current_state_event_counts(self, room_id):
async def get_current_state_event_counts(self, room_id: str) -> int:
"""
Gets the current number of state events in a room.
Args:
room_id (str)
room_id: The room ID to query.
Returns:
Deferred[int]
The current number of state events.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_current_state_event_counts",
self._get_current_state_event_counts_txn,
room_id,
@ -991,7 +1023,9 @@ class EventsWorkerStore(SQLBaseStore):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
async def get_all_new_forward_event_rows(
self, last_id: int, current_id: int, limit: int
) -> List[Tuple]:
"""Returns new events, for the Events replication stream
Args:
@ -999,7 +1033,7 @@ class EventsWorkerStore(SQLBaseStore):
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return
Returns: Deferred[List[Tuple]]
Returns:
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
@ -1020,18 +1054,20 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
def get_ex_outlier_stream_rows(self, last_id, current_id):
async def get_ex_outlier_stream_rows(
self, last_id: int, current_id: int
) -> List[Tuple]:
"""Returns de-outliered events, for the Events replication stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
Returns: Deferred[List[Tuple]]
Returns:
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
@ -1054,7 +1090,7 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id))
return txn.fetchall()
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
)
@ -1226,11 +1262,11 @@ class EventsWorkerStore(SQLBaseStore):
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
def get_next_event_to_expire(self):
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
table, or None if there's no more event to expire.
Returns: Deferred[Optional[Tuple[str, int]]]
Returns:
A tuple containing the event ID as its first element and an expiry timestamp
as its second one, if there's at least one row in the event_expiry table.
None otherwise.
@ -1246,6 +1282,6 @@ class EventsWorkerStore(SQLBaseStore):
return txn.fetchone()
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)

View File

@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore):
return db_to_json(def_json)
def add_user_filter(self, user_localpart, user_filter):
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore):
return filter_id
return self.db_pool.runInteraction("add_user_filter", _do_txn)
return await self.db_pool.runInteraction("add_user_filter", _do_txn)

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@ -93,7 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="mark_local_media_as_safe",
)
def get_url_cache(self, url, ts):
async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
None if the URL isn't cached.
@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
)
return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
@ -237,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_cached_remote_media",
)
def update_cached_last_access_time(self, local_media, remote_media, time_ms):
async def update_cached_last_access_time(
self,
local_media: Iterable[str],
remote_media: Iterable[Tuple[str, str]],
time_ms: int,
):
"""Updates the last access time of the given media
Args:
local_media (iterable[str]): Set of media_ids
remote_media (iterable[(str, str)]): Set of (server_name, media_id)
local_media: Set of media_ids
remote_media: Set of (server_name, media_id)
time_ms: Current time in milliseconds
"""
@ -267,7 +272,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
)
@ -325,7 +330,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
)
def delete_remote_media(self, media_origin, media_id):
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
def delete_remote_media_txn(txn):
self.db_pool.simple_delete_txn(
txn,
@ -338,11 +343,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_remote_media", delete_remote_media_txn
)
def get_expired_url_cache(self, now_ts):
async def get_expired_url_cache(self, now_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository_url_cache"
" WHERE expires_ts < ?"
@ -354,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_expired_url_cache", _get_expired_url_cache_txn
)
@ -371,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"delete_url_cache", _delete_url_cache_txn
)
def get_url_cache_media_before(self, before_ts):
async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository"
" WHERE created_ts < ? AND url_cache IS NOT NULL"
@ -383,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)

View File

@ -1,3 +1,5 @@
from typing import Optional
from synapse.storage._base import SQLBaseStore
@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore):
desc="insert_open_id_token",
)
def get_user_id_for_open_id_token(self, token, ts_now_ms):
async def get_user_id_for_open_id_token(
self, token: str, ts_now_ms: int
) -> Optional[str]:
def get_user_id_for_token_txn(txn):
sql = (
"SELECT user_id FROM open_id_tokens"
@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore):
else:
return rows[0][0]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn
)

View File

@ -138,7 +138,9 @@ class ProfileStore(ProfileWorkerStore):
desc="delete_remote_profile_cache",
)
def get_remote_profile_cache_entries_that_expire(self, last_checked):
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> Dict[str, str]:
"""Get all users who haven't been checked since `last_checked`
"""
@ -153,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
return self.db_pool.cursor_to_dict(txn)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)

View File

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Any, Tuple
from typing import Any, List, Set, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
@ -25,25 +25,24 @@ logger = logging.getLogger(__name__)
class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
def purge_history(self, room_id, token, delete_local_events):
async def purge_history(
self, room_id: str, token: str, delete_local_events: bool
) -> Set[int]:
"""Deletes room history before a certain point
Args:
room_id (str):
token (str): A topological token to delete events before
delete_local_events (bool):
room_id:
token: A topological token to delete events before
delete_local_events:
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
Returns:
Deferred[set[int]]: The set of state groups that are referenced by
deleted events.
The set of state groups that are referenced by deleted events.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
@ -283,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
return referenced_state_groups
def purge_room(self, room_id):
async def purge_room(self, room_id: str) -> List[int]:
"""Deletes all record of a room
Args:
room_id (str)
room_id
Returns:
Deferred[List[int]]: The list of state groups to delete.
The list of state groups to delete.
"""
return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id)
return await self.db_pool.runInteraction(
"purge_room", self._purge_room_txn, room_id
)
def _purge_room_txn(self, txn, room_id):
# First we fetch all the state groups that should be deleted, before

View File

@ -18,8 +18,6 @@ import abc
import logging
from typing import List, Tuple, Union
from twisted.internet import defer
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
@ -149,9 +147,11 @@ class PushRulesWorkerStore(
)
return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
def have_push_rules_changed_for_user(self, user_id, last_id):
async def have_push_rules_changed_for_user(
self, user_id: str, last_id: int
) -> bool:
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False)
return False
else:
def have_push_rules_changed_txn(txn):
@ -163,7 +163,7 @@ class PushRulesWorkerStore(
(count,) = txn.fetchone()
return bool(count)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)

View File

@ -276,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
return results
def get_users_sent_receipts_between(self, last_id: int, current_id: int):
async def get_users_sent_receipts_between(
self, last_id: int, current_id: int
) -> List[str]:
"""Get all users who sent receipts between `last_id` exclusive and
`current_id` inclusive.
Returns:
Deferred[List[str]]
The list of users.
"""
if last_id == current_id:
@ -296,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return [r[0] for r in txn]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
)
@ -553,8 +555,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
return stream_id, max_persisted_id
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
return self.db_pool.runInteraction(
async def insert_graph_receipt(
self, room_id, receipt_type, user_id, event_ids, data
):
return await self.db_pool.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,

View File

@ -17,7 +17,7 @@
import logging
import re
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@ -84,22 +84,22 @@ class RegistrationWorkerStore(SQLBaseStore):
return is_trial
@cached()
def get_user_by_access_token(self, token):
async def get_user_by_access_token(self, token: str) -> Optional[dict]:
"""Get a user from the given access token.
Args:
token (str): The access token of a user.
token: The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`,
`valid_until_ms`.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
)
@cached()
async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]:
"""Get the expiration timestamp for the account bearing a given user ID.
Args:
@ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore):
return bool(res) if res else False
def set_server_admin(self, user, admin):
async def set_server_admin(self, user: UserID, admin: bool) -> None:
"""Sets whether a user is an admin of this homeserver.
Args:
user (UserID): user ID of the user to test
admin (bool): true iff the user is to be a server admin,
false otherwise.
user: user ID of the user to test
admin: true iff the user is to be a server admin, false otherwise.
"""
def set_server_admin_txn(txn):
@ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn, self.get_user_by_id, (user.to_string(),)
)
return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
sql = (
@ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore):
)
return True if res == UserTypes.SUPPORT else False
def get_users_by_id_case_insensitive(self, user_id):
async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
Returns:
A mapping of user_id -> password_hash.
"""
def f(txn):
@ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return dict(txn)
return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
@ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction("count_users", _count_users)
def count_daily_user_type(self):
async def count_daily_user_type(self) -> Dict[str, int]:
"""
Counts 1) native non guest users
2) native guests users
@ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
results[row[0]] = row[1]
return results
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_daily_user_type", _count_daily_user_type
)
@ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore):
# Convert the integer into a boolean.
return res == 1
def get_threepid_validation_session(
self, medium, client_secret, address=None, sid=None, validated=True
):
async def get_threepid_validation_session(
self,
medium: Optional[str],
client_secret: str,
address: Optional[str] = None,
sid: Optional[str] = None,
validated: Optional[bool] = True,
) -> Optional[Dict[str, Any]]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata
Args:
medium (str|None): The medium of the 3PID
address (str|None): The address of the 3PID
sid (str|None): The ID of the validation session
client_secret (str): A unique string provided by the client to help identify this
medium: The medium of the 3PID
client_secret: A unique string provided by the client to help identify this
validation attempt
validated (bool|None): Whether sessions should be filtered by
address: The address of the 3PID
sid: The ID of the validation session
validated: Whether sessions should be filtered by
whether they have been validated already or not. None to
perform no filtering
Returns:
Deferred[dict|None]: A dict containing the following:
A dict containing the following:
* address - address of the 3pid
* medium - medium of the 3pid
* client_secret - a secret provided by the client for this validation session
@ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore):
return rows[0]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
)
def delete_threepid_session(self, session_id):
async def delete_threepid_session(self, session_id: str) -> None:
"""Removes a threepid validation session from the database. This can
be done after validation has been performed and whatever action was
waiting on it has been carried out
Args:
session_id (str): The ID of the session to delete
session_id: The ID of the session to delete
"""
def delete_threepid_session_txn(txn):
@ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore):
keyvalues={"session_id": session_id},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_threepid_session", delete_threepid_session_txn
)
@ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="add_access_token_to_user",
)
def register_user(
async def register_user(
self,
user_id,
password_hash=None,
was_guest=False,
make_guest=False,
appservice_id=None,
create_profile_with_displayname=None,
admin=False,
user_type=None,
shadow_banned=False,
):
user_id: str,
password_hash: Optional[str] = None,
was_guest: bool = False,
make_guest: bool = False,
appservice_id: Optional[str] = None,
create_profile_with_displayname: Optional[str] = None,
admin: bool = False,
user_type: Optional[str] = None,
shadow_banned: bool = False,
) -> None:
"""Attempts to register an account.
Args:
user_id (str): The desired user ID to register.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str): The ID of the appservice registering the user.
create_profile_with_displayname (unicode): Optionally create a profile for
user_id: The desired user ID to register.
password_hash: Optional. The password hash for this user.
was_guest: Whether this is a guest account being upgraded to a
non-guest account.
make_guest: True if the the new user should be guest, false to add a
regular user account.
appservice_id: The ID of the appservice registering the user.
create_profile_with_displayname: Optionally create a profile for
the user, setting their displayname to the given value
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
shadow_banned (bool): Whether the user is shadow-banned,
i.e. they may be told their requests succeeded but we ignore them.
admin: is an admin user?
user_type: type of user. One of the values from api.constants.UserTypes,
or None for a normal user.
shadow_banned: Whether the user is shadow-banned, i.e. they may be
told their requests succeeded but we ignore them.
Raises:
StoreError if the user_id could not be registered.
Returns:
Deferred
"""
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"register_user",
self._register_user,
user_id,
@ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="record_user_external_id",
)
def user_set_password_hash(self, user_id, password_hash):
async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
"""
NB. This does *not* evict any cache because the one use for this
removes most of the entries subsequently anyway so it would be
@ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
def user_set_consent_version(self, user_id, consent_version):
async def user_set_consent_version(
self, user_id: str, consent_version: str
) -> None:
"""Updates the user table to record privacy policy consent
Args:
user_id (str): full mxid of the user to update
consent_version (str): version of the policy the user has consented
to
user_id: full mxid of the user to update
consent_version: version of the policy the user has consented to
Raises:
StoreError(404) if user not found
@ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db_pool.runInteraction("user_set_consent_version", f)
await self.db_pool.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version):
async def user_set_consent_server_notice_sent(
self, user_id: str, consent_version: str
) -> None:
"""Updates the user table to record that we have sent the user a server
notice about privacy policy consent
Args:
user_id (str): full mxid of the user to update
consent_version (str): version of the policy we have notified the
user about
user_id: full mxid of the user to update
consent_version: version of the policy we have notified the user about
Raises:
StoreError(404) if user not found
@ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
async def user_delete_access_tokens(
self,
user_id: str,
except_token_id: Optional[str] = None,
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
Invalidate access tokens belonging to a user
Args:
user_id (str): ID of user the tokens belong to
except_token_id (str): list of access_tokens IDs which should
*not* be deleted
device_id (str|None): ID of device the tokens are associated with.
user_id: ID of user the tokens belong to
except_token_id: access_tokens ID which should *not* be deleted
device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
defer.Deferred[list[str, int, str|None, int]]: a list of
(token, token id, device id) for each of the deleted tokens
A tuple of (token, token id, device id) for each of the deleted tokens
"""
def f(txn):
@ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return tokens_and_devices
return self.db_pool.runInteraction("user_delete_access_tokens", f)
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token):
async def delete_access_token(self, access_token: str) -> None:
def f(txn):
self.db_pool.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
@ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, self.get_user_by_access_token, (access_token,)
)
return self.db_pool.runInteraction("delete_access_token", f)
await self.db_pool.runInteraction("delete_access_token", f)
@cached()
async def is_guest(self, user_id: str) -> bool:
@ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="get_users_pending_deactivation",
)
def validate_threepid_session(self, session_id, client_secret, token, current_ts):
async def validate_threepid_session(
self, session_id: str, client_secret: str, token: str, current_ts: int
) -> Optional[str]:
"""Attempt to validate a threepid session using a token
Args:
session_id (str): The id of a validation session
client_secret (str): A unique string provided by the client to
help identify this validation attempt
token (str): A validation token
current_ts (int): The current unix time in milliseconds. Used for
checking token expiry status
session_id: The id of a validation session
client_secret: A unique string provided by the client to help identify
this validation attempt
token: A validation token
current_ts: The current unix time in milliseconds. Used for checking
token expiry status
Raises:
ThreepidValidationError: if a matching validation token was not found or has
expired
Returns:
deferred str|None: A str representing a link to redirect the user
to if there is one.
A str representing a link to redirect the user to if there is one.
"""
# Insert everything into a transaction in order to run atomically
@ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return next_link
# Return next_link if it exists
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"validate_threepid_session_txn", validate_threepid_session_txn
)
def start_or_continue_validation_session(
async def start_or_continue_validation_session(
self,
medium,
address,
session_id,
client_secret,
send_attempt,
next_link,
token,
token_expires,
):
medium: str,
address: str,
session_id: str,
client_secret: str,
send_attempt: int,
next_link: Optional[str],
token: str,
token_expires: int,
) -> None:
"""Creates a new threepid validation session if it does not already
exist and associates a new validation token with it
Args:
medium (str): The medium of the 3PID
address (str): The address of the 3PID
session_id (str): The id of this validation session
client_secret (str): A unique string provided by the client to
help identify this validation attempt
send_attempt (int): The latest send_attempt on this session
next_link (str|None): The link to redirect the user to upon
successful validation
token (str): The validation token
token_expires (int): The timestamp for which after the token
will no longer be valid
medium: The medium of the 3PID
address: The address of the 3PID
session_id: The id of this validation session
client_secret: A unique string provided by the client to help
identify this validation attempt
send_attempt: The latest send_attempt on this session
next_link: The link to redirect the user to upon successful validation
token: The validation token
token_expires: The timestamp for which after the token will no
longer be valid
"""
def start_or_continue_validation_session_txn(txn):
@ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"start_or_continue_validation_session",
start_or_continue_validation_session_txn,
)
def cull_expired_threepid_validation_tokens(self):
async def cull_expired_threepid_validation_tokens(self) -> None:
"""Remove threepid validation tokens with expiry dates that have passed"""
def cull_expired_threepid_validation_tokens_txn(txn, ts):
@ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
DELETE FROM threepid_validation_token WHERE
expires < ?
"""
return txn.execute(sql, (ts,))
txn.execute(sql, (ts,))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),

View File

@ -34,38 +34,33 @@ logger = logging.getLogger(__name__)
class RelationsWorkerStore(SQLBaseStore):
@cached(tree=True)
def get_relations_for_event(
async def get_relations_for_event(
self,
event_id,
relation_type=None,
event_type=None,
aggregation_key=None,
limit=5,
direction="b",
from_token=None,
to_token=None,
):
event_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[RelationPaginationToken] = None,
to_token: Optional[RelationPaginationToken] = None,
) -> PaginationChunk:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
event_id (str): Fetch events that relate to this event ID.
relation_type (str|None): Only fetch events with this relation
type, if given.
event_type (str|None): Only fetch events with this event type, if
given.
aggregation_key (str|None): Only fetch events with this aggregation
key, if given.
limit (int): Only fetch the most recent `limit` events.
direction (str): Whether to fetch the most recent first (`"b"`) or
the oldest first (`"f"`).
from_token (RelationPaginationToken|None): Fetch rows from the given
token, or from the start if None.
to_token (RelationPaginationToken|None): Fetch rows up to the given
token, or up to the end if None.
event_id: Fetch events that relate to this event ID.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
Deferred[PaginationChunk]: List of event IDs that match relations
requested. The rows are of the form `{"event_id": "..."}`.
List of event IDs that match relations requested. The rows are of
the form `{"event_id": "..."}`.
"""
where_clause = ["relates_to_id = ?"]
@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
)
@cached(tree=True)
def get_aggregation_groups_for_event(
async def get_aggregation_groups_for_event(
self,
event_id,
event_type=None,
limit=5,
direction="b",
from_token=None,
to_token=None,
):
event_id: str,
event_type: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[AggregationPaginationToken] = None,
to_token: Optional[AggregationPaginationToken] = None,
) -> PaginationChunk:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore):
on an event.
Args:
event_id (str): Fetch events that relate to this event ID.
event_type (str|None): Only fetch events with this event type, if
given.
limit (int): Only fetch the `limit` groups.
direction (str): Whether to fetch the highest count first (`"b"`) or
event_id: Fetch events that relate to this event ID.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups.
direction: Whether to fetch the highest count first (`"b"`) or
the lowest count first (`"f"`).
from_token (AggregationPaginationToken|None): Fetch rows from the
given token, or from the start if None.
to_token (AggregationPaginationToken|None): Fetch rows up to the
given token, or up to the end if None.
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
Deferred[PaginationChunk]: List of groups of annotations that
match. Each row is a dict with `type`, `key` and `count` fields.
List of groups of annotations that match. Each row is a dict with
`type`, `key` and `count` fields.
"""
where_clause = ["relates_to_id = ?", "relation_type = ?"]
@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore):
return await self.get_event(edit_id, allow_none=True)
def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
async def has_user_annotated_event(
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
) -> bool:
"""Check if a user has already annotated an event with the same key
(e.g. already liked an event).
Args:
parent_id (str): The event being annotated
event_type (str): The event type of the annotation
aggregation_key (str): The aggregation key of the annotation
sender (str): The sender of the annotation
parent_id: The event being annotated
event_type: The event type of the annotation
aggregation_key: The aggregation key of the annotation
sender: The sender of the annotation
Returns:
Deferred[bool]
True if the event is already annotated.
"""
sql = """
@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore):
return bool(txn.fetchone())
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)

View File

@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
allow_none=True,
)
def get_room_with_stats(self, room_id: str):
async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve room with statistics.
Args:
@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
res["public"] = bool(res["public"])
return res
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_room_with_stats", get_room_with_stats_txn, room_id
)
@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore):
desc="get_public_room_ids",
)
def count_public_rooms(self, network_tuple, ignore_non_federatable):
async def count_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
ignore_non_federatable: bool,
) -> int:
"""Counts the number of public rooms as tracked in the room_stats_current
and room_stats_state table.
Args:
network_tuple (ThirdPartyInstanceID|None)
ignore_non_federatable (bool): If true filters out non-federatable rooms
network_tuple
ignore_non_federatable: If true filters out non-federatable rooms
"""
def _count_public_rooms_txn(txn):
@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore):
txn.execute(sql, query_args)
return txn.fetchone()[0]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_public_rooms", _count_public_rooms_txn
)
@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore):
return row
def get_media_mxcs_in_room(self, room_id):
async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
room_id (str)
room_id
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
The local and remote media as a lists of the media IDs.
"""
def _get_media_mxcs_in_room_txn(txn):
@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
)
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
async def quarantine_media_ids_in_room(
self, room_id: str, quarantined_by: str
) -> int:
"""For a room loops through all events with media and quarantines
the associated media
"""
@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
def quarantine_media_by_id(
async def quarantine_media_by_id(
self, server_name: str, media_id: str, quarantined_by: str,
):
) -> int:
"""quarantines a single local or remote media id
Args:
@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_id_txn
)
def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
async def quarantine_media_ids_by_user(
self, user_id: str, quarantined_by: str
) -> int:
"""quarantines all local media associated with a single user
Args:
@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore):
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_user_txn
)
@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
self.hs.get_notifier().on_new_replication_data()
def get_room_count(self):
"""Retrieve a list of all rooms
async def get_room_count(self) -> int:
"""Retrieve the total number of rooms.
"""
def f(txn):
@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
row = txn.fetchone()
return row[0] or 0
return self.db_pool.runInteraction("get_rooms", f)
return await self.db_pool.runInteraction("get_rooms", f)
async def add_event_report(
self,

View File

@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
@ -152,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id: str):
return self.db_pool.runInteraction(
async def get_users_in_room(self, room_id: str) -> List[str]:
return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
@ -180,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return [r[0] for r in txn]
@cached(max_entries=100000)
def get_room_summary(self, room_id: str):
async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
""" Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
room_id: The room ID to query
Returns:
Deferred[dict[str, MemberSummary]:
dict of membership states, pointing to a MemberSummary named tuple.
dict of membership states, pointing to a MemberSummary named tuple.
"""
def _get_room_summary_txn(txn):
@ -261,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return res
return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
return await self.db_pool.runInteraction(
"get_room_summary", _get_room_summary_txn
)
@cached()
def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
"""Get all the rooms the *local* user is invited to.
Args:
user_id: The user ID.
Returns:
A awaitable list of RoomsForUser.
A list of RoomsForUser.
"""
return self.get_rooms_for_local_user_where_membership_is(
return await self.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE]
)
@ -297,8 +298,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return None
async def get_rooms_for_local_user_where_membership_is(
self, user_id: str, membership_list: List[str]
) -> Optional[List[RoomsForUser]]:
self, user_id: str, membership_list: Collection[str]
) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@ -313,7 +314,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
The RoomsForUser that the user matches the membership types.
"""
if not membership_list:
return None
return []
rooms = await self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is",
@ -357,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached(max_entries=500000, iterable=True)
def get_rooms_for_user_with_stream_ordering(self, user_id: str):
async def get_rooms_for_user_with_stream_ordering(
self, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
"""Returns a set of room_ids the user is currently joined to.
If a remote user only returns rooms this server is currently
@ -367,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id
Returns:
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
the rooms the user is in currently, along with the stream ordering
of the most recent join for that user and room.
Returns the rooms the user is in currently, along with the stream
ordering of the most recent join for that user and room.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_rooms_for_user_with_stream_ordering",
self._get_rooms_for_user_with_stream_ordering_txn,
user_id,
)
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
def _get_rooms_for_user_with_stream_ordering_txn(
self, txn, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
# for rooms the server is participating in.
@ -404,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (user_id, Membership.JOIN))
results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
return results
return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
@ -711,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return count == 0
@cached()
def get_forgotten_rooms_for_user(self, user_id: str):
async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
"""Gets all rooms the user has forgotten.
Args:
user_id
user_id: The user ID to query the rooms of.
Returns:
Deferred[set[str]]
The forgotten rooms.
"""
def _get_forgotten_rooms_for_user_txn(txn):
@ -744,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id,))
return {row[0] for row in txn if row[1] == 0}
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
@ -973,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs)
def forget(self, user_id: str, room_id: str):
async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@ -994,7 +996,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
return self.db_pool.runInteraction("forget_membership", f)
await self.db_pool.runInteraction("forget_membership", f)
class _JoinedHostsCache(object):

View File

@ -0,0 +1,16 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE events ADD COLUMN instance_name TEXT;

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
SELECT setval('events_stream_seq', (
SELECT COALESCE(MAX(stream_ordering), 1) FROM events
));
CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
SELECT setval('events_backfill_stream_seq', (
SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
));

View File

@ -0,0 +1,26 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- We're hijacking the push actions to store unread messages and unread counts (specified
-- in MSC2654) because doing otherwise would result in either performance issues or
-- reimplementing a consequent bit of the push actions.
-- Add columns to event_push_actions and event_push_actions_staging to track unread
-- messages and calculate unread counts.
ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
-- Add column to event_push_summary
ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0;

View File

@ -16,9 +16,10 @@
import logging
import re
from collections import namedtuple
from typing import List, Optional
from typing import List, Optional, Set
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
"count": count,
}
def _find_highlights_in_postgres(self, search_query, events):
async def _find_highlights_in_postgres(
self, search_query: str, events: List[EventBase]
) -> Set[str]:
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
highlight the matching parts.
Args:
search_query (str)
events (list): A list of events
search_query
events: A list of events
Returns:
deferred : A set of strings.
A set of strings.
"""
def f(txn):
@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
return highlight_words
return self.db_pool.runInteraction("_find_highlights", f)
return await self.db_pool.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):

View File

@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, List, Tuple
from unpaddedbase64 import encode_base64
from synapse.storage._base import SQLBaseStore
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
)
def get_event_reference_hashes(self, event_ids):
async def get_event_reference_hashes(
self, event_ids: Iterable[str]
) -> Dict[str, Dict[str, bytes]]:
"""Get all hashes for given events.
Args:
event_ids: The event IDs to get hashes for.
Returns:
A mapping of event ID to a mapping of algorithm to hash.
"""
def f(txn):
return {
event_id: self._get_event_reference_hashes_txn(txn, event_id)
for event_id in event_ids
}
return self.db_pool.runInteraction("get_event_reference_hashes", f)
return await self.db_pool.runInteraction("get_event_reference_hashes", f)
async def add_event_hashes(self, event_ids):
async def add_event_hashes(
self, event_ids: Iterable[str]
) -> List[Tuple[str, Dict[str, str]]]:
"""
Args:
event_ids: The event IDs
Returns:
A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
"""
hashes = await self.get_event_reference_hashes(event_ids)
hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore):
return list(hashes.items())
def _get_event_reference_hashes_txn(self, txn, event_id):
def _get_event_reference_hashes_txn(
self, txn: Cursor, event_id: str
) -> Dict[str, bytes]:
"""Get all the hashes for a given PDU.
Args:
txn (cursor):
event_id (str): Id for the Event.
txn:
event_id: Id for the Event.
Returns:
A dict[unicode, bytes] of algorithm -> hash.
A mapping of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"

View File

@ -224,14 +224,32 @@ class StatsStore(StateDeltasStore):
)
async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None:
"""
Args:
room_id
fields
"""
"""Update the state of a room.
# For whatever reason some of the fields may contain null bytes, which
# postgres isn't a fan of, so we replace those fields with null.
fields can contain the following keys with string values:
* join_rules
* history_visibility
* encryption
* name
* topic
* avatar
* canonical_alias
A is_federatable key can also be included with a boolean value.
Args:
room_id: The room ID to update the state of.
fields: The fields to update. This can include a partial list of the
above fields to only update some room information.
"""
# Ensure that the values to update are valid, they should be strings and
# not contain any null bytes.
#
# Invalid data gets overwritten with null.
#
# Note that a missing value should not be overwritten (it keeps the
# previous value).
sentinel = object()
for col in (
"join_rules",
"history_visibility",
@ -241,8 +259,8 @@ class StatsStore(StateDeltasStore):
"avatar",
"canonical_alias",
):
field = fields.get(col)
if field and "\0" in field:
field = fields.get(col, sentinel)
if field is not sentinel and (not isinstance(field, str) or "\0" in field):
fields[col] = None
await self.db_pool.simple_upsert(

View File

@ -39,7 +39,7 @@ what sort order was used:
import abc
import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from twisted.internet import defer
@ -47,12 +47,19 @@ from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.types import RoomStreamToken
from synapse.types import Collection, RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@ -202,7 +209,7 @@ def _make_generic_sql_bound(
)
def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
@ -260,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
__metaclass__ = abc.ABCMeta
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
@ -293,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._stream_order_on_start = self.get_room_max_stream_ordering()
@abc.abstractmethod
def get_room_max_stream_ordering(self):
def get_room_max_stream_ordering(self) -> int:
raise NotImplementedError()
@abc.abstractmethod
def get_room_min_stream_ordering(self):
def get_room_min_stream_ordering(self) -> int:
raise NotImplementedError()
async def get_room_events_stream_for_rooms(
self,
room_ids: Iterable[str],
room_ids: Collection[str],
from_key: str,
to_key: str,
limit: int = 0,
@ -356,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
def get_rooms_that_changed(self, room_ids, from_key):
def get_rooms_that_changed(
self, room_ids: Collection[str], from_key: str
) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
Args:
room_ids (list)
from_key (str): The room_key portion of a StreamToken
room_ids
from_key: The room_key portion of a StreamToken
"""
from_key = RoomStreamToken.parse_stream_token(from_key).stream
from_id = RoomStreamToken.parse_stream_token(from_key).stream
return {
room_id
for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key)
if self._events_stream_cache.has_entity_changed(room_id, from_id)
}
async def get_room_events_stream_for_room(
@ -440,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(self, user_id, from_key, to_key):
async def get_membership_changes_for_user(
self, user_id: str, from_key: str, to_key: str
) -> List[EventBase]:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@ -593,8 +604,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A stream ID.
"""
return await self.db_pool.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
return await self.db_pool.runInteraction(
"get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
)
def get_stream_id_for_event_txn(
self, txn: LoggingTransaction, event_id: str, allow_none=False,
) -> int:
return self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="events",
keyvalues={"event_id": event_id},
retcol="stream_ordering",
allow_none=allow_none,
)
async def get_stream_token_for_event(self, event_id: str) -> str:
@ -646,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return row[0][0] if row else 0
def _get_max_topological_txn(self, txn, room_id):
def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
txn.execute(
"SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
(room_id,),
@ -719,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_events_around_txn(
self,
txn,
txn: LoggingTransaction,
room_id: str,
event_id: str,
before_limit: int,
@ -747,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"],
)
# This cannot happen as `allow_none=False`.
assert results is not None
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
@ -856,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="update_federation_out_pos",
)
def _reset_federation_positions_txn(self, txn) -> None:
def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
"""Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up.
"""
@ -895,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
GROUP BY type
"""
txn.execute(sql)
min_positions = dict(txn) # Map from type -> min position
min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position
# Ensure we do actually have some values here
assert set(min_positions) == {"federation", "events"}
@ -922,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def _paginate_room_events_txn(
self,
txn,
txn: LoggingTransaction,
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,

View File

@ -43,7 +43,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
tags_by_room = {}
tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]]
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"])
@ -123,7 +123,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_updated_tags(
self, user_id: str, stream_id: int
) -> Dict[str, List[str]]:
) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the
given version

View File

@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore):
class UIAuthStore(UIAuthWorkerStore):
def delete_old_ui_auth_sessions(self, expiration_time: int):
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
Remove sessions which were last used earlier than the expiration time.
@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore):
This is an epoch time in milliseconds.
"""
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_old_ui_auth_sessions",
self._delete_old_ui_auth_sessions_txn,
expiration_time,

View File

@ -15,7 +15,7 @@
import logging
import re
from typing import Any, Dict, Optional
from typing import Any, Dict, Iterable, Optional, Set, Tuple
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@ -365,10 +365,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return False
def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
async def update_profile_in_user_dir(
self, user_id: str, display_name: str, avatar_url: str
) -> None:
"""
Update or add a user's profile in the user directory.
"""
# If the display name or avatar URL are unexpected types, overwrite them.
if not isinstance(display_name, str):
display_name = None
if not isinstance(avatar_url, str):
avatar_url = None
def _update_profile_in_user_dir_txn(txn):
new_entry = self.db_pool.simple_upsert_txn(
@ -458,17 +465,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
def add_users_who_share_private_room(self, room_id, user_id_tuples):
async def add_users_who_share_private_room(
self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
room_id (str)
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
room_id
user_id_tuples: iterable of 2-tuple of user IDs.
"""
def _add_users_who_share_room_txn(txn):
@ -484,17 +493,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
def add_users_in_public_rooms(self, room_id, user_ids):
async def add_users_in_public_rooms(
self, room_id: str, user_ids: Iterable[str]
) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
room_id (str)
user_ids (list[str])
room_id
user_ids
"""
def _add_users_in_public_rooms_txn(txn):
@ -508,11 +519,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
def delete_all_from_user_dir(self):
async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory
"""
@ -523,7 +534,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@ -555,7 +566,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(UserDirectoryStore, self).__init__(database, db_conn, hs)
def remove_from_user_dir(self, user_id):
async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
@ -578,7 +589,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"remove_from_user_dir", _remove_from_user_dir_txn
)
@ -605,14 +616,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return user_ids
def remove_user_who_share_room(self, user_id, room_id):
async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
"""
Deletes entries in the users_who_share_*_rooms table. The first
user should be a local user.
Args:
user_id (str)
room_id (str)
user_id
room_id
"""
def _remove_user_who_share_room_txn(txn):
@ -632,7 +643,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
keyvalues={"user_id": user_id, "room_id": room_id},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
@ -664,6 +675,48 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
@cached()
async def get_shared_rooms_for_users(
self, user_id: str, other_user_id: str
) -> Set[str]:
"""
Returns the rooms that a local user shares with another local or remote user.
Args:
user_id: The MXID of a local user
other_user_id: The MXID of the other user
Returns:
A set of room ID's that the users share.
"""
def _get_shared_rooms_for_users_txn(txn):
txn.execute(
"""
SELECT p1.room_id
FROM users_in_public_rooms as p1
INNER JOIN users_in_public_rooms as p2
ON p1.room_id = p2.room_id
AND p1.user_id = ?
AND p2.user_id = ?
UNION
SELECT room_id
FROM users_who_share_private_rooms
WHERE
user_id = ?
AND other_user_id = ?
""",
(user_id, other_user_id, user_id, other_user_id),
)
rows = self.db_pool.cursor_to_dict(txn)
return rows
rows = await self.db_pool.runInteraction(
"get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
)
return {row["room_id"] for row in rows}
async def get_user_directory_stream_pos(self) -> int:
return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",

View File

@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore):
class UserErasureStore(UserErasureWorkerStore):
def mark_user_erased(self, user_id: str) -> None:
async def mark_user_erased(self, user_id: str) -> None:
"""Indicate that user_id wishes their message history to be erased.
Args:
@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
return self.db_pool.runInteraction("mark_user_erased", f)
await self.db_pool.runInteraction("mark_user_erased", f)
def mark_user_not_erased(self, user_id: str) -> None:
async def mark_user_not_erased(self, user_id: str) -> None:
"""Indicate that user_id is no longer erased.
Args:
@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
return self.db_pool.runInteraction("mark_user_not_erased", f)
await self.db_pool.runInteraction("mark_user_not_erased", f)

View File

@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
positive: Whether the IDs are positive (true) or negative (false).
When using negative IDs we go backwards from -1 to -2, -3, etc.
"""
def __init__(
@ -196,13 +198,19 @@ class MultiWriterIdGenerator:
instance_column: str,
id_column: str,
sequence_name: str,
positive: bool = True,
):
self._db = db
self._instance_name = instance_name
self._positive = positive
self._return_factor = 1 if positive else -1
# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()
# Note: If we are a negative stream then we still store all the IDs as
# positive to make life easier for us, and simply negate the IDs when we
# return them.
self._current_positions = self._load_current_ids(
db_conn, table, instance_column, id_column
)
@ -223,8 +231,12 @@ class MultiWriterIdGenerator:
# gaps should be relatively rare it's still worth doing the book keeping
# that allows us to skip forwards when there are gapless runs of
# positions.
#
# We start at 1 here as a) the first generated stream ID will be 2, and
# b) other parts of the code assume that stream IDs are strictly greater
# than 0.
self._persisted_upto_position = (
min(self._current_positions.values()) if self._current_positions else 0
min(self._current_positions.values()) if self._current_positions else 1
)
self._known_persisted_positions = [] # type: List[int]
@ -233,13 +245,16 @@ class MultiWriterIdGenerator:
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
# If positive stream aggregate via MAX. For negative stream use MIN
# *and* negate the result to get a positive number.
sql = """
SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
GROUP BY %(instance)s
""" % {
"instance": instance_column,
"id": id_column,
"table": table,
"agg": "MAX" if self._positive else "-MIN",
}
cur = db_conn.cursor()
@ -269,15 +284,16 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
assert self.get_current_token_for_writer(self._instance_name) < next_id
with self._lock:
assert self._current_positions.get(self._instance_name, 0) < next_id
self._unfinished_ids.add(next_id)
@contextlib.contextmanager
def manager():
try:
yield next_id
# Multiply by the return factor so that the ID has correct sign.
yield self._return_factor * next_id
finally:
self._mark_id_as_finished(next_id)
@ -296,15 +312,15 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync
# somehow.
assert max(self.get_positions().values(), default=0) < min(next_ids)
with self._lock:
assert max(self._current_positions.values(), default=0) < min(next_ids)
self._unfinished_ids.update(next_ids)
@contextlib.contextmanager
def manager():
try:
yield next_ids
yield [self._return_factor * i for i in next_ids]
finally:
for i in next_ids:
self._mark_id_as_finished(i)
@ -327,7 +343,7 @@ class MultiWriterIdGenerator:
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
return next_id
return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the
@ -350,29 +366,32 @@ class MultiWriterIdGenerator:
equal to it have been successfully persisted.
"""
# Currently we don't support this operation, as it's not obvious how to
# condense the stream positions of multiple writers into a single int.
raise NotImplementedError()
return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
"""
with self._lock:
return self._current_positions.get(instance_name, 0)
return self._return_factor * self._current_positions.get(instance_name, 0)
def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
"""
with self._lock:
return dict(self._current_positions)
return {
name: self._return_factor * i
for name, i in self._current_positions.items()
}
def advance(self, instance_name: str, new_id: int):
"""Advance the postion of the named writer to the given ID, if greater
than existing entry.
"""
new_id *= self._return_factor
with self._lock:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
@ -390,7 +409,7 @@ class MultiWriterIdGenerator:
"""
with self._lock:
return self._persisted_upto_position
return self._return_factor * self._persisted_upto_position
def _add_persisted_position(self, new_id: int):
"""Record that we have persisted a position.

View File

@ -20,6 +20,7 @@ from contextlib import contextmanager
from typing import Dict, Sequence, Set, Union
import attr
from typing_extensions import ContextManager
from twisted.internet import defer
from twisted.internet.defer import CancelledError
@ -338,11 +339,11 @@ class Linearizer(object):
class ReadWriteLock(object):
"""A deferred style read write lock.
"""An async read write lock.
Example:
with (yield read_write_lock.read("test_key")):
with await read_write_lock.read("test_key"):
# do some work
"""
@ -365,8 +366,7 @@ class ReadWriteLock(object):
# Latest writer queued
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks
def read(self, key):
async def read(self, key: str) -> ContextManager:
new_defer = defer.Deferred()
curr_readers = self.key_to_current_readers.setdefault(key, set())
@ -376,7 +376,8 @@ class ReadWriteLock(object):
# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
yield make_deferred_yieldable(curr_writer)
if curr_writer:
await make_deferred_yieldable(curr_writer)
@contextmanager
def _ctx_manager():
@ -388,8 +389,7 @@ class ReadWriteLock(object):
return _ctx_manager()
@defer.inlineCallbacks
def write(self, key):
async def write(self, key: str) -> ContextManager:
new_defer = defer.Deferred()
curr_readers = self.key_to_current_readers.get(key, set())
@ -405,7 +405,7 @@ class ReadWriteLock(object):
curr_readers.clear()
self.key_to_current_writer[key] = new_defer
yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
await make_deferred_yieldable(defer.gatherResults(to_wait_on))
@contextmanager
def _ctx_manager():

Some files were not shown because too many files have changed in this diff Show More