|
|
|
@ -15,9 +15,7 @@
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import logging
|
|
|
|
|
from typing import List, Optional, Set, Tuple
|
|
|
|
|
|
|
|
|
|
from twisted.internet import defer
|
|
|
|
|
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
|
|
|
|
|
|
|
|
|
from synapse.api.errors import Codes, StoreError
|
|
|
|
|
from synapse.logging.opentracing import (
|
|
|
|
@ -33,14 +31,9 @@ from synapse.storage.database import (
|
|
|
|
|
LoggingTransaction,
|
|
|
|
|
make_tuple_comparison_clause,
|
|
|
|
|
)
|
|
|
|
|
from synapse.types import Collection, get_verify_key_from_cross_signing_key
|
|
|
|
|
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
|
|
|
|
|
from synapse.util import json_encoder
|
|
|
|
|
from synapse.util.caches.descriptors import (
|
|
|
|
|
Cache,
|
|
|
|
|
cached,
|
|
|
|
|
cachedInlineCallbacks,
|
|
|
|
|
cachedList,
|
|
|
|
|
)
|
|
|
|
|
from synapse.util.caches.descriptors import Cache, cached, cachedList
|
|
|
|
|
from synapse.util.iterutils import batch_iter
|
|
|
|
|
from synapse.util.stringutils import shortstr
|
|
|
|
|
|
|
|
|
@ -54,13 +47,13 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
def get_device(self, user_id, device_id):
|
|
|
|
|
def get_device(self, user_id: str, device_id: str):
|
|
|
|
|
"""Retrieve a device. Only returns devices that are not marked as
|
|
|
|
|
hidden.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str): The ID of the user which owns the device
|
|
|
|
|
device_id (str): The ID of the device to retrieve
|
|
|
|
|
user_id: The ID of the user which owns the device
|
|
|
|
|
device_id: The ID of the device to retrieve
|
|
|
|
|
Returns:
|
|
|
|
|
defer.Deferred for a dict containing the device information
|
|
|
|
|
Raises:
|
|
|
|
@ -73,19 +66,17 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
desc="get_device",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def get_devices_by_user(self, user_id):
|
|
|
|
|
async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
|
|
|
|
|
"""Retrieve all of a user's registered devices. Only returns devices
|
|
|
|
|
that are not marked as hidden.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str):
|
|
|
|
|
user_id:
|
|
|
|
|
Returns:
|
|
|
|
|
defer.Deferred: resolves to a dict from device_id to a dict
|
|
|
|
|
containing "device_id", "user_id" and "display_name" for each
|
|
|
|
|
device.
|
|
|
|
|
A mapping from device_id to a dict containing "device_id", "user_id"
|
|
|
|
|
and "display_name" for each device.
|
|
|
|
|
"""
|
|
|
|
|
devices = yield self.db_pool.simple_select_list(
|
|
|
|
|
devices = await self.db_pool.simple_select_list(
|
|
|
|
|
table="devices",
|
|
|
|
|
keyvalues={"user_id": user_id, "hidden": False},
|
|
|
|
|
retcols=("user_id", "device_id", "display_name"),
|
|
|
|
@ -95,19 +86,20 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
return {d["device_id"]: d for d in devices}
|
|
|
|
|
|
|
|
|
|
@trace
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def get_device_updates_by_remote(self, destination, from_stream_id, limit):
|
|
|
|
|
async def get_device_updates_by_remote(
|
|
|
|
|
self, destination: str, from_stream_id: int, limit: int
|
|
|
|
|
) -> Tuple[int, List[Tuple[str, dict]]]:
|
|
|
|
|
"""Get a stream of device updates to send to the given remote server.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
destination (str): The host the device updates are intended for
|
|
|
|
|
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
|
|
|
|
limit (int): Maximum number of device updates to return
|
|
|
|
|
destination: The host the device updates are intended for
|
|
|
|
|
from_stream_id: The minimum stream_id to filter updates by, exclusive
|
|
|
|
|
limit: Maximum number of device updates to return
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Deferred[tuple[int, list[tuple[string,dict]]]]:
|
|
|
|
|
current stream id (ie, the stream id of the last update included in the
|
|
|
|
|
response), and the list of updates, where each update is a pair of EDU
|
|
|
|
|
type and EDU contents
|
|
|
|
|
A mapping from the current stream id (ie, the stream id of the last
|
|
|
|
|
update included in the response), and the list of updates, where
|
|
|
|
|
each update is a pair of EDU type and EDU contents.
|
|
|
|
|
"""
|
|
|
|
|
now_stream_id = self._device_list_id_gen.get_current_token()
|
|
|
|
|
|
|
|
|
@ -117,7 +109,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
if not has_changed:
|
|
|
|
|
return now_stream_id, []
|
|
|
|
|
|
|
|
|
|
updates = yield self.db_pool.runInteraction(
|
|
|
|
|
updates = await self.db_pool.runInteraction(
|
|
|
|
|
"get_device_updates_by_remote",
|
|
|
|
|
self._get_device_updates_by_remote_txn,
|
|
|
|
|
destination,
|
|
|
|
@ -136,9 +128,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
master_key_by_user = {}
|
|
|
|
|
self_signing_key_by_user = {}
|
|
|
|
|
for user in users:
|
|
|
|
|
cross_signing_key = yield defer.ensureDeferred(
|
|
|
|
|
self.get_e2e_cross_signing_key(user, "master")
|
|
|
|
|
)
|
|
|
|
|
cross_signing_key = await self.get_e2e_cross_signing_key(user, "master")
|
|
|
|
|
if cross_signing_key:
|
|
|
|
|
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
|
|
|
|
cross_signing_key
|
|
|
|
@ -151,8 +141,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
"device_id": verify_key.version,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cross_signing_key = yield defer.ensureDeferred(
|
|
|
|
|
self.get_e2e_cross_signing_key(user, "self_signing")
|
|
|
|
|
cross_signing_key = await self.get_e2e_cross_signing_key(
|
|
|
|
|
user, "self_signing"
|
|
|
|
|
)
|
|
|
|
|
if cross_signing_key:
|
|
|
|
|
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
|
|
|
@ -202,7 +192,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
if update_stream_id > previous_update_stream_id:
|
|
|
|
|
query_map[key] = (update_stream_id, update_context)
|
|
|
|
|
|
|
|
|
|
results = yield self._get_device_update_edus_by_remote(
|
|
|
|
|
results = await self._get_device_update_edus_by_remote(
|
|
|
|
|
destination, from_stream_id, query_map
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -215,16 +205,21 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
return now_stream_id, results
|
|
|
|
|
|
|
|
|
|
def _get_device_updates_by_remote_txn(
|
|
|
|
|
self, txn, destination, from_stream_id, now_stream_id, limit
|
|
|
|
|
self,
|
|
|
|
|
txn: LoggingTransaction,
|
|
|
|
|
destination: str,
|
|
|
|
|
from_stream_id: int,
|
|
|
|
|
now_stream_id: int,
|
|
|
|
|
limit: int,
|
|
|
|
|
):
|
|
|
|
|
"""Return device update information for a given remote destination
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
txn (LoggingTransaction): The transaction to execute
|
|
|
|
|
destination (str): The host the device updates are intended for
|
|
|
|
|
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
|
|
|
|
now_stream_id (int): The maximum stream_id to filter updates by, inclusive
|
|
|
|
|
limit (int): Maximum number of device updates to return
|
|
|
|
|
txn: The transaction to execute
|
|
|
|
|
destination: The host the device updates are intended for
|
|
|
|
|
from_stream_id: The minimum stream_id to filter updates by, exclusive
|
|
|
|
|
now_stream_id: The maximum stream_id to filter updates by, inclusive
|
|
|
|
|
limit: Maximum number of device updates to return
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
List: List of device updates
|
|
|
|
@ -240,23 +235,26 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
|
|
|
|
|
return list(txn)
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map):
|
|
|
|
|
async def _get_device_update_edus_by_remote(
|
|
|
|
|
self,
|
|
|
|
|
destination: str,
|
|
|
|
|
from_stream_id: int,
|
|
|
|
|
query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]],
|
|
|
|
|
) -> List[Tuple[str, dict]]:
|
|
|
|
|
"""Returns a list of device update EDUs as well as E2EE keys
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
destination (str): The host the device updates are intended for
|
|
|
|
|
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
|
|
|
|
destination: The host the device updates are intended for
|
|
|
|
|
from_stream_id: The minimum stream_id to filter updates by, exclusive
|
|
|
|
|
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
|
|
|
|
|
user_id/device_id to update stream_id and the relevant json-encoded
|
|
|
|
|
opentracing context
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
List[Dict]: List of objects representing an device update EDU
|
|
|
|
|
|
|
|
|
|
List of objects representing an device update EDU
|
|
|
|
|
"""
|
|
|
|
|
devices = (
|
|
|
|
|
yield self.db_pool.runInteraction(
|
|
|
|
|
await self.db_pool.runInteraction(
|
|
|
|
|
"_get_e2e_device_keys_txn",
|
|
|
|
|
self._get_e2e_device_keys_txn,
|
|
|
|
|
query_map.keys(),
|
|
|
|
@ -271,7 +269,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
for user_id, user_devices in devices.items():
|
|
|
|
|
# The prev_id for the first row is always the last row before
|
|
|
|
|
# `from_stream_id`
|
|
|
|
|
prev_id = yield self._get_last_device_update_for_remote_user(
|
|
|
|
|
prev_id = await self._get_last_device_update_for_remote_user(
|
|
|
|
|
destination, user_id, from_stream_id
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -315,7 +313,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
return results
|
|
|
|
|
|
|
|
|
|
def _get_last_device_update_for_remote_user(
|
|
|
|
|
self, destination, user_id, from_stream_id
|
|
|
|
|
self, destination: str, user_id: str, from_stream_id: int
|
|
|
|
|
):
|
|
|
|
|
def f(txn):
|
|
|
|
|
prev_sent_id_sql = """
|
|
|
|
@ -329,7 +327,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
|
|
|
|
|
return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
|
|
|
|
|
|
|
|
|
|
def mark_as_sent_devices_by_remote(self, destination, stream_id):
|
|
|
|
|
def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
|
|
|
|
|
"""Mark that updates have successfully been sent to the destination.
|
|
|
|
|
"""
|
|
|
|
|
return self.db_pool.runInteraction(
|
|
|
|
@ -339,7 +337,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
stream_id,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
|
|
|
|
|
def _mark_as_sent_devices_by_remote_txn(
|
|
|
|
|
self, txn: LoggingTransaction, destination: str, stream_id: int
|
|
|
|
|
) -> None:
|
|
|
|
|
# We update the device_lists_outbound_last_success with the successfully
|
|
|
|
|
# poked users.
|
|
|
|
|
sql = """
|
|
|
|
@ -367,17 +367,21 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
"""
|
|
|
|
|
txn.execute(sql, (destination, stream_id))
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def add_user_signature_change_to_streams(self, from_user_id, user_ids):
|
|
|
|
|
async def add_user_signature_change_to_streams(
|
|
|
|
|
self, from_user_id: str, user_ids: List[str]
|
|
|
|
|
) -> int:
|
|
|
|
|
"""Persist that a user has made new signatures
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
from_user_id (str): the user who made the signatures
|
|
|
|
|
user_ids (list[str]): the users who were signed
|
|
|
|
|
from_user_id: the user who made the signatures
|
|
|
|
|
user_ids: the users who were signed
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
THe new stream ID.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
with self._device_list_id_gen.get_next() as stream_id:
|
|
|
|
|
yield self.db_pool.runInteraction(
|
|
|
|
|
await self.db_pool.runInteraction(
|
|
|
|
|
"add_user_sig_change_to_streams",
|
|
|
|
|
self._add_user_signature_change_txn,
|
|
|
|
|
from_user_id,
|
|
|
|
@ -386,7 +390,13 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
)
|
|
|
|
|
return stream_id
|
|
|
|
|
|
|
|
|
|
def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
|
|
|
|
|
def _add_user_signature_change_txn(
|
|
|
|
|
self,
|
|
|
|
|
txn: LoggingTransaction,
|
|
|
|
|
from_user_id: str,
|
|
|
|
|
user_ids: List[str],
|
|
|
|
|
stream_id: int,
|
|
|
|
|
) -> None:
|
|
|
|
|
txn.call_after(
|
|
|
|
|
self._user_signature_stream_cache.entity_has_changed,
|
|
|
|
|
from_user_id,
|
|
|
|
@ -402,29 +412,30 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def get_device_stream_token(self):
|
|
|
|
|
def get_device_stream_token(self) -> int:
|
|
|
|
|
return self._device_list_id_gen.get_current_token()
|
|
|
|
|
|
|
|
|
|
@trace
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def get_user_devices_from_cache(self, query_list):
|
|
|
|
|
async def get_user_devices_from_cache(
|
|
|
|
|
self, query_list: List[Tuple[str, str]]
|
|
|
|
|
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
|
|
|
|
|
"""Get the devices (and keys if any) for remote users from the cache.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
query_list(list): List of (user_id, device_ids), if device_ids is
|
|
|
|
|
query_list: List of (user_id, device_ids), if device_ids is
|
|
|
|
|
falsey then return all device ids for that user.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
|
|
|
|
|
a set of user_ids and results_map is a mapping of
|
|
|
|
|
user_id -> device_id -> device_info
|
|
|
|
|
A tuple of (user_ids_not_in_cache, results_map), where
|
|
|
|
|
user_ids_not_in_cache is a set of user_ids and results_map is a
|
|
|
|
|
mapping of user_id -> device_id -> device_info.
|
|
|
|
|
"""
|
|
|
|
|
user_ids = {user_id for user_id, _ in query_list}
|
|
|
|
|
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
|
|
|
|
|
user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
|
|
|
|
|
|
|
|
|
|
# We go and check if any of the users need to have their device lists
|
|
|
|
|
# resynced. If they do then we remove them from the cached list.
|
|
|
|
|
users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
|
|
|
|
|
users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
|
|
|
|
|
user_ids
|
|
|
|
|
)
|
|
|
|
|
user_ids_in_cache = {
|
|
|
|
@ -438,19 +449,19 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
if device_id:
|
|
|
|
|
device = yield self._get_cached_user_device(user_id, device_id)
|
|
|
|
|
device = await self._get_cached_user_device(user_id, device_id)
|
|
|
|
|
results.setdefault(user_id, {})[device_id] = device
|
|
|
|
|
else:
|
|
|
|
|
results[user_id] = yield self.get_cached_devices_for_user(user_id)
|
|
|
|
|
results[user_id] = await self.get_cached_devices_for_user(user_id)
|
|
|
|
|
|
|
|
|
|
set_tag("in_cache", results)
|
|
|
|
|
set_tag("not_in_cache", user_ids_not_in_cache)
|
|
|
|
|
|
|
|
|
|
return user_ids_not_in_cache, results
|
|
|
|
|
|
|
|
|
|
@cachedInlineCallbacks(num_args=2, tree=True)
|
|
|
|
|
def _get_cached_user_device(self, user_id, device_id):
|
|
|
|
|
content = yield self.db_pool.simple_select_one_onecol(
|
|
|
|
|
@cached(num_args=2, tree=True)
|
|
|
|
|
async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
|
|
|
|
|
content = await self.db_pool.simple_select_one_onecol(
|
|
|
|
|
table="device_lists_remote_cache",
|
|
|
|
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
|
|
|
|
retcol="content",
|
|
|
|
@ -458,9 +469,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
)
|
|
|
|
|
return db_to_json(content)
|
|
|
|
|
|
|
|
|
|
@cachedInlineCallbacks()
|
|
|
|
|
def get_cached_devices_for_user(self, user_id):
|
|
|
|
|
devices = yield self.db_pool.simple_select_list(
|
|
|
|
|
@cached()
|
|
|
|
|
async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
|
|
|
|
|
devices = await self.db_pool.simple_select_list(
|
|
|
|
|
table="device_lists_remote_cache",
|
|
|
|
|
keyvalues={"user_id": user_id},
|
|
|
|
|
retcols=("device_id", "content"),
|
|
|
|
@ -470,11 +481,11 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
device["device_id"]: db_to_json(device["content"]) for device in devices
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def get_devices_with_keys_by_user(self, user_id):
|
|
|
|
|
def get_devices_with_keys_by_user(self, user_id: str):
|
|
|
|
|
"""Get all devices (with any device keys) for a user
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
(stream_id, devices)
|
|
|
|
|
Deferred which resolves to (stream_id, devices)
|
|
|
|
|
"""
|
|
|
|
|
return self.db_pool.runInteraction(
|
|
|
|
|
"get_devices_with_keys_by_user",
|
|
|
|
@ -482,7 +493,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
user_id,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
|
|
|
|
|
def _get_devices_with_keys_by_user_txn(
|
|
|
|
|
self, txn: LoggingTransaction, user_id: str
|
|
|
|
|
) -> Tuple[int, List[JsonDict]]:
|
|
|
|
|
now_stream_id = self._device_list_id_gen.get_current_token()
|
|
|
|
|
|
|
|
|
|
devices = self._get_e2e_device_keys_txn(
|
|
|
|
@ -515,17 +528,18 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
|
|
|
|
|
return now_stream_id, []
|
|
|
|
|
|
|
|
|
|
def get_users_whose_devices_changed(self, from_key, user_ids):
|
|
|
|
|
async def get_users_whose_devices_changed(
|
|
|
|
|
self, from_key: str, user_ids: Iterable[str]
|
|
|
|
|
) -> Set[str]:
|
|
|
|
|
"""Get set of users whose devices have changed since `from_key` that
|
|
|
|
|
are in the given list of user_ids.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
from_key (str): The device lists stream token
|
|
|
|
|
user_ids (Iterable[str])
|
|
|
|
|
from_key: The device lists stream token
|
|
|
|
|
user_ids: The user IDs to query for devices.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Deferred[set[str]]: The set of user_ids whose devices have changed
|
|
|
|
|
since `from_key`
|
|
|
|
|
The set of user_ids whose devices have changed since `from_key`
|
|
|
|
|
"""
|
|
|
|
|
from_key = int(from_key)
|
|
|
|
|
|
|
|
|
@ -536,7 +550,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if not to_check:
|
|
|
|
|
return defer.succeed(set())
|
|
|
|
|
return set()
|
|
|
|
|
|
|
|
|
|
def _get_users_whose_devices_changed_txn(txn):
|
|
|
|
|
changes = set()
|
|
|
|
@ -556,18 +570,22 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
|
|
|
|
|
return changes
|
|
|
|
|
|
|
|
|
|
return self.db_pool.runInteraction(
|
|
|
|
|
return await self.db_pool.runInteraction(
|
|
|
|
|
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def get_users_whose_signatures_changed(self, user_id, from_key):
|
|
|
|
|
async def get_users_whose_signatures_changed(
|
|
|
|
|
self, user_id: str, from_key: str
|
|
|
|
|
) -> Set[str]:
|
|
|
|
|
"""Get the users who have new cross-signing signatures made by `user_id` since
|
|
|
|
|
`from_key`.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str): the user who made the signatures
|
|
|
|
|
from_key (str): The device lists stream token
|
|
|
|
|
user_id: the user who made the signatures
|
|
|
|
|
from_key: The device lists stream token
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
A set of user IDs with updated signatures.
|
|
|
|
|
"""
|
|
|
|
|
from_key = int(from_key)
|
|
|
|
|
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
|
|
|
|
@ -575,7 +593,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
SELECT DISTINCT user_ids FROM user_signature_stream
|
|
|
|
|
WHERE from_user_id = ? AND stream_id > ?
|
|
|
|
|
"""
|
|
|
|
|
rows = yield self.db_pool.execute(
|
|
|
|
|
rows = await self.db_pool.execute(
|
|
|
|
|
"get_users_whose_signatures_changed", None, sql, user_id, from_key
|
|
|
|
|
)
|
|
|
|
|
return {user for row in rows for user in db_to_json(row[0])}
|
|
|
|
@ -638,7 +656,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@cached(max_entries=10000)
|
|
|
|
|
def get_device_list_last_stream_id_for_remote(self, user_id):
|
|
|
|
|
def get_device_list_last_stream_id_for_remote(self, user_id: str):
|
|
|
|
|
"""Get the last stream_id we got for a user. May be None if we haven't
|
|
|
|
|
got any information for them.
|
|
|
|
|
"""
|
|
|
|
@ -655,7 +673,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
list_name="user_ids",
|
|
|
|
|
inlineCallbacks=True,
|
|
|
|
|
)
|
|
|
|
|
def get_device_list_last_stream_id_for_remotes(self, user_ids):
|
|
|
|
|
def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
|
|
|
|
|
rows = yield self.db_pool.simple_select_many_batch(
|
|
|
|
|
table="device_lists_remote_extremeties",
|
|
|
|
|
column="user_id",
|
|
|
|
@ -669,8 +687,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
|
|
|
|
|
return results
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def get_user_ids_requiring_device_list_resync(
|
|
|
|
|
async def get_user_ids_requiring_device_list_resync(
|
|
|
|
|
self, user_ids: Optional[Collection[str]] = None,
|
|
|
|
|
) -> Set[str]:
|
|
|
|
|
"""Given a list of remote users return the list of users that we
|
|
|
|
@ -681,7 +698,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
The IDs of users whose device lists need resync.
|
|
|
|
|
"""
|
|
|
|
|
if user_ids:
|
|
|
|
|
rows = yield self.db_pool.simple_select_many_batch(
|
|
|
|
|
rows = await self.db_pool.simple_select_many_batch(
|
|
|
|
|
table="device_lists_remote_resync",
|
|
|
|
|
column="user_id",
|
|
|
|
|
iterable=user_ids,
|
|
|
|
@ -689,7 +706,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
desc="get_user_ids_requiring_device_list_resync_with_iterable",
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
rows = yield self.db_pool.simple_select_list(
|
|
|
|
|
rows = await self.db_pool.simple_select_list(
|
|
|
|
|
table="device_lists_remote_resync",
|
|
|
|
|
keyvalues=None,
|
|
|
|
|
retcols=("user_id",),
|
|
|
|
@ -710,7 +727,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
|
|
desc="make_remote_user_device_cache_as_stale",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
|
|
|
|
|
def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
|
|
|
|
|
"""Mark that we no longer track device lists for remote user.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
@ -779,16 +796,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|
|
|
|
"drop_device_lists_outbound_last_success_non_unique_idx",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
|
|
|
|
|
async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
|
|
|
|
|
def f(conn):
|
|
|
|
|
txn = conn.cursor()
|
|
|
|
|
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
|
|
|
|
|
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
|
|
|
|
|
txn.close()
|
|
|
|
|
|
|
|
|
|
yield self.db_pool.runWithConnection(f)
|
|
|
|
|
yield self.db_pool.updates._end_background_update(
|
|
|
|
|
await self.db_pool.runWithConnection(f)
|
|
|
|
|
await self.db_pool.updates._end_background_update(
|
|
|
|
|
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
|
|
|
|
|
)
|
|
|
|
|
return 1
|
|
|
|
@ -868,18 +884,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
|
|
|
|
|
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def store_device(self, user_id, device_id, initial_device_display_name):
|
|
|
|
|
async def store_device(
|
|
|
|
|
self, user_id: str, device_id: str, initial_device_display_name: str
|
|
|
|
|
) -> bool:
|
|
|
|
|
"""Ensure the given device is known; add it to the store if not
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str): id of user associated with the device
|
|
|
|
|
device_id (str): id of device
|
|
|
|
|
initial_device_display_name (str): initial displayname of the
|
|
|
|
|
device. Ignored if device exists.
|
|
|
|
|
user_id: id of user associated with the device
|
|
|
|
|
device_id: id of device
|
|
|
|
|
initial_device_display_name: initial displayname of the device.
|
|
|
|
|
Ignored if device exists.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
defer.Deferred: boolean whether the device was inserted or an
|
|
|
|
|
existing device existed with that ID.
|
|
|
|
|
Whether the device was inserted or an existing device existed with that ID.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
StoreError: if the device is already in use
|
|
|
|
|
"""
|
|
|
|
@ -888,7 +906,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
inserted = yield self.db_pool.simple_insert(
|
|
|
|
|
inserted = await self.db_pool.simple_insert(
|
|
|
|
|
"devices",
|
|
|
|
|
values={
|
|
|
|
|
"user_id": user_id,
|
|
|
|
@ -902,7 +920,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
if not inserted:
|
|
|
|
|
# if the device already exists, check if it's a real device, or
|
|
|
|
|
# if the device ID is reserved by something else
|
|
|
|
|
hidden = yield self.db_pool.simple_select_one_onecol(
|
|
|
|
|
hidden = await self.db_pool.simple_select_one_onecol(
|
|
|
|
|
"devices",
|
|
|
|
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
|
|
|
|
retcol="hidden",
|
|
|
|
@ -927,17 +945,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
)
|
|
|
|
|
raise StoreError(500, "Problem storing device.")
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def delete_device(self, user_id, device_id):
|
|
|
|
|
async def delete_device(self, user_id: str, device_id: str) -> None:
|
|
|
|
|
"""Delete a device.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str): The ID of the user which owns the device
|
|
|
|
|
device_id (str): The ID of the device to delete
|
|
|
|
|
Returns:
|
|
|
|
|
defer.Deferred
|
|
|
|
|
user_id: The ID of the user which owns the device
|
|
|
|
|
device_id: The ID of the device to delete
|
|
|
|
|
"""
|
|
|
|
|
yield self.db_pool.simple_delete_one(
|
|
|
|
|
await self.db_pool.simple_delete_one(
|
|
|
|
|
table="devices",
|
|
|
|
|
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
|
|
|
|
desc="delete_device",
|
|
|
|
@ -945,17 +960,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
|
|
|
|
|
self.device_id_exists_cache.invalidate((user_id, device_id))
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def delete_devices(self, user_id, device_ids):
|
|
|
|
|
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
|
|
|
|
|
"""Deletes several devices.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str): The ID of the user which owns the devices
|
|
|
|
|
device_ids (list): The IDs of the devices to delete
|
|
|
|
|
Returns:
|
|
|
|
|
defer.Deferred
|
|
|
|
|
user_id: The ID of the user which owns the devices
|
|
|
|
|
device_ids: The IDs of the devices to delete
|
|
|
|
|
"""
|
|
|
|
|
yield self.db_pool.simple_delete_many(
|
|
|
|
|
await self.db_pool.simple_delete_many(
|
|
|
|
|
table="devices",
|
|
|
|
|
column="device_id",
|
|
|
|
|
iterable=device_ids,
|
|
|
|
@ -965,26 +977,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
for device_id in device_ids:
|
|
|
|
|
self.device_id_exists_cache.invalidate((user_id, device_id))
|
|
|
|
|
|
|
|
|
|
def update_device(self, user_id, device_id, new_display_name=None):
|
|
|
|
|
async def update_device(
|
|
|
|
|
self, user_id: str, device_id: str, new_display_name: Optional[str] = None
|
|
|
|
|
) -> None:
|
|
|
|
|
"""Update a device. Only updates the device if it is not marked as
|
|
|
|
|
hidden.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str): The ID of the user which owns the device
|
|
|
|
|
device_id (str): The ID of the device to update
|
|
|
|
|
new_display_name (str|None): new displayname for device; None
|
|
|
|
|
to leave unchanged
|
|
|
|
|
user_id: The ID of the user which owns the device
|
|
|
|
|
device_id: The ID of the device to update
|
|
|
|
|
new_display_name: new displayname for device; None to leave unchanged
|
|
|
|
|
Raises:
|
|
|
|
|
StoreError: if the device is not found
|
|
|
|
|
Returns:
|
|
|
|
|
defer.Deferred
|
|
|
|
|
"""
|
|
|
|
|
updates = {}
|
|
|
|
|
if new_display_name is not None:
|
|
|
|
|
updates["display_name"] = new_display_name
|
|
|
|
|
if not updates:
|
|
|
|
|
return defer.succeed(None)
|
|
|
|
|
return self.db_pool.simple_update_one(
|
|
|
|
|
return None
|
|
|
|
|
await self.db_pool.simple_update_one(
|
|
|
|
|
table="devices",
|
|
|
|
|
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
|
|
|
|
updatevalues=updates,
|
|
|
|
@ -992,7 +1003,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def update_remote_device_list_cache_entry(
|
|
|
|
|
self, user_id, device_id, content, stream_id
|
|
|
|
|
self, user_id: str, device_id: str, content: JsonDict, stream_id: int
|
|
|
|
|
):
|
|
|
|
|
"""Updates a single device in the cache of a remote user's devicelist.
|
|
|
|
|
|
|
|
|
@ -1000,10 +1011,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
device list.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str): User to update device list for
|
|
|
|
|
device_id (str): ID of decivice being updated
|
|
|
|
|
content (dict): new data on this device
|
|
|
|
|
stream_id (int): the version of the device list
|
|
|
|
|
user_id: User to update device list for
|
|
|
|
|
device_id: ID of decivice being updated
|
|
|
|
|
content: new data on this device
|
|
|
|
|
stream_id: the version of the device list
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Deferred[None]
|
|
|
|
@ -1018,8 +1029,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _update_remote_device_list_cache_entry_txn(
|
|
|
|
|
self, txn, user_id, device_id, content, stream_id
|
|
|
|
|
):
|
|
|
|
|
self,
|
|
|
|
|
txn: LoggingTransaction,
|
|
|
|
|
user_id: str,
|
|
|
|
|
device_id: str,
|
|
|
|
|
content: JsonDict,
|
|
|
|
|
stream_id: int,
|
|
|
|
|
) -> None:
|
|
|
|
|
if content.get("deleted"):
|
|
|
|
|
self.db_pool.simple_delete_txn(
|
|
|
|
|
txn,
|
|
|
|
@ -1055,16 +1071,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
lock=False,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def update_remote_device_list_cache(self, user_id, devices, stream_id):
|
|
|
|
|
def update_remote_device_list_cache(
|
|
|
|
|
self, user_id: str, devices: List[dict], stream_id: int
|
|
|
|
|
):
|
|
|
|
|
"""Replace the entire cache of the remote user's devices.
|
|
|
|
|
|
|
|
|
|
Note: assumes that we are the only thread that can be updating this user's
|
|
|
|
|
device list.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
user_id (str): User to update device list for
|
|
|
|
|
devices (list[dict]): list of device objects supplied over federation
|
|
|
|
|
stream_id (int): the version of the device list
|
|
|
|
|
user_id: User to update device list for
|
|
|
|
|
devices: list of device objects supplied over federation
|
|
|
|
|
stream_id: the version of the device list
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Deferred[None]
|
|
|
|
@ -1077,7 +1095,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
stream_id,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
|
|
|
|
|
def _update_remote_device_list_cache_txn(
|
|
|
|
|
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
|
|
|
|
|
):
|
|
|
|
|
self.db_pool.simple_delete_txn(
|
|
|
|
|
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
|
|
|
|
|
)
|
|
|
|
@ -1118,8 +1138,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
|
|
def add_device_change_to_streams(self, user_id, device_ids, hosts):
|
|
|
|
|
async def add_device_change_to_streams(
|
|
|
|
|
self, user_id: str, device_ids: Collection[str], hosts: List[str]
|
|
|
|
|
):
|
|
|
|
|
"""Persist that a user's devices have been updated, and which hosts
|
|
|
|
|
(if any) should be poked.
|
|
|
|
|
"""
|
|
|
|
@ -1127,7 +1148,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
|
|
|
|
|
yield self.db_pool.runInteraction(
|
|
|
|
|
await self.db_pool.runInteraction(
|
|
|
|
|
"add_device_change_to_stream",
|
|
|
|
|
self._add_device_change_to_stream_txn,
|
|
|
|
|
user_id,
|
|
|
|
@ -1142,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
with self._device_list_id_gen.get_next_mult(
|
|
|
|
|
len(hosts) * len(device_ids)
|
|
|
|
|
) as stream_ids:
|
|
|
|
|
yield self.db_pool.runInteraction(
|
|
|
|
|
await self.db_pool.runInteraction(
|
|
|
|
|
"add_device_outbound_poke_to_stream",
|
|
|
|
|
self._add_device_outbound_poke_to_stream_txn,
|
|
|
|
|
user_id,
|
|
|
|
@ -1187,7 +1208,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _add_device_outbound_poke_to_stream_txn(
|
|
|
|
|
self, txn, user_id, device_ids, hosts, stream_ids, context,
|
|
|
|
|
self,
|
|
|
|
|
txn: LoggingTransaction,
|
|
|
|
|
user_id: str,
|
|
|
|
|
device_ids: Collection[str],
|
|
|
|
|
hosts: List[str],
|
|
|
|
|
stream_ids: List[str],
|
|
|
|
|
context: Dict[str, str],
|
|
|
|
|
):
|
|
|
|
|
for host in hosts:
|
|
|
|
|
txn.call_after(
|
|
|
|
@ -1219,7 +1246,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
|
|
|
|
|
def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
|
|
|
|
|
"""Delete old entries out of the device_lists_outbound_pokes to ensure
|
|
|
|
|
that we don't fill up due to dead servers.
|
|
|
|
|
|
|
|
|
|