Add support for claiming multiple OTKs at once. (#15468)
MSC3983 provides a way to request multiple OTKs at once from appservices, this extends this concept to the Client-Server API. Note that this will likely be spit out into a separate MSC, but is currently part of MSC3983.pull/15507/head
							parent
							
								
									6efa674004
								
							
						
					
					
						commit
						57aeeb308b
					
				|  | @ -0,0 +1 @@ | |||
| Support claiming more than one OTK at a time. | ||||
|  | @ -442,8 +442,10 @@ class ApplicationServiceApi(SimpleHttpClient): | |||
|         return False | ||||
| 
 | ||||
|     async def claim_client_keys( | ||||
|         self, service: "ApplicationService", query: List[Tuple[str, str, str]] | ||||
|     ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: | ||||
|         self, service: "ApplicationService", query: List[Tuple[str, str, str, int]] | ||||
|     ) -> Tuple[ | ||||
|         Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] | ||||
|     ]: | ||||
|         """Claim one time keys from an application service. | ||||
| 
 | ||||
|         Note that any error (including a timeout) is treated as the application | ||||
|  | @ -469,8 +471,10 @@ class ApplicationServiceApi(SimpleHttpClient): | |||
| 
 | ||||
|         # Create the expected payload shape. | ||||
|         body: Dict[str, Dict[str, List[str]]] = {} | ||||
|         for user_id, device, algorithm in query: | ||||
|             body.setdefault(user_id, {}).setdefault(device, []).append(algorithm) | ||||
|         for user_id, device, algorithm, count in query: | ||||
|             body.setdefault(user_id, {}).setdefault(device, []).extend( | ||||
|                 [algorithm] * count | ||||
|             ) | ||||
| 
 | ||||
|         uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim" | ||||
|         try: | ||||
|  | @ -493,11 +497,20 @@ class ApplicationServiceApi(SimpleHttpClient): | |||
|         # or if some are still missing. | ||||
|         # | ||||
|         # TODO This places a lot of faith in the response shape being correct. | ||||
|         missing = [ | ||||
|             (user_id, device, algorithm) | ||||
|             for user_id, device, algorithm in query | ||||
|             if algorithm not in response.get(user_id, {}).get(device, []) | ||||
|         ] | ||||
|         missing = [] | ||||
|         for user_id, device, algorithm, count in query: | ||||
|             # Count the number of keys in the response for this algorithm by | ||||
|             # checking which key IDs start with the algorithm. This uses that | ||||
|             # True == 1 in Python to generate a count. | ||||
|             response_count = sum( | ||||
|                 key_id.startswith(f"{algorithm}:") | ||||
|                 for key_id in response.get(user_id, {}).get(device, {}) | ||||
|             ) | ||||
|             count -= response_count | ||||
|             # If the appservice responds with fewer keys than requested, then | ||||
|             # consider the request unfulfilled. | ||||
|             if count > 0: | ||||
|                 missing.append((user_id, device, algorithm, count)) | ||||
| 
 | ||||
|         return response, missing | ||||
| 
 | ||||
|  |  | |||
|  | @ -235,7 +235,10 @@ class FederationClient(FederationBase): | |||
|         ) | ||||
| 
 | ||||
|     async def claim_client_keys( | ||||
|         self, destination: str, content: JsonDict, timeout: Optional[int] | ||||
|         self, | ||||
|         destination: str, | ||||
|         query: Dict[str, Dict[str, Dict[str, int]]], | ||||
|         timeout: Optional[int], | ||||
|     ) -> JsonDict: | ||||
|         """Claims one-time keys for a device hosted on a remote server. | ||||
| 
 | ||||
|  | @ -247,6 +250,50 @@ class FederationClient(FederationBase): | |||
|             The JSON object from the response | ||||
|         """ | ||||
|         sent_queries_counter.labels("client_one_time_keys").inc() | ||||
| 
 | ||||
|         # Convert the query with counts into a stable and unstable query and check | ||||
|         # if attempting to claim more than 1 OTK. | ||||
|         content: Dict[str, Dict[str, str]] = {} | ||||
|         unstable_content: Dict[str, Dict[str, List[str]]] = {} | ||||
|         use_unstable = False | ||||
|         for user_id, one_time_keys in query.items(): | ||||
|             for device_id, algorithms in one_time_keys.items(): | ||||
|                 if any(count > 1 for count in algorithms.values()): | ||||
|                     use_unstable = True | ||||
|                 if algorithms: | ||||
|                     # For the stable query, choose only the first algorithm. | ||||
|                     content.setdefault(user_id, {})[device_id] = next(iter(algorithms)) | ||||
|                     # For the unstable query, repeat each algorithm by count, then | ||||
|                     # splat those into chain to get a flattened list of all algorithms. | ||||
|                     # | ||||
|                     # Converts from {"algo1": 2, "algo2": 2} to ["algo1", "algo1", "algo2"]. | ||||
|                     unstable_content.setdefault(user_id, {})[device_id] = list( | ||||
|                         itertools.chain( | ||||
|                             *( | ||||
|                                 itertools.repeat(algorithm, count) | ||||
|                                 for algorithm, count in algorithms.items() | ||||
|                             ) | ||||
|                         ) | ||||
|                     ) | ||||
| 
 | ||||
|         if use_unstable: | ||||
|             try: | ||||
|                 return await self.transport_layer.claim_client_keys_unstable( | ||||
|                     destination, unstable_content, timeout | ||||
|                 ) | ||||
|             except HttpResponseException as e: | ||||
|                 # If an error is received that is due to an unrecognised endpoint, | ||||
|                 # fallback to the v1 endpoint. Otherwise, consider it a legitimate error | ||||
|                 # and raise. | ||||
|                 if not is_unknown_endpoint(e): | ||||
|                     raise | ||||
| 
 | ||||
|             logger.debug( | ||||
|                 "Couldn't claim client keys with the unstable API, falling back to the v1 API" | ||||
|             ) | ||||
|         else: | ||||
|             logger.debug("Skipping unstable claim client keys API") | ||||
| 
 | ||||
|         return await self.transport_layer.claim_client_keys( | ||||
|             destination, content, timeout | ||||
|         ) | ||||
|  |  | |||
|  | @ -1005,13 +1005,8 @@ class FederationServer(FederationBase): | |||
| 
 | ||||
|     @trace | ||||
|     async def on_claim_client_keys( | ||||
|         self, origin: str, content: JsonDict, always_include_fallback_keys: bool | ||||
|         self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool | ||||
|     ) -> Dict[str, Any]: | ||||
|         query = [] | ||||
|         for user_id, device_keys in content.get("one_time_keys", {}).items(): | ||||
|             for device_id, algorithm in device_keys.items(): | ||||
|                 query.append((user_id, device_id, algorithm)) | ||||
| 
 | ||||
|         log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) | ||||
|         results = await self._e2e_keys_handler.claim_local_one_time_keys( | ||||
|             query, always_include_fallback_keys=always_include_fallback_keys | ||||
|  |  | |||
|  | @ -650,10 +650,10 @@ class TransportLayerClient: | |||
| 
 | ||||
|         Response: | ||||
|             { | ||||
|               "device_keys": { | ||||
|               "one_time_keys": { | ||||
|                 "<user_id>": { | ||||
|                   "<device_id>": { | ||||
|                     "<algorithm>:<key_id>": "<key_base64>" | ||||
|                     "<algorithm>:<key_id>": <OTK JSON> | ||||
|                   } | ||||
|                 } | ||||
|               } | ||||
|  | @ -669,7 +669,50 @@ class TransportLayerClient: | |||
|         path = _create_v1_path("/user/keys/claim") | ||||
| 
 | ||||
|         return await self.client.post_json( | ||||
|             destination=destination, path=path, data=query_content, timeout=timeout | ||||
|             destination=destination, | ||||
|             path=path, | ||||
|             data={"one_time_keys": query_content}, | ||||
|             timeout=timeout, | ||||
|         ) | ||||
| 
 | ||||
|     async def claim_client_keys_unstable( | ||||
|         self, destination: str, query_content: JsonDict, timeout: Optional[int] | ||||
|     ) -> JsonDict: | ||||
|         """Claim one-time keys for a list of devices hosted on a remote server. | ||||
| 
 | ||||
|         Request: | ||||
|             { | ||||
|               "one_time_keys": { | ||||
|                 "<user_id>": { | ||||
|                   "<device_id>": {"<algorithm>": <count>} | ||||
|                 } | ||||
|               } | ||||
|             } | ||||
| 
 | ||||
|         Response: | ||||
|             { | ||||
|               "one_time_keys": { | ||||
|                 "<user_id>": { | ||||
|                   "<device_id>": { | ||||
|                     "<algorithm>:<key_id>": <OTK JSON> | ||||
|                   } | ||||
|                 } | ||||
|               } | ||||
|             } | ||||
| 
 | ||||
|         Args: | ||||
|             destination: The server to query. | ||||
|             query_content: The user ids to query. | ||||
|         Returns: | ||||
|             A dict containing the one-time keys. | ||||
|         """ | ||||
|         path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim") | ||||
| 
 | ||||
|         return await self.client.post_json( | ||||
|             destination=destination, | ||||
|             path=path, | ||||
|             data={"one_time_keys": query_content}, | ||||
|             timeout=timeout, | ||||
|         ) | ||||
| 
 | ||||
|     async def get_missing_events( | ||||
|  |  | |||
|  | @ -12,6 +12,7 @@ | |||
| #  See the License for the specific language governing permissions and | ||||
| #  limitations under the License. | ||||
| import logging | ||||
| from collections import Counter | ||||
| from typing import ( | ||||
|     TYPE_CHECKING, | ||||
|     Dict, | ||||
|  | @ -577,16 +578,23 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet): | |||
|     async def on_POST( | ||||
|         self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         # Generate a count for each algorithm, which is hard-coded to 1. | ||||
|         key_query: List[Tuple[str, str, str, int]] = [] | ||||
|         for user_id, device_keys in content.get("one_time_keys", {}).items(): | ||||
|             for device_id, algorithm in device_keys.items(): | ||||
|                 key_query.append((user_id, device_id, algorithm, 1)) | ||||
| 
 | ||||
|         response = await self.handler.on_claim_client_keys( | ||||
|             origin, content, always_include_fallback_keys=False | ||||
|             key_query, always_include_fallback_keys=False | ||||
|         ) | ||||
|         return 200, response | ||||
| 
 | ||||
| 
 | ||||
| class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet): | ||||
|     """ | ||||
|     Identical to the stable endpoint (FederationClientKeysClaimServlet) except it | ||||
|     always includes fallback keys in the response. | ||||
|     Identical to the stable endpoint (FederationClientKeysClaimServlet) except | ||||
|     it allows for querying for multiple OTKs at once and always includes fallback | ||||
|     keys in the response. | ||||
|     """ | ||||
| 
 | ||||
|     PREFIX = FEDERATION_UNSTABLE_PREFIX | ||||
|  | @ -596,8 +604,16 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet): | |||
|     async def on_POST( | ||||
|         self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] | ||||
|     ) -> Tuple[int, JsonDict]: | ||||
|         # Generate a count for each algorithm. | ||||
|         key_query: List[Tuple[str, str, str, int]] = [] | ||||
|         for user_id, device_keys in content.get("one_time_keys", {}).items(): | ||||
|             for device_id, algorithms in device_keys.items(): | ||||
|                 counts = Counter(algorithms) | ||||
|                 for algorithm, count in counts.items(): | ||||
|                     key_query.append((user_id, device_id, algorithm, count)) | ||||
| 
 | ||||
|         response = await self.handler.on_claim_client_keys( | ||||
|             origin, content, always_include_fallback_keys=True | ||||
|             key_query, always_include_fallback_keys=True | ||||
|         ) | ||||
|         return 200, response | ||||
| 
 | ||||
|  | @ -805,6 +821,7 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( | |||
|     FederationClientKeysQueryServlet, | ||||
|     FederationUserDevicesQueryServlet, | ||||
|     FederationClientKeysClaimServlet, | ||||
|     FederationUnstableClientKeysClaimServlet, | ||||
|     FederationThirdPartyInviteExchangeServlet, | ||||
|     On3pidBindServlet, | ||||
|     FederationVersionServlet, | ||||
|  |  | |||
|  | @ -841,8 +841,10 @@ class ApplicationServicesHandler: | |||
|         return True | ||||
| 
 | ||||
|     async def claim_e2e_one_time_keys( | ||||
|         self, query: Iterable[Tuple[str, str, str]] | ||||
|     ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: | ||||
|         self, query: Iterable[Tuple[str, str, str, int]] | ||||
|     ) -> Tuple[ | ||||
|         Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] | ||||
|     ]: | ||||
|         """Claim one time keys from application services. | ||||
| 
 | ||||
|         Users which are exclusively owned by an application service are sent a | ||||
|  | @ -863,18 +865,18 @@ class ApplicationServicesHandler: | |||
|         services = self.store.get_app_services() | ||||
| 
 | ||||
|         # Partition the users by appservice. | ||||
|         query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {} | ||||
|         query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {} | ||||
|         missing = [] | ||||
|         for user_id, device, algorithm in query: | ||||
|         for user_id, device, algorithm, count in query: | ||||
|             if not self.store.get_if_app_services_interested_in_user(user_id): | ||||
|                 missing.append((user_id, device, algorithm)) | ||||
|                 missing.append((user_id, device, algorithm, count)) | ||||
|                 continue | ||||
| 
 | ||||
|             # Find the associated appservice. | ||||
|             for service in services: | ||||
|                 if service.is_exclusive_user(user_id): | ||||
|                     query_by_appservice.setdefault(service.id, []).append( | ||||
|                         (user_id, device, algorithm) | ||||
|                         (user_id, device, algorithm, count) | ||||
|                     ) | ||||
|                     continue | ||||
| 
 | ||||
|  |  | |||
|  | @ -564,7 +564,7 @@ class E2eKeysHandler: | |||
| 
 | ||||
|     async def claim_local_one_time_keys( | ||||
|         self, | ||||
|         local_query: List[Tuple[str, str, str]], | ||||
|         local_query: List[Tuple[str, str, str, int]], | ||||
|         always_include_fallback_keys: bool, | ||||
|     ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]: | ||||
|         """Claim one time keys for local users. | ||||
|  | @ -581,6 +581,12 @@ class E2eKeysHandler: | |||
|             An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. | ||||
|         """ | ||||
| 
 | ||||
|         # Cap the number of OTKs that can be claimed at once to avoid abuse. | ||||
|         local_query = [ | ||||
|             (user_id, device_id, algorithm, min(count, 5)) | ||||
|             for user_id, device_id, algorithm, count in local_query | ||||
|         ] | ||||
| 
 | ||||
|         otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query) | ||||
| 
 | ||||
|         # If the application services have not provided any keys via the C-S | ||||
|  | @ -607,7 +613,7 @@ class E2eKeysHandler: | |||
|             # from the appservice for that user ID / device ID. If it is found, | ||||
|             # check if any of the keys match the requested algorithm & are a | ||||
|             # fallback key. | ||||
|             for user_id, device_id, algorithm in local_query: | ||||
|             for user_id, device_id, algorithm, _count in local_query: | ||||
|                 # Check if the appservice responded for this query. | ||||
|                 as_result = appservice_results.get(user_id, {}).get(device_id, {}) | ||||
|                 found_otk = False | ||||
|  | @ -630,13 +636,17 @@ class E2eKeysHandler: | |||
|                         .get(device_id, {}) | ||||
|                         .keys() | ||||
|                     ) | ||||
|                     # Note that it doesn't make sense to request more than 1 fallback key | ||||
|                     # per (user_id, device_id, algorithm). | ||||
|                     fallback_query.append((user_id, device_id, algorithm, mark_as_used)) | ||||
| 
 | ||||
|         else: | ||||
|             # All fallback keys get marked as used. | ||||
|             fallback_query = [ | ||||
|                 # Note that it doesn't make sense to request more than 1 fallback key | ||||
|                 # per (user_id, device_id, algorithm). | ||||
|                 (user_id, device_id, algorithm, True) | ||||
|                 for user_id, device_id, algorithm in not_found | ||||
|                 for user_id, device_id, algorithm, count in not_found | ||||
|             ] | ||||
| 
 | ||||
|         # For each user that does not have a one-time keys available, see if | ||||
|  | @ -650,18 +660,19 @@ class E2eKeysHandler: | |||
|     @trace | ||||
|     async def claim_one_time_keys( | ||||
|         self, | ||||
|         query: Dict[str, Dict[str, Dict[str, str]]], | ||||
|         query: Dict[str, Dict[str, Dict[str, int]]], | ||||
|         timeout: Optional[int], | ||||
|         always_include_fallback_keys: bool, | ||||
|     ) -> JsonDict: | ||||
|         local_query: List[Tuple[str, str, str]] = [] | ||||
|         remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} | ||||
|         local_query: List[Tuple[str, str, str, int]] = [] | ||||
|         remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {} | ||||
| 
 | ||||
|         for user_id, one_time_keys in query.get("one_time_keys", {}).items(): | ||||
|         for user_id, one_time_keys in query.items(): | ||||
|             # we use UserID.from_string to catch invalid user ids | ||||
|             if self.is_mine(UserID.from_string(user_id)): | ||||
|                 for device_id, algorithm in one_time_keys.items(): | ||||
|                     local_query.append((user_id, device_id, algorithm)) | ||||
|                 for device_id, algorithms in one_time_keys.items(): | ||||
|                     for algorithm, count in algorithms.items(): | ||||
|                         local_query.append((user_id, device_id, algorithm, count)) | ||||
|             else: | ||||
|                 domain = get_domain_from_id(user_id) | ||||
|                 remote_queries.setdefault(domain, {})[user_id] = one_time_keys | ||||
|  | @ -692,7 +703,7 @@ class E2eKeysHandler: | |||
|             device_keys = remote_queries[destination] | ||||
|             try: | ||||
|                 remote_result = await self.federation.claim_client_keys( | ||||
|                     destination, {"one_time_keys": device_keys}, timeout=timeout | ||||
|                     destination, device_keys, timeout=timeout | ||||
|                 ) | ||||
|                 for user_id, keys in remote_result["one_time_keys"].items(): | ||||
|                     if user_id in device_keys: | ||||
|  |  | |||
|  | @ -16,7 +16,8 @@ | |||
| 
 | ||||
| import logging | ||||
| import re | ||||
| from typing import TYPE_CHECKING, Any, Optional, Tuple | ||||
| from collections import Counter | ||||
| from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple | ||||
| 
 | ||||
| from synapse.api.errors import InvalidAPICallError, SynapseError | ||||
| from synapse.http.server import HttpServer | ||||
|  | @ -289,16 +290,40 @@ class OneTimeKeyServlet(RestServlet): | |||
|         await self.auth.get_user_by_req(request, allow_guest=True) | ||||
|         timeout = parse_integer(request, "timeout", 10 * 1000) | ||||
|         body = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         # Generate a count for each algorithm, which is hard-coded to 1. | ||||
|         query: Dict[str, Dict[str, Dict[str, int]]] = {} | ||||
|         for user_id, one_time_keys in body.get("one_time_keys", {}).items(): | ||||
|             for device_id, algorithm in one_time_keys.items(): | ||||
|                 query.setdefault(user_id, {})[device_id] = {algorithm: 1} | ||||
| 
 | ||||
|         result = await self.e2e_keys_handler.claim_one_time_keys( | ||||
|             body, timeout, always_include_fallback_keys=False | ||||
|             query, timeout, always_include_fallback_keys=False | ||||
|         ) | ||||
|         return 200, result | ||||
| 
 | ||||
| 
 | ||||
| class UnstableOneTimeKeyServlet(RestServlet): | ||||
|     """ | ||||
|     Identical to the stable endpoint (OneTimeKeyServlet) except it always includes | ||||
|     fallback keys in the response. | ||||
|     Identical to the stable endpoint (OneTimeKeyServlet) except it allows for | ||||
|     querying for multiple OTKs at once and always includes fallback keys in the | ||||
|     response. | ||||
| 
 | ||||
|     POST /keys/claim HTTP/1.1 | ||||
|     { | ||||
|       "one_time_keys": { | ||||
|         "<user_id>": { | ||||
|           "<device_id>": ["<algorithm>", ...] | ||||
|     } } } | ||||
| 
 | ||||
|     HTTP/1.1 200 OK | ||||
|     { | ||||
|       "one_time_keys": { | ||||
|         "<user_id>": { | ||||
|           "<device_id>": { | ||||
|             "<algorithm>:<key_id>": "<key_base64>" | ||||
|     } } } } | ||||
| 
 | ||||
|     """ | ||||
| 
 | ||||
|     PATTERNS = [re.compile(r"^/_matrix/client/unstable/org.matrix.msc3983/keys/claim$")] | ||||
|  | @ -313,8 +338,15 @@ class UnstableOneTimeKeyServlet(RestServlet): | |||
|         await self.auth.get_user_by_req(request, allow_guest=True) | ||||
|         timeout = parse_integer(request, "timeout", 10 * 1000) | ||||
|         body = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         # Generate a count for each algorithm. | ||||
|         query: Dict[str, Dict[str, Dict[str, int]]] = {} | ||||
|         for user_id, one_time_keys in body.get("one_time_keys", {}).items(): | ||||
|             for device_id, algorithms in one_time_keys.items(): | ||||
|                 query.setdefault(user_id, {})[device_id] = Counter(algorithms) | ||||
| 
 | ||||
|         result = await self.e2e_keys_handler.claim_one_time_keys( | ||||
|             body, timeout, always_include_fallback_keys=True | ||||
|             query, timeout, always_include_fallback_keys=True | ||||
|         ) | ||||
|         return 200, result | ||||
| 
 | ||||
|  |  | |||
|  | @ -1027,8 +1027,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
|         ... | ||||
| 
 | ||||
|     async def claim_e2e_one_time_keys( | ||||
|         self, query_list: Iterable[Tuple[str, str, str]] | ||||
|     ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: | ||||
|         self, query_list: Iterable[Tuple[str, str, str, int]] | ||||
|     ) -> Tuple[ | ||||
|         Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] | ||||
|     ]: | ||||
|         """Take a list of one time keys out of the database. | ||||
| 
 | ||||
|         Args: | ||||
|  | @ -1043,8 +1045,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
| 
 | ||||
|         @trace | ||||
|         def _claim_e2e_one_time_key_simple( | ||||
|             txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str | ||||
|         ) -> Optional[Tuple[str, str]]: | ||||
|             txn: LoggingTransaction, | ||||
|             user_id: str, | ||||
|             device_id: str, | ||||
|             algorithm: str, | ||||
|             count: int, | ||||
|         ) -> List[Tuple[str, str]]: | ||||
|             """Claim OTK for device for DBs that don't support RETURNING. | ||||
| 
 | ||||
|             Returns: | ||||
|  | @ -1055,36 +1061,41 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
|             sql = """ | ||||
|                 SELECT key_id, key_json FROM e2e_one_time_keys_json | ||||
|                 WHERE user_id = ? AND device_id = ? AND algorithm = ? | ||||
|                 LIMIT 1 | ||||
|                 LIMIT ? | ||||
|             """ | ||||
| 
 | ||||
|             txn.execute(sql, (user_id, device_id, algorithm)) | ||||
|             otk_row = txn.fetchone() | ||||
|             if otk_row is None: | ||||
|                 return None | ||||
|             txn.execute(sql, (user_id, device_id, algorithm, count)) | ||||
|             otk_rows = list(txn) | ||||
|             if not otk_rows: | ||||
|                 return [] | ||||
| 
 | ||||
|             key_id, key_json = otk_row | ||||
| 
 | ||||
|             self.db_pool.simple_delete_one_txn( | ||||
|             self.db_pool.simple_delete_many_txn( | ||||
|                 txn, | ||||
|                 table="e2e_one_time_keys_json", | ||||
|                 column="key_id", | ||||
|                 values=[otk_row[0] for otk_row in otk_rows], | ||||
|                 keyvalues={ | ||||
|                     "user_id": user_id, | ||||
|                     "device_id": device_id, | ||||
|                     "algorithm": algorithm, | ||||
|                     "key_id": key_id, | ||||
|                 }, | ||||
|             ) | ||||
|             self._invalidate_cache_and_stream( | ||||
|                 txn, self.count_e2e_one_time_keys, (user_id, device_id) | ||||
|             ) | ||||
| 
 | ||||
|             return f"{algorithm}:{key_id}", key_json | ||||
|             return [ | ||||
|                 (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows | ||||
|             ] | ||||
| 
 | ||||
|         @trace | ||||
|         def _claim_e2e_one_time_key_returning( | ||||
|             txn: LoggingTransaction, user_id: str, device_id: str, algorithm: str | ||||
|         ) -> Optional[Tuple[str, str]]: | ||||
|             txn: LoggingTransaction, | ||||
|             user_id: str, | ||||
|             device_id: str, | ||||
|             algorithm: str, | ||||
|             count: int, | ||||
|         ) -> List[Tuple[str, str]]: | ||||
|             """Claim OTK for device for DBs that support RETURNING. | ||||
| 
 | ||||
|             Returns: | ||||
|  | @ -1099,28 +1110,30 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
|                     AND key_id IN ( | ||||
|                         SELECT key_id FROM e2e_one_time_keys_json | ||||
|                         WHERE user_id = ? AND device_id = ? AND algorithm = ? | ||||
|                         LIMIT 1 | ||||
|                         LIMIT ? | ||||
|                     ) | ||||
|                 RETURNING key_id, key_json | ||||
|             """ | ||||
| 
 | ||||
|             txn.execute( | ||||
|                 sql, (user_id, device_id, algorithm, user_id, device_id, algorithm) | ||||
|                 sql, | ||||
|                 (user_id, device_id, algorithm, user_id, device_id, algorithm, count), | ||||
|             ) | ||||
|             otk_row = txn.fetchone() | ||||
|             if otk_row is None: | ||||
|                 return None | ||||
|             otk_rows = list(txn) | ||||
|             if not otk_rows: | ||||
|                 return [] | ||||
| 
 | ||||
|             self._invalidate_cache_and_stream( | ||||
|                 txn, self.count_e2e_one_time_keys, (user_id, device_id) | ||||
|             ) | ||||
| 
 | ||||
|             key_id, key_json = otk_row | ||||
|             return f"{algorithm}:{key_id}", key_json | ||||
|             return [ | ||||
|                 (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows | ||||
|             ] | ||||
| 
 | ||||
|         results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} | ||||
|         missing: List[Tuple[str, str, str]] = [] | ||||
|         for user_id, device_id, algorithm in query_list: | ||||
|         missing: List[Tuple[str, str, str, int]] = [] | ||||
|         for user_id, device_id, algorithm, count in query_list: | ||||
|             if self.database_engine.supports_returning: | ||||
|                 # If we support RETURNING clause we can use a single query that | ||||
|                 # allows us to use autocommit mode. | ||||
|  | @ -1130,21 +1143,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
|                 _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple | ||||
|                 db_autocommit = False | ||||
| 
 | ||||
|             claim_row = await self.db_pool.runInteraction( | ||||
|             claim_rows = await self.db_pool.runInteraction( | ||||
|                 "claim_e2e_one_time_keys", | ||||
|                 _claim_e2e_one_time_key, | ||||
|                 user_id, | ||||
|                 device_id, | ||||
|                 algorithm, | ||||
|                 count, | ||||
|                 db_autocommit=db_autocommit, | ||||
|             ) | ||||
|             if claim_row: | ||||
|             if claim_rows: | ||||
|                 device_results = results.setdefault(user_id, {}).setdefault( | ||||
|                     device_id, {} | ||||
|                 ) | ||||
|                 device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) | ||||
|             else: | ||||
|                 missing.append((user_id, device_id, algorithm)) | ||||
|                 for claim_row in claim_rows: | ||||
|                     device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) | ||||
|             # Did we get enough OTKs? | ||||
|             count -= len(claim_rows) | ||||
|             if count: | ||||
|                 missing.append((user_id, device_id, algorithm, count)) | ||||
| 
 | ||||
|         return results, missing | ||||
| 
 | ||||
|  |  | |||
|  | @ -195,11 +195,11 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         MISSING_KEYS = [ | ||||
|             # Known user, known device, missing algorithm. | ||||
|             ("@alice:example.org", "DEVICE_1", "signed_curve25519:DDDDHg"), | ||||
|             ("@alice:example.org", "DEVICE_2", "xyz", 1), | ||||
|             # Known user, missing device. | ||||
|             ("@alice:example.org", "DEVICE_3", "signed_curve25519:EEEEHg"), | ||||
|             ("@alice:example.org", "DEVICE_3", "signed_curve25519", 1), | ||||
|             # Unknown user. | ||||
|             ("@bob:example.org", "DEVICE_4", "signed_curve25519:FFFFHg"), | ||||
|             ("@bob:example.org", "DEVICE_4", "signed_curve25519", 1), | ||||
|         ] | ||||
| 
 | ||||
|         claimed_keys, missing = self.get_success( | ||||
|  | @ -207,9 +207,8 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): | |||
|                 self.service, | ||||
|                 [ | ||||
|                     # Found devices | ||||
|                     ("@alice:example.org", "DEVICE_1", "signed_curve25519:AAAAHg"), | ||||
|                     ("@alice:example.org", "DEVICE_1", "signed_curve25519:BBBBHg"), | ||||
|                     ("@alice:example.org", "DEVICE_2", "signed_curve25519:CCCCHg"), | ||||
|                     ("@alice:example.org", "DEVICE_1", "signed_curve25519", 1), | ||||
|                     ("@alice:example.org", "DEVICE_2", "signed_curve25519", 1), | ||||
|                 ] | ||||
|                 + MISSING_KEYS, | ||||
|             ) | ||||
|  |  | |||
|  | @ -160,7 +160,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         res2 = self.get_success( | ||||
|             self.handler.claim_one_time_keys( | ||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||
|                 {local_user: {device_id: {"alg1": 1}}}, | ||||
|                 timeout=None, | ||||
|                 always_include_fallback_keys=False, | ||||
|             ) | ||||
|  | @ -205,7 +205,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
|         # key | ||||
|         claim_res = self.get_success( | ||||
|             self.handler.claim_one_time_keys( | ||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||
|                 {local_user: {device_id: {"alg1": 1}}}, | ||||
|                 timeout=None, | ||||
|                 always_include_fallback_keys=False, | ||||
|             ) | ||||
|  | @ -224,7 +224,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
|         # claiming an OTK again should return the same fallback key | ||||
|         claim_res = self.get_success( | ||||
|             self.handler.claim_one_time_keys( | ||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||
|                 {local_user: {device_id: {"alg1": 1}}}, | ||||
|                 timeout=None, | ||||
|                 always_include_fallback_keys=False, | ||||
|             ) | ||||
|  | @ -273,7 +273,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         claim_res = self.get_success( | ||||
|             self.handler.claim_one_time_keys( | ||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||
|                 {local_user: {device_id: {"alg1": 1}}}, | ||||
|                 timeout=None, | ||||
|                 always_include_fallback_keys=False, | ||||
|             ) | ||||
|  | @ -285,7 +285,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         claim_res = self.get_success( | ||||
|             self.handler.claim_one_time_keys( | ||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||
|                 {local_user: {device_id: {"alg1": 1}}}, | ||||
|                 timeout=None, | ||||
|                 always_include_fallback_keys=False, | ||||
|             ) | ||||
|  | @ -306,7 +306,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         claim_res = self.get_success( | ||||
|             self.handler.claim_one_time_keys( | ||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||
|                 {local_user: {device_id: {"alg1": 1}}}, | ||||
|                 timeout=None, | ||||
|                 always_include_fallback_keys=False, | ||||
|             ) | ||||
|  | @ -347,7 +347,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
|         # return both. | ||||
|         claim_res = self.get_success( | ||||
|             self.handler.claim_one_time_keys( | ||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||
|                 {local_user: {device_id: {"alg1": 1}}}, | ||||
|                 timeout=None, | ||||
|                 always_include_fallback_keys=True, | ||||
|             ) | ||||
|  | @ -369,7 +369,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
|         # Claiming an OTK again should return only the fallback key. | ||||
|         claim_res = self.get_success( | ||||
|             self.handler.claim_one_time_keys( | ||||
|                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, | ||||
|                 {local_user: {device_id: {"alg1": 1}}}, | ||||
|                 timeout=None, | ||||
|                 always_include_fallback_keys=True, | ||||
|             ) | ||||
|  | @ -1052,7 +1052,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
| 
 | ||||
|         # Setup a response, but only for device 2. | ||||
|         self.appservice_api.claim_client_keys.return_value = make_awaitable( | ||||
|             ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1")]) | ||||
|             ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)]) | ||||
|         ) | ||||
| 
 | ||||
|         # we shouldn't have any unused fallback keys yet | ||||
|  | @ -1079,11 +1079,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
|         # query the fallback keys. | ||||
|         claim_res = self.get_success( | ||||
|             self.handler.claim_one_time_keys( | ||||
|                 { | ||||
|                     "one_time_keys": { | ||||
|                         local_user: {device_id_1: "alg1", device_id_2: "alg1"} | ||||
|                     } | ||||
|                 }, | ||||
|                 {local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}}, | ||||
|                 timeout=None, | ||||
|                 always_include_fallback_keys=False, | ||||
|             ) | ||||
|  | @ -1128,7 +1124,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
|         # Claim OTKs, which will ask the appservice and do nothing else. | ||||
|         claim_res = self.get_success( | ||||
|             self.handler.claim_one_time_keys( | ||||
|                 {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, | ||||
|                 {local_user: {device_id_1: {"alg1": 1}}}, | ||||
|                 timeout=None, | ||||
|                 always_include_fallback_keys=True, | ||||
|             ) | ||||
|  | @ -1172,7 +1168,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
|         # uploaded fallback key. | ||||
|         claim_res = self.get_success( | ||||
|             self.handler.claim_one_time_keys( | ||||
|                 {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, | ||||
|                 {local_user: {device_id_1: {"alg1": 1}}}, | ||||
|                 timeout=None, | ||||
|                 always_include_fallback_keys=True, | ||||
|             ) | ||||
|  | @ -1205,7 +1201,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
|         # Claim OTKs, which will return information only from the database. | ||||
|         claim_res = self.get_success( | ||||
|             self.handler.claim_one_time_keys( | ||||
|                 {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, | ||||
|                 {local_user: {device_id_1: {"alg1": 1}}}, | ||||
|                 timeout=None, | ||||
|                 always_include_fallback_keys=True, | ||||
|             ) | ||||
|  | @ -1232,7 +1228,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
|         # Claim OTKs, which will return only the fallback key from the database. | ||||
|         claim_res = self.get_success( | ||||
|             self.handler.claim_one_time_keys( | ||||
|                 {"one_time_keys": {local_user: {device_id_1: "alg1"}}}, | ||||
|                 {local_user: {device_id_1: {"alg1": 1}}}, | ||||
|                 timeout=None, | ||||
|                 always_include_fallback_keys=True, | ||||
|             ) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Patrick Cloke
						Patrick Cloke