Add support for claiming multiple OTKs at once. (#15468)
MSC3983 provides a way to request multiple OTKs at once from appservices, this extends this concept to the Client-Server API. Note that this will likely be spit out into a separate MSC, but is currently part of MSC3983.pull/15507/head
parent
6efa674004
commit
57aeeb308b
|
@ -0,0 +1 @@
|
|||
Support claiming more than one OTK at a time.
|
|
@ -442,8 +442,10 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
return False
|
||||
|
||||
async def claim_client_keys(
|
||||
self, service: "ApplicationService", query: List[Tuple[str, str, str]]
|
||||
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
|
||||
self, service: "ApplicationService", query: List[Tuple[str, str, str, int]]
|
||||
) -> Tuple[
|
||||
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
|
||||
]:
|
||||
"""Claim one time keys from an application service.
|
||||
|
||||
Note that any error (including a timeout) is treated as the application
|
||||
|
@ -469,8 +471,10 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
|
||||
# Create the expected payload shape.
|
||||
body: Dict[str, Dict[str, List[str]]] = {}
|
||||
for user_id, device, algorithm in query:
|
||||
body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)
|
||||
for user_id, device, algorithm, count in query:
|
||||
body.setdefault(user_id, {}).setdefault(device, []).extend(
|
||||
[algorithm] * count
|
||||
)
|
||||
|
||||
uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
|
||||
try:
|
||||
|
@ -493,11 +497,20 @@ class ApplicationServiceApi(SimpleHttpClient):
|
|||
# or if some are still missing.
|
||||
#
|
||||
# TODO This places a lot of faith in the response shape being correct.
|
||||
missing = [
|
||||
(user_id, device, algorithm)
|
||||
for user_id, device, algorithm in query
|
||||
if algorithm not in response.get(user_id, {}).get(device, [])
|
||||
]
|
||||
missing = []
|
||||
for user_id, device, algorithm, count in query:
|
||||
# Count the number of keys in the response for this algorithm by
|
||||
# checking which key IDs start with the algorithm. This uses that
|
||||
# True == 1 in Python to generate a count.
|
||||
response_count = sum(
|
||||
key_id.startswith(f"{algorithm}:")
|
||||
for key_id in response.get(user_id, {}).get(device, {})
|
||||
)
|
||||
count -= response_count
|
||||
# If the appservice responds with fewer keys than requested, then
|
||||
# consider the request unfulfilled.
|
||||
if count > 0:
|
||||
missing.append((user_id, device, algorithm, count))
|
||||
|
||||
return response, missing
|
||||
|
||||
|
|
|
@ -235,7 +235,10 @@ class FederationClient(FederationBase):
|
|||
)
|
||||
|
||||
async def claim_client_keys(
|
||||
self, destination: str, content: JsonDict, timeout: Optional[int]
|
||||
self,
|
||||
destination: str,
|
||||
query: Dict[str, Dict[str, Dict[str, int]]],
|
||||
timeout: Optional[int],
|
||||
) -> JsonDict:
|
||||
"""Claims one-time keys for a device hosted on a remote server.
|
||||
|
||||
|
@ -247,6 +250,50 @@ class FederationClient(FederationBase):
|
|||
The JSON object from the response
|
||||
"""
|
||||
sent_queries_counter.labels("client_one_time_keys").inc()
|
||||
|
||||
# Convert the query with counts into a stable and unstable query and check
|
||||
# if attempting to claim more than 1 OTK.
|
||||
content: Dict[str, Dict[str, str]] = {}
|
||||
unstable_content: Dict[str, Dict[str, List[str]]] = {}
|
||||
use_unstable = False
|
||||
for user_id, one_time_keys in query.items():
|
||||
for device_id, algorithms in one_time_keys.items():
|
||||
if any(count > 1 for count in algorithms.values()):
|
||||
use_unstable = True
|
||||
if algorithms:
|
||||
# For the stable query, choose only the first algorithm.
|
||||
content.setdefault(user_id, {})[device_id] = next(iter(algorithms))
|
||||
# For the unstable query, repeat each algorithm by count, then
|
||||
# splat those into chain to get a flattened list of all algorithms.
|
||||
#
|
||||
# Converts from {"algo1": 2, "algo2": 2} to ["algo1", "algo1", "algo2"].
|
||||
unstable_content.setdefault(user_id, {})[device_id] = list(
|
||||
itertools.chain(
|
||||
*(
|
||||
itertools.repeat(algorithm, count)
|
||||
for algorithm, count in algorithms.items()
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if use_unstable:
|
||||
try:
|
||||
return await self.transport_layer.claim_client_keys_unstable(
|
||||
destination, unstable_content, timeout
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
# If an error is received that is due to an unrecognised endpoint,
|
||||
# fallback to the v1 endpoint. Otherwise, consider it a legitimate error
|
||||
# and raise.
|
||||
if not is_unknown_endpoint(e):
|
||||
raise
|
||||
|
||||
logger.debug(
|
||||
"Couldn't claim client keys with the unstable API, falling back to the v1 API"
|
||||
)
|
||||
else:
|
||||
logger.debug("Skipping unstable claim client keys API")
|
||||
|
||||
return await self.transport_layer.claim_client_keys(
|
||||
destination, content, timeout
|
||||
)
|
||||
|
|
|
@ -1005,13 +1005,8 @@ class FederationServer(FederationBase):
|
|||
|
||||
@trace
|
||||
async def on_claim_client_keys(
|
||||
self, origin: str, content: JsonDict, always_include_fallback_keys: bool
|
||||
self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
|
||||
) -> Dict[str, Any]:
|
||||
query = []
|
||||
for user_id, device_keys in content.get("one_time_keys", {}).items():
|
||||
for device_id, algorithm in device_keys.items():
|
||||
query.append((user_id, device_id, algorithm))
|
||||
|
||||
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
|
||||
results = await self._e2e_keys_handler.claim_local_one_time_keys(
|
||||
query, always_include_fallback_keys=always_include_fallback_keys
|
||||
|
|
|
@ -650,10 +650,10 @@ class TransportLayerClient:
|
|||
|
||||
Response:
|
||||
{
|
||||
"device_keys": {
|
||||
"one_time_keys": {
|
||||
"<user_id>": {
|
||||
"<device_id>": {
|
||||
"<algorithm>:<key_id>": "<key_base64>"
|
||||
"<algorithm>:<key_id>": <OTK JSON>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -669,7 +669,50 @@ class TransportLayerClient:
|
|||
path = _create_v1_path("/user/keys/claim")
|
||||
|
||||
return await self.client.post_json(
|
||||
destination=destination, path=path, data=query_content, timeout=timeout
|
||||
destination=destination,
|
||||
path=path,
|
||||
data={"one_time_keys": query_content},
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def claim_client_keys_unstable(
|
||||
self, destination: str, query_content: JsonDict, timeout: Optional[int]
|
||||
) -> JsonDict:
|
||||
"""Claim one-time keys for a list of devices hosted on a remote server.
|
||||
|
||||
Request:
|
||||
{
|
||||
"one_time_keys": {
|
||||
"<user_id>": {
|
||||
"<device_id>": {"<algorithm>": <count>}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Response:
|
||||
{
|
||||
"one_time_keys": {
|
||||
"<user_id>": {
|
||||
"<device_id>": {
|
||||
"<algorithm>:<key_id>": <OTK JSON>
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Args:
|
||||
destination: The server to query.
|
||||
query_content: The user ids to query.
|
||||
Returns:
|
||||
A dict containing the one-time keys.
|
||||
"""
|
||||
path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim")
|
||||
|
||||
return await self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data={"one_time_keys": query_content},
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def get_missing_events(
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from collections import Counter
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
|
@ -577,16 +578,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
|
|||
async def on_POST(
|
||||
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
|
||||
) -> Tuple[int, JsonDict]:
|
||||
# Generate a count for each algorithm, which is hard-coded to 1.
|
||||
key_query: List[Tuple[str, str, str, int]] = []
|
||||
for user_id, device_keys in content.get("one_time_keys", {}).items():
|
||||
for device_id, algorithm in device_keys.items():
|
||||
key_query.append((user_id, device_id, algorithm, 1))
|
||||
|
||||
response = await self.handler.on_claim_client_keys(
|
||||
origin, content, always_include_fallback_keys=False
|
||||
key_query, always_include_fallback_keys=False
|
||||
)
|
||||
return 200, response
|
||||
|
||||
|
||||
class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
|
||||
"""
|
||||
Identical to the stable endpoint (FederationClientKeysClaimServlet) except it
|
||||
always includes fallback keys in the response.
|
||||
Identical to the stable endpoint (FederationClientKeysClaimServlet) except
|
||||
it allows for querying for multiple OTKs at once and always includes fallback
|
||||
keys in the response.
|
||||
"""
|
||||
|
||||
PREFIX = FEDERATION_UNSTABLE_PREFIX
|
||||
|
@ -596,8 +604,16 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
|
|||
async def on_POST(
|
||||
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
|
||||
) -> Tuple[int, JsonDict]:
|
||||
# Generate a count for each algorithm.
|
||||
key_query: List[Tuple[str, str, str, int]] = []
|
||||
for user_id, device_keys in content.get("one_time_keys", {}).items():
|
||||
for device_id, algorithms in device_keys.items():
|
||||
counts = Counter(algorithms)
|
||||
for algorithm, count in counts.items():
|
||||
key_query.append((user_id, device_id, algorithm, count))
|
||||
|
||||
response = await self.handler.on_claim_client_keys(
|
||||
origin, content, always_include_fallback_keys=True
|
||||
key_query, always_include_fallback_keys=True
|
||||
)
|
||||
return 200, response
|
||||
|
||||
|
@ -805,6 +821,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
|
|||
FederationClientKeysQueryServlet,
|
||||
FederationUserDevicesQueryServlet,
|
||||
FederationClientKeysClaimServlet,
|
||||
FederationUnstableClientKeysClaimServlet,
|
||||
FederationThirdPartyInviteExchangeServlet,
|
||||
On3pidBindServlet,
|
||||
FederationVersionServlet,
|
||||
|
|
|
@ -841,8 +841,10 @@ class ApplicationServicesHandler:
|
|||
return True
|
||||
|
||||
async def claim_e2e_one_time_keys(
|
||||
self, query: Iterable[Tuple[str, str, str]]
|
||||
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
|
||||
self, query: Iterable[Tuple[str, str, str, int]]
|
||||
) -> Tuple[
|
||||
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
|
||||
]:
|
||||
"""Claim one time keys from application services.
|
||||
|
||||
Users which are exclusively owned by an application service are sent a
|
||||
|
@ -863,18 +865,18 @@ class ApplicationServicesHandler:
|
|||
services = self.store.get_app_services()
|
||||
|
||||
# Partition the users by appservice.
|
||||
query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
|
||||
query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {}
|
||||
missing = []
|
||||
for user_id, device, algorithm in query:
|
||||
for user_id, device, algorithm, count in query:
|
||||
if not self.store.get_if_app_services_interested_in_user(user_id):
|
||||
missing.append((user_id, device, algorithm))
|
||||
missing.append((user_id, device, algorithm, count))
|
||||
continue
|
||||
|
||||
# Find the associated appservice.
|
||||
for service in services:
|
||||
if service.is_exclusive_user(user_id):
|
||||
query_by_appservice.setdefault(service.id, []).append(
|
||||
(user_id, device, algorithm)
|
||||
(user_id, device, algorithm, count)
|
||||
)
|
||||
continue
|
||||
|
||||
|
|
|
@ -564,7 +564,7 @@ class E2eKeysHandler:
|
|||
|
||||
async def claim_local_one_time_keys(
|
||||
self,
|
||||
local_query: List[Tuple[str, str, str]],
|
||||
local_query: List[Tuple[str, str, str, int]],
|
||||
always_include_fallback_keys: bool,
|
||||
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
|
||||
"""Claim one time keys for local users.
|
||||
|
@ -581,6 +581,12 @@ class E2eKeysHandler:
|
|||
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
|
||||
"""
|
||||
|
||||
# Cap the number of OTKs that can be claimed at once to avoid abuse.
|
||||
local_query = [
|
||||
(user_id, device_id, algorithm, min(count, 5))
|
||||
for user_id, device_id, algorithm, count in local_query
|
||||
]
|
||||
|
||||
otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
|
||||
|
||||
# If the application services have not provided any keys via the C-S
|
||||
|
@ -607,7 +613,7 @@ class E2eKeysHandler:
|
|||
# from the appservice for that user ID / device ID. If it is found,
|
||||
# check if any of the keys match the requested algorithm & are a
|
||||
# fallback key.
|
||||
for user_id, device_id, algorithm in local_query:
|
||||
for user_id, device_id, algorithm, _count in local_query:
|
||||
# Check if the appservice responded for this query.
|
||||
as_result = appservice_results.get(user_id, {}).get(device_id, {})
|
||||
found_otk = False
|
||||
|
@ -630,13 +636,17 @@ class E2eKeysHandler:
|
|||
.get(device_id, {})
|
||||
.keys()
|
||||
)
|
||||
# Note that it doesn't make sense to request more than 1 fallback key
|
||||
# per (user_id, device_id, algorithm).
|
||||
fallback_query.append((user_id, device_id, algorithm, mark_as_used))
|
||||
|
||||
else:
|
||||
# All fallback keys get marked as used.
|
||||
fallback_query = [
|
||||
# Note that it doesn't make sense to request more than 1 fallback key
|
||||
# per (user_id, device_id, algorithm).
|
||||
(user_id, device_id, algorithm, True)
|
||||
for user_id, device_id, algorithm in not_found
|
||||
for user_id, device_id, algorithm, count in not_found
|
||||
]
|
||||
|
||||
# For each user that does not have a one-time keys available, see if
|
||||
|
@ -650,18 +660,19 @@ class E2eKeysHandler:
|
|||
@trace
|
||||
async def claim_one_time_keys(
|
||||
self,
|
||||
query: Dict[str, Dict[str, Dict[str, str]]],
|
||||
query: Dict[str, Dict[str, Dict[str, int]]],
|
||||
timeout: Optional[int],
|
||||
always_include_fallback_keys: bool,
|
||||
) -> JsonDict:
|
||||
local_query: List[Tuple[str, str, str]] = []
|
||||
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
|
||||
local_query: List[Tuple[str, str, str, int]] = []
|
||||
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
|
||||
|
||||
for user_id, one_time_keys in query.get("one_time_keys", {}).items():
|
||||
for user_id, one_time_keys in query.items():
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if self.is_mine(UserID.from_string(user_id)):
|
||||
for device_id, algorithm in one_time_keys.items():
|
||||
local_query.append((user_id, device_id, algorithm))
|
||||
for device_id, algorithms in one_time_keys.items():
|
||||
for algorithm, count in algorithms.items():
|
||||
local_query.append((user_id, device_id, algorithm, count))
|
||||
else:
|
||||
domain = get_domain_from_id(user_id)
|
||||
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
|
||||
|
@ -692,7 +703,7 @@ class E2eKeysHandler:
|
|||
device_keys = remote_queries[destination]
|
||||
try:
|
||||
remote_result = await self.federation.claim_client_keys(
|
||||
destination, {"one_time_keys": device_keys}, timeout=timeout
|
||||
destination, device_keys, timeout=timeout
|
||||
)
|
||||
for user_id, keys in remote_result["one_time_keys"].items():
|
||||
if user_id in device_keys:
|
||||
|
|
|
@ -16,7 +16,8 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Optional, Tuple
|
||||
from collections import Counter
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import InvalidAPICallError, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
|
@ -289,16 +290,40 @@ class OneTimeKeyServlet(RestServlet):
|
|||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
# Generate a count for each algorithm, which is hard-coded to 1.
|
||||
query: Dict[str, Dict[str, Dict[str, int]]] = {}
|
||||
for user_id, one_time_keys in body.get("one_time_keys", {}).items():
|
||||
for device_id, algorithm in one_time_keys.items():
|
||||
query.setdefault(user_id, {})[device_id] = {algorithm: 1}
|
||||
|
||||
result = await self.e2e_keys_handler.claim_one_time_keys(
|
||||
body, timeout, always_include_fallback_keys=False
|
||||
query, timeout, always_include_fallback_keys=False
|
||||
)
|
||||
return 200, result
|
||||
|
||||
|
||||
class UnstableOneTimeKeyServlet(RestServlet):
|
||||
"""
|
||||
Identical to the stable endpoint (OneTimeKeyServlet) except it always includes
|
||||
fallback keys in the response.
|
||||
Identical to the stable endpoint (OneTimeKeyServlet) except it allows for
|
||||
querying for multiple OTKs at once and always includes fallback keys in the
|
||||
response.
|
||||
|
||||
POST /keys/claim HTTP/1.1
|
||||
{
|
||||
"one_time_keys": {
|
||||
"<user_id>": {
|
||||
"<device_id>": ["<algorithm>", ...]
|
||||
} } }
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
{
|
||||
"one_time_keys": {
|
||||
"<user_id>": {
|
||||
"<device_id>": {
|
||||
"<algorithm>:<key_id>": "<key_base64>"
|
||||
} } } }
|
||||
|
||||
"""
|
||||
|
||||
PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")]
|
||||
|
@ -313,8 +338,15 @@ class UnstableOneTimeKeyServlet(RestServlet):
|
|||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
# Generate a count for each algorithm.
|
||||
query: Dict[str, Dict[str, Dict[str, int]]] = {}
|
||||
for user_id, one_time_keys in body.get("one_time_keys", {}).items():
|
||||
for device_id, algorithms in one_time_keys.items():
|
||||
query.setdefault(user_id, {})[device_id] = Counter(algorithms)
|
||||
|
||||
result = await self.e2e_keys_handler.claim_one_time_keys(
|
||||
body, timeout, always_include_fallback_keys=True
|
||||
query, timeout, always_include_fallback_keys=True
|
||||
)
|
||||
return 200, result
|
||||
|
||||
|
|
|
@ -1027,8 +1027,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
...
|
||||
|
||||
async def claim_e2e_one_time_keys(
|
||||
self, query_list: Iterable[Tuple[str, str, str]]
|
||||
) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
|
||||
self, query_list: Iterable[Tuple[str, str, str, int]]
|
||||
) -> Tuple[
|
||||
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
|
||||
]:
|
||||
"""Take a list of one time keys out of the database.
|
||||
|
||||
Args:
|
||||
|
@ -1043,8 +1045,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
|
||||
@trace
|
||||
def _claim_e2e_one_time_key_simple(
|
||||
txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
|
||||
) -> Optional[Tuple[str, str]]:
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
algorithm: str,
|
||||
count: int,
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Claim OTK for device for DBs that don't support RETURNING.
|
||||
|
||||
Returns:
|
||||
|
@ -1055,36 +1061,41 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
sql = """
|
||||
SELECT key_id, key_json FROM e2e_one_time_keys_json
|
||||
WHERE user_id = ? AND device_id = ? AND algorithm = ?
|
||||
LIMIT 1
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (user_id, device_id, algorithm))
|
||||
otk_row = txn.fetchone()
|
||||
if otk_row is None:
|
||||
return None
|
||||
txn.execute(sql, (user_id, device_id, algorithm, count))
|
||||
otk_rows = list(txn)
|
||||
if not otk_rows:
|
||||
return []
|
||||
|
||||
key_id, key_json = otk_row
|
||||
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="e2e_one_time_keys_json",
|
||||
column="key_id",
|
||||
values=[otk_row[0] for otk_row in otk_rows],
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"algorithm": algorithm,
|
||||
"key_id": key_id,
|
||||
},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
|
||||
return f"{algorithm}:{key_id}", key_json
|
||||
return [
|
||||
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
|
||||
]
|
||||
|
||||
@trace
|
||||
def _claim_e2e_one_time_key_returning(
|
||||
txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str
|
||||
) -> Optional[Tuple[str, str]]:
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
algorithm: str,
|
||||
count: int,
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Claim OTK for device for DBs that support RETURNING.
|
||||
|
||||
Returns:
|
||||
|
@ -1099,28 +1110,30 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
AND key_id IN (
|
||||
SELECT key_id FROM e2e_one_time_keys_json
|
||||
WHERE user_id = ? AND device_id = ? AND algorithm = ?
|
||||
LIMIT 1
|
||||
LIMIT ?
|
||||
)
|
||||
RETURNING key_id, key_json
|
||||
"""
|
||||
|
||||
txn.execute(
|
||||
sql, (user_id, device_id, algorithm, user_id, device_id, algorithm)
|
||||
sql,
|
||||
(user_id, device_id, algorithm, user_id, device_id, algorithm, count),
|
||||
)
|
||||
otk_row = txn.fetchone()
|
||||
if otk_row is None:
|
||||
return None
|
||||
otk_rows = list(txn)
|
||||
if not otk_rows:
|
||||
return []
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
|
||||
key_id, key_json = otk_row
|
||||
return f"{algorithm}:{key_id}", key_json
|
||||
return [
|
||||
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows
|
||||
]
|
||||
|
||||
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
|
||||
missing: List[Tuple[str, str, str]] = []
|
||||
for user_id, device_id, algorithm in query_list:
|
||||
missing: List[Tuple[str, str, str, int]] = []
|
||||
for user_id, device_id, algorithm, count in query_list:
|
||||
if self.database_engine.supports_returning:
|
||||
# If we support RETURNING clause we can use a single query that
|
||||
# allows us to use autocommit mode.
|
||||
|
@ -1130,21 +1143,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|||
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
|
||||
db_autocommit = False
|
||||
|
||||
claim_row = await self.db_pool.runInteraction(
|
||||
claim_rows = await self.db_pool.runInteraction(
|
||||
"claim_e2e_one_time_keys",
|
||||
_claim_e2e_one_time_key,
|
||||
user_id,
|
||||
device_id,
|
||||
algorithm,
|
||||
count,
|
||||
db_autocommit=db_autocommit,
|
||||
)
|
||||
if claim_row:
|
||||
if claim_rows:
|
||||
device_results = results.setdefault(user_id, {}).setdefault(
|
||||
device_id, {}
|
||||
)
|
||||
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
|
||||
else:
|
||||
missing.append((user_id, device_id, algorithm))
|
||||
for claim_row in claim_rows:
|
||||
device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
|
||||
# Did we get enough OTKs?
|
||||
count -= len(claim_rows)
|
||||
if count:
|
||||
missing.append((user_id, device_id, algorithm, count))
|
||||
|
||||
return results, missing
|
||||
|
||||
|
|
|
@ -195,11 +195,11 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
MISSING_KEYS = [
|
||||
# Known user, known device, missing algorithm.
|
||||
("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"),
|
||||
("@alice:example.org", "DEVICE_2", "xyz", 1),
|
||||
# Known user, missing device.
|
||||
("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"),
|
||||
("@alice:example.org", "DEVICE_3", "signed_curve25519", 1),
|
||||
# Unknown user.
|
||||
("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"),
|
||||
("@bob:example.org", "DEVICE_4", "signed_curve25519", 1),
|
||||
]
|
||||
|
||||
claimed_keys, missing = self.get_success(
|
||||
|
@ -207,9 +207,8 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
|
|||
self.service,
|
||||
[
|
||||
# Found devices
|
||||
("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"),
|
||||
("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"),
|
||||
("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"),
|
||||
("@alice:example.org", "DEVICE_1", "signed_curve25519", 1),
|
||||
("@alice:example.org", "DEVICE_2", "signed_curve25519", 1),
|
||||
]
|
||||
+ MISSING_KEYS,
|
||||
)
|
||||
|
|
|
@ -160,7 +160,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
res2 = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
{local_user: {device_id: {"alg1": 1}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
|
@ -205,7 +205,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# key
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
{local_user: {device_id: {"alg1": 1}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
|
@ -224,7 +224,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# claiming an OTK again should return the same fallback key
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
{local_user: {device_id: {"alg1": 1}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
|
@ -273,7 +273,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
{local_user: {device_id: {"alg1": 1}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
|
@ -285,7 +285,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
{local_user: {device_id: {"alg1": 1}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
|
@ -306,7 +306,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
{local_user: {device_id: {"alg1": 1}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
|
@ -347,7 +347,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# return both.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
{local_user: {device_id: {"alg1": 1}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
|
@ -369,7 +369,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# Claiming an OTK again should return only the fallback key.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}},
|
||||
{local_user: {device_id: {"alg1": 1}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
|
@ -1052,7 +1052,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
# Setup a response, but only for device 2.
|
||||
self.appservice_api.claim_client_keys.return_value = make_awaitable(
|
||||
({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")])
|
||||
({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)])
|
||||
)
|
||||
|
||||
# we shouldn't have any unused fallback keys yet
|
||||
|
@ -1079,11 +1079,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# query the fallback keys.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{
|
||||
"one_time_keys": {
|
||||
local_user: {device_id_1: "alg1", device_id_2: "alg1"}
|
||||
}
|
||||
},
|
||||
{local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=False,
|
||||
)
|
||||
|
@ -1128,7 +1124,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# Claim OTKs, which will ask the appservice and do nothing else.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||
{local_user: {device_id_1: {"alg1": 1}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
|
@ -1172,7 +1168,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# uploaded fallback key.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||
{local_user: {device_id_1: {"alg1": 1}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
|
@ -1205,7 +1201,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# Claim OTKs, which will return information only from the database.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||
{local_user: {device_id_1: {"alg1": 1}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
|
@ -1232,7 +1228,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
|||
# Claim OTKs, which will return only the fallback key from the database.
|
||||
claim_res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id_1: "alg1"}}},
|
||||
{local_user: {device_id_1: {"alg1": 1}}},
|
||||
timeout=None,
|
||||
always_include_fallback_keys=True,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue