Merge branch 'develop' into matrix-org-hotfixes
commit
505ea932f5
13
UPGRADE.rst
13
UPGRADE.rst
|
@ -1,3 +1,16 @@
|
|||
Upgrading to v1.20.0
|
||||
====================
|
||||
|
||||
Shared rooms endpoint (MSC2666)
|
||||
-------------------------------
|
||||
|
||||
This release contains a new unstable endpoint `/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/.*`
|
||||
for fetching rooms one user has in common with another. This feature requires the
|
||||
`update_user_directory` config flag to be `True`. If you are you are using a `synapse.app.user_dir`
|
||||
worker, requests to this endpoint must be handled by that worker.
|
||||
See `docs/workers.md <docs/workers.md>`_ for more details.
|
||||
|
||||
|
||||
Upgrading Synapse
|
||||
=================
|
||||
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Add an endpoint to query your shared rooms with another user as an implementation of [MSC2666](https://github.com/matrix-org/matrix-doc/pull/2666).
|
|
@ -0,0 +1 @@
|
|||
Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654).
|
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -0,0 +1 @@
|
|||
Add experimental support for sharding event persister.
|
|
@ -0,0 +1 @@
|
|||
Explain better what GDPR-erased means when deactivating a user.
|
|
@ -0,0 +1 @@
|
|||
Fix `wait_for_stream_position` to allow multiple waiters on same stream ID.
|
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -0,0 +1 @@
|
|||
Make `MultiWriterIDGenerator` work for streams that use negative values.
|
|
@ -0,0 +1 @@
|
|||
Refactor queries for device keys and cross-signatures.
|
|
@ -0,0 +1 @@
|
|||
Refactor queries for device keys and cross-signatures.
|
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -0,0 +1 @@
|
|||
Refactor queries for device keys and cross-signatures.
|
|
@ -0,0 +1 @@
|
|||
Fixes a longstanding bug where user directory updates could break when unexpected profile data was included in events.
|
|
@ -0,0 +1 @@
|
|||
Refactor queries for device keys and cross-signatures.
|
|
@ -0,0 +1 @@
|
|||
Refactor queries for device keys and cross-signatures.
|
|
@ -0,0 +1 @@
|
|||
Fix a longstanding bug where stats updates could break when unexpected profile data was included in events.
|
|
@ -0,0 +1 @@
|
|||
Refactor queries for device keys and cross-signatures.
|
|
@ -0,0 +1 @@
|
|||
Add type hints to `StreamStore`.
|
|
@ -0,0 +1 @@
|
|||
Add type hints to `StreamStore`.
|
|
@ -0,0 +1 @@
|
|||
Fix type hints in `SyncHandler`.
|
|
@ -214,9 +214,11 @@ Deactivate Account
|
|||
|
||||
This API deactivates an account. It removes active access tokens, resets the
|
||||
password, and deletes third-party IDs (to prevent the user requesting a
|
||||
password reset). It can also mark the user as GDPR-erased (stopping their data
|
||||
from distributed further, and deleting it entirely if there are no other
|
||||
references to it).
|
||||
password reset).
|
||||
|
||||
It can also mark the user as GDPR-erased. This means messages sent by the
|
||||
user will still be visible by anyone that was in the room when these messages
|
||||
were sent, but hidden from users joining the room afterwards.
|
||||
|
||||
The api is::
|
||||
|
||||
|
|
|
@ -380,6 +380,7 @@ Handles searches in the user directory. It can handle REST endpoints matching
|
|||
the following regular expressions:
|
||||
|
||||
^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$
|
||||
^/_matrix/client/unstable/uk.half-shot.msc2666/user/shared_rooms/.*$
|
||||
|
||||
When using this worker you must also set `update_user_directory: False` in the
|
||||
shared configuration file to stop the main synapse running background
|
||||
|
|
2
mypy.ini
2
mypy.ini
|
@ -28,6 +28,7 @@ files =
|
|||
synapse/handlers/saml_handler.py,
|
||||
synapse/handlers/sync.py,
|
||||
synapse/handlers/ui_auth,
|
||||
synapse/http/federation/well_known_resolver.py,
|
||||
synapse/http/server.py,
|
||||
synapse/http/site.py,
|
||||
synapse/logging/,
|
||||
|
@ -42,6 +43,7 @@ files =
|
|||
synapse/server_notices,
|
||||
synapse/spam_checker_api,
|
||||
synapse/state,
|
||||
synapse/storage/databases/main/stream.py,
|
||||
synapse/storage/databases/main/ui_auth.py,
|
||||
synapse/storage/database.py,
|
||||
synapse/storage/engines,
|
||||
|
|
|
@ -334,6 +334,13 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
|
|||
This is to workaround https://twistedmatrix.com/trac/ticket/9620, where we
|
||||
can run out of file descriptors and infinite loop if we attempt to do too
|
||||
many DNS queries at once
|
||||
|
||||
XXX: I'm confused by this. reactor.nameResolver does not use twisted.names unless
|
||||
you explicitly install twisted.names as the resolver; rather it uses a GAIResolver
|
||||
backed by the reactor's default threadpool (which is limited to 10 threads). So
|
||||
(a) I don't understand why twisted ticket 9620 is relevant, and (b) I don't
|
||||
understand why we would run out of FDs if we did too many lookups at once.
|
||||
-- richvdh 2020/08/29
|
||||
"""
|
||||
new_resolver = _LimitedHostnameResolver(
|
||||
reactor.nameResolver, max_dns_requests_in_flight
|
||||
|
|
|
@ -79,8 +79,7 @@ class AdminCmdServer(HomeServer):
|
|||
pass
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def export_data_command(hs, args):
|
||||
async def export_data_command(hs, args):
|
||||
"""Export data for a user.
|
||||
|
||||
Args:
|
||||
|
@ -91,10 +90,8 @@ def export_data_command(hs, args):
|
|||
user_id = args.user_id
|
||||
directory = args.output_directory
|
||||
|
||||
res = yield defer.ensureDeferred(
|
||||
hs.get_handlers().admin_handler.export_user_data(
|
||||
user_id, FileExfiltrationWriter(user_id, directory=directory)
|
||||
)
|
||||
res = await hs.get_handlers().admin_handler.export_user_data(
|
||||
user_id, FileExfiltrationWriter(user_id, directory=directory)
|
||||
)
|
||||
print(res)
|
||||
|
||||
|
@ -232,14 +229,15 @@ def start(config_options):
|
|||
# We also make sure that `_base.start` gets run before we actually run the
|
||||
# command.
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def run(_reactor):
|
||||
async def run():
|
||||
with LoggingContext("command"):
|
||||
yield _base.start(ss, [])
|
||||
yield args.func(ss, args)
|
||||
_base.start(ss, [])
|
||||
await args.func(ss, args)
|
||||
|
||||
_base.start_worker_reactor(
|
||||
"synapse-admin-cmd", config, run_command=lambda: task.react(run)
|
||||
"synapse-admin-cmd",
|
||||
config,
|
||||
run_command=lambda: task.react(lambda _reactor: defer.ensureDeferred(run())),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -411,26 +411,24 @@ def setup(config_options):
|
|||
|
||||
return provision
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def reprovision_acme():
|
||||
async def reprovision_acme():
|
||||
"""
|
||||
Provision a certificate from ACME, if required, and reload the TLS
|
||||
certificate if it's renewed.
|
||||
"""
|
||||
reprovisioned = yield defer.ensureDeferred(do_acme())
|
||||
reprovisioned = await do_acme()
|
||||
if reprovisioned:
|
||||
_base.refresh_certificate(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start():
|
||||
async def start():
|
||||
try:
|
||||
# Run the ACME provisioning code, if it's enabled.
|
||||
if hs.config.acme_enabled:
|
||||
acme = hs.get_acme_handler()
|
||||
# Start up the webservices which we will respond to ACME
|
||||
# challenges with, and then provision.
|
||||
yield defer.ensureDeferred(acme.start_listening())
|
||||
yield defer.ensureDeferred(do_acme())
|
||||
await acme.start_listening()
|
||||
await do_acme()
|
||||
|
||||
# Check if it needs to be reprovisioned every day.
|
||||
hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
|
||||
|
@ -439,8 +437,8 @@ def setup(config_options):
|
|||
if hs.config.oidc_enabled:
|
||||
oidc = hs.get_oidc_handler()
|
||||
# Loading the provider metadata also ensures the provider config is valid.
|
||||
yield defer.ensureDeferred(oidc.load_metadata())
|
||||
yield defer.ensureDeferred(oidc.load_jwks())
|
||||
await oidc.load_metadata()
|
||||
await oidc.load_jwks()
|
||||
|
||||
_base.start(hs, config.listeners)
|
||||
|
||||
|
@ -456,7 +454,7 @@ def setup(config_options):
|
|||
reactor.stop()
|
||||
sys.exit(1)
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
reactor.callWhenRunning(lambda: defer.ensureDeferred(start()))
|
||||
|
||||
return hs
|
||||
|
||||
|
|
|
@ -14,18 +14,20 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
import urllib
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, ThirdPartyEntityKind
|
||||
from synapse.api.errors import CodeMessageException
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.types import ThirdPartyInstanceID
|
||||
from synapse.types import JsonDict, ThirdPartyInstanceID
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.appservice import ApplicationService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
sent_transactions_counter = Counter(
|
||||
|
@ -163,19 +165,20 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
logger.warning("query_3pe to %s threw exception %s", uri, ex)
|
||||
return []
|
||||
|
||||
def get_3pe_protocol(self, service, protocol):
|
||||
async def get_3pe_protocol(
|
||||
self, service: "ApplicationService", protocol: str
|
||||
) -> Optional[JsonDict]:
|
||||
if service.url is None:
|
||||
return {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get():
|
||||
async def _get() -> Optional[JsonDict]:
|
||||
uri = "%s%s/thirdparty/protocol/%s" % (
|
||||
service.url,
|
||||
APP_SERVICE_PREFIX,
|
||||
urllib.parse.quote(protocol),
|
||||
)
|
||||
try:
|
||||
info = yield defer.ensureDeferred(self.get_json(uri, {}))
|
||||
info = await self.get_json(uri, {})
|
||||
|
||||
if not _is_valid_3pe_metadata(info):
|
||||
logger.warning(
|
||||
|
@ -196,7 +199,7 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
return None
|
||||
|
||||
key = (service.id, protocol)
|
||||
return self.protocol_meta_cache.wrap(key, _get)
|
||||
return await self.protocol_meta_cache.wrap(key, _get)
|
||||
|
||||
async def push_bulk(self, service, events, txn_id=None):
|
||||
if service.url is None:
|
||||
|
|
|
@ -832,11 +832,26 @@ class ShardedWorkerHandlingConfig:
|
|||
def should_handle(self, instance_name: str, key: str) -> bool:
|
||||
"""Whether this instance is responsible for handling the given key.
|
||||
"""
|
||||
|
||||
# If multiple instances are not defined we always return true.
|
||||
# If multiple instances are not defined we always return true
|
||||
if not self.instances or len(self.instances) == 1:
|
||||
return True
|
||||
|
||||
return self.get_instance(key) == instance_name
|
||||
|
||||
def get_instance(self, key: str) -> str:
|
||||
"""Get the instance responsible for handling the given key.
|
||||
|
||||
Note: For things like federation sending the config for which instance
|
||||
is sending is known only to the sender instance if there is only one.
|
||||
Therefore `should_handle` should be used where possible.
|
||||
"""
|
||||
|
||||
if not self.instances:
|
||||
return "master"
|
||||
|
||||
if len(self.instances) == 1:
|
||||
return self.instances[0]
|
||||
|
||||
# We shard by taking the hash, modulo it by the number of instances and
|
||||
# then checking whether this instance matches the instance at that
|
||||
# index.
|
||||
|
@ -846,7 +861,7 @@ class ShardedWorkerHandlingConfig:
|
|||
dest_hash = sha256(key.encode("utf8")).digest()
|
||||
dest_int = int.from_bytes(dest_hash, byteorder="little")
|
||||
remainder = dest_int % (len(self.instances))
|
||||
return self.instances[remainder] == instance_name
|
||||
return self.instances[remainder]
|
||||
|
||||
|
||||
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
|
||||
|
|
|
@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig:
|
|||
instances: List[str]
|
||||
def __init__(self, instances: List[str]) -> None: ...
|
||||
def should_handle(self, instance_name: str, key: str) -> bool: ...
|
||||
def get_instance(self, key: str) -> str: ...
|
||||
|
|
|
@ -13,12 +13,24 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
import attr
|
||||
|
||||
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
|
||||
from .server import ListenerConfig, parse_listener_def
|
||||
|
||||
|
||||
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
|
||||
"""Helper for allowing parsing a string or list of strings to a config
|
||||
option expecting a list of strings.
|
||||
"""
|
||||
|
||||
if isinstance(obj, str):
|
||||
return [obj]
|
||||
return obj
|
||||
|
||||
|
||||
@attr.s
|
||||
class InstanceLocationConfig:
|
||||
"""The host and port to talk to an instance via HTTP replication.
|
||||
|
@ -33,11 +45,13 @@ class WriterLocations:
|
|||
"""Specifies the instances that write various streams.
|
||||
|
||||
Attributes:
|
||||
events: The instance that writes to the event and backfill streams.
|
||||
events: The instance that writes to the typing stream.
|
||||
events: The instances that write to the event and backfill streams.
|
||||
typing: The instance that writes to the typing stream.
|
||||
"""
|
||||
|
||||
events = attr.ib(default="master", type=str)
|
||||
events = attr.ib(
|
||||
default=["master"], type=List[str], converter=_instance_to_list_converter
|
||||
)
|
||||
typing = attr.ib(default="master", type=str)
|
||||
|
||||
|
||||
|
@ -105,15 +119,18 @@ class WorkerConfig(Config):
|
|||
writers = config.get("stream_writers") or {}
|
||||
self.writers = WriterLocations(**writers)
|
||||
|
||||
# Check that the configured writer for events and typing also appears in
|
||||
# Check that the configured writers for events and typing also appears in
|
||||
# `instance_map`.
|
||||
for stream in ("events", "typing"):
|
||||
instance = getattr(self.writers, stream)
|
||||
if instance != "master" and instance not in self.instance_map:
|
||||
raise ConfigError(
|
||||
"Instance %r is configured to write %s but does not appear in `instance_map` config."
|
||||
% (instance, stream)
|
||||
)
|
||||
instances = _instance_to_list_converter(getattr(self.writers, stream))
|
||||
for instance in instances:
|
||||
if instance != "master" and instance not in self.instance_map:
|
||||
raise ConfigError(
|
||||
"Instance %r is configured to write %s but does not appear in `instance_map` config."
|
||||
% (instance, stream)
|
||||
)
|
||||
|
||||
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
|
||||
|
||||
def generate_config_section(self, config_dir_path, server_name, **kwargs):
|
||||
return """\
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
import abc
|
||||
import os
|
||||
from distutils.util import strtobool
|
||||
from typing import Dict, Optional, Type
|
||||
from typing import Dict, Optional, Tuple, Type
|
||||
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
|
@ -120,7 +120,7 @@ class _EventInternalMetadata(object):
|
|||
# be here
|
||||
before = DictProperty("before") # type: str
|
||||
after = DictProperty("after") # type: str
|
||||
order = DictProperty("order") # type: int
|
||||
order = DictProperty("order") # type: Tuple[int, int]
|
||||
|
||||
def get_dict(self) -> JsonDict:
|
||||
return dict(self._dict)
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
from nacl.signing import SigningKey
|
||||
|
@ -97,14 +97,14 @@ class EventBuilder(object):
|
|||
def is_state(self):
|
||||
return self._state_key is not None
|
||||
|
||||
async def build(self, prev_event_ids):
|
||||
async def build(self, prev_event_ids: List[str]) -> EventBase:
|
||||
"""Transform into a fully signed and hashed event
|
||||
|
||||
Args:
|
||||
prev_event_ids (list[str]): The event IDs to use as the prev events
|
||||
prev_event_ids: The event IDs to use as the prev events
|
||||
|
||||
Returns:
|
||||
FrozenEvent
|
||||
The signed and hashed event.
|
||||
"""
|
||||
|
||||
state_ids = await self._state.get_current_state_ids(
|
||||
|
@ -114,8 +114,13 @@ class EventBuilder(object):
|
|||
|
||||
format_version = self.room_version.event_format
|
||||
if format_version == EventFormatVersions.V1:
|
||||
auth_events = await self._store.add_event_hashes(auth_ids)
|
||||
prev_events = await self._store.add_event_hashes(prev_event_ids)
|
||||
# The types of auth/prev events changes between event versions.
|
||||
auth_events = await self._store.add_event_hashes(
|
||||
auth_ids
|
||||
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
||||
prev_events = await self._store.add_event_hashes(
|
||||
prev_event_ids
|
||||
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
||||
else:
|
||||
auth_events = auth_ids
|
||||
prev_events = prev_event_ids
|
||||
|
@ -138,7 +143,7 @@ class EventBuilder(object):
|
|||
"unsigned": self.unsigned,
|
||||
"depth": depth,
|
||||
"prev_state": [],
|
||||
}
|
||||
} # type: Dict[str, Any]
|
||||
|
||||
if self.is_state():
|
||||
event_dict["state_key"] = self._state_key
|
||||
|
|
|
@ -234,7 +234,9 @@ class DeviceWorkerHandler(BaseHandler):
|
|||
return result
|
||||
|
||||
async def on_federation_query_user_devices(self, user_id):
|
||||
stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
|
||||
stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
|
||||
user_id
|
||||
)
|
||||
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
|
||||
self_signing_key = await self.store.get_e2e_cross_signing_key(
|
||||
user_id, "self_signing"
|
||||
|
|
|
@ -353,7 +353,7 @@ class E2eKeysHandler(object):
|
|||
# make sure that each queried user appears in the result dict
|
||||
result_dict[user_id] = {}
|
||||
|
||||
results = await self.store.get_e2e_device_keys(local_query)
|
||||
results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
|
||||
|
||||
# Build the result structure
|
||||
for user_id, device_keys in results.items():
|
||||
|
@ -734,7 +734,7 @@ class E2eKeysHandler(object):
|
|||
# fetch our stored devices. This is used to 1. verify
|
||||
# signatures on the master key, and 2. to compare with what
|
||||
# was sent if the device was signed
|
||||
devices = await self.store.get_e2e_device_keys([(user_id, None)])
|
||||
devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)])
|
||||
|
||||
if user_id not in devices:
|
||||
raise NotFoundError("No device keys found")
|
||||
|
|
|
@ -923,7 +923,8 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
)
|
||||
|
||||
await self._handle_new_events(dest, ev_infos, backfilled=True)
|
||||
if ev_infos:
|
||||
await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
|
||||
|
||||
# Step 2: Persist the rest of the events in the chunk one by one
|
||||
events.sort(key=lambda e: e.depth)
|
||||
|
@ -1216,7 +1217,7 @@ class FederationHandler(BaseHandler):
|
|||
event_infos.append(_NewEventInfo(event, None, auth))
|
||||
|
||||
await self._handle_new_events(
|
||||
destination, event_infos,
|
||||
destination, room_id, event_infos,
|
||||
)
|
||||
|
||||
def _sanity_check_event(self, ev):
|
||||
|
@ -1363,15 +1364,15 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
max_stream_id = await self._persist_auth_tree(
|
||||
origin, auth_chain, state, event, room_version_obj
|
||||
origin, room_id, auth_chain, state, event, room_version_obj
|
||||
)
|
||||
|
||||
# We wait here until this instance has seen the events come down
|
||||
# replication (if we're using replication) as the below uses caches.
|
||||
#
|
||||
# TODO: Currently the events stream is written to from master
|
||||
await self._replication.wait_for_stream_position(
|
||||
self.config.worker.writers.events, "events", max_stream_id
|
||||
self.config.worker.events_shard_config.get_instance(room_id),
|
||||
"events",
|
||||
max_stream_id,
|
||||
)
|
||||
|
||||
# Check whether this room is the result of an upgrade of a room we already know
|
||||
|
@ -1625,7 +1626,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
context = await self.state_handler.compute_event_context(event)
|
||||
await self.persist_events_and_notify([(event, context)])
|
||||
await self.persist_events_and_notify(event.room_id, [(event, context)])
|
||||
|
||||
return event
|
||||
|
||||
|
@ -1652,7 +1653,9 @@ class FederationHandler(BaseHandler):
|
|||
await self.federation_client.send_leave(host_list, event)
|
||||
|
||||
context = await self.state_handler.compute_event_context(event)
|
||||
stream_id = await self.persist_events_and_notify([(event, context)])
|
||||
stream_id = await self.persist_events_and_notify(
|
||||
event.room_id, [(event, context)]
|
||||
)
|
||||
|
||||
return event, stream_id
|
||||
|
||||
|
@ -1900,7 +1903,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
await self.persist_events_and_notify(
|
||||
[(event, context)], backfilled=backfilled
|
||||
event.room_id, [(event, context)], backfilled=backfilled
|
||||
)
|
||||
except Exception:
|
||||
run_in_background(
|
||||
|
@ -1913,6 +1916,7 @@ class FederationHandler(BaseHandler):
|
|||
async def _handle_new_events(
|
||||
self,
|
||||
origin: str,
|
||||
room_id: str,
|
||||
event_infos: Iterable[_NewEventInfo],
|
||||
backfilled: bool = False,
|
||||
) -> None:
|
||||
|
@ -1944,6 +1948,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
await self.persist_events_and_notify(
|
||||
room_id,
|
||||
[
|
||||
(ev_info.event, context)
|
||||
for ev_info, context in zip(event_infos, contexts)
|
||||
|
@ -1954,6 +1959,7 @@ class FederationHandler(BaseHandler):
|
|||
async def _persist_auth_tree(
|
||||
self,
|
||||
origin: str,
|
||||
room_id: str,
|
||||
auth_events: List[EventBase],
|
||||
state: List[EventBase],
|
||||
event: EventBase,
|
||||
|
@ -1968,6 +1974,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
Args:
|
||||
origin: Where the events came from
|
||||
room_id,
|
||||
auth_events
|
||||
state
|
||||
event
|
||||
|
@ -2042,17 +2049,20 @@ class FederationHandler(BaseHandler):
|
|||
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
|
||||
|
||||
await self.persist_events_and_notify(
|
||||
room_id,
|
||||
[
|
||||
(e, events_to_context[e.event_id])
|
||||
for e in itertools.chain(auth_events, state)
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
new_event_context = await self.state_handler.compute_event_context(
|
||||
event, old_state=state
|
||||
)
|
||||
|
||||
return await self.persist_events_and_notify([(event, new_event_context)])
|
||||
return await self.persist_events_and_notify(
|
||||
room_id, [(event, new_event_context)]
|
||||
)
|
||||
|
||||
async def _prep_event(
|
||||
self,
|
||||
|
@ -2903,6 +2913,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
async def persist_events_and_notify(
|
||||
self,
|
||||
room_id: str,
|
||||
event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool = False,
|
||||
) -> int:
|
||||
|
@ -2910,14 +2921,19 @@ class FederationHandler(BaseHandler):
|
|||
necessary.
|
||||
|
||||
Args:
|
||||
event_and_contexts:
|
||||
room_id: The room ID of events being persisted.
|
||||
event_and_contexts: Sequence of events with their associated
|
||||
context that should be persisted. All events must belong to
|
||||
the same room.
|
||||
backfilled: Whether these events are a result of
|
||||
backfilling or not
|
||||
"""
|
||||
if self.config.worker.writers.events != self._instance_name:
|
||||
instance = self.config.worker.events_shard_config.get_instance(room_id)
|
||||
if instance != self._instance_name:
|
||||
result = await self._send_events(
|
||||
instance_name=self.config.worker.writers.events,
|
||||
instance_name=instance,
|
||||
store=self.store,
|
||||
room_id=room_id,
|
||||
event_and_contexts=event_and_contexts,
|
||||
backfilled=backfilled,
|
||||
)
|
||||
|
|
|
@ -49,14 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
|||
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import (
|
||||
Collection,
|
||||
Requester,
|
||||
RoomAlias,
|
||||
StreamToken,
|
||||
UserID,
|
||||
create_requester,
|
||||
)
|
||||
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.frozenutils import frozendict_json_encoder
|
||||
|
@ -383,9 +376,8 @@ class EventCreationHandler(object):
|
|||
self.notifier = hs.get_notifier()
|
||||
self.config = hs.config
|
||||
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
|
||||
self._is_event_writer = (
|
||||
self.config.worker.writers.events == hs.get_instance_name()
|
||||
)
|
||||
self._events_shard_config = self.config.worker.events_shard_config
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self.room_invite_state_types = self.hs.config.room_invite_state_types
|
||||
|
||||
|
@ -448,7 +440,7 @@ class EventCreationHandler(object):
|
|||
event_dict: dict,
|
||||
token_id: Optional[str] = None,
|
||||
txn_id: Optional[str] = None,
|
||||
prev_event_ids: Optional[Collection[str]] = None,
|
||||
prev_event_ids: Optional[List[str]] = None,
|
||||
require_consent: bool = True,
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
"""
|
||||
|
@ -788,7 +780,7 @@ class EventCreationHandler(object):
|
|||
self,
|
||||
builder: EventBuilder,
|
||||
requester: Optional[Requester] = None,
|
||||
prev_event_ids: Optional[Collection[str]] = None,
|
||||
prev_event_ids: Optional[List[str]] = None,
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
"""Create a new event for a local client
|
||||
|
||||
|
@ -913,9 +905,10 @@ class EventCreationHandler(object):
|
|||
|
||||
try:
|
||||
# If we're a worker we need to hit out to the master.
|
||||
if not self._is_event_writer:
|
||||
writer_instance = self._events_shard_config.get_instance(event.room_id)
|
||||
if writer_instance != self._instance_name:
|
||||
result = await self.send_event(
|
||||
instance_name=self.config.worker.writers.events,
|
||||
instance_name=writer_instance,
|
||||
event_id=event.event_id,
|
||||
store=self.store,
|
||||
requester=requester,
|
||||
|
@ -983,7 +976,9 @@ class EventCreationHandler(object):
|
|||
|
||||
This should only be run on the instance in charge of persisting events.
|
||||
"""
|
||||
assert self._is_event_writer
|
||||
assert self._events_shard_config.should_handle(
|
||||
self._instance_name, event.room_id
|
||||
)
|
||||
|
||||
if ratelimit:
|
||||
# We check if this is a room admin redacting an event so that we
|
||||
|
|
|
@ -14,15 +14,18 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import RoomStreamToken
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import Requester, RoomStreamToken
|
||||
from synapse.util.async_helpers import ReadWriteLock
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
@ -247,15 +250,16 @@ class PaginationHandler(object):
|
|||
)
|
||||
return purge_id
|
||||
|
||||
async def _purge_history(self, purge_id, room_id, token, delete_local_events):
|
||||
async def _purge_history(
|
||||
self, purge_id: str, room_id: str, token: str, delete_local_events: bool
|
||||
) -> None:
|
||||
"""Carry out a history purge on a room.
|
||||
|
||||
Args:
|
||||
purge_id (str): The id for this purge
|
||||
room_id (str): The room to purge from
|
||||
token (str): topological token to delete events before
|
||||
delete_local_events (bool): True to delete local events as well as
|
||||
remote ones
|
||||
purge_id: The id for this purge
|
||||
room_id: The room to purge from
|
||||
token: topological token to delete events before
|
||||
delete_local_events: True to delete local events as well as remote ones
|
||||
"""
|
||||
self._purges_in_progress_by_room.add(room_id)
|
||||
try:
|
||||
|
@ -291,9 +295,9 @@ class PaginationHandler(object):
|
|||
"""
|
||||
return self._purges_by_id.get(purge_id)
|
||||
|
||||
async def purge_room(self, room_id):
|
||||
async def purge_room(self, room_id: str) -> None:
|
||||
"""Purge the given room from the database"""
|
||||
with (await self.pagination_lock.write(room_id)):
|
||||
with await self.pagination_lock.write(room_id):
|
||||
# check we know about the room
|
||||
await self.store.get_room_version_id(room_id)
|
||||
|
||||
|
@ -307,23 +311,22 @@ class PaginationHandler(object):
|
|||
|
||||
async def get_messages(
|
||||
self,
|
||||
requester,
|
||||
room_id=None,
|
||||
pagin_config=None,
|
||||
as_client_event=True,
|
||||
event_filter=None,
|
||||
):
|
||||
requester: Requester,
|
||||
room_id: Optional[str] = None,
|
||||
pagin_config: Optional[PaginationConfig] = None,
|
||||
as_client_event: bool = True,
|
||||
event_filter: Optional[Filter] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Get messages in a room.
|
||||
|
||||
Args:
|
||||
requester (Requester): The user requesting messages.
|
||||
room_id (str): The room they want messages from.
|
||||
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
||||
config rules to apply, if any.
|
||||
as_client_event (bool): True to get events in client-server format.
|
||||
event_filter (Filter): Filter to apply to results or None
|
||||
requester: The user requesting messages.
|
||||
room_id: The room they want messages from.
|
||||
pagin_config: The pagination config rules to apply, if any.
|
||||
as_client_event: True to get events in client-server format.
|
||||
event_filter: Filter to apply to results or None
|
||||
Returns:
|
||||
dict: Pagination API results
|
||||
Pagination API results
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
|
@ -343,7 +346,7 @@ class PaginationHandler(object):
|
|||
|
||||
source_config = pagin_config.get_source_config("room")
|
||||
|
||||
with (await self.pagination_lock.read(room_id)):
|
||||
with await self.pagination_lock.read(room_id):
|
||||
(
|
||||
membership,
|
||||
member_event_id,
|
||||
|
|
|
@ -161,6 +161,9 @@ class BaseProfileHandler(BaseHandler):
|
|||
Codes.FORBIDDEN,
|
||||
)
|
||||
|
||||
if not isinstance(new_displayname, str):
|
||||
raise SynapseError(400, "Invalid displayname")
|
||||
|
||||
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
|
||||
raise SynapseError(
|
||||
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
|
||||
|
@ -235,6 +238,9 @@ class BaseProfileHandler(BaseHandler):
|
|||
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
if not isinstance(new_avatar_url, str):
|
||||
raise SynapseError(400, "Invalid displayname")
|
||||
|
||||
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
|
||||
raise SynapseError(
|
||||
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
|
||||
|
|
|
@ -804,7 +804,9 @@ class RoomCreationHandler(BaseHandler):
|
|||
|
||||
# Always wait for room creation to progate before returning
|
||||
await self._replication.wait_for_stream_position(
|
||||
self.hs.config.worker.writers.events, "events", last_stream_id
|
||||
self.hs.config.worker.events_shard_config.get_instance(room_id),
|
||||
"events",
|
||||
last_stream_id,
|
||||
)
|
||||
|
||||
return result, last_stream_id
|
||||
|
@ -1260,10 +1262,10 @@ class RoomShutdownHandler(object):
|
|||
# We now wait for the create room to come back in via replication so
|
||||
# that we can assume that all the joins/invites have propogated before
|
||||
# we try and auto join below.
|
||||
#
|
||||
# TODO: Currently the events stream is written to from master
|
||||
await self._replication.wait_for_stream_position(
|
||||
self.hs.config.worker.writers.events, "events", stream_id
|
||||
self.hs.config.worker.events_shard_config.get_instance(new_room_id),
|
||||
"events",
|
||||
stream_id,
|
||||
)
|
||||
else:
|
||||
new_room_id = None
|
||||
|
@ -1293,7 +1295,9 @@ class RoomShutdownHandler(object):
|
|||
|
||||
# Wait for leave to come in over replication before trying to forget.
|
||||
await self._replication.wait_for_stream_position(
|
||||
self.hs.config.worker.writers.events, "events", stream_id
|
||||
self.hs.config.worker.events_shard_config.get_instance(room_id),
|
||||
"events",
|
||||
stream_id,
|
||||
)
|
||||
|
||||
await self.room_member_handler.forget(target_requester.user, room_id)
|
||||
|
|
|
@ -38,15 +38,7 @@ from synapse.events.builder import create_local_event_from_event_dict
|
|||
from synapse.events.snapshot import EventContext
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.storage.roommember import RoomsForUser
|
||||
from synapse.types import (
|
||||
Collection,
|
||||
JsonDict,
|
||||
Requester,
|
||||
RoomAlias,
|
||||
RoomID,
|
||||
StateMap,
|
||||
UserID,
|
||||
)
|
||||
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.distributor import user_joined_room, user_left_room
|
||||
|
||||
|
@ -91,13 +83,6 @@ class RoomMemberHandler(object):
|
|||
self._enable_lookup = hs.config.enable_3pid_lookup
|
||||
self.allow_per_room_profiles = self.config.allow_per_room_profiles
|
||||
|
||||
self._event_stream_writer_instance = hs.config.worker.writers.events
|
||||
self._is_on_event_persistence_instance = (
|
||||
self._event_stream_writer_instance == hs.get_instance_name()
|
||||
)
|
||||
if self._is_on_event_persistence_instance:
|
||||
self.persist_event_storage = hs.get_storage().persistence
|
||||
|
||||
self._join_rate_limiter_local = Ratelimiter(
|
||||
clock=self.clock,
|
||||
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
|
||||
|
@ -185,7 +170,7 @@ class RoomMemberHandler(object):
|
|||
target: UserID,
|
||||
room_id: str,
|
||||
membership: str,
|
||||
prev_event_ids: Collection[str],
|
||||
prev_event_ids: List[str],
|
||||
txn_id: Optional[str] = None,
|
||||
ratelimit: bool = True,
|
||||
content: Optional[dict] = None,
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
|
@ -44,6 +44,9 @@ from synapse.util.caches.response_cache import ResponseCache
|
|||
from synapse.util.metrics import Measure, measure_func
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Debug logger for https://github.com/matrix-org/synapse/issues/4422
|
||||
|
@ -96,7 +99,12 @@ class TimelineBatch:
|
|||
__bool__ = __nonzero__ # python3
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
# We can't freeze this class, because we need to update it after it's instantiated to
|
||||
# update its unread count. This is because we calculate the unread count for a room only
|
||||
# if there are updates for it, which we check after the instance has been created.
|
||||
# This should not be a big deal because we update the notification counts afterwards as
|
||||
# well anyway.
|
||||
@attr.s(slots=True)
|
||||
class JoinedSyncResult:
|
||||
room_id = attr.ib(type=str)
|
||||
timeline = attr.ib(type=TimelineBatch)
|
||||
|
@ -105,6 +113,7 @@ class JoinedSyncResult:
|
|||
account_data = attr.ib(type=List[JsonDict])
|
||||
unread_notifications = attr.ib(type=JsonDict)
|
||||
summary = attr.ib(type=Optional[JsonDict])
|
||||
unread_count = attr.ib(type=int)
|
||||
|
||||
def __nonzero__(self) -> bool:
|
||||
"""Make the result appear empty if there are no updates. This is used
|
||||
|
@ -239,7 +248,7 @@ class SyncResult:
|
|||
|
||||
|
||||
class SyncHandler(object):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs_config = hs.config
|
||||
self.store = hs.get_datastore()
|
||||
self.notifier = hs.get_notifier()
|
||||
|
@ -714,9 +723,8 @@ class SyncHandler(object):
|
|||
]
|
||||
|
||||
missing_hero_state = await self.store.get_events(missing_hero_event_ids)
|
||||
missing_hero_state = missing_hero_state.values()
|
||||
|
||||
for s in missing_hero_state:
|
||||
for s in missing_hero_state.values():
|
||||
cache.set(s.state_key, s.event_id)
|
||||
state[(EventTypes.Member, s.state_key)] = s
|
||||
|
||||
|
@ -934,7 +942,7 @@ class SyncHandler(object):
|
|||
|
||||
async def unread_notifs_for_room_id(
|
||||
self, room_id: str, sync_config: SyncConfig
|
||||
) -> Optional[Dict[str, str]]:
|
||||
) -> Dict[str, int]:
|
||||
with Measure(self.clock, "unread_notifs_for_room_id"):
|
||||
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
|
||||
user_id=sync_config.user.to_string(),
|
||||
|
@ -942,15 +950,10 @@ class SyncHandler(object):
|
|||
receipt_type="m.read",
|
||||
)
|
||||
|
||||
if last_unread_event_id:
|
||||
notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
|
||||
room_id, sync_config.user.to_string(), last_unread_event_id
|
||||
)
|
||||
return notifs
|
||||
|
||||
# There is no new information in this period, so your notification
|
||||
# count is whatever it was last time.
|
||||
return None
|
||||
notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
|
||||
room_id, sync_config.user.to_string(), last_unread_event_id
|
||||
)
|
||||
return notifs
|
||||
|
||||
async def generate_sync_result(
|
||||
self,
|
||||
|
@ -1773,7 +1776,7 @@ class SyncHandler(object):
|
|||
ignored_users: Set[str],
|
||||
room_builder: "RoomSyncResultBuilder",
|
||||
ephemeral: List[JsonDict],
|
||||
tags: Optional[List[JsonDict]],
|
||||
tags: Optional[Dict[str, Dict[str, Any]]],
|
||||
account_data: Dict[str, JsonDict],
|
||||
always_include: bool = False,
|
||||
):
|
||||
|
@ -1889,7 +1892,7 @@ class SyncHandler(object):
|
|||
)
|
||||
|
||||
if room_builder.rtype == "joined":
|
||||
unread_notifications = {} # type: Dict[str, str]
|
||||
unread_notifications = {} # type: Dict[str, int]
|
||||
room_sync = JoinedSyncResult(
|
||||
room_id=room_id,
|
||||
timeline=batch,
|
||||
|
@ -1898,14 +1901,16 @@ class SyncHandler(object):
|
|||
account_data=account_data_events,
|
||||
unread_notifications=unread_notifications,
|
||||
summary=summary,
|
||||
unread_count=0,
|
||||
)
|
||||
|
||||
if room_sync or always_include:
|
||||
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
|
||||
|
||||
if notifs is not None:
|
||||
unread_notifications["notification_count"] = notifs["notify_count"]
|
||||
unread_notifications["highlight_count"] = notifs["highlight_count"]
|
||||
unread_notifications["notification_count"] = notifs["notify_count"]
|
||||
unread_notifications["highlight_count"] = notifs["highlight_count"]
|
||||
|
||||
room_sync.unread_count = notifs["unread_count"]
|
||||
|
||||
sync_result_builder.joined.append(room_sync)
|
||||
|
||||
|
|
|
@ -234,7 +234,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||
async def _handle_room_publicity_change(
|
||||
self, room_id, prev_event_id, event_id, typ
|
||||
):
|
||||
"""Handle a room having potentially changed from/to world_readable/publically
|
||||
"""Handle a room having potentially changed from/to world_readable/publicly
|
||||
joinable.
|
||||
|
||||
Args:
|
||||
|
@ -388,9 +388,15 @@ class UserDirectoryHandler(StateDeltasHandler):
|
|||
|
||||
prev_name = prev_event.content.get("displayname")
|
||||
new_name = event.content.get("displayname")
|
||||
# If the new name is an unexpected form, do not update the directory.
|
||||
if not isinstance(new_name, str):
|
||||
new_name = prev_name
|
||||
|
||||
prev_avatar = prev_event.content.get("avatar_url")
|
||||
new_avatar = event.content.get("avatar_url")
|
||||
# If the new avatar is an unexpected form, do not update the directory.
|
||||
if not isinstance(new_avatar, str):
|
||||
new_avatar = prev_avatar
|
||||
|
||||
if prev_name != new_name or prev_avatar != new_avatar:
|
||||
await self.store.update_profile_in_user_dir(user_id, new_name, new_avatar)
|
||||
|
|
|
@ -134,8 +134,8 @@ class MatrixFederationAgent(object):
|
|||
and not _is_ip_literal(parsed_uri.hostname)
|
||||
and not parsed_uri.port
|
||||
):
|
||||
well_known_result = yield self._well_known_resolver.get_well_known(
|
||||
parsed_uri.hostname
|
||||
well_known_result = yield defer.ensureDeferred(
|
||||
self._well_known_resolver.get_well_known(parsed_uri.hostname)
|
||||
)
|
||||
delegated_server = well_known_result.delegated_server
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
import logging
|
||||
import random
|
||||
import time
|
||||
from typing import Callable, Dict, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
|
@ -23,6 +24,7 @@ from twisted.internet import defer
|
|||
from twisted.web.client import RedirectAgent, readBody
|
||||
from twisted.web.http import stringToDatetime
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import IResponse
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.util import Clock, json_decoder
|
||||
|
@ -99,15 +101,14 @@ class WellKnownResolver(object):
|
|||
self._well_known_agent = RedirectAgent(agent)
|
||||
self.user_agent = user_agent
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_well_known(self, server_name):
|
||||
async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult:
|
||||
"""Attempt to fetch and parse a .well-known file for the given server
|
||||
|
||||
Args:
|
||||
server_name (bytes): name of the server, from the requested url
|
||||
server_name: name of the server, from the requested url
|
||||
|
||||
Returns:
|
||||
Deferred[WellKnownLookupResult]: The result of the lookup
|
||||
The result of the lookup
|
||||
"""
|
||||
|
||||
if server_name == b"kde.org":
|
||||
|
@ -128,7 +129,9 @@ class WellKnownResolver(object):
|
|||
# requests for the same server in parallel?
|
||||
try:
|
||||
with Measure(self._clock, "get_well_known"):
|
||||
result, cache_period = yield self._fetch_well_known(server_name)
|
||||
result, cache_period = await self._fetch_well_known(
|
||||
server_name
|
||||
) # type: Tuple[Optional[bytes], float]
|
||||
|
||||
except _FetchWellKnownFailure as e:
|
||||
if prev_result and e.temporary:
|
||||
|
@ -157,18 +160,17 @@ class WellKnownResolver(object):
|
|||
|
||||
return WellKnownLookupResult(delegated_server=result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _fetch_well_known(self, server_name):
|
||||
async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]:
|
||||
"""Actually fetch and parse a .well-known, without checking the cache
|
||||
|
||||
Args:
|
||||
server_name (bytes): name of the server, from the requested url
|
||||
server_name: name of the server, from the requested url
|
||||
|
||||
Raises:
|
||||
_FetchWellKnownFailure if we fail to lookup a result
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[bytes,int]]: The lookup result and cache period.
|
||||
The lookup result and cache period.
|
||||
"""
|
||||
|
||||
had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
|
||||
|
@ -176,7 +178,7 @@ class WellKnownResolver(object):
|
|||
# We do this in two steps to differentiate between possibly transient
|
||||
# errors (e.g. can't connect to host, 503 response) and more permenant
|
||||
# errors (such as getting a 404 response).
|
||||
response, body = yield self._make_well_known_request(
|
||||
response, body = await self._make_well_known_request(
|
||||
server_name, retry=had_valid_well_known
|
||||
)
|
||||
|
||||
|
@ -219,20 +221,20 @@ class WellKnownResolver(object):
|
|||
|
||||
return result, cache_period
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _make_well_known_request(self, server_name, retry):
|
||||
async def _make_well_known_request(
|
||||
self, server_name: bytes, retry: bool
|
||||
) -> Tuple[IResponse, bytes]:
|
||||
"""Make the well known request.
|
||||
|
||||
This will retry the request if requested and it fails (with unable
|
||||
to connect or receives a 5xx error).
|
||||
|
||||
Args:
|
||||
server_name (bytes)
|
||||
retry (bool): Whether to retry the request if it fails.
|
||||
server_name: name of the server, from the requested url
|
||||
retry: Whether to retry the request if it fails.
|
||||
|
||||
Returns:
|
||||
Deferred[tuple[IResponse, bytes]] Returns the response object and
|
||||
body. Response may be a non-200 response.
|
||||
Returns the response object and body. Response may be a non-200 response.
|
||||
"""
|
||||
uri = b"https://%s/.well-known/matrix/server" % (server_name,)
|
||||
uri_str = uri.decode("ascii")
|
||||
|
@ -247,12 +249,12 @@ class WellKnownResolver(object):
|
|||
|
||||
logger.info("Fetching %s", uri_str)
|
||||
try:
|
||||
response = yield make_deferred_yieldable(
|
||||
response = await make_deferred_yieldable(
|
||||
self._well_known_agent.request(
|
||||
b"GET", uri, headers=Headers(headers)
|
||||
)
|
||||
)
|
||||
body = yield make_deferred_yieldable(readBody(response))
|
||||
body = await make_deferred_yieldable(readBody(response))
|
||||
|
||||
if 500 <= response.code < 600:
|
||||
raise Exception("Non-200 response %s" % (response.code,))
|
||||
|
@ -269,21 +271,24 @@ class WellKnownResolver(object):
|
|||
logger.info("Error fetching %s: %s. Retrying", uri_str, e)
|
||||
|
||||
# Sleep briefly in the hopes that they come back up
|
||||
yield self._clock.sleep(0.5)
|
||||
await self._clock.sleep(0.5)
|
||||
|
||||
|
||||
def _cache_period_from_headers(headers, time_now=time.time):
|
||||
def _cache_period_from_headers(
|
||||
headers: Headers, time_now: Callable[[], float] = time.time
|
||||
) -> Optional[float]:
|
||||
cache_controls = _parse_cache_control(headers)
|
||||
|
||||
if b"no-store" in cache_controls:
|
||||
return 0
|
||||
|
||||
if b"max-age" in cache_controls:
|
||||
try:
|
||||
max_age = int(cache_controls[b"max-age"])
|
||||
return max_age
|
||||
except ValueError:
|
||||
pass
|
||||
max_age = cache_controls[b"max-age"]
|
||||
if max_age:
|
||||
try:
|
||||
return int(max_age)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
expires = headers.getRawHeaders(b"expires")
|
||||
if expires is not None:
|
||||
|
@ -299,7 +304,7 @@ def _cache_period_from_headers(headers, time_now=time.time):
|
|||
return None
|
||||
|
||||
|
||||
def _parse_cache_control(headers):
|
||||
def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
|
||||
cache_controls = {}
|
||||
for hdr in headers.getRawHeaders(b"cache-control", []):
|
||||
for directive in hdr.split(b","):
|
||||
|
|
|
@ -19,8 +19,10 @@ from collections import namedtuple
|
|||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.constants import EventTypes, Membership, RelationTypes
|
||||
from synapse.event_auth import get_user_power_level
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.state import POWER_KEY
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import register_cache
|
||||
|
@ -51,6 +53,48 @@ push_rules_delta_state_cache_metric = register_cache(
|
|||
)
|
||||
|
||||
|
||||
STATE_EVENT_TYPES_TO_MARK_UNREAD = {
|
||||
EventTypes.Topic,
|
||||
EventTypes.Name,
|
||||
EventTypes.RoomAvatar,
|
||||
EventTypes.Tombstone,
|
||||
}
|
||||
|
||||
|
||||
def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
|
||||
# Exclude rejected and soft-failed events.
|
||||
if context.rejected or event.internal_metadata.is_soft_failed():
|
||||
return False
|
||||
|
||||
# Exclude notices.
|
||||
if (
|
||||
not event.is_state()
|
||||
and event.type == EventTypes.Message
|
||||
and event.content.get("msgtype") == "m.notice"
|
||||
):
|
||||
return False
|
||||
|
||||
# Exclude edits.
|
||||
relates_to = event.content.get("m.relates_to", {})
|
||||
if relates_to.get("rel_type") == RelationTypes.REPLACE:
|
||||
return False
|
||||
|
||||
# Mark events that have a non-empty string body as unread.
|
||||
body = event.content.get("body")
|
||||
if isinstance(body, str) and body:
|
||||
return True
|
||||
|
||||
# Mark some state events as unread.
|
||||
if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
|
||||
return True
|
||||
|
||||
# Mark encrypted events as unread.
|
||||
if not event.is_state() and event.type == EventTypes.Encrypted:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class BulkPushRuleEvaluator(object):
|
||||
"""Calculates the outcome of push rules for an event for all users in the
|
||||
room at once.
|
||||
|
@ -133,9 +177,12 @@ class BulkPushRuleEvaluator(object):
|
|||
return pl_event.content if pl_event else {}, sender_level
|
||||
|
||||
async def action_for_event_by_user(self, event, context) -> None:
|
||||
"""Given an event and context, evaluate the push rules and insert the
|
||||
results into the event_push_actions_staging table.
|
||||
"""Given an event and context, evaluate the push rules, check if the message
|
||||
should increment the unread count, and insert the results into the
|
||||
event_push_actions_staging table.
|
||||
"""
|
||||
count_as_unread = _should_count_as_unread(event, context)
|
||||
|
||||
rules_by_user = await self._get_rules_for_event(event, context)
|
||||
actions_by_user = {}
|
||||
|
||||
|
@ -172,6 +219,8 @@ class BulkPushRuleEvaluator(object):
|
|||
if event.type == EventTypes.Member and event.state_key == uid:
|
||||
display_name = event.content.get("displayname", None)
|
||||
|
||||
actions_by_user[uid] = []
|
||||
|
||||
for rule in rules:
|
||||
if "enabled" in rule and not rule["enabled"]:
|
||||
continue
|
||||
|
@ -189,7 +238,9 @@ class BulkPushRuleEvaluator(object):
|
|||
# Mark in the DB staging area the push actions for users who should be
|
||||
# notified for this event. (This will then get handled when we persist
|
||||
# the event)
|
||||
await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
|
||||
await self.store.add_push_actions_to_staging(
|
||||
event.event_id, actions_by_user, count_as_unread,
|
||||
)
|
||||
|
||||
|
||||
def _condition_checker(evaluator, conditions, uid, display_name, cache):
|
||||
|
@ -369,8 +420,8 @@ class RulesForRoom(object):
|
|||
Args:
|
||||
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
|
||||
updated with any new rules.
|
||||
member_event_ids (list): List of event ids for membership events that
|
||||
have happened since the last time we filled rules_by_user
|
||||
member_event_ids (dict): Dict of user id to event id for membership events
|
||||
that have happened since the last time we filled rules_by_user
|
||||
state_group: The state group we are currently computing push rules
|
||||
for. Used when updating the cache.
|
||||
"""
|
||||
|
@ -390,34 +441,19 @@ class RulesForRoom(object):
|
|||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("Found members %r: %r", self.room_id, members.values())
|
||||
|
||||
interested_in_user_ids = {
|
||||
user_ids = {
|
||||
user_id
|
||||
for user_id, membership in members.values()
|
||||
if membership == Membership.JOIN
|
||||
}
|
||||
|
||||
logger.debug("Joined: %r", interested_in_user_ids)
|
||||
logger.debug("Joined: %r", user_ids)
|
||||
|
||||
if_users_with_pushers = await self.store.get_if_users_have_pushers(
|
||||
interested_in_user_ids, on_invalidate=self.invalidate_all_cb
|
||||
)
|
||||
|
||||
user_ids = {
|
||||
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
|
||||
}
|
||||
|
||||
logger.debug("With pushers: %r", user_ids)
|
||||
|
||||
users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
|
||||
self.room_id, on_invalidate=self.invalidate_all_cb
|
||||
)
|
||||
|
||||
logger.debug("With receipts: %r", users_with_receipts)
|
||||
|
||||
# any users with pushers must be ours: they have pushers
|
||||
for uid in users_with_receipts:
|
||||
if uid in interested_in_user_ids:
|
||||
user_ids.add(uid)
|
||||
# Previously we only considered users with pushers or read receipts in that
|
||||
# room. We can't do this anymore because we use push actions to calculate unread
|
||||
# counts, which don't rely on the user having pushers or sent a read receipt into
|
||||
# the room. Therefore we just need to filter for local users here.
|
||||
user_ids = list(filter(self.is_mine_id, user_ids))
|
||||
|
||||
rules_by_user = await self.store.bulk_get_push_rules(
|
||||
user_ids, on_invalidate=self.invalidate_all_cb
|
||||
|
|
|
@ -36,7 +36,7 @@ async def get_badge_count(store, user_id):
|
|||
)
|
||||
# return one badge count per conversation, as count per
|
||||
# message is so noisy as to be almost useless
|
||||
badge += 1 if notifs["notify_count"] else 0
|
||||
badge += 1 if notifs["unread_count"] else 0
|
||||
return badge
|
||||
|
||||
|
||||
|
|
|
@ -66,7 +66,9 @@ REQUIREMENTS = [
|
|||
"msgpack>=0.5.2",
|
||||
"phonenumbers>=8.2.0",
|
||||
"prometheus_client>=0.0.18,<0.9.0",
|
||||
# we use attr.validators.deep_iterable, which arrived in 19.1.0
|
||||
# we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note:
|
||||
# Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33
|
||||
# is out in November.)
|
||||
"attrs>=19.1.0",
|
||||
"netaddr>=0.7.18",
|
||||
"Jinja2>=2.9",
|
||||
|
|
|
@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
self.federation_handler = hs.get_handlers().federation_handler
|
||||
|
||||
@staticmethod
|
||||
async def _serialize_payload(store, event_and_contexts, backfilled):
|
||||
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
|
||||
"""
|
||||
Args:
|
||||
store
|
||||
room_id (str)
|
||||
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
|
||||
backfilled (bool): Whether or not the events are the result of
|
||||
backfilling
|
||||
|
@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
}
|
||||
)
|
||||
|
||||
payload = {"events": event_payloads, "backfilled": backfilled}
|
||||
payload = {
|
||||
"events": event_payloads,
|
||||
"backfilled": backfilled,
|
||||
"room_id": room_id,
|
||||
}
|
||||
|
||||
return payload
|
||||
|
||||
|
@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
with Measure(self.clock, "repl_fed_send_events_parse"):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
room_id = content["room_id"]
|
||||
backfilled = content["backfilled"]
|
||||
|
||||
event_payloads = content["events"]
|
||||
|
@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
|||
logger.info("Got %d events from federation", len(event_and_contexts))
|
||||
|
||||
max_stream_id = await self.federation_handler.persist_events_and_notify(
|
||||
event_and_contexts, backfilled
|
||||
room_id, event_and_contexts, backfilled
|
||||
)
|
||||
|
||||
return 200, {"max_stream_id": max_stream_id}
|
||||
|
|
|
@ -48,6 +48,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
|||
"DeviceListFederationStreamChangeCache", device_list_max
|
||||
)
|
||||
|
||||
def get_device_stream_token(self) -> int:
|
||||
return self._device_list_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == DeviceListsStream.NAME:
|
||||
self._device_list_id_gen.advance(instance_name, token)
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
# limitations under the License.
|
||||
"""A replication client for use by synapse workers.
|
||||
"""
|
||||
import heapq
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple
|
||||
|
||||
|
@ -219,9 +218,8 @@ class ReplicationDataHandler:
|
|||
|
||||
waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
|
||||
|
||||
# We insert into the list using heapq as it is more efficient than
|
||||
# pushing then resorting each time.
|
||||
heapq.heappush(waiting_list, (position, deferred))
|
||||
waiting_list.append((position, deferred))
|
||||
waiting_list.sort(key=lambda t: t[0])
|
||||
|
||||
# We measure here to get in flight counts and average waiting time.
|
||||
with Measure(self._clock, "repl.wait_for_stream_position"):
|
||||
|
|
|
@ -109,7 +109,7 @@ class ReplicationCommandHandler:
|
|||
if isinstance(stream, (EventsStream, BackfillStream)):
|
||||
# Only add EventStream and BackfillStream as a source on the
|
||||
# instance in charge of event persistence.
|
||||
if hs.config.worker.writers.events == hs.get_instance_name():
|
||||
if hs.get_instance_name() in hs.config.worker.writers.events:
|
||||
self._streams_to_replicate.append(stream)
|
||||
|
||||
continue
|
||||
|
|
|
@ -19,7 +19,7 @@ from typing import List, Tuple, Type
|
|||
|
||||
import attr
|
||||
|
||||
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
|
||||
from ._base import Stream, StreamUpdateResult, Token
|
||||
|
||||
"""Handling of the 'events' replication stream
|
||||
|
||||
|
@ -117,7 +117,7 @@ class EventsStream(Stream):
|
|||
self._store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(self._store.get_current_events_token),
|
||||
self._store._stream_id_gen.get_current_token_for_writer,
|
||||
self._update_function,
|
||||
)
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import (
|
|||
room_keys,
|
||||
room_upgrade_rest_servlet,
|
||||
sendtodevice,
|
||||
shared_rooms,
|
||||
sync,
|
||||
tags,
|
||||
thirdparty,
|
||||
|
@ -125,3 +126,6 @@ class ClientRestResource(JsonResource):
|
|||
synapse.rest.admin.register_servlets_for_client_rest_resource(
|
||||
hs, client_resource
|
||||
)
|
||||
|
||||
# unstable
|
||||
shared_rooms.register_servlets(hs, client_resource)
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 Half-Shot
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.types import UserID
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserSharedRoomsServlet(RestServlet):
|
||||
"""
|
||||
GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)",
|
||||
releases=(), # This is an unstable feature
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(UserSharedRoomsServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.user_directory_active = hs.config.update_user_directory
|
||||
|
||||
async def on_GET(self, request, user_id):
|
||||
|
||||
if not self.user_directory_active:
|
||||
raise SynapseError(
|
||||
code=400,
|
||||
msg="The user directory is disabled on this server. Cannot determine shared rooms.",
|
||||
errcode=Codes.FORBIDDEN,
|
||||
)
|
||||
|
||||
UserID.from_string(user_id)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
if user_id == requester.user.to_string():
|
||||
raise SynapseError(
|
||||
code=400,
|
||||
msg="You cannot request a list of shared rooms with yourself",
|
||||
errcode=Codes.FORBIDDEN,
|
||||
)
|
||||
rooms = await self.store.get_shared_rooms_for_users(
|
||||
requester.user.to_string(), user_id
|
||||
)
|
||||
|
||||
return 200, {"joined": list(rooms)}
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
UserSharedRoomsServlet(hs).register(http_server)
|
|
@ -425,6 +425,7 @@ class SyncRestServlet(RestServlet):
|
|||
result["ephemeral"] = {"events": ephemeral_events}
|
||||
result["unread_notifications"] = room.unread_notifications
|
||||
result["summary"] = room.summary
|
||||
result["org.matrix.msc2654.unread_count"] = room.unread_count
|
||||
|
||||
return result
|
||||
|
||||
|
|
|
@ -60,6 +60,8 @@ class VersionsRestServlet(RestServlet):
|
|||
"org.matrix.e2e_cross_signing": True,
|
||||
# Implements additional endpoints as described in MSC2432
|
||||
"org.matrix.msc2432": True,
|
||||
# Implements additional endpoints as described in MSC2666
|
||||
"uk.half-shot.msc2666": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
|
|
@ -433,7 +433,7 @@ class BackgroundUpdater(object):
|
|||
"background_updates", keyvalues={"update_name": update_name}
|
||||
)
|
||||
|
||||
def _background_update_progress(self, update_name: str, progress: dict):
|
||||
async def _background_update_progress(self, update_name: str, progress: dict):
|
||||
"""Update the progress of a background update
|
||||
|
||||
Args:
|
||||
|
@ -441,7 +441,7 @@ class BackgroundUpdater(object):
|
|||
progress: The progress of the update.
|
||||
"""
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"background_update_progress",
|
||||
self._background_update_progress_txn,
|
||||
update_name,
|
||||
|
|
|
@ -28,6 +28,7 @@ from typing import (
|
|||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
|
@ -35,7 +36,6 @@ from prometheus_client import Histogram
|
|||
from typing_extensions import Literal
|
||||
|
||||
from twisted.enterprise import adbapi
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.config.database import DatabaseConnectionConfig
|
||||
|
@ -507,8 +507,9 @@ class DatabasePool(object):
|
|||
self._txn_perf_counters.update(desc, duration)
|
||||
sql_txn_timer.labels(desc).observe(duration)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
|
||||
async def runInteraction(
|
||||
self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
|
||||
) -> R:
|
||||
"""Starts a transaction on the database and runs a given function
|
||||
|
||||
Arguments:
|
||||
|
@ -521,7 +522,7 @@ class DatabasePool(object):
|
|||
kwargs: named args to pass to `func`
|
||||
|
||||
Returns:
|
||||
Deferred: The result of func
|
||||
The result of func
|
||||
"""
|
||||
after_callbacks = [] # type: List[_CallbackListEntry]
|
||||
exception_callbacks = [] # type: List[_CallbackListEntry]
|
||||
|
@ -530,16 +531,14 @@ class DatabasePool(object):
|
|||
logger.warning("Starting db txn '%s' from sentinel context", desc)
|
||||
|
||||
try:
|
||||
result = yield defer.ensureDeferred(
|
||||
self.runWithConnection(
|
||||
self.new_transaction,
|
||||
desc,
|
||||
after_callbacks,
|
||||
exception_callbacks,
|
||||
func,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
result = await self.runWithConnection(
|
||||
self.new_transaction,
|
||||
desc,
|
||||
after_callbacks,
|
||||
exception_callbacks,
|
||||
func,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||
|
@ -549,7 +548,7 @@ class DatabasePool(object):
|
|||
after_callback(*after_args, **after_kwargs)
|
||||
raise
|
||||
|
||||
return result
|
||||
return cast(R, result)
|
||||
|
||||
async def runWithConnection(
|
||||
self, func: "Callable[..., R]", *args: Any, **kwargs: Any
|
||||
|
@ -604,6 +603,18 @@ class DatabasePool(object):
|
|||
results = [dict(zip(col_headers, row)) for row in cursor]
|
||||
return results
|
||||
|
||||
@overload
|
||||
async def execute(
|
||||
self, desc: str, decoder: Literal[None], query: str, *args: Any
|
||||
) -> List[Tuple[Any, ...]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def execute(
|
||||
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
|
||||
) -> R:
|
||||
...
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
desc: str,
|
||||
|
@ -1088,6 +1099,28 @@ class DatabasePool(object):
|
|||
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
|
||||
)
|
||||
|
||||
@overload
|
||||
async def simple_select_one_onecol(
|
||||
self,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcol: Iterable[str],
|
||||
allow_none: Literal[False] = False,
|
||||
desc: str = "simple_select_one_onecol",
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def simple_select_one_onecol(
|
||||
self,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcol: Iterable[str],
|
||||
allow_none: Literal[True] = True,
|
||||
desc: str = "simple_select_one_onecol",
|
||||
) -> Optional[Any]:
|
||||
...
|
||||
|
||||
async def simple_select_one_onecol(
|
||||
self,
|
||||
table: str,
|
||||
|
@ -1116,6 +1149,30 @@ class DatabasePool(object):
|
|||
allow_none=allow_none,
|
||||
)
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def simple_select_one_onecol_txn(
|
||||
cls,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcol: Iterable[str],
|
||||
allow_none: Literal[False] = False,
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def simple_select_one_onecol_txn(
|
||||
cls,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcol: Iterable[str],
|
||||
allow_none: Literal[True] = True,
|
||||
) -> Optional[Any]:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def simple_select_one_onecol_txn(
|
||||
cls,
|
||||
|
|
|
@ -68,7 +68,7 @@ class Databases(object):
|
|||
|
||||
# If we're on a process that can persist events also
|
||||
# instantiate a `PersistEventsStore`
|
||||
if hs.config.worker.writers.events == hs.get_instance_name():
|
||||
if hs.get_instance_name() in hs.config.worker.writers.events:
|
||||
persist_events = PersistEventsStore(hs, database, main)
|
||||
|
||||
if "state" in database_config.databases:
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
import calendar
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import PresenceState
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
|
@ -264,6 +264,9 @@ class DataStore(
|
|||
# Used in _generate_user_daily_visits to keep track of progress
|
||||
self._last_user_visit_update = self._get_start_of_day()
|
||||
|
||||
def get_device_stream_token(self) -> int:
|
||||
return self._device_list_id_gen.get_current_token()
|
||||
|
||||
def take_presence_startup_info(self):
|
||||
active_on_startup = self._presence_on_startup
|
||||
self._presence_on_startup = None
|
||||
|
@ -291,16 +294,16 @@ class DataStore(
|
|||
|
||||
return [UserPresenceState(**row) for row in rows]
|
||||
|
||||
def count_daily_users(self):
|
||||
async def count_daily_users(self) -> int:
|
||||
"""
|
||||
Counts the number of users who used this homeserver in the last 24 hours.
|
||||
"""
|
||||
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"count_daily_users", self._count_users, yesterday
|
||||
)
|
||||
|
||||
def count_monthly_users(self):
|
||||
async def count_monthly_users(self) -> int:
|
||||
"""
|
||||
Counts the number of users who used this homeserver in the last 30 days.
|
||||
Note this method is intended for phonehome metrics only and is different
|
||||
|
@ -308,7 +311,7 @@ class DataStore(
|
|||
amongst other things, includes a 3 day grace period before a user counts.
|
||||
"""
|
||||
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"count_monthly_users", self._count_users, thirty_days_ago
|
||||
)
|
||||
|
||||
|
@ -327,15 +330,15 @@ class DataStore(
|
|||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
def count_r30_users(self):
|
||||
async def count_r30_users(self) -> Dict[str, int]:
|
||||
"""
|
||||
Counts the number of 30 day retained users, defined as:-
|
||||
* Users who have created their accounts more than 30 days ago
|
||||
* Where last seen at most 30 days ago
|
||||
* Where account creation and last_seen are > 30 days apart
|
||||
|
||||
Returns counts globaly for a given user as well as breaking
|
||||
by platform
|
||||
Returns:
|
||||
A mapping of counts globally as well as broken out by platform.
|
||||
"""
|
||||
|
||||
def _count_r30_users(txn):
|
||||
|
@ -408,7 +411,7 @@ class DataStore(
|
|||
|
||||
return results
|
||||
|
||||
return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
|
||||
return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
|
||||
|
||||
def _get_start_of_day(self):
|
||||
"""
|
||||
|
@ -418,7 +421,7 @@ class DataStore(
|
|||
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
|
||||
return today_start * 1000
|
||||
|
||||
def generate_user_daily_visits(self):
|
||||
async def generate_user_daily_visits(self) -> None:
|
||||
"""
|
||||
Generates daily visit data for use in cohort/ retention analysis
|
||||
"""
|
||||
|
@ -473,7 +476,7 @@ class DataStore(
|
|||
# frequently
|
||||
self._last_user_visit_update = now
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"generate_user_daily_visits", _generate_user_daily_visits
|
||||
)
|
||||
|
||||
|
@ -497,22 +500,28 @@ class DataStore(
|
|||
desc="get_users",
|
||||
)
|
||||
|
||||
def get_users_paginate(
|
||||
self, start, limit, user_id=None, name=None, guests=True, deactivated=False
|
||||
):
|
||||
async def get_users_paginate(
|
||||
self,
|
||||
start: int,
|
||||
limit: int,
|
||||
user_id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
guests: bool = True,
|
||||
deactivated: bool = False,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""Function to retrieve a paginated list of users from
|
||||
users list. This will return a json list of users and the
|
||||
total number of users matching the filter criteria.
|
||||
|
||||
Args:
|
||||
start (int): start number to begin the query from
|
||||
limit (int): number of rows to retrieve
|
||||
user_id (string): search for user_id. ignored if name is not None
|
||||
name (string): search for local part of user_id or display name
|
||||
guests (bool): whether to in include guest users
|
||||
deactivated (bool): whether to include deactivated users
|
||||
start: start number to begin the query from
|
||||
limit: number of rows to retrieve
|
||||
user_id: search for user_id. ignored if name is not None
|
||||
name: search for local part of user_id or display name
|
||||
guests: whether to in include guest users
|
||||
deactivated: whether to include deactivated users
|
||||
Returns:
|
||||
defer.Deferred: resolves to list[dict[str, Any]], int
|
||||
A tuple of a list of mappings from user to information and a count of total users.
|
||||
"""
|
||||
|
||||
def get_users_paginate_txn(txn):
|
||||
|
@ -555,7 +564,7 @@ class DataStore(
|
|||
users = self.db_pool.cursor_to_dict(txn)
|
||||
return users, count
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_users_paginate_txn", get_users_paginate_txn
|
||||
)
|
||||
|
||||
|
|
|
@ -16,9 +16,7 @@
|
|||
|
||||
import abc
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool
|
||||
|
@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
raise NotImplementedError()
|
||||
|
||||
@cached()
|
||||
def get_account_data_for_user(self, user_id):
|
||||
async def get_account_data_for_user(
|
||||
self, user_id: str
|
||||
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
|
||||
"""Get all the client account_data for a user.
|
||||
|
||||
Args:
|
||||
user_id(str): The user to get the account_data for.
|
||||
user_id: The user to get the account_data for.
|
||||
Returns:
|
||||
A deferred pair of a dict of global account_data and a dict
|
||||
mapping from room_id string to per room account_data dicts.
|
||||
A 2-tuple of a dict of global account_data and a dict mapping from
|
||||
room_id string to per room account_data dicts.
|
||||
"""
|
||||
|
||||
def get_account_data_for_user_txn(txn):
|
||||
|
@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
|
||||
return global_account_data, by_room
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_account_data_for_user", get_account_data_for_user_txn
|
||||
)
|
||||
|
||||
|
@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
return None
|
||||
|
||||
@cached(num_args=2)
|
||||
def get_account_data_for_room(self, user_id, room_id):
|
||||
async def get_account_data_for_room(
|
||||
self, user_id: str, room_id: str
|
||||
) -> Dict[str, JsonDict]:
|
||||
"""Get all the client account_data for a user for a room.
|
||||
|
||||
Args:
|
||||
user_id(str): The user to get the account_data for.
|
||||
room_id(str): The room to get the account_data for.
|
||||
user_id: The user to get the account_data for.
|
||||
room_id: The room to get the account_data for.
|
||||
Returns:
|
||||
A deferred dict of the room account_data
|
||||
A dict of the room account_data
|
||||
"""
|
||||
|
||||
def get_account_data_for_room_txn(txn):
|
||||
|
@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
row["account_data_type"]: db_to_json(row["content"]) for row in rows
|
||||
}
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_account_data_for_room", get_account_data_for_room_txn
|
||||
)
|
||||
|
||||
@cached(num_args=3, max_entries=5000)
|
||||
def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
|
||||
async def get_account_data_for_room_and_type(
|
||||
self, user_id: str, room_id: str, account_data_type: str
|
||||
) -> Optional[JsonDict]:
|
||||
"""Get the client account_data of given type for a user for a room.
|
||||
|
||||
Args:
|
||||
user_id(str): The user to get the account_data for.
|
||||
room_id(str): The room to get the account_data for.
|
||||
account_data_type (str): The account data type to get.
|
||||
user_id: The user to get the account_data for.
|
||||
room_id: The room to get the account_data for.
|
||||
account_data_type: The account data type to get.
|
||||
Returns:
|
||||
A deferred of the room account_data for that type, or None if
|
||||
there isn't any set.
|
||||
The room account_data for that type, or None if there isn't any set.
|
||||
"""
|
||||
|
||||
def get_account_data_for_room_and_type_txn(txn):
|
||||
|
@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
|
||||
return db_to_json(content_json) if content_json else None
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
|
||||
)
|
||||
|
||||
|
@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
"get_updated_room_account_data", get_updated_room_account_data_txn
|
||||
)
|
||||
|
||||
def get_updated_account_data_for_user(self, user_id, stream_id):
|
||||
async def get_updated_account_data_for_user(
|
||||
self, user_id: str, stream_id: int
|
||||
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
|
||||
"""Get all the client account_data for a that's changed for a user
|
||||
|
||||
Args:
|
||||
user_id(str): The user to get the account_data for.
|
||||
stream_id(int): The point in the stream since which to get updates
|
||||
user_id: The user to get the account_data for.
|
||||
stream_id: The point in the stream since which to get updates
|
||||
Returns:
|
||||
A deferred pair of a dict of global account_data and a dict
|
||||
mapping from room_id string to per room account_data dicts.
|
||||
|
@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
|
|||
user_id, int(stream_id)
|
||||
)
|
||||
if not changed:
|
||||
return defer.succeed(({}, {}))
|
||||
return ({}, {})
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
|
||||
)
|
||||
|
||||
|
@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
|||
|
||||
return self._account_data_id_gen.get_current_token()
|
||||
|
||||
def _update_max_stream_id(self, next_id: int):
|
||||
async def _update_max_stream_id(self, next_id: int) -> None:
|
||||
"""Update the max stream_id
|
||||
|
||||
Args:
|
||||
|
@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
|
|||
)
|
||||
txn.execute(update_max_id_sql, (next_id, next_id))
|
||||
|
||||
return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
|
||||
await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
|
||||
|
|
|
@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
|||
self._batch_row_update[key] = (user_agent, device_id, now)
|
||||
|
||||
@wrap_as_background_process("update_client_ips")
|
||||
def _update_client_ips_batch(self):
|
||||
async def _update_client_ips_batch(self) -> None:
|
||||
|
||||
# If the DB pool has already terminated, don't try updating
|
||||
if not self.db_pool.is_running():
|
||||
|
@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
|||
to_update = self._batch_row_update
|
||||
self._batch_row_update = {}
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
|
||||
)
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
|
@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
update included in the response), and the list of updates, where
|
||||
each update is a pair of EDU type and EDU contents.
|
||||
"""
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
now_stream_id = self.get_device_stream_token()
|
||||
|
||||
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
|
||||
destination, int(from_stream_id)
|
||||
|
@ -254,9 +255,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
List of objects representing an device update EDU
|
||||
"""
|
||||
devices = (
|
||||
await self.db_pool.runInteraction(
|
||||
"_get_e2e_device_keys_txn",
|
||||
self._get_e2e_device_keys_txn,
|
||||
await self.get_e2e_device_keys_and_signatures(
|
||||
query_map.keys(),
|
||||
include_all_devices=True,
|
||||
include_deleted_devices=True,
|
||||
|
@ -292,17 +291,17 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
prev_id = stream_id
|
||||
|
||||
if device is not None:
|
||||
key_json = device.get("key_json", None)
|
||||
key_json = device.key_json
|
||||
if key_json:
|
||||
result["keys"] = db_to_json(key_json)
|
||||
|
||||
if "signatures" in device:
|
||||
for sig_user_id, sigs in device["signatures"].items():
|
||||
if device.signatures:
|
||||
for sig_user_id, sigs in device.signatures.items():
|
||||
result["keys"].setdefault("signatures", {}).setdefault(
|
||||
sig_user_id, {}
|
||||
).update(sigs)
|
||||
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
device_display_name = device.display_name
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
else:
|
||||
|
@ -312,9 +311,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
|
||||
return results
|
||||
|
||||
def _get_last_device_update_for_remote_user(
|
||||
async def _get_last_device_update_for_remote_user(
|
||||
self, destination: str, user_id: str, from_stream_id: int
|
||||
):
|
||||
) -> int:
|
||||
def f(txn):
|
||||
prev_sent_id_sql = """
|
||||
SELECT coalesce(max(stream_id), 0) as stream_id
|
||||
|
@ -325,12 +324,16 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
rows = txn.fetchall()
|
||||
return rows[0][0]
|
||||
|
||||
return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_last_device_update_for_remote_user", f
|
||||
)
|
||||
|
||||
def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
|
||||
async def mark_as_sent_devices_by_remote(
|
||||
self, destination: str, stream_id: int
|
||||
) -> None:
|
||||
"""Mark that updates have successfully been sent to the destination.
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"mark_as_sent_devices_by_remote",
|
||||
self._mark_as_sent_devices_by_remote_txn,
|
||||
destination,
|
||||
|
@ -412,8 +415,10 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
},
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_device_stream_token(self) -> int:
|
||||
return self._device_list_id_gen.get_current_token()
|
||||
"""Get the current stream id from the _device_list_id_gen"""
|
||||
...
|
||||
|
||||
@trace
|
||||
async def get_user_devices_from_cache(
|
||||
|
@ -481,51 +486,6 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
device["device_id"]: db_to_json(device["content"]) for device in devices
|
||||
}
|
||||
|
||||
def get_devices_with_keys_by_user(self, user_id: str):
|
||||
"""Get all devices (with any device keys) for a user
|
||||
|
||||
Returns:
|
||||
Deferred which resolves to (stream_id, devices)
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
"get_devices_with_keys_by_user",
|
||||
self._get_devices_with_keys_by_user_txn,
|
||||
user_id,
|
||||
)
|
||||
|
||||
def _get_devices_with_keys_by_user_txn(
|
||||
self, txn: LoggingTransaction, user_id: str
|
||||
) -> Tuple[int, List[JsonDict]]:
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
|
||||
devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
|
||||
|
||||
if devices:
|
||||
user_devices = devices[user_id]
|
||||
results = []
|
||||
for device_id, device in user_devices.items():
|
||||
result = {"device_id": device_id}
|
||||
|
||||
key_json = device.get("key_json", None)
|
||||
if key_json:
|
||||
result["keys"] = db_to_json(key_json)
|
||||
|
||||
if "signatures" in device:
|
||||
for sig_user_id, sigs in device["signatures"].items():
|
||||
result["keys"].setdefault("signatures", {}).setdefault(
|
||||
sig_user_id, {}
|
||||
).update(sigs)
|
||||
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
|
||||
results.append(result)
|
||||
|
||||
return now_stream_id, results
|
||||
|
||||
return now_stream_id, []
|
||||
|
||||
async def get_users_whose_devices_changed(
|
||||
self, from_key: str, user_ids: Iterable[str]
|
||||
) -> Set[str]:
|
||||
|
@ -726,7 +686,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
desc="make_remote_user_device_cache_as_stale",
|
||||
)
|
||||
|
||||
def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
|
||||
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
|
||||
"""Mark that we no longer track device lists for remote user.
|
||||
"""
|
||||
|
||||
|
@ -740,7 +700,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"mark_remote_user_device_list_as_unsubscribed",
|
||||
_mark_remote_user_device_list_as_unsubscribed_txn,
|
||||
)
|
||||
|
@ -1001,9 +961,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
desc="update_device",
|
||||
)
|
||||
|
||||
def update_remote_device_list_cache_entry(
|
||||
async def update_remote_device_list_cache_entry(
|
||||
self, user_id: str, device_id: str, content: JsonDict, stream_id: int
|
||||
):
|
||||
) -> None:
|
||||
"""Updates a single device in the cache of a remote user's devicelist.
|
||||
|
||||
Note: assumes that we are the only thread that can be updating this user's
|
||||
|
@ -1014,11 +974,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
device_id: ID of decivice being updated
|
||||
content: new data on this device
|
||||
stream_id: the version of the device list
|
||||
|
||||
Returns:
|
||||
Deferred[None]
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"update_remote_device_list_cache_entry",
|
||||
self._update_remote_device_list_cache_entry_txn,
|
||||
user_id,
|
||||
|
@ -1070,9 +1027,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
lock=False,
|
||||
)
|
||||
|
||||
def update_remote_device_list_cache(
|
||||
async def update_remote_device_list_cache(
|
||||
self, user_id: str, devices: List[dict], stream_id: int
|
||||
):
|
||||
) -> None:
|
||||
"""Replace the entire cache of the remote user's devices.
|
||||
|
||||
Note: assumes that we are the only thread that can be updating this user's
|
||||
|
@ -1082,11 +1039,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
user_id: User to update device list for
|
||||
devices: list of device objects supplied over federation
|
||||
stream_id: the version of the device list
|
||||
|
||||
Returns:
|
||||
Deferred[None]
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"update_remote_device_list_cache",
|
||||
self._update_remote_device_list_cache_txn,
|
||||
user_id,
|
||||
|
@ -1096,7 +1050,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
|
||||
def _update_remote_device_list_cache_txn(
|
||||
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
|
||||
):
|
||||
) -> None:
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
|
||||
)
|
||||
|
|
|
@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore):
|
|||
|
||||
return room_id
|
||||
|
||||
def update_aliases_for_room(
|
||||
async def update_aliases_for_room(
|
||||
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Repoint all of the aliases for a given room, to a different room.
|
||||
|
||||
Args:
|
||||
|
@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
|
|||
txn, self.get_aliases_for_room, (new_room_id,)
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
|
||||
)
|
||||
|
|
|
@ -14,8 +14,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.enterprise.adbapi import Connection
|
||||
|
@ -23,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
|
|||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import make_in_list_sql_clause
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
@ -31,19 +34,67 @@ if TYPE_CHECKING:
|
|||
from synapse.handlers.e2e_keys import SignatureListItem
|
||||
|
||||
|
||||
@attr.s
|
||||
class DeviceKeyLookupResult:
|
||||
"""The type returned by get_e2e_device_keys_and_signatures"""
|
||||
|
||||
display_name = attr.ib(type=Optional[str])
|
||||
|
||||
# the key data from e2e_device_keys_json. Typically includes fields like
|
||||
# "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
|
||||
# key) and "signatures" (a signature of the structure by the ed25519 key)
|
||||
key_json = attr.ib(type=Optional[str])
|
||||
|
||||
# cross-signing sigs
|
||||
signatures = attr.ib(type=Optional[Dict], default=None)
|
||||
|
||||
|
||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
async def get_e2e_device_keys_for_federation_query(
|
||||
self, user_id: str
|
||||
) -> Tuple[int, List[JsonDict]]:
|
||||
"""Get all devices (with any device keys) for a user
|
||||
|
||||
Returns:
|
||||
(stream_id, devices)
|
||||
"""
|
||||
now_stream_id = self.get_device_stream_token()
|
||||
|
||||
devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
|
||||
|
||||
if devices:
|
||||
user_devices = devices[user_id]
|
||||
results = []
|
||||
for device_id, device in user_devices.items():
|
||||
result = {"device_id": device_id}
|
||||
|
||||
key_json = device.key_json
|
||||
if key_json:
|
||||
result["keys"] = db_to_json(key_json)
|
||||
|
||||
if device.signatures:
|
||||
for sig_user_id, sigs in device.signatures.items():
|
||||
result["keys"].setdefault("signatures", {}).setdefault(
|
||||
sig_user_id, {}
|
||||
).update(sigs)
|
||||
|
||||
device_display_name = device.display_name
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
|
||||
results.append(result)
|
||||
|
||||
return now_stream_id, results
|
||||
|
||||
return now_stream_id, []
|
||||
|
||||
@trace
|
||||
async def get_e2e_device_keys(
|
||||
self, query_list, include_all_devices=False, include_deleted_devices=False
|
||||
):
|
||||
"""Fetch a list of device keys.
|
||||
async def get_e2e_device_keys_for_cs_api(
|
||||
self, query_list: List[Tuple[str, Optional[str]]]
|
||||
) -> Dict[str, Dict[str, JsonDict]]:
|
||||
"""Fetch a list of device keys, formatted suitably for the C/S API.
|
||||
Args:
|
||||
query_list(list): List of pairs of user_ids and device_ids.
|
||||
include_all_devices (bool): whether to include entries for devices
|
||||
that don't have device keys
|
||||
include_deleted_devices (bool): whether to include null entries for
|
||||
devices which no longer exist (but were in the query_list).
|
||||
This option only takes effect if include_all_devices is true.
|
||||
Returns:
|
||||
Dict mapping from user-id to dict mapping from device_id to
|
||||
key data. The key data will be a dict in the same format as the
|
||||
|
@ -53,13 +104,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
if not query_list:
|
||||
return {}
|
||||
|
||||
results = await self.db_pool.runInteraction(
|
||||
"get_e2e_device_keys",
|
||||
self._get_e2e_device_keys_txn,
|
||||
query_list,
|
||||
include_all_devices,
|
||||
include_deleted_devices,
|
||||
)
|
||||
results = await self.get_e2e_device_keys_and_signatures(query_list)
|
||||
|
||||
# Build the result structure, un-jsonify the results, and add the
|
||||
# "unsigned" section
|
||||
|
@ -67,13 +112,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
for user_id, device_keys in results.items():
|
||||
rv[user_id] = {}
|
||||
for device_id, device_info in device_keys.items():
|
||||
r = db_to_json(device_info.pop("key_json"))
|
||||
r = db_to_json(device_info.key_json)
|
||||
r["unsigned"] = {}
|
||||
display_name = device_info["device_display_name"]
|
||||
display_name = device_info.display_name
|
||||
if display_name is not None:
|
||||
r["unsigned"]["device_display_name"] = display_name
|
||||
if "signatures" in device_info:
|
||||
for sig_user_id, sigs in device_info["signatures"].items():
|
||||
if device_info.signatures:
|
||||
for sig_user_id, sigs in device_info.signatures.items():
|
||||
r.setdefault("signatures", {}).setdefault(
|
||||
sig_user_id, {}
|
||||
).update(sigs)
|
||||
|
@ -82,12 +127,45 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
return rv
|
||||
|
||||
@trace
|
||||
def _get_e2e_device_keys_txn(
|
||||
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
|
||||
):
|
||||
async def get_e2e_device_keys_and_signatures(
|
||||
self,
|
||||
query_list: List[Tuple[str, Optional[str]]],
|
||||
include_all_devices: bool = False,
|
||||
include_deleted_devices: bool = False,
|
||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||
"""Fetch a list of device keys, together with their cross-signatures.
|
||||
|
||||
Args:
|
||||
query_list: List of pairs of user_ids and device_ids. Device id can be None
|
||||
to indicate "all devices for this user"
|
||||
|
||||
include_all_devices: whether to return devices without device keys
|
||||
|
||||
include_deleted_devices: whether to include null entries for
|
||||
devices which no longer exist (but were in the query_list).
|
||||
This option only takes effect if include_all_devices is true.
|
||||
|
||||
Returns:
|
||||
Dict mapping from user-id to dict mapping from device_id to
|
||||
key data.
|
||||
"""
|
||||
set_tag("include_all_devices", include_all_devices)
|
||||
set_tag("include_deleted_devices", include_deleted_devices)
|
||||
|
||||
result = await self.db_pool.runInteraction(
|
||||
"get_e2e_device_keys",
|
||||
self._get_e2e_device_keys_and_signatures_txn,
|
||||
query_list,
|
||||
include_all_devices,
|
||||
include_deleted_devices,
|
||||
)
|
||||
|
||||
log_kv(result)
|
||||
return result
|
||||
|
||||
def _get_e2e_device_keys_and_signatures_txn(
|
||||
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
|
||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||
query_clauses = []
|
||||
query_params = []
|
||||
signature_query_clauses = []
|
||||
|
@ -119,7 +197,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
|
||||
sql = (
|
||||
"SELECT user_id, device_id, "
|
||||
" d.display_name AS device_display_name, "
|
||||
" d.display_name, "
|
||||
" k.key_json"
|
||||
" FROM devices d"
|
||||
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
|
||||
|
@ -130,13 +208,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
txn.execute(sql, query_params)
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
result = {}
|
||||
for row in rows:
|
||||
result = {} # type: Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]
|
||||
for (user_id, device_id, display_name, key_json) in txn:
|
||||
if include_deleted_devices:
|
||||
deleted_devices.remove((row["user_id"], row["device_id"]))
|
||||
result.setdefault(row["user_id"], {})[row["device_id"]] = row
|
||||
deleted_devices.remove((user_id, device_id))
|
||||
result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
|
||||
display_name, key_json
|
||||
)
|
||||
|
||||
if include_deleted_devices:
|
||||
for user_id, device_id in deleted_devices:
|
||||
|
@ -167,13 +246,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
# note that target_device_result will be None for deleted devices.
|
||||
continue
|
||||
|
||||
target_device_signatures = target_device_result.setdefault("signatures", {})
|
||||
target_device_signatures = target_device_result.signatures
|
||||
if target_device_signatures is None:
|
||||
target_device_signatures = target_device_result.signatures = {}
|
||||
|
||||
signing_user_signatures = target_device_signatures.setdefault(
|
||||
signing_user_id, {}
|
||||
)
|
||||
signing_user_signatures[signing_key_id] = signature
|
||||
|
||||
log_kv(result)
|
||||
return result
|
||||
|
||||
async def get_e2e_one_time_keys(
|
||||
|
@ -252,10 +333,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
def count_e2e_one_time_keys(self, user_id, device_id):
|
||||
async def count_e2e_one_time_keys(
|
||||
self, user_id: str, device_id: str
|
||||
) -> Dict[str, int]:
|
||||
""" Count the number of one time keys the server has for a device
|
||||
Returns:
|
||||
Dict mapping from algorithm to number of keys for that algorithm.
|
||||
A mapping from algorithm to number of keys for that algorithm.
|
||||
"""
|
||||
|
||||
def _count_e2e_one_time_keys(txn):
|
||||
|
@ -270,7 +353,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
result[algorithm] = key_count
|
||||
return result
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||
)
|
||||
|
||||
|
@ -308,7 +391,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
list_name="user_ids",
|
||||
num_args=1,
|
||||
)
|
||||
def _get_bare_e2e_cross_signing_keys_bulk(
|
||||
async def _get_bare_e2e_cross_signing_keys_bulk(
|
||||
self, user_ids: List[str]
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
"""Returns the cross-signing keys for a set of users. The output of this
|
||||
|
@ -316,16 +399,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
the signatures for the calling user need to be fetched.
|
||||
|
||||
Args:
|
||||
user_ids (list[str]): the users whose keys are being requested
|
||||
user_ids: the users whose keys are being requested
|
||||
|
||||
Returns:
|
||||
dict[str, dict[str, dict]]: mapping from user ID to key type to key
|
||||
data. If a user's cross-signing keys were not found, either
|
||||
their user ID will not be in the dict, or their user ID will map
|
||||
to None.
|
||||
A mapping from user ID to key type to key data. If a user's cross-signing
|
||||
keys were not found, either their user ID will not be in the dict, or
|
||||
their user ID will map to None.
|
||||
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_bare_e2e_cross_signing_keys_bulk",
|
||||
self._get_bare_e2e_cross_signing_keys_bulk_txn,
|
||||
user_ids,
|
||||
|
@ -541,9 +623,16 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||
_get_all_user_signature_changes_for_remotes_txn,
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_device_stream_token(self) -> int:
|
||||
"""Get the current stream id from the _device_list_id_gen"""
|
||||
...
|
||||
|
||||
|
||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|
||||
async def set_e2e_device_keys(
|
||||
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
|
||||
) -> bool:
|
||||
"""Stores device keys for a device. Returns whether there was a change
|
||||
or the keys were already in the database.
|
||||
"""
|
||||
|
@ -579,12 +668,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
log_kv({"message": "Device keys stored."})
|
||||
return True
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"set_e2e_device_keys", _set_e2e_device_keys_txn
|
||||
)
|
||||
|
||||
def claim_e2e_one_time_keys(self, query_list):
|
||||
"""Take a list of one time keys out of the database"""
|
||||
async def claim_e2e_one_time_keys(
|
||||
self, query_list: Iterable[Tuple[str, str, str]]
|
||||
) -> Dict[str, Dict[str, Dict[str, bytes]]]:
|
||||
"""Take a list of one time keys out of the database.
|
||||
|
||||
Args:
|
||||
query_list: An iterable of tuples of (user ID, device ID, algorithm).
|
||||
|
||||
Returns:
|
||||
A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
||||
"""
|
||||
|
||||
@trace
|
||||
def _claim_e2e_one_time_keys(txn):
|
||||
|
@ -620,11 +718,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
)
|
||||
return result
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
||||
)
|
||||
|
||||
def delete_e2e_keys_by_device(self, user_id, device_id):
|
||||
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
|
||||
def delete_e2e_keys_by_device_txn(txn):
|
||||
log_kv(
|
||||
{
|
||||
|
@ -647,7 +745,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
|||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||
)
|
||||
|
||||
|
|
|
@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
"""
|
||||
|
||||
if stream_ordering <= self.stream_ordering_month_ago:
|
||||
raise StoreError(400, "stream_ordering too old")
|
||||
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
|
||||
|
||||
sql = """
|
||||
SELECT event_id FROM stream_ordering_to_exterm
|
||||
|
|
|
@ -15,7 +15,9 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
|
||||
|
@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
|
||||
@cached(num_args=3, tree=True, max_entries=5000)
|
||||
async def get_unread_event_push_actions_by_room_for_user(
|
||||
self, room_id, user_id, last_read_event_id
|
||||
):
|
||||
self, room_id: str, user_id: str, last_read_event_id: Optional[str],
|
||||
) -> Dict[str, int]:
|
||||
"""Get the notification count, the highlight count and the unread message count
|
||||
for a given user in a given room after the given read receipt.
|
||||
|
||||
Note that this function assumes the user to be a current member of the room,
|
||||
since it's either called by the sync handler to handle joined room entries, or by
|
||||
the HTTP pusher to calculate the badge of unread joined rooms.
|
||||
|
||||
Args:
|
||||
room_id: The room to retrieve the counts in.
|
||||
user_id: The user to retrieve the counts for.
|
||||
last_read_event_id: The event associated with the latest read receipt for
|
||||
this user in this room. None if no receipt for this user in this room.
|
||||
|
||||
Returns
|
||||
A dict containing the counts mentioned earlier in this docstring,
|
||||
respectively under the keys "notify_count", "highlight_count" and
|
||||
"unread_count".
|
||||
"""
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_unread_event_push_actions_by_room",
|
||||
self._get_unread_counts_by_receipt_txn,
|
||||
|
@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
def _get_unread_counts_by_receipt_txn(
|
||||
self, txn, room_id, user_id, last_read_event_id
|
||||
self, txn, room_id, user_id, last_read_event_id,
|
||||
):
|
||||
sql = (
|
||||
"SELECT stream_ordering"
|
||||
" FROM events"
|
||||
" WHERE room_id = ? AND event_id = ?"
|
||||
)
|
||||
txn.execute(sql, (room_id, last_read_event_id))
|
||||
results = txn.fetchall()
|
||||
if len(results) == 0:
|
||||
return {"notify_count": 0, "highlight_count": 0}
|
||||
stream_ordering = None
|
||||
|
||||
stream_ordering = results[0][0]
|
||||
if last_read_event_id is not None:
|
||||
stream_ordering = self.get_stream_id_for_event_txn(
|
||||
txn, last_read_event_id, allow_none=True,
|
||||
)
|
||||
|
||||
if stream_ordering is None:
|
||||
# Either last_read_event_id is None, or it's an event we don't have (e.g.
|
||||
# because it's been purged), in which case retrieve the stream ordering for
|
||||
# the latest membership event from this user in this room (which we assume is
|
||||
# a join).
|
||||
event_id = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn=txn,
|
||||
table="local_current_membership",
|
||||
keyvalues={"room_id": room_id, "user_id": user_id},
|
||||
retcol="event_id",
|
||||
)
|
||||
|
||||
stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
|
||||
|
||||
return self._get_unread_counts_by_pos_txn(
|
||||
txn, room_id, user_id, stream_ordering
|
||||
)
|
||||
|
||||
def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
|
||||
|
||||
# First get number of notifications.
|
||||
# We don't need to put a notif=1 clause as all rows always have
|
||||
# notif=1
|
||||
sql = (
|
||||
"SELECT count(*)"
|
||||
"SELECT"
|
||||
" COUNT(CASE WHEN notif = 1 THEN 1 END),"
|
||||
" COUNT(CASE WHEN highlight = 1 THEN 1 END),"
|
||||
" COUNT(CASE WHEN unread = 1 THEN 1 END)"
|
||||
" FROM event_push_actions ea"
|
||||
" WHERE"
|
||||
" user_id = ?"
|
||||
" AND room_id = ?"
|
||||
" AND stream_ordering > ?"
|
||||
" WHERE user_id = ?"
|
||||
" AND room_id = ?"
|
||||
" AND stream_ordering > ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (user_id, room_id, stream_ordering))
|
||||
row = txn.fetchone()
|
||||
notify_count = row[0] if row else 0
|
||||
|
||||
(notif_count, highlight_count, unread_count) = (0, 0, 0)
|
||||
|
||||
if row:
|
||||
(notif_count, highlight_count, unread_count) = row
|
||||
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT notif_count FROM event_push_summary
|
||||
WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
|
||||
""",
|
||||
SELECT notif_count, unread_count FROM event_push_summary
|
||||
WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
|
||||
""",
|
||||
(room_id, user_id, stream_ordering),
|
||||
)
|
||||
rows = txn.fetchall()
|
||||
if rows:
|
||||
notify_count += rows[0][0]
|
||||
|
||||
# Now get the number of highlights
|
||||
sql = (
|
||||
"SELECT count(*)"
|
||||
" FROM event_push_actions ea"
|
||||
" WHERE"
|
||||
" highlight = 1"
|
||||
" AND user_id = ?"
|
||||
" AND room_id = ?"
|
||||
" AND stream_ordering > ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (user_id, room_id, stream_ordering))
|
||||
row = txn.fetchone()
|
||||
highlight_count = row[0] if row else 0
|
||||
|
||||
return {"notify_count": notify_count, "highlight_count": highlight_count}
|
||||
if row:
|
||||
notif_count += row[0]
|
||||
unread_count += row[1]
|
||||
|
||||
return {
|
||||
"notify_count": notif_count,
|
||||
"unread_count": unread_count,
|
||||
"highlight_count": highlight_count,
|
||||
}
|
||||
|
||||
async def get_push_action_users_in_range(
|
||||
self, min_stream_ordering, max_stream_ordering
|
||||
|
@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
" AND ep.user_id = ?"
|
||||
" AND ep.stream_ordering > ?"
|
||||
" AND ep.stream_ordering <= ?"
|
||||
" AND ep.notif = 1"
|
||||
" ORDER BY ep.stream_ordering ASC LIMIT ?"
|
||||
)
|
||||
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
|
||||
|
@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
" AND ep.user_id = ?"
|
||||
" AND ep.stream_ordering > ?"
|
||||
" AND ep.stream_ordering <= ?"
|
||||
" AND ep.notif = 1"
|
||||
" ORDER BY ep.stream_ordering ASC LIMIT ?"
|
||||
)
|
||||
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
|
||||
|
@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
" AND ep.user_id = ?"
|
||||
" AND ep.stream_ordering > ?"
|
||||
" AND ep.stream_ordering <= ?"
|
||||
" AND ep.notif = 1"
|
||||
" ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||
)
|
||||
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
|
||||
|
@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
" AND ep.user_id = ?"
|
||||
" AND ep.stream_ordering > ?"
|
||||
" AND ep.stream_ordering <= ?"
|
||||
" AND ep.notif = 1"
|
||||
" ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||
)
|
||||
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
|
||||
|
@ -383,62 +409,66 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
# Now return the first `limit`
|
||||
return notifs[:limit]
|
||||
|
||||
def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
|
||||
async def get_if_maybe_push_in_range_for_user(
|
||||
self, user_id: str, min_stream_ordering: int
|
||||
) -> bool:
|
||||
"""A fast check to see if there might be something to push for the
|
||||
user since the given stream ordering. May return false positives.
|
||||
|
||||
Useful to know whether to bother starting a pusher on start up or not.
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
min_stream_ordering (int)
|
||||
user_id
|
||||
min_stream_ordering
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: True if there may be push to process, False if
|
||||
there definitely isn't.
|
||||
True if there may be push to process, False if there definitely isn't.
|
||||
"""
|
||||
|
||||
def _get_if_maybe_push_in_range_for_user_txn(txn):
|
||||
sql = """
|
||||
SELECT 1 FROM event_push_actions
|
||||
WHERE user_id = ? AND stream_ordering > ?
|
||||
WHERE user_id = ? AND stream_ordering > ? AND notif = 1
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
txn.execute(sql, (user_id, min_stream_ordering))
|
||||
return bool(txn.fetchone())
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_if_maybe_push_in_range_for_user",
|
||||
_get_if_maybe_push_in_range_for_user_txn,
|
||||
)
|
||||
|
||||
async def add_push_actions_to_staging(self, event_id, user_id_actions):
|
||||
async def add_push_actions_to_staging(
|
||||
self,
|
||||
event_id: str,
|
||||
user_id_actions: Dict[str, List[Union[dict, str]]],
|
||||
count_as_unread: bool,
|
||||
) -> None:
|
||||
"""Add the push actions for the event to the push action staging area.
|
||||
|
||||
Args:
|
||||
event_id (str)
|
||||
user_id_actions (dict[str, list[dict|str])]): A dictionary mapping
|
||||
user_id to list of push actions, where an action can either be
|
||||
a string or dict.
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
event_id
|
||||
user_id_actions: A mapping of user_id to list of push actions, where
|
||||
an action can either be a string or dict.
|
||||
count_as_unread: Whether this event should increment unread counts.
|
||||
"""
|
||||
|
||||
if not user_id_actions:
|
||||
return
|
||||
|
||||
# This is a helper function for generating the necessary tuple that
|
||||
# can be used to inert into the `event_push_actions_staging` table.
|
||||
# can be used to insert into the `event_push_actions_staging` table.
|
||||
def _gen_entry(user_id, actions):
|
||||
is_highlight = 1 if _action_has_highlight(actions) else 0
|
||||
notif = 1 if "notify" in actions else 0
|
||||
return (
|
||||
event_id, # event_id column
|
||||
user_id, # user_id column
|
||||
_serialize_action(actions, is_highlight), # actions column
|
||||
1, # notif column
|
||||
notif, # notif column
|
||||
is_highlight, # highlight column
|
||||
int(count_as_unread), # unread column
|
||||
)
|
||||
|
||||
def _add_push_actions_to_staging_txn(txn):
|
||||
|
@ -447,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
|
||||
sql = """
|
||||
INSERT INTO event_push_actions_staging
|
||||
(event_id, user_id, actions, notif, highlight)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
(event_id, user_id, actions, notif, highlight, unread)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
|
||||
txn.executemany(
|
||||
|
@ -507,7 +537,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
"Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
|
||||
)
|
||||
|
||||
def find_first_stream_ordering_after_ts(self, ts):
|
||||
async def find_first_stream_ordering_after_ts(self, ts: int) -> int:
|
||||
"""Gets the stream ordering corresponding to a given timestamp.
|
||||
|
||||
Specifically, finds the stream_ordering of the first event that was
|
||||
|
@ -516,13 +546,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
relatively slow.
|
||||
|
||||
Args:
|
||||
ts (int): timestamp in millis
|
||||
ts: timestamp in millis
|
||||
|
||||
Returns:
|
||||
Deferred[int]: stream ordering of the first event received on/after
|
||||
the timestamp
|
||||
stream ordering of the first event received on/after the timestamp
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"_find_first_stream_ordering_after_ts_txn",
|
||||
self._find_first_stream_ordering_after_ts_txn,
|
||||
ts,
|
||||
|
@ -813,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
|||
# Calculate the new counts that should be upserted into event_push_summary
|
||||
sql = """
|
||||
SELECT user_id, room_id,
|
||||
coalesce(old.notif_count, 0) + upd.notif_count,
|
||||
coalesce(old.%s, 0) + upd.cnt,
|
||||
upd.stream_ordering,
|
||||
old.user_id
|
||||
FROM (
|
||||
SELECT user_id, room_id, count(*) as notif_count,
|
||||
SELECT user_id, room_id, count(*) as cnt,
|
||||
max(stream_ordering) as stream_ordering
|
||||
FROM event_push_actions
|
||||
WHERE ? <= stream_ordering AND stream_ordering < ?
|
||||
AND highlight = 0
|
||||
AND %s = 1
|
||||
GROUP BY user_id, room_id
|
||||
) AS upd
|
||||
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
|
||||
"""
|
||||
|
||||
txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
|
||||
rows = txn.fetchall()
|
||||
# First get the count of unread messages.
|
||||
txn.execute(
|
||||
sql % ("unread_count", "unread"),
|
||||
(old_rotate_stream_ordering, rotate_to_stream_ordering),
|
||||
)
|
||||
|
||||
logger.info("Rotating notifications, handling %d rows", len(rows))
|
||||
# We need to merge results from the two requests (the one that retrieves the
|
||||
# unread count and the one that retrieves the notifications count) into a single
|
||||
# object because we might not have the same amount of rows in each of them. To do
|
||||
# this, we use a dict indexed on the user ID and room ID to make it easier to
|
||||
# populate.
|
||||
summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary]
|
||||
for row in txn:
|
||||
summaries[(row[0], row[1])] = _EventPushSummary(
|
||||
unread_count=row[2],
|
||||
stream_ordering=row[3],
|
||||
old_user_id=row[4],
|
||||
notif_count=0,
|
||||
)
|
||||
|
||||
# Then get the count of notifications.
|
||||
txn.execute(
|
||||
sql % ("notif_count", "notif"),
|
||||
(old_rotate_stream_ordering, rotate_to_stream_ordering),
|
||||
)
|
||||
|
||||
for row in txn:
|
||||
if (row[0], row[1]) in summaries:
|
||||
summaries[(row[0], row[1])].notif_count = row[2]
|
||||
else:
|
||||
# Because the rules on notifying are different than the rules on marking
|
||||
# a message unread, we might end up with messages that notify but aren't
|
||||
# marked unread, so we might not have a summary for this (user, room)
|
||||
# tuple to complete.
|
||||
summaries[(row[0], row[1])] = _EventPushSummary(
|
||||
unread_count=0,
|
||||
stream_ordering=row[3],
|
||||
old_user_id=row[4],
|
||||
notif_count=row[2],
|
||||
)
|
||||
|
||||
logger.info("Rotating notifications, handling %d rows", len(summaries))
|
||||
|
||||
# If the `old.user_id` above is NULL then we know there isn't already an
|
||||
# entry in the table, so we simply insert it. Otherwise we update the
|
||||
|
@ -840,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
|
|||
table="event_push_summary",
|
||||
values=[
|
||||
{
|
||||
"user_id": row[0],
|
||||
"room_id": row[1],
|
||||
"notif_count": row[2],
|
||||
"stream_ordering": row[3],
|
||||
"user_id": user_id,
|
||||
"room_id": room_id,
|
||||
"notif_count": summary.notif_count,
|
||||
"unread_count": summary.unread_count,
|
||||
"stream_ordering": summary.stream_ordering,
|
||||
}
|
||||
for row in rows
|
||||
if row[4] is None
|
||||
for ((user_id, room_id), summary) in summaries.items()
|
||||
if summary.old_user_id is None
|
||||
],
|
||||
)
|
||||
|
||||
txn.executemany(
|
||||
"""
|
||||
UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
|
||||
UPDATE event_push_summary
|
||||
SET notif_count = ?, unread_count = ?, stream_ordering = ?
|
||||
WHERE user_id = ? AND room_id = ?
|
||||
""",
|
||||
((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
|
||||
(
|
||||
(
|
||||
summary.notif_count,
|
||||
summary.unread_count,
|
||||
summary.stream_ordering,
|
||||
user_id,
|
||||
room_id,
|
||||
)
|
||||
for ((user_id, room_id), summary) in summaries.items()
|
||||
if summary.old_user_id is not None
|
||||
),
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
|
@ -881,3 +961,15 @@ def _action_has_highlight(actions):
|
|||
pass
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@attr.s
|
||||
class _EventPushSummary:
|
||||
"""Summary of pending event push actions for a given user in a given room.
|
||||
Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
|
||||
"""
|
||||
|
||||
unread_count = attr.ib(type=int)
|
||||
stream_ordering = attr.ib(type=int)
|
||||
old_user_id = attr.ib(type=str)
|
||||
notif_count = attr.ib(type=int)
|
||||
|
|
|
@ -97,6 +97,7 @@ class PersistEventsStore:
|
|||
self.store = main_data_store
|
||||
self.database_engine = db.engine
|
||||
self._clock = hs.get_clock()
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
@ -108,7 +109,7 @@ class PersistEventsStore:
|
|||
|
||||
# This should only exist on instances that are configured to write
|
||||
assert (
|
||||
hs.config.worker.writers.events == hs.get_instance_name()
|
||||
hs.get_instance_name() in hs.config.worker.writers.events
|
||||
), "Can only instantiate EventsStore on master"
|
||||
|
||||
async def _persist_events_and_state_updates(
|
||||
|
@ -800,6 +801,7 @@ class PersistEventsStore:
|
|||
table="events",
|
||||
values=[
|
||||
{
|
||||
"instance_name": self._instance_name,
|
||||
"stream_ordering": event.internal_metadata.stream_ordering,
|
||||
"topological_ordering": event.depth,
|
||||
"depth": event.depth,
|
||||
|
@ -1296,9 +1298,9 @@ class PersistEventsStore:
|
|||
sql = """
|
||||
INSERT INTO event_push_actions (
|
||||
room_id, event_id, user_id, actions, stream_ordering,
|
||||
topological_ordering, notif, highlight
|
||||
topological_ordering, notif, highlight, unread
|
||||
)
|
||||
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
|
||||
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
|
||||
FROM event_push_actions_staging
|
||||
WHERE event_id = ?
|
||||
"""
|
||||
|
|
|
@ -42,7 +42,8 @@ from synapse.replication.tcp.streams import BackfillStream
|
|||
from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||
from synapse.types import Collection, get_domain_from_id
|
||||
from synapse.util.caches.descriptors import Cache, cached
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
@ -78,27 +79,54 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
if hs.config.worker.writers.events == hs.get_instance_name():
|
||||
# We are the process in charge of generating stream ids for events,
|
||||
# so instantiate ID generators based on the database
|
||||
self._stream_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering",
|
||||
if isinstance(database.engine, PostgresEngine):
|
||||
# If we're using Postgres than we can use `MultiWriterIdGenerator`
|
||||
# regardless of whether this process writes to the streams or not.
|
||||
self._stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
instance_name=hs.get_instance_name(),
|
||||
table="events",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_ordering",
|
||||
sequence_name="events_stream_seq",
|
||||
)
|
||||
self._backfill_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
"events",
|
||||
"stream_ordering",
|
||||
step=-1,
|
||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
|
||||
self._backfill_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
instance_name=hs.get_instance_name(),
|
||||
table="events",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_ordering",
|
||||
sequence_name="events_backfill_stream_seq",
|
||||
positive=False,
|
||||
)
|
||||
else:
|
||||
# Another process is in charge of persisting events and generating
|
||||
# stream IDs: rely on the replication streams to let us know which
|
||||
# IDs we can process.
|
||||
self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
|
||||
self._backfill_id_gen = SlavedIdTracker(
|
||||
db_conn, "events", "stream_ordering", step=-1
|
||||
)
|
||||
# We shouldn't be running in worker mode with SQLite, but its useful
|
||||
# to support it for unit tests.
|
||||
#
|
||||
# If this process is the writer than we need to use
|
||||
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
|
||||
# updated over replication. (Multiple writers are not supported for
|
||||
# SQLite).
|
||||
if hs.get_instance_name() in hs.config.worker.writers.events:
|
||||
self._stream_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering",
|
||||
)
|
||||
self._backfill_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
"events",
|
||||
"stream_ordering",
|
||||
step=-1,
|
||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
|
||||
)
|
||||
else:
|
||||
self._stream_id_gen = SlavedIdTracker(
|
||||
db_conn, "events", "stream_ordering"
|
||||
)
|
||||
self._backfill_id_gen = SlavedIdTracker(
|
||||
db_conn, "events", "stream_ordering", step=-1
|
||||
)
|
||||
|
||||
self._get_event_cache = Cache(
|
||||
"*getEvent*",
|
||||
|
@ -823,20 +851,24 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
return event_dict
|
||||
|
||||
def _maybe_redact_event_row(self, original_ev, redactions, event_map):
|
||||
def _maybe_redact_event_row(
|
||||
self,
|
||||
original_ev: EventBase,
|
||||
redactions: Iterable[str],
|
||||
event_map: Dict[str, EventBase],
|
||||
) -> Optional[EventBase]:
|
||||
"""Given an event object and a list of possible redacting event ids,
|
||||
determine whether to honour any of those redactions and if so return a redacted
|
||||
event.
|
||||
|
||||
Args:
|
||||
original_ev (EventBase):
|
||||
redactions (iterable[str]): list of event ids of potential redaction events
|
||||
event_map (dict[str, EventBase]): other events which have been fetched, in
|
||||
which we can look up the redaaction events. Map from event id to event.
|
||||
original_ev: The original event.
|
||||
redactions: list of event ids of potential redaction events
|
||||
event_map: other events which have been fetched, in which we can
|
||||
look up the redaaction events. Map from event id to event.
|
||||
|
||||
Returns:
|
||||
Deferred[EventBase|None]: if the event should be redacted, a pruned
|
||||
event object. Otherwise, None.
|
||||
If the event should be redacted, a pruned event object. Otherwise, None.
|
||||
"""
|
||||
if original_ev.type == "m.room.create":
|
||||
# we choose to ignore redactions of m.room.create events.
|
||||
|
@ -946,17 +978,17 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
row = txn.fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
def get_current_state_event_counts(self, room_id):
|
||||
async def get_current_state_event_counts(self, room_id: str) -> int:
|
||||
"""
|
||||
Gets the current number of state events in a room.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id: The room ID to query.
|
||||
|
||||
Returns:
|
||||
Deferred[int]
|
||||
The current number of state events.
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_current_state_event_counts",
|
||||
self._get_current_state_event_counts_txn,
|
||||
room_id,
|
||||
|
@ -991,7 +1023,9 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
"""The current maximum token that events have reached"""
|
||||
return self._stream_id_gen.get_current_token()
|
||||
|
||||
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
|
||||
async def get_all_new_forward_event_rows(
|
||||
self, last_id: int, current_id: int, limit: int
|
||||
) -> List[Tuple]:
|
||||
"""Returns new events, for the Events replication stream
|
||||
|
||||
Args:
|
||||
|
@ -999,7 +1033,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
current_id: the maximum stream_id to return up to
|
||||
limit: the maximum number of rows to return
|
||||
|
||||
Returns: Deferred[List[Tuple]]
|
||||
Returns:
|
||||
a list of events stream rows. Each tuple consists of a stream id as
|
||||
the first element, followed by fields suitable for casting into an
|
||||
EventsStreamRow.
|
||||
|
@ -1020,18 +1054,20 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, (last_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
|
||||
)
|
||||
|
||||
def get_ex_outlier_stream_rows(self, last_id, current_id):
|
||||
async def get_ex_outlier_stream_rows(
|
||||
self, last_id: int, current_id: int
|
||||
) -> List[Tuple]:
|
||||
"""Returns de-outliered events, for the Events replication stream
|
||||
|
||||
Args:
|
||||
last_id: the last stream_id from the previous batch.
|
||||
current_id: the maximum stream_id to return up to
|
||||
|
||||
Returns: Deferred[List[Tuple]]
|
||||
Returns:
|
||||
a list of events stream rows. Each tuple consists of a stream id as
|
||||
the first element, followed by fields suitable for casting into an
|
||||
EventsStreamRow.
|
||||
|
@ -1054,7 +1090,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, (last_id, current_id))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
|
||||
)
|
||||
|
||||
|
@ -1226,11 +1262,11 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
|
||||
|
||||
def get_next_event_to_expire(self):
|
||||
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
|
||||
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
|
||||
table, or None if there's no more event to expire.
|
||||
|
||||
Returns: Deferred[Optional[Tuple[str, int]]]
|
||||
Returns:
|
||||
A tuple containing the event ID as its first element and an expiry timestamp
|
||||
as its second one, if there's at least one row in the event_expiry table.
|
||||
None otherwise.
|
||||
|
@ -1246,6 +1282,6 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
return txn.fetchone()
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
|
||||
)
|
||||
|
|
|
@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json
|
|||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
|
||||
|
@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore):
|
|||
|
||||
return db_to_json(def_json)
|
||||
|
||||
def add_user_filter(self, user_localpart, user_filter):
|
||||
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
|
||||
def_json = encode_canonical_json(user_filter)
|
||||
|
||||
# Need an atomic transaction to SELECT the maximal ID so far then
|
||||
|
@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore):
|
|||
|
||||
return filter_id
|
||||
|
||||
return self.db_pool.runInteraction("add_user_filter", _do_txn)
|
||||
return await self.db_pool.runInteraction("add_user_filter", _do_txn)
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool
|
||||
|
@ -93,7 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
desc="mark_local_media_as_safe",
|
||||
)
|
||||
|
||||
def get_url_cache(self, url, ts):
|
||||
async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
|
||||
"""Get the media_id and ts for a cached URL as of the given timestamp
|
||||
Returns:
|
||||
None if the URL isn't cached.
|
||||
|
@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
)
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
|
||||
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
|
||||
|
||||
async def store_url_cache(
|
||||
self, url, response_code, etag, expires_ts, og, media_id, download_ts
|
||||
|
@ -237,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
desc="store_cached_remote_media",
|
||||
)
|
||||
|
||||
def update_cached_last_access_time(self, local_media, remote_media, time_ms):
|
||||
async def update_cached_last_access_time(
|
||||
self,
|
||||
local_media: Iterable[str],
|
||||
remote_media: Iterable[Tuple[str, str]],
|
||||
time_ms: int,
|
||||
):
|
||||
"""Updates the last access time of the given media
|
||||
|
||||
Args:
|
||||
local_media (iterable[str]): Set of media_ids
|
||||
remote_media (iterable[(str, str)]): Set of (server_name, media_id)
|
||||
local_media: Set of media_ids
|
||||
remote_media: Set of (server_name, media_id)
|
||||
time_ms: Current time in milliseconds
|
||||
"""
|
||||
|
||||
|
@ -267,7 +272,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
|
||||
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"update_cached_last_access_time", update_cache_txn
|
||||
)
|
||||
|
||||
|
@ -325,7 +330,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
|
||||
)
|
||||
|
||||
def delete_remote_media(self, media_origin, media_id):
|
||||
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
|
||||
def delete_remote_media_txn(txn):
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
|
@ -338,11 +343,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
keyvalues={"media_origin": media_origin, "media_id": media_id},
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_remote_media", delete_remote_media_txn
|
||||
)
|
||||
|
||||
def get_expired_url_cache(self, now_ts):
|
||||
async def get_expired_url_cache(self, now_ts: int) -> List[str]:
|
||||
sql = (
|
||||
"SELECT media_id FROM local_media_repository_url_cache"
|
||||
" WHERE expires_ts < ?"
|
||||
|
@ -354,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
txn.execute(sql, (now_ts,))
|
||||
return [row[0] for row in txn]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_expired_url_cache", _get_expired_url_cache_txn
|
||||
)
|
||||
|
||||
|
@ -371,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
"delete_url_cache", _delete_url_cache_txn
|
||||
)
|
||||
|
||||
def get_url_cache_media_before(self, before_ts):
|
||||
async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
|
||||
sql = (
|
||||
"SELECT media_id FROM local_media_repository"
|
||||
" WHERE created_ts < ? AND url_cache IS NOT NULL"
|
||||
|
@ -383,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
txn.execute(sql, (before_ts,))
|
||||
return [row[0] for row in txn]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_url_cache_media_before", _get_url_cache_media_before_txn
|
||||
)
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Optional
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
|
||||
|
||||
|
@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore):
|
|||
desc="insert_open_id_token",
|
||||
)
|
||||
|
||||
def get_user_id_for_open_id_token(self, token, ts_now_ms):
|
||||
async def get_user_id_for_open_id_token(
|
||||
self, token: str, ts_now_ms: int
|
||||
) -> Optional[str]:
|
||||
def get_user_id_for_token_txn(txn):
|
||||
sql = (
|
||||
"SELECT user_id FROM open_id_tokens"
|
||||
|
@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore):
|
|||
else:
|
||||
return rows[0][0]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_user_id_for_token", get_user_id_for_token_txn
|
||||
)
|
||||
|
|
|
@ -138,7 +138,9 @@ class ProfileStore(ProfileWorkerStore):
|
|||
desc="delete_remote_profile_cache",
|
||||
)
|
||||
|
||||
def get_remote_profile_cache_entries_that_expire(self, last_checked):
|
||||
async def get_remote_profile_cache_entries_that_expire(
|
||||
self, last_checked: int
|
||||
) -> Dict[str, str]:
|
||||
"""Get all users who haven't been checked since `last_checked`
|
||||
"""
|
||||
|
||||
|
@ -153,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
|
|||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_remote_profile_cache_entries_that_expire",
|
||||
_get_remote_profile_cache_entries_that_expire_txn,
|
||||
)
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Tuple
|
||||
from typing import Any, List, Set, Tuple
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
|
@ -25,25 +25,24 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
||||
def purge_history(self, room_id, token, delete_local_events):
|
||||
async def purge_history(
|
||||
self, room_id: str, token: str, delete_local_events: bool
|
||||
) -> Set[int]:
|
||||
"""Deletes room history before a certain point
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
|
||||
token (str): A topological token to delete events before
|
||||
|
||||
delete_local_events (bool):
|
||||
room_id:
|
||||
token: A topological token to delete events before
|
||||
delete_local_events:
|
||||
if True, we will delete local events as well as remote ones
|
||||
(instead of just marking them as outliers and deleting their
|
||||
state groups).
|
||||
|
||||
Returns:
|
||||
Deferred[set[int]]: The set of state groups that are referenced by
|
||||
deleted events.
|
||||
The set of state groups that are referenced by deleted events.
|
||||
"""
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"purge_history",
|
||||
self._purge_history_txn,
|
||||
room_id,
|
||||
|
@ -283,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
|||
|
||||
return referenced_state_groups
|
||||
|
||||
def purge_room(self, room_id):
|
||||
async def purge_room(self, room_id: str) -> List[int]:
|
||||
"""Deletes all record of a room
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id
|
||||
|
||||
Returns:
|
||||
Deferred[List[int]]: The list of state groups to delete.
|
||||
The list of state groups to delete.
|
||||
"""
|
||||
|
||||
return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id)
|
||||
return await self.db_pool.runInteraction(
|
||||
"purge_room", self._purge_room_txn, room_id
|
||||
)
|
||||
|
||||
def _purge_room_txn(self, txn, room_id):
|
||||
# First we fetch all the state groups that should be deleted, before
|
||||
|
|
|
@ -18,8 +18,6 @@ import abc
|
|||
import logging
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.push.baserules import list_with_base_rules
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
|
@ -149,9 +147,11 @@ class PushRulesWorkerStore(
|
|||
)
|
||||
return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
|
||||
|
||||
def have_push_rules_changed_for_user(self, user_id, last_id):
|
||||
async def have_push_rules_changed_for_user(
|
||||
self, user_id: str, last_id: int
|
||||
) -> bool:
|
||||
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
|
||||
return defer.succeed(False)
|
||||
return False
|
||||
else:
|
||||
|
||||
def have_push_rules_changed_txn(txn):
|
||||
|
@ -163,7 +163,7 @@ class PushRulesWorkerStore(
|
|||
(count,) = txn.fetchone()
|
||||
return bool(count)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"have_push_rules_changed", have_push_rules_changed_txn
|
||||
)
|
||||
|
||||
|
|
|
@ -276,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
}
|
||||
return results
|
||||
|
||||
def get_users_sent_receipts_between(self, last_id: int, current_id: int):
|
||||
async def get_users_sent_receipts_between(
|
||||
self, last_id: int, current_id: int
|
||||
) -> List[str]:
|
||||
"""Get all users who sent receipts between `last_id` exclusive and
|
||||
`current_id` inclusive.
|
||||
|
||||
Returns:
|
||||
Deferred[List[str]]
|
||||
The list of users.
|
||||
"""
|
||||
|
||||
if last_id == current_id:
|
||||
|
@ -296,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
|
||||
return [r[0] for r in txn]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
|
||||
)
|
||||
|
||||
|
@ -553,8 +555,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
|||
|
||||
return stream_id, max_persisted_id
|
||||
|
||||
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
|
||||
return self.db_pool.runInteraction(
|
||||
async def insert_graph_receipt(
|
||||
self, room_id, receipt_type, user_id, event_ids, data
|
||||
):
|
||||
return await self.db_pool.runInteraction(
|
||||
"insert_graph_receipt",
|
||||
self.insert_graph_receipt_txn,
|
||||
room_id,
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||
|
@ -84,22 +84,22 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
return is_trial
|
||||
|
||||
@cached()
|
||||
def get_user_by_access_token(self, token):
|
||||
async def get_user_by_access_token(self, token: str) -> Optional[dict]:
|
||||
"""Get a user from the given access token.
|
||||
|
||||
Args:
|
||||
token (str): The access token of a user.
|
||||
token: The access token of a user.
|
||||
Returns:
|
||||
defer.Deferred: None, if the token did not match, otherwise dict
|
||||
including the keys `name`, `is_guest`, `device_id`, `token_id`,
|
||||
`valid_until_ms`.
|
||||
None, if the token did not match, otherwise dict
|
||||
including the keys `name`, `is_guest`, `device_id`, `token_id`,
|
||||
`valid_until_ms`.
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_user_by_access_token", self._query_for_auth, token
|
||||
)
|
||||
|
||||
@cached()
|
||||
async def get_expiration_ts_for_user(self, user_id: str) -> Optional[None]:
|
||||
async def get_expiration_ts_for_user(self, user_id: str) -> Optional[int]:
|
||||
"""Get the expiration timestamp for the account bearing a given user ID.
|
||||
|
||||
Args:
|
||||
|
@ -281,13 +281,12 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
|
||||
return bool(res) if res else False
|
||||
|
||||
def set_server_admin(self, user, admin):
|
||||
async def set_server_admin(self, user: UserID, admin: bool) -> None:
|
||||
"""Sets whether a user is an admin of this homeserver.
|
||||
|
||||
Args:
|
||||
user (UserID): user ID of the user to test
|
||||
admin (bool): true iff the user is to be a server admin,
|
||||
false otherwise.
|
||||
user: user ID of the user to test
|
||||
admin: true iff the user is to be a server admin, false otherwise.
|
||||
"""
|
||||
|
||||
def set_server_admin_txn(txn):
|
||||
|
@ -298,7 +297,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
txn, self.get_user_by_id, (user.to_string(),)
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
|
||||
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
|
||||
|
||||
def _query_for_auth(self, txn, token):
|
||||
sql = (
|
||||
|
@ -364,9 +363,11 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
)
|
||||
return True if res == UserTypes.SUPPORT else False
|
||||
|
||||
def get_users_by_id_case_insensitive(self, user_id):
|
||||
async def get_users_by_id_case_insensitive(self, user_id: str) -> Dict[str, str]:
|
||||
"""Gets users that match user_id case insensitively.
|
||||
Returns a mapping of user_id -> password_hash.
|
||||
|
||||
Returns:
|
||||
A mapping of user_id -> password_hash.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
|
@ -374,7 +375,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, (user_id,))
|
||||
return dict(txn)
|
||||
|
||||
return self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
|
||||
return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
|
||||
|
||||
async def get_user_by_external_id(
|
||||
self, auth_provider: str, external_id: str
|
||||
|
@ -408,7 +409,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
|
||||
return await self.db_pool.runInteraction("count_users", _count_users)
|
||||
|
||||
def count_daily_user_type(self):
|
||||
async def count_daily_user_type(self) -> Dict[str, int]:
|
||||
"""
|
||||
Counts 1) native non guest users
|
||||
2) native guests users
|
||||
|
@ -437,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
results[row[0]] = row[1]
|
||||
return results
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"count_daily_user_type", _count_daily_user_type
|
||||
)
|
||||
|
||||
|
@ -663,24 +664,29 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
# Convert the integer into a boolean.
|
||||
return res == 1
|
||||
|
||||
def get_threepid_validation_session(
|
||||
self, medium, client_secret, address=None, sid=None, validated=True
|
||||
):
|
||||
async def get_threepid_validation_session(
|
||||
self,
|
||||
medium: Optional[str],
|
||||
client_secret: str,
|
||||
address: Optional[str] = None,
|
||||
sid: Optional[str] = None,
|
||||
validated: Optional[bool] = True,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Gets a session_id and last_send_attempt (if available) for a
|
||||
combination of validation metadata
|
||||
|
||||
Args:
|
||||
medium (str|None): The medium of the 3PID
|
||||
address (str|None): The address of the 3PID
|
||||
sid (str|None): The ID of the validation session
|
||||
client_secret (str): A unique string provided by the client to help identify this
|
||||
medium: The medium of the 3PID
|
||||
client_secret: A unique string provided by the client to help identify this
|
||||
validation attempt
|
||||
validated (bool|None): Whether sessions should be filtered by
|
||||
address: The address of the 3PID
|
||||
sid: The ID of the validation session
|
||||
validated: Whether sessions should be filtered by
|
||||
whether they have been validated already or not. None to
|
||||
perform no filtering
|
||||
|
||||
Returns:
|
||||
Deferred[dict|None]: A dict containing the following:
|
||||
A dict containing the following:
|
||||
* address - address of the 3pid
|
||||
* medium - medium of the 3pid
|
||||
* client_secret - a secret provided by the client for this validation session
|
||||
|
@ -726,17 +732,17 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
|
||||
return rows[0]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_threepid_validation_session", get_threepid_validation_session_txn
|
||||
)
|
||||
|
||||
def delete_threepid_session(self, session_id):
|
||||
async def delete_threepid_session(self, session_id: str) -> None:
|
||||
"""Removes a threepid validation session from the database. This can
|
||||
be done after validation has been performed and whatever action was
|
||||
waiting on it has been carried out
|
||||
|
||||
Args:
|
||||
session_id (str): The ID of the session to delete
|
||||
session_id: The ID of the session to delete
|
||||
"""
|
||||
|
||||
def delete_threepid_session_txn(txn):
|
||||
|
@ -751,7 +757,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
keyvalues={"session_id": session_id},
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_threepid_session", delete_threepid_session_txn
|
||||
)
|
||||
|
||||
|
@ -941,43 +947,40 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
desc="add_access_token_to_user",
|
||||
)
|
||||
|
||||
def register_user(
|
||||
async def register_user(
|
||||
self,
|
||||
user_id,
|
||||
password_hash=None,
|
||||
was_guest=False,
|
||||
make_guest=False,
|
||||
appservice_id=None,
|
||||
create_profile_with_displayname=None,
|
||||
admin=False,
|
||||
user_type=None,
|
||||
shadow_banned=False,
|
||||
):
|
||||
user_id: str,
|
||||
password_hash: Optional[str] = None,
|
||||
was_guest: bool = False,
|
||||
make_guest: bool = False,
|
||||
appservice_id: Optional[str] = None,
|
||||
create_profile_with_displayname: Optional[str] = None,
|
||||
admin: bool = False,
|
||||
user_type: Optional[str] = None,
|
||||
shadow_banned: bool = False,
|
||||
) -> None:
|
||||
"""Attempts to register an account.
|
||||
|
||||
Args:
|
||||
user_id (str): The desired user ID to register.
|
||||
password_hash (str|None): Optional. The password hash for this user.
|
||||
was_guest (bool): Optional. Whether this is a guest account being
|
||||
upgraded to a non-guest account.
|
||||
make_guest (boolean): True if the the new user should be guest,
|
||||
false to add a regular user account.
|
||||
appservice_id (str): The ID of the appservice registering the user.
|
||||
create_profile_with_displayname (unicode): Optionally create a profile for
|
||||
user_id: The desired user ID to register.
|
||||
password_hash: Optional. The password hash for this user.
|
||||
was_guest: Whether this is a guest account being upgraded to a
|
||||
non-guest account.
|
||||
make_guest: True if the the new user should be guest, false to add a
|
||||
regular user account.
|
||||
appservice_id: The ID of the appservice registering the user.
|
||||
create_profile_with_displayname: Optionally create a profile for
|
||||
the user, setting their displayname to the given value
|
||||
admin (boolean): is an admin user?
|
||||
user_type (str|None): type of user. One of the values from
|
||||
api.constants.UserTypes, or None for a normal user.
|
||||
shadow_banned (bool): Whether the user is shadow-banned,
|
||||
i.e. they may be told their requests succeeded but we ignore them.
|
||||
admin: is an admin user?
|
||||
user_type: type of user. One of the values from api.constants.UserTypes,
|
||||
or None for a normal user.
|
||||
shadow_banned: Whether the user is shadow-banned, i.e. they may be
|
||||
told their requests succeeded but we ignore them.
|
||||
|
||||
Raises:
|
||||
StoreError if the user_id could not be registered.
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"register_user",
|
||||
self._register_user,
|
||||
user_id,
|
||||
|
@ -1101,7 +1104,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
desc="record_user_external_id",
|
||||
)
|
||||
|
||||
def user_set_password_hash(self, user_id, password_hash):
|
||||
async def user_set_password_hash(self, user_id: str, password_hash: str) -> None:
|
||||
"""
|
||||
NB. This does *not* evict any cache because the one use for this
|
||||
removes most of the entries subsequently anyway so it would be
|
||||
|
@ -1114,17 +1117,18 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
)
|
||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"user_set_password_hash", user_set_password_hash_txn
|
||||
)
|
||||
|
||||
def user_set_consent_version(self, user_id, consent_version):
|
||||
async def user_set_consent_version(
|
||||
self, user_id: str, consent_version: str
|
||||
) -> None:
|
||||
"""Updates the user table to record privacy policy consent
|
||||
|
||||
Args:
|
||||
user_id (str): full mxid of the user to update
|
||||
consent_version (str): version of the policy the user has consented
|
||||
to
|
||||
user_id: full mxid of the user to update
|
||||
consent_version: version of the policy the user has consented to
|
||||
|
||||
Raises:
|
||||
StoreError(404) if user not found
|
||||
|
@ -1139,16 +1143,17 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
)
|
||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
|
||||
return self.db_pool.runInteraction("user_set_consent_version", f)
|
||||
await self.db_pool.runInteraction("user_set_consent_version", f)
|
||||
|
||||
def user_set_consent_server_notice_sent(self, user_id, consent_version):
|
||||
async def user_set_consent_server_notice_sent(
|
||||
self, user_id: str, consent_version: str
|
||||
) -> None:
|
||||
"""Updates the user table to record that we have sent the user a server
|
||||
notice about privacy policy consent
|
||||
|
||||
Args:
|
||||
user_id (str): full mxid of the user to update
|
||||
consent_version (str): version of the policy we have notified the
|
||||
user about
|
||||
user_id: full mxid of the user to update
|
||||
consent_version: version of the policy we have notified the user about
|
||||
|
||||
Raises:
|
||||
StoreError(404) if user not found
|
||||
|
@ -1163,22 +1168,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
)
|
||||
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||
|
||||
return self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
|
||||
await self.db_pool.runInteraction("user_set_consent_server_notice_sent", f)
|
||||
|
||||
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
|
||||
async def user_delete_access_tokens(
|
||||
self,
|
||||
user_id: str,
|
||||
except_token_id: Optional[str] = None,
|
||||
device_id: Optional[str] = None,
|
||||
) -> List[Tuple[str, int, Optional[str]]]:
|
||||
"""
|
||||
Invalidate access tokens belonging to a user
|
||||
|
||||
Args:
|
||||
user_id (str): ID of user the tokens belong to
|
||||
except_token_id (str): list of access_tokens IDs which should
|
||||
*not* be deleted
|
||||
device_id (str|None): ID of device the tokens are associated with.
|
||||
user_id: ID of user the tokens belong to
|
||||
except_token_id: access_tokens ID which should *not* be deleted
|
||||
device_id: ID of device the tokens are associated with.
|
||||
If None, tokens associated with any device (or no device) will
|
||||
be deleted
|
||||
Returns:
|
||||
defer.Deferred[list[str, int, str|None, int]]: a list of
|
||||
(token, token id, device id) for each of the deleted tokens
|
||||
A tuple of (token, token id, device id) for each of the deleted tokens
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
|
@ -1209,9 +1217,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
|
||||
return tokens_and_devices
|
||||
|
||||
return self.db_pool.runInteraction("user_delete_access_tokens", f)
|
||||
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
|
||||
|
||||
def delete_access_token(self, access_token):
|
||||
async def delete_access_token(self, access_token: str) -> None:
|
||||
def f(txn):
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn, table="access_tokens", keyvalues={"token": access_token}
|
||||
|
@ -1221,7 +1229,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
txn, self.get_user_by_access_token, (access_token,)
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction("delete_access_token", f)
|
||||
await self.db_pool.runInteraction("delete_access_token", f)
|
||||
|
||||
@cached()
|
||||
async def is_guest(self, user_id: str) -> bool:
|
||||
|
@ -1272,24 +1280,25 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
desc="get_users_pending_deactivation",
|
||||
)
|
||||
|
||||
def validate_threepid_session(self, session_id, client_secret, token, current_ts):
|
||||
async def validate_threepid_session(
|
||||
self, session_id: str, client_secret: str, token: str, current_ts: int
|
||||
) -> Optional[str]:
|
||||
"""Attempt to validate a threepid session using a token
|
||||
|
||||
Args:
|
||||
session_id (str): The id of a validation session
|
||||
client_secret (str): A unique string provided by the client to
|
||||
help identify this validation attempt
|
||||
token (str): A validation token
|
||||
current_ts (int): The current unix time in milliseconds. Used for
|
||||
checking token expiry status
|
||||
session_id: The id of a validation session
|
||||
client_secret: A unique string provided by the client to help identify
|
||||
this validation attempt
|
||||
token: A validation token
|
||||
current_ts: The current unix time in milliseconds. Used for checking
|
||||
token expiry status
|
||||
|
||||
Raises:
|
||||
ThreepidValidationError: if a matching validation token was not found or has
|
||||
expired
|
||||
|
||||
Returns:
|
||||
deferred str|None: A str representing a link to redirect the user
|
||||
to if there is one.
|
||||
A str representing a link to redirect the user to if there is one.
|
||||
"""
|
||||
|
||||
# Insert everything into a transaction in order to run atomically
|
||||
|
@ -1359,36 +1368,35 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
return next_link
|
||||
|
||||
# Return next_link if it exists
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"validate_threepid_session_txn", validate_threepid_session_txn
|
||||
)
|
||||
|
||||
def start_or_continue_validation_session(
|
||||
async def start_or_continue_validation_session(
|
||||
self,
|
||||
medium,
|
||||
address,
|
||||
session_id,
|
||||
client_secret,
|
||||
send_attempt,
|
||||
next_link,
|
||||
token,
|
||||
token_expires,
|
||||
):
|
||||
medium: str,
|
||||
address: str,
|
||||
session_id: str,
|
||||
client_secret: str,
|
||||
send_attempt: int,
|
||||
next_link: Optional[str],
|
||||
token: str,
|
||||
token_expires: int,
|
||||
) -> None:
|
||||
"""Creates a new threepid validation session if it does not already
|
||||
exist and associates a new validation token with it
|
||||
|
||||
Args:
|
||||
medium (str): The medium of the 3PID
|
||||
address (str): The address of the 3PID
|
||||
session_id (str): The id of this validation session
|
||||
client_secret (str): A unique string provided by the client to
|
||||
help identify this validation attempt
|
||||
send_attempt (int): The latest send_attempt on this session
|
||||
next_link (str|None): The link to redirect the user to upon
|
||||
successful validation
|
||||
token (str): The validation token
|
||||
token_expires (int): The timestamp for which after the token
|
||||
will no longer be valid
|
||||
medium: The medium of the 3PID
|
||||
address: The address of the 3PID
|
||||
session_id: The id of this validation session
|
||||
client_secret: A unique string provided by the client to help
|
||||
identify this validation attempt
|
||||
send_attempt: The latest send_attempt on this session
|
||||
next_link: The link to redirect the user to upon successful validation
|
||||
token: The validation token
|
||||
token_expires: The timestamp for which after the token will no
|
||||
longer be valid
|
||||
"""
|
||||
|
||||
def start_or_continue_validation_session_txn(txn):
|
||||
|
@ -1417,12 +1425,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
},
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"start_or_continue_validation_session",
|
||||
start_or_continue_validation_session_txn,
|
||||
)
|
||||
|
||||
def cull_expired_threepid_validation_tokens(self):
|
||||
async def cull_expired_threepid_validation_tokens(self) -> None:
|
||||
"""Remove threepid validation tokens with expiry dates that have passed"""
|
||||
|
||||
def cull_expired_threepid_validation_tokens_txn(txn, ts):
|
||||
|
@ -1430,9 +1438,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
DELETE FROM threepid_validation_token WHERE
|
||||
expires < ?
|
||||
"""
|
||||
return txn.execute(sql, (ts,))
|
||||
txn.execute(sql, (ts,))
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"cull_expired_threepid_validation_tokens",
|
||||
cull_expired_threepid_validation_tokens_txn,
|
||||
self.clock.time_msec(),
|
||||
|
|
|
@ -34,38 +34,33 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class RelationsWorkerStore(SQLBaseStore):
|
||||
@cached(tree=True)
|
||||
def get_relations_for_event(
|
||||
async def get_relations_for_event(
|
||||
self,
|
||||
event_id,
|
||||
relation_type=None,
|
||||
event_type=None,
|
||||
aggregation_key=None,
|
||||
limit=5,
|
||||
direction="b",
|
||||
from_token=None,
|
||||
to_token=None,
|
||||
):
|
||||
event_id: str,
|
||||
relation_type: Optional[str] = None,
|
||||
event_type: Optional[str] = None,
|
||||
aggregation_key: Optional[str] = None,
|
||||
limit: int = 5,
|
||||
direction: str = "b",
|
||||
from_token: Optional[RelationPaginationToken] = None,
|
||||
to_token: Optional[RelationPaginationToken] = None,
|
||||
) -> PaginationChunk:
|
||||
"""Get a list of relations for an event, ordered by topological ordering.
|
||||
|
||||
Args:
|
||||
event_id (str): Fetch events that relate to this event ID.
|
||||
relation_type (str|None): Only fetch events with this relation
|
||||
type, if given.
|
||||
event_type (str|None): Only fetch events with this event type, if
|
||||
given.
|
||||
aggregation_key (str|None): Only fetch events with this aggregation
|
||||
key, if given.
|
||||
limit (int): Only fetch the most recent `limit` events.
|
||||
direction (str): Whether to fetch the most recent first (`"b"`) or
|
||||
the oldest first (`"f"`).
|
||||
from_token (RelationPaginationToken|None): Fetch rows from the given
|
||||
token, or from the start if None.
|
||||
to_token (RelationPaginationToken|None): Fetch rows up to the given
|
||||
token, or up to the end if None.
|
||||
event_id: Fetch events that relate to this event ID.
|
||||
relation_type: Only fetch events with this relation type, if given.
|
||||
event_type: Only fetch events with this event type, if given.
|
||||
aggregation_key: Only fetch events with this aggregation key, if given.
|
||||
limit: Only fetch the most recent `limit` events.
|
||||
direction: Whether to fetch the most recent first (`"b"`) or the
|
||||
oldest first (`"f"`).
|
||||
from_token: Fetch rows from the given token, or from the start if None.
|
||||
to_token: Fetch rows up to the given token, or up to the end if None.
|
||||
|
||||
Returns:
|
||||
Deferred[PaginationChunk]: List of event IDs that match relations
|
||||
requested. The rows are of the form `{"event_id": "..."}`.
|
||||
List of event IDs that match relations requested. The rows are of
|
||||
the form `{"event_id": "..."}`.
|
||||
"""
|
||||
|
||||
where_clause = ["relates_to_id = ?"]
|
||||
|
@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_recent_references_for_event", _get_recent_references_for_event_txn
|
||||
)
|
||||
|
||||
@cached(tree=True)
|
||||
def get_aggregation_groups_for_event(
|
||||
async def get_aggregation_groups_for_event(
|
||||
self,
|
||||
event_id,
|
||||
event_type=None,
|
||||
limit=5,
|
||||
direction="b",
|
||||
from_token=None,
|
||||
to_token=None,
|
||||
):
|
||||
event_id: str,
|
||||
event_type: Optional[str] = None,
|
||||
limit: int = 5,
|
||||
direction: str = "b",
|
||||
from_token: Optional[AggregationPaginationToken] = None,
|
||||
to_token: Optional[AggregationPaginationToken] = None,
|
||||
) -> PaginationChunk:
|
||||
"""Get a list of annotations on the event, grouped by event type and
|
||||
aggregation key, sorted by count.
|
||||
|
||||
|
@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
on an event.
|
||||
|
||||
Args:
|
||||
event_id (str): Fetch events that relate to this event ID.
|
||||
event_type (str|None): Only fetch events with this event type, if
|
||||
given.
|
||||
limit (int): Only fetch the `limit` groups.
|
||||
direction (str): Whether to fetch the highest count first (`"b"`) or
|
||||
event_id: Fetch events that relate to this event ID.
|
||||
event_type: Only fetch events with this event type, if given.
|
||||
limit: Only fetch the `limit` groups.
|
||||
direction: Whether to fetch the highest count first (`"b"`) or
|
||||
the lowest count first (`"f"`).
|
||||
from_token (AggregationPaginationToken|None): Fetch rows from the
|
||||
given token, or from the start if None.
|
||||
to_token (AggregationPaginationToken|None): Fetch rows up to the
|
||||
given token, or up to the end if None.
|
||||
|
||||
from_token: Fetch rows from the given token, or from the start if None.
|
||||
to_token: Fetch rows up to the given token, or up to the end if None.
|
||||
|
||||
Returns:
|
||||
Deferred[PaginationChunk]: List of groups of annotations that
|
||||
match. Each row is a dict with `type`, `key` and `count` fields.
|
||||
List of groups of annotations that match. Each row is a dict with
|
||||
`type`, `key` and `count` fields.
|
||||
"""
|
||||
|
||||
where_clause = ["relates_to_id = ?", "relation_type = ?"]
|
||||
|
@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
|
||||
)
|
||||
|
||||
|
@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
return await self.get_event(edit_id, allow_none=True)
|
||||
|
||||
def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
|
||||
async def has_user_annotated_event(
|
||||
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
|
||||
) -> bool:
|
||||
"""Check if a user has already annotated an event with the same key
|
||||
(e.g. already liked an event).
|
||||
|
||||
Args:
|
||||
parent_id (str): The event being annotated
|
||||
event_type (str): The event type of the annotation
|
||||
aggregation_key (str): The aggregation key of the annotation
|
||||
sender (str): The sender of the annotation
|
||||
parent_id: The event being annotated
|
||||
event_type: The event type of the annotation
|
||||
aggregation_key: The aggregation key of the annotation
|
||||
sender: The sender of the annotation
|
||||
|
||||
Returns:
|
||||
Deferred[bool]
|
||||
True if the event is already annotated.
|
||||
"""
|
||||
|
||||
sql = """
|
||||
|
@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
return bool(txn.fetchone())
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
|
||||
)
|
||||
|
||||
|
|
|
@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
allow_none=True,
|
||||
)
|
||||
|
||||
def get_room_with_stats(self, room_id: str):
|
||||
async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve room with statistics.
|
||||
|
||||
Args:
|
||||
|
@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
res["public"] = bool(res["public"])
|
||||
return res
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_room_with_stats", get_room_with_stats_txn, room_id
|
||||
)
|
||||
|
||||
|
@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
desc="get_public_room_ids",
|
||||
)
|
||||
|
||||
def count_public_rooms(self, network_tuple, ignore_non_federatable):
|
||||
async def count_public_rooms(
|
||||
self,
|
||||
network_tuple: Optional[ThirdPartyInstanceID],
|
||||
ignore_non_federatable: bool,
|
||||
) -> int:
|
||||
"""Counts the number of public rooms as tracked in the room_stats_current
|
||||
and room_stats_state table.
|
||||
|
||||
Args:
|
||||
network_tuple (ThirdPartyInstanceID|None)
|
||||
ignore_non_federatable (bool): If true filters out non-federatable rooms
|
||||
network_tuple
|
||||
ignore_non_federatable: If true filters out non-federatable rooms
|
||||
"""
|
||||
|
||||
def _count_public_rooms_txn(txn):
|
||||
|
@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, query_args)
|
||||
return txn.fetchone()[0]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"count_public_rooms", _count_public_rooms_txn
|
||||
)
|
||||
|
||||
|
@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
|
||||
return row
|
||||
|
||||
def get_media_mxcs_in_room(self, room_id):
|
||||
async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
|
||||
"""Retrieves all the local and remote media MXC URIs in a given room
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id
|
||||
|
||||
Returns:
|
||||
The local and remote media as a lists of tuples where the key is
|
||||
the hostname and the value is the media ID.
|
||||
The local and remote media as a lists of the media IDs.
|
||||
"""
|
||||
|
||||
def _get_media_mxcs_in_room_txn(txn):
|
||||
|
@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
|
||||
return local_media_mxcs, remote_media_mxcs
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
|
||||
)
|
||||
|
||||
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
|
||||
async def quarantine_media_ids_in_room(
|
||||
self, room_id: str, quarantined_by: str
|
||||
) -> int:
|
||||
"""For a room loops through all events with media and quarantines
|
||||
the associated media
|
||||
"""
|
||||
|
@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
txn, local_mxcs, remote_mxcs, quarantined_by
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"quarantine_media_in_room", _quarantine_media_in_room_txn
|
||||
)
|
||||
|
||||
|
@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
|
||||
return local_media_mxcs, remote_media_mxcs
|
||||
|
||||
def quarantine_media_by_id(
|
||||
async def quarantine_media_by_id(
|
||||
self, server_name: str, media_id: str, quarantined_by: str,
|
||||
):
|
||||
) -> int:
|
||||
"""quarantines a single local or remote media id
|
||||
|
||||
Args:
|
||||
|
@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
txn, local_mxcs, remote_mxcs, quarantined_by
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"quarantine_media_by_user", _quarantine_media_by_id_txn
|
||||
)
|
||||
|
||||
def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
|
||||
async def quarantine_media_ids_by_user(
|
||||
self, user_id: str, quarantined_by: str
|
||||
) -> int:
|
||||
"""quarantines all local media associated with a single user
|
||||
|
||||
Args:
|
||||
|
@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
|
||||
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"quarantine_media_by_user", _quarantine_media_by_user_txn
|
||||
)
|
||||
|
||||
|
@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
)
|
||||
self.hs.get_notifier().on_new_replication_data()
|
||||
|
||||
def get_room_count(self):
|
||||
"""Retrieve a list of all rooms
|
||||
async def get_room_count(self) -> int:
|
||||
"""Retrieve the total number of rooms.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
|
@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
row = txn.fetchone()
|
||||
return row[0] or 0
|
||||
|
||||
return self.db_pool.runInteraction("get_rooms", f)
|
||||
return await self.db_pool.runInteraction("get_rooms", f)
|
||||
|
||||
async def add_event_report(
|
||||
self,
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
|
||||
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.events import EventBase
|
||||
|
@ -152,8 +152,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
)
|
||||
|
||||
@cached(max_entries=100000, iterable=True)
|
||||
def get_users_in_room(self, room_id: str):
|
||||
return self.db_pool.runInteraction(
|
||||
async def get_users_in_room(self, room_id: str) -> List[str]:
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_users_in_room", self.get_users_in_room_txn, room_id
|
||||
)
|
||||
|
||||
|
@ -180,14 +180,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
return [r[0] for r in txn]
|
||||
|
||||
@cached(max_entries=100000)
|
||||
def get_room_summary(self, room_id: str):
|
||||
async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
|
||||
""" Get the details of a room roughly suitable for use by the room
|
||||
summary extension to /sync. Useful when lazy loading room members.
|
||||
Args:
|
||||
room_id: The room ID to query
|
||||
Returns:
|
||||
Deferred[dict[str, MemberSummary]:
|
||||
dict of membership states, pointing to a MemberSummary named tuple.
|
||||
dict of membership states, pointing to a MemberSummary named tuple.
|
||||
"""
|
||||
|
||||
def _get_room_summary_txn(txn):
|
||||
|
@ -261,20 +260,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
|
||||
return res
|
||||
|
||||
return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_room_summary", _get_room_summary_txn
|
||||
)
|
||||
|
||||
@cached()
|
||||
def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
|
||||
async def get_invited_rooms_for_local_user(self, user_id: str) -> RoomsForUser:
|
||||
"""Get all the rooms the *local* user is invited to.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
|
||||
Returns:
|
||||
A awaitable list of RoomsForUser.
|
||||
A list of RoomsForUser.
|
||||
"""
|
||||
|
||||
return self.get_rooms_for_local_user_where_membership_is(
|
||||
return await self.get_rooms_for_local_user_where_membership_is(
|
||||
user_id, [Membership.INVITE]
|
||||
)
|
||||
|
||||
|
@ -297,8 +298,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
return None
|
||||
|
||||
async def get_rooms_for_local_user_where_membership_is(
|
||||
self, user_id: str, membership_list: List[str]
|
||||
) -> Optional[List[RoomsForUser]]:
|
||||
self, user_id: str, membership_list: Collection[str]
|
||||
) -> List[RoomsForUser]:
|
||||
"""Get all the rooms for this *local* user where the membership for this user
|
||||
matches one in the membership list.
|
||||
|
||||
|
@ -313,7 +314,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
The RoomsForUser that the user matches the membership types.
|
||||
"""
|
||||
if not membership_list:
|
||||
return None
|
||||
return []
|
||||
|
||||
rooms = await self.db_pool.runInteraction(
|
||||
"get_rooms_for_local_user_where_membership_is",
|
||||
|
@ -357,7 +358,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
return results
|
||||
|
||||
@cached(max_entries=500000, iterable=True)
|
||||
def get_rooms_for_user_with_stream_ordering(self, user_id: str):
|
||||
async def get_rooms_for_user_with_stream_ordering(
|
||||
self, user_id: str
|
||||
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
|
||||
"""Returns a set of room_ids the user is currently joined to.
|
||||
|
||||
If a remote user only returns rooms this server is currently
|
||||
|
@ -367,17 +370,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
user_id
|
||||
|
||||
Returns:
|
||||
Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
|
||||
the rooms the user is in currently, along with the stream ordering
|
||||
of the most recent join for that user and room.
|
||||
Returns the rooms the user is in currently, along with the stream
|
||||
ordering of the most recent join for that user and room.
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_rooms_for_user_with_stream_ordering",
|
||||
self._get_rooms_for_user_with_stream_ordering_txn,
|
||||
user_id,
|
||||
)
|
||||
|
||||
def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
|
||||
def _get_rooms_for_user_with_stream_ordering_txn(
|
||||
self, txn, user_id: str
|
||||
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
|
||||
# We use `current_state_events` here and not `local_current_membership`
|
||||
# as a) this gets called with remote users and b) this only gets called
|
||||
# for rooms the server is participating in.
|
||||
|
@ -404,9 +408,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
"""
|
||||
|
||||
txn.execute(sql, (user_id, Membership.JOIN))
|
||||
results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
|
||||
|
||||
return results
|
||||
return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
|
||||
|
||||
async def get_users_server_still_shares_room_with(
|
||||
self, user_ids: Collection[str]
|
||||
|
@ -711,14 +713,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
return count == 0
|
||||
|
||||
@cached()
|
||||
def get_forgotten_rooms_for_user(self, user_id: str):
|
||||
async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
|
||||
"""Gets all rooms the user has forgotten.
|
||||
|
||||
Args:
|
||||
user_id
|
||||
user_id: The user ID to query the rooms of.
|
||||
|
||||
Returns:
|
||||
Deferred[set[str]]
|
||||
The forgotten rooms.
|
||||
"""
|
||||
|
||||
def _get_forgotten_rooms_for_user_txn(txn):
|
||||
|
@ -744,7 +746,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
txn.execute(sql, (user_id,))
|
||||
return {row[0] for row in txn if row[1] == 0}
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
|
||||
)
|
||||
|
||||
|
@ -973,7 +975,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
|||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(RoomMemberStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
def forget(self, user_id: str, room_id: str):
|
||||
async def forget(self, user_id: str, room_id: str) -> None:
|
||||
"""Indicate that user_id wishes to discard history for room_id."""
|
||||
|
||||
def f(txn):
|
||||
|
@ -994,7 +996,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
|||
txn, self.get_forgotten_rooms_for_user, (user_id,)
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction("forget_membership", f)
|
||||
await self.db_pool.runInteraction("forget_membership", f)
|
||||
|
||||
|
||||
class _JoinedHostsCache(object):
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
ALTER TABLE events ADD COLUMN instance_name TEXT;
|
|
@ -0,0 +1,26 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
|
||||
|
||||
SELECT setval('events_stream_seq', (
|
||||
SELECT COALESCE(MAX(stream_ordering), 1) FROM events
|
||||
));
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
|
||||
|
||||
SELECT setval('events_backfill_stream_seq', (
|
||||
SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
|
||||
));
|
|
@ -0,0 +1,26 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- We're hijacking the push actions to store unread messages and unread counts (specified
|
||||
-- in MSC2654) because doing otherwise would result in either performance issues or
|
||||
-- reimplementing a consequent bit of the push actions.
|
||||
|
||||
-- Add columns to event_push_actions and event_push_actions_staging to track unread
|
||||
-- messages and calculate unread counts.
|
||||
ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
|
||||
ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
|
||||
|
||||
-- Add column to event_push_summary
|
||||
ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0;
|
|
@ -16,9 +16,10 @@
|
|||
import logging
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
|
@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
|
|||
"count": count,
|
||||
}
|
||||
|
||||
def _find_highlights_in_postgres(self, search_query, events):
|
||||
async def _find_highlights_in_postgres(
|
||||
self, search_query: str, events: List[EventBase]
|
||||
) -> Set[str]:
|
||||
"""Given a list of events and a search term, return a list of words
|
||||
that match from the content of the event.
|
||||
|
||||
|
@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
|
|||
highlight the matching parts.
|
||||
|
||||
Args:
|
||||
search_query (str)
|
||||
events (list): A list of events
|
||||
search_query
|
||||
events: A list of events
|
||||
|
||||
Returns:
|
||||
deferred : A set of strings.
|
||||
A set of strings.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
|
@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
|||
|
||||
return highlight_words
|
||||
|
||||
return self.db_pool.runInteraction("_find_highlights", f)
|
||||
return await self.db_pool.runInteraction("_find_highlights", f)
|
||||
|
||||
|
||||
def _to_postgres_options(options_dict):
|
||||
|
|
|
@ -13,9 +13,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Iterable, List, Tuple
|
||||
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
|
||||
|
@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore):
|
|||
@cachedList(
|
||||
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
|
||||
)
|
||||
def get_event_reference_hashes(self, event_ids):
|
||||
async def get_event_reference_hashes(
|
||||
self, event_ids: Iterable[str]
|
||||
) -> Dict[str, Dict[str, bytes]]:
|
||||
"""Get all hashes for given events.
|
||||
|
||||
Args:
|
||||
event_ids: The event IDs to get hashes for.
|
||||
|
||||
Returns:
|
||||
A mapping of event ID to a mapping of algorithm to hash.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
return {
|
||||
event_id: self._get_event_reference_hashes_txn(txn, event_id)
|
||||
for event_id in event_ids
|
||||
}
|
||||
|
||||
return self.db_pool.runInteraction("get_event_reference_hashes", f)
|
||||
return await self.db_pool.runInteraction("get_event_reference_hashes", f)
|
||||
|
||||
async def add_event_hashes(self, event_ids):
|
||||
async def add_event_hashes(
|
||||
self, event_ids: Iterable[str]
|
||||
) -> List[Tuple[str, Dict[str, str]]]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
event_ids: The event IDs
|
||||
|
||||
Returns:
|
||||
A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
|
||||
"""
|
||||
hashes = await self.get_event_reference_hashes(event_ids)
|
||||
hashes = {
|
||||
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
|
||||
|
@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore):
|
|||
|
||||
return list(hashes.items())
|
||||
|
||||
def _get_event_reference_hashes_txn(self, txn, event_id):
|
||||
def _get_event_reference_hashes_txn(
|
||||
self, txn: Cursor, event_id: str
|
||||
) -> Dict[str, bytes]:
|
||||
"""Get all the hashes for a given PDU.
|
||||
Args:
|
||||
txn (cursor):
|
||||
event_id (str): Id for the Event.
|
||||
txn:
|
||||
event_id: Id for the Event.
|
||||
Returns:
|
||||
A dict[unicode, bytes] of algorithm -> hash.
|
||||
A mapping of algorithm -> hash.
|
||||
"""
|
||||
query = (
|
||||
"SELECT algorithm, hash"
|
||||
|
|
|
@ -224,14 +224,32 @@ class StatsStore(StateDeltasStore):
|
|||
)
|
||||
|
||||
async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Args:
|
||||
room_id
|
||||
fields
|
||||
"""
|
||||
"""Update the state of a room.
|
||||
|
||||
# For whatever reason some of the fields may contain null bytes, which
|
||||
# postgres isn't a fan of, so we replace those fields with null.
|
||||
fields can contain the following keys with string values:
|
||||
* join_rules
|
||||
* history_visibility
|
||||
* encryption
|
||||
* name
|
||||
* topic
|
||||
* avatar
|
||||
* canonical_alias
|
||||
|
||||
A is_federatable key can also be included with a boolean value.
|
||||
|
||||
Args:
|
||||
room_id: The room ID to update the state of.
|
||||
fields: The fields to update. This can include a partial list of the
|
||||
above fields to only update some room information.
|
||||
"""
|
||||
# Ensure that the values to update are valid, they should be strings and
|
||||
# not contain any null bytes.
|
||||
#
|
||||
# Invalid data gets overwritten with null.
|
||||
#
|
||||
# Note that a missing value should not be overwritten (it keeps the
|
||||
# previous value).
|
||||
sentinel = object()
|
||||
for col in (
|
||||
"join_rules",
|
||||
"history_visibility",
|
||||
|
@ -241,8 +259,8 @@ class StatsStore(StateDeltasStore):
|
|||
"avatar",
|
||||
"canonical_alias",
|
||||
):
|
||||
field = fields.get(col)
|
||||
if field and "\0" in field:
|
||||
field = fields.get(col, sentinel)
|
||||
if field is not sentinel and (not isinstance(field, str) or "\0" in field):
|
||||
fields[col] = None
|
||||
|
||||
await self.db_pool.simple_upsert(
|
||||
|
|
|
@ -39,7 +39,7 @@ what sort order was used:
|
|||
import abc
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
|
@ -47,12 +47,19 @@ from synapse.api.filtering import Filter
|
|||
from synapse.events import EventBase
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
from synapse.types import RoomStreamToken
|
||||
from synapse.types import Collection, RoomStreamToken
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -202,7 +209,7 @@ def _make_generic_sql_bound(
|
|||
)
|
||||
|
||||
|
||||
def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
|
||||
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
|
||||
# NB: This may create SQL clauses that don't optimise well (and we don't
|
||||
# have indices on all possible clauses). E.g. it may create
|
||||
# "room_id == X AND room_id != X", which postgres doesn't optimise.
|
||||
|
@ -260,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
__metaclass__ = abc.ABCMeta
|
||||
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
|
||||
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
@ -293,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_room_max_stream_ordering(self):
|
||||
def get_room_max_stream_ordering(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_room_min_stream_ordering(self):
|
||||
def get_room_min_stream_ordering(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_room_events_stream_for_rooms(
|
||||
self,
|
||||
room_ids: Iterable[str],
|
||||
room_ids: Collection[str],
|
||||
from_key: str,
|
||||
to_key: str,
|
||||
limit: int = 0,
|
||||
|
@ -356,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return results
|
||||
|
||||
def get_rooms_that_changed(self, room_ids, from_key):
|
||||
def get_rooms_that_changed(
|
||||
self, room_ids: Collection[str], from_key: str
|
||||
) -> Set[str]:
|
||||
"""Given a list of rooms and a token, return rooms where there may have
|
||||
been changes.
|
||||
|
||||
Args:
|
||||
room_ids (list)
|
||||
from_key (str): The room_key portion of a StreamToken
|
||||
room_ids
|
||||
from_key: The room_key portion of a StreamToken
|
||||
"""
|
||||
from_key = RoomStreamToken.parse_stream_token(from_key).stream
|
||||
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
||||
return {
|
||||
room_id
|
||||
for room_id in room_ids
|
||||
if self._events_stream_cache.has_entity_changed(room_id, from_key)
|
||||
if self._events_stream_cache.has_entity_changed(room_id, from_id)
|
||||
}
|
||||
|
||||
async def get_room_events_stream_for_room(
|
||||
|
@ -440,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
return ret, key
|
||||
|
||||
async def get_membership_changes_for_user(self, user_id, from_key, to_key):
|
||||
async def get_membership_changes_for_user(
|
||||
self, user_id: str, from_key: str, to_key: str
|
||||
) -> List[EventBase]:
|
||||
from_id = RoomStreamToken.parse_stream_token(from_key).stream
|
||||
to_id = RoomStreamToken.parse_stream_token(to_key).stream
|
||||
|
||||
|
@ -593,8 +604,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
Returns:
|
||||
A stream ID.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
|
||||
)
|
||||
|
||||
def get_stream_id_for_event_txn(
|
||||
self, txn: LoggingTransaction, event_id: str, allow_none=False,
|
||||
) -> int:
|
||||
return self.db_pool.simple_select_one_onecol_txn(
|
||||
txn=txn,
|
||||
table="events",
|
||||
keyvalues={"event_id": event_id},
|
||||
retcol="stream_ordering",
|
||||
allow_none=allow_none,
|
||||
)
|
||||
|
||||
async def get_stream_token_for_event(self, event_id: str) -> str:
|
||||
|
@ -646,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
)
|
||||
return row[0][0] if row else 0
|
||||
|
||||
def _get_max_topological_txn(self, txn, room_id):
|
||||
def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
|
||||
txn.execute(
|
||||
"SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
|
||||
(room_id,),
|
||||
|
@ -719,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
def _get_events_around_txn(
|
||||
self,
|
||||
txn,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
before_limit: int,
|
||||
|
@ -747,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
retcols=["stream_ordering", "topological_ordering"],
|
||||
)
|
||||
|
||||
# This cannot happen as `allow_none=False`.
|
||||
assert results is not None
|
||||
|
||||
# Paginating backwards includes the event at the token, but paginating
|
||||
# forward doesn't.
|
||||
before_token = RoomStreamToken(
|
||||
|
@ -856,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
desc="update_federation_out_pos",
|
||||
)
|
||||
|
||||
def _reset_federation_positions_txn(self, txn) -> None:
|
||||
def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
|
||||
"""Fiddles with the `federation_stream_position` table to make it match
|
||||
the configured federation sender instances during start up.
|
||||
"""
|
||||
|
@ -895,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
GROUP BY type
|
||||
"""
|
||||
txn.execute(sql)
|
||||
min_positions = dict(txn) # Map from type -> min position
|
||||
min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position
|
||||
|
||||
# Ensure we do actually have some values here
|
||||
assert set(min_positions) == {"federation", "events"}
|
||||
|
@ -922,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
|
||||
def _paginate_room_events_txn(
|
||||
self,
|
||||
txn,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
from_token: RoomStreamToken,
|
||||
to_token: Optional[RoomStreamToken] = None,
|
||||
|
|
|
@ -43,7 +43,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
|||
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
|
||||
)
|
||||
|
||||
tags_by_room = {}
|
||||
tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]]
|
||||
for row in rows:
|
||||
room_tags = tags_by_room.setdefault(row["room_id"], {})
|
||||
room_tags[row["tag"]] = db_to_json(row["content"])
|
||||
|
@ -123,7 +123,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
|||
|
||||
async def get_updated_tags(
|
||||
self, user_id: str, stream_id: int
|
||||
) -> Dict[str, List[str]]:
|
||||
) -> Dict[str, Dict[str, JsonDict]]:
|
||||
"""Get all the tags for the rooms where the tags have changed since the
|
||||
given version
|
||||
|
||||
|
|
|
@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
|
||||
|
||||
class UIAuthStore(UIAuthWorkerStore):
|
||||
def delete_old_ui_auth_sessions(self, expiration_time: int):
|
||||
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
|
||||
"""
|
||||
Remove sessions which were last used earlier than the expiration time.
|
||||
|
||||
|
@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore):
|
|||
This is an epoch time in milliseconds.
|
||||
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_old_ui_auth_sessions",
|
||||
self._delete_old_ui_auth_sessions_txn,
|
||||
expiration_time,
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Iterable, Optional, Set, Tuple
|
||||
|
||||
from synapse.api.constants import EventTypes, JoinRules
|
||||
from synapse.storage.database import DatabasePool
|
||||
|
@ -365,10 +365,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
|
||||
return False
|
||||
|
||||
def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
|
||||
async def update_profile_in_user_dir(
|
||||
self, user_id: str, display_name: str, avatar_url: str
|
||||
) -> None:
|
||||
"""
|
||||
Update or add a user's profile in the user directory.
|
||||
"""
|
||||
# If the display name or avatar URL are unexpected types, overwrite them.
|
||||
if not isinstance(display_name, str):
|
||||
display_name = None
|
||||
if not isinstance(avatar_url, str):
|
||||
avatar_url = None
|
||||
|
||||
def _update_profile_in_user_dir_txn(txn):
|
||||
new_entry = self.db_pool.simple_upsert_txn(
|
||||
|
@ -458,17 +465,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
|
||||
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
|
||||
)
|
||||
|
||||
def add_users_who_share_private_room(self, room_id, user_id_tuples):
|
||||
async def add_users_who_share_private_room(
|
||||
self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
|
||||
) -> None:
|
||||
"""Insert entries into the users_who_share_private_rooms table. The first
|
||||
user should be a local user.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
|
||||
room_id
|
||||
user_id_tuples: iterable of 2-tuple of user IDs.
|
||||
"""
|
||||
|
||||
def _add_users_who_share_room_txn(txn):
|
||||
|
@ -484,17 +493,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
value_values=None,
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"add_users_who_share_room", _add_users_who_share_room_txn
|
||||
)
|
||||
|
||||
def add_users_in_public_rooms(self, room_id, user_ids):
|
||||
async def add_users_in_public_rooms(
|
||||
self, room_id: str, user_ids: Iterable[str]
|
||||
) -> None:
|
||||
"""Insert entries into the users_who_share_private_rooms table. The first
|
||||
user should be a local user.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
user_ids (list[str])
|
||||
room_id
|
||||
user_ids
|
||||
"""
|
||||
|
||||
def _add_users_in_public_rooms_txn(txn):
|
||||
|
@ -508,11 +519,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
value_values=None,
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
|
||||
)
|
||||
|
||||
def delete_all_from_user_dir(self):
|
||||
async def delete_all_from_user_dir(self) -> None:
|
||||
"""Delete the entire user directory
|
||||
"""
|
||||
|
||||
|
@ -523,7 +534,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
txn.execute("DELETE FROM users_who_share_private_rooms")
|
||||
txn.call_after(self.get_user_in_directory.invalidate_all)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
|
||||
)
|
||||
|
||||
|
@ -555,7 +566,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
|||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(UserDirectoryStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
def remove_from_user_dir(self, user_id):
|
||||
async def remove_from_user_dir(self, user_id: str) -> None:
|
||||
def _remove_from_user_dir_txn(txn):
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, table="user_directory", keyvalues={"user_id": user_id}
|
||||
|
@ -578,7 +589,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
|||
)
|
||||
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"remove_from_user_dir", _remove_from_user_dir_txn
|
||||
)
|
||||
|
||||
|
@ -605,14 +616,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
|||
|
||||
return user_ids
|
||||
|
||||
def remove_user_who_share_room(self, user_id, room_id):
|
||||
async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
|
||||
"""
|
||||
Deletes entries in the users_who_share_*_rooms table. The first
|
||||
user should be a local user.
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
room_id (str)
|
||||
user_id
|
||||
room_id
|
||||
"""
|
||||
|
||||
def _remove_user_who_share_room_txn(txn):
|
||||
|
@ -632,7 +643,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
|||
keyvalues={"user_id": user_id, "room_id": room_id},
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"remove_user_who_share_room", _remove_user_who_share_room_txn
|
||||
)
|
||||
|
||||
|
@ -664,6 +675,48 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
|||
users.update(rows)
|
||||
return list(users)
|
||||
|
||||
@cached()
|
||||
async def get_shared_rooms_for_users(
|
||||
self, user_id: str, other_user_id: str
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Returns the rooms that a local user shares with another local or remote user.
|
||||
|
||||
Args:
|
||||
user_id: The MXID of a local user
|
||||
other_user_id: The MXID of the other user
|
||||
|
||||
Returns:
|
||||
A set of room ID's that the users share.
|
||||
"""
|
||||
|
||||
def _get_shared_rooms_for_users_txn(txn):
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT p1.room_id
|
||||
FROM users_in_public_rooms as p1
|
||||
INNER JOIN users_in_public_rooms as p2
|
||||
ON p1.room_id = p2.room_id
|
||||
AND p1.user_id = ?
|
||||
AND p2.user_id = ?
|
||||
UNION
|
||||
SELECT room_id
|
||||
FROM users_who_share_private_rooms
|
||||
WHERE
|
||||
user_id = ?
|
||||
AND other_user_id = ?
|
||||
""",
|
||||
(user_id, other_user_id, user_id, other_user_id),
|
||||
)
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
return rows
|
||||
|
||||
rows = await self.db_pool.runInteraction(
|
||||
"get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
|
||||
)
|
||||
|
||||
return {row["room_id"] for row in rows}
|
||||
|
||||
async def get_user_directory_stream_pos(self) -> int:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="user_directory_stream_pos",
|
||||
|
|
|
@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore):
|
|||
|
||||
|
||||
class UserErasureStore(UserErasureWorkerStore):
|
||||
def mark_user_erased(self, user_id: str) -> None:
|
||||
async def mark_user_erased(self, user_id: str) -> None:
|
||||
"""Indicate that user_id wishes their message history to be erased.
|
||||
|
||||
Args:
|
||||
|
@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore):
|
|||
|
||||
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
|
||||
|
||||
return self.db_pool.runInteraction("mark_user_erased", f)
|
||||
await self.db_pool.runInteraction("mark_user_erased", f)
|
||||
|
||||
def mark_user_not_erased(self, user_id: str) -> None:
|
||||
async def mark_user_not_erased(self, user_id: str) -> None:
|
||||
"""Indicate that user_id is no longer erased.
|
||||
|
||||
Args:
|
||||
|
@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore):
|
|||
|
||||
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
|
||||
|
||||
return self.db_pool.runInteraction("mark_user_not_erased", f)
|
||||
await self.db_pool.runInteraction("mark_user_not_erased", f)
|
||||
|
|
|
@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
|
|||
id_column: Column that stores the stream ID.
|
||||
sequence_name: The name of the postgres sequence used to generate new
|
||||
IDs.
|
||||
positive: Whether the IDs are positive (true) or negative (false).
|
||||
When using negative IDs we go backwards from -1 to -2, -3, etc.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -196,13 +198,19 @@ class MultiWriterIdGenerator:
|
|||
instance_column: str,
|
||||
id_column: str,
|
||||
sequence_name: str,
|
||||
positive: bool = True,
|
||||
):
|
||||
self._db = db
|
||||
self._instance_name = instance_name
|
||||
self._positive = positive
|
||||
self._return_factor = 1 if positive else -1
|
||||
|
||||
# We lock as some functions may be called from DB threads.
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Note: If we are a negative stream then we still store all the IDs as
|
||||
# positive to make life easier for us, and simply negate the IDs when we
|
||||
# return them.
|
||||
self._current_positions = self._load_current_ids(
|
||||
db_conn, table, instance_column, id_column
|
||||
)
|
||||
|
@ -223,8 +231,12 @@ class MultiWriterIdGenerator:
|
|||
# gaps should be relatively rare it's still worth doing the book keeping
|
||||
# that allows us to skip forwards when there are gapless runs of
|
||||
# positions.
|
||||
#
|
||||
# We start at 1 here as a) the first generated stream ID will be 2, and
|
||||
# b) other parts of the code assume that stream IDs are strictly greater
|
||||
# than 0.
|
||||
self._persisted_upto_position = (
|
||||
min(self._current_positions.values()) if self._current_positions else 0
|
||||
min(self._current_positions.values()) if self._current_positions else 1
|
||||
)
|
||||
self._known_persisted_positions = [] # type: List[int]
|
||||
|
||||
|
@ -233,13 +245,16 @@ class MultiWriterIdGenerator:
|
|||
def _load_current_ids(
|
||||
self, db_conn, table: str, instance_column: str, id_column: str
|
||||
) -> Dict[str, int]:
|
||||
# If positive stream aggregate via MAX. For negative stream use MIN
|
||||
# *and* negate the result to get a positive number.
|
||||
sql = """
|
||||
SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
|
||||
SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
|
||||
GROUP BY %(instance)s
|
||||
""" % {
|
||||
"instance": instance_column,
|
||||
"id": id_column,
|
||||
"table": table,
|
||||
"agg": "MAX" if self._positive else "-MIN",
|
||||
}
|
||||
|
||||
cur = db_conn.cursor()
|
||||
|
@ -269,15 +284,16 @@ class MultiWriterIdGenerator:
|
|||
# Assert the fetched ID is actually greater than what we currently
|
||||
# believe the ID to be. If not, then the sequence and table have got
|
||||
# out of sync somehow.
|
||||
assert self.get_current_token_for_writer(self._instance_name) < next_id
|
||||
|
||||
with self._lock:
|
||||
assert self._current_positions.get(self._instance_name, 0) < next_id
|
||||
|
||||
self._unfinished_ids.add(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_id
|
||||
# Multiply by the return factor so that the ID has correct sign.
|
||||
yield self._return_factor * next_id
|
||||
finally:
|
||||
self._mark_id_as_finished(next_id)
|
||||
|
||||
|
@ -296,15 +312,15 @@ class MultiWriterIdGenerator:
|
|||
# Assert the fetched ID is actually greater than any ID we've already
|
||||
# seen. If not, then the sequence and table have got out of sync
|
||||
# somehow.
|
||||
assert max(self.get_positions().values(), default=0) < min(next_ids)
|
||||
|
||||
with self._lock:
|
||||
assert max(self._current_positions.values(), default=0) < min(next_ids)
|
||||
|
||||
self._unfinished_ids.update(next_ids)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_ids
|
||||
yield [self._return_factor * i for i in next_ids]
|
||||
finally:
|
||||
for i in next_ids:
|
||||
self._mark_id_as_finished(i)
|
||||
|
@ -327,7 +343,7 @@ class MultiWriterIdGenerator:
|
|||
txn.call_after(self._mark_id_as_finished, next_id)
|
||||
txn.call_on_exception(self._mark_id_as_finished, next_id)
|
||||
|
||||
return next_id
|
||||
return self._return_factor * next_id
|
||||
|
||||
def _mark_id_as_finished(self, next_id: int):
|
||||
"""The ID has finished being processed so we should advance the
|
||||
|
@ -350,29 +366,32 @@ class MultiWriterIdGenerator:
|
|||
equal to it have been successfully persisted.
|
||||
"""
|
||||
|
||||
# Currently we don't support this operation, as it's not obvious how to
|
||||
# condense the stream positions of multiple writers into a single int.
|
||||
raise NotImplementedError()
|
||||
return self.get_persisted_upto_position()
|
||||
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
"""Returns the position of the given writer.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return self._current_positions.get(instance_name, 0)
|
||||
return self._return_factor * self._current_positions.get(instance_name, 0)
|
||||
|
||||
def get_positions(self) -> Dict[str, int]:
|
||||
"""Get a copy of the current positon map.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return dict(self._current_positions)
|
||||
return {
|
||||
name: self._return_factor * i
|
||||
for name, i in self._current_positions.items()
|
||||
}
|
||||
|
||||
def advance(self, instance_name: str, new_id: int):
|
||||
"""Advance the postion of the named writer to the given ID, if greater
|
||||
than existing entry.
|
||||
"""
|
||||
|
||||
new_id *= self._return_factor
|
||||
|
||||
with self._lock:
|
||||
self._current_positions[instance_name] = max(
|
||||
new_id, self._current_positions.get(instance_name, 0)
|
||||
|
@ -390,7 +409,7 @@ class MultiWriterIdGenerator:
|
|||
"""
|
||||
|
||||
with self._lock:
|
||||
return self._persisted_upto_position
|
||||
return self._return_factor * self._persisted_upto_position
|
||||
|
||||
def _add_persisted_position(self, new_id: int):
|
||||
"""Record that we have persisted a position.
|
||||
|
|
|
@ -20,6 +20,7 @@ from contextlib import contextmanager
|
|||
from typing import Dict, Sequence, Set, Union
|
||||
|
||||
import attr
|
||||
from typing_extensions import ContextManager
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import CancelledError
|
||||
|
@ -338,11 +339,11 @@ class Linearizer(object):
|
|||
|
||||
|
||||
class ReadWriteLock(object):
|
||||
"""A deferred style read write lock.
|
||||
"""An async read write lock.
|
||||
|
||||
Example:
|
||||
|
||||
with (yield read_write_lock.read("test_key")):
|
||||
with await read_write_lock.read("test_key"):
|
||||
# do some work
|
||||
"""
|
||||
|
||||
|
@ -365,8 +366,7 @@ class ReadWriteLock(object):
|
|||
# Latest writer queued
|
||||
self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def read(self, key):
|
||||
async def read(self, key: str) -> ContextManager:
|
||||
new_defer = defer.Deferred()
|
||||
|
||||
curr_readers = self.key_to_current_readers.setdefault(key, set())
|
||||
|
@ -376,7 +376,8 @@ class ReadWriteLock(object):
|
|||
|
||||
# We wait for the latest writer to finish writing. We can safely ignore
|
||||
# any existing readers... as they're readers.
|
||||
yield make_deferred_yieldable(curr_writer)
|
||||
if curr_writer:
|
||||
await make_deferred_yieldable(curr_writer)
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
|
@ -388,8 +389,7 @@ class ReadWriteLock(object):
|
|||
|
||||
return _ctx_manager()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def write(self, key):
|
||||
async def write(self, key: str) -> ContextManager:
|
||||
new_defer = defer.Deferred()
|
||||
|
||||
curr_readers = self.key_to_current_readers.get(key, set())
|
||||
|
@ -405,7 +405,7 @@ class ReadWriteLock(object):
|
|||
curr_readers.clear()
|
||||
self.key_to_current_writer[key] = new_defer
|
||||
|
||||
yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
|
||||
await make_deferred_yieldable(defer.gatherResults(to_wait_on))
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue