Merge pull request #1046 from matrix-org/markjh/direct_to_device
Start adding store-and-forward direct-to-device messagingpull/1050/head
						commit
						8c1e746f54
					
				|  | @ -35,6 +35,7 @@ SyncConfig = collections.namedtuple("SyncConfig", [ | |||
|     "filter_collection", | ||||
|     "is_guest", | ||||
|     "request_key", | ||||
|     "device_id", | ||||
| ]) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -113,6 +114,7 @@ class SyncResult(collections.namedtuple("SyncResult", [ | |||
|     "joined",  # JoinedSyncResult for each joined room. | ||||
|     "invited",  # InvitedSyncResult for each invited room. | ||||
|     "archived",  # ArchivedSyncResult for each archived room. | ||||
|     "to_device",  # List of direct messages for the device. | ||||
| ])): | ||||
|     __slots__ = [] | ||||
| 
 | ||||
|  | @ -126,7 +128,8 @@ class SyncResult(collections.namedtuple("SyncResult", [ | |||
|             self.joined or | ||||
|             self.invited or | ||||
|             self.archived or | ||||
|             self.account_data | ||||
|             self.account_data or | ||||
|             self.to_device | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -531,15 +534,52 @@ class SyncHandler(object): | |||
|             sync_result_builder, newly_joined_rooms, newly_joined_users | ||||
|         ) | ||||
| 
 | ||||
|         yield self._generate_sync_entry_for_to_device(sync_result_builder) | ||||
| 
 | ||||
|         defer.returnValue(SyncResult( | ||||
|             presence=sync_result_builder.presence, | ||||
|             account_data=sync_result_builder.account_data, | ||||
|             joined=sync_result_builder.joined, | ||||
|             invited=sync_result_builder.invited, | ||||
|             archived=sync_result_builder.archived, | ||||
|             to_device=sync_result_builder.to_device, | ||||
|             next_batch=sync_result_builder.now_token, | ||||
|         )) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _generate_sync_entry_for_to_device(self, sync_result_builder): | ||||
|         """Generates the portion of the sync response. Populates | ||||
|         `sync_result_builder` with the result. | ||||
| 
 | ||||
|         Args: | ||||
|             sync_result_builder(SyncResultBuilder) | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred(dict): A dictionary containing the per room account data. | ||||
|         """ | ||||
|         user_id = sync_result_builder.sync_config.user.to_string() | ||||
|         device_id = sync_result_builder.sync_config.device_id | ||||
|         now_token = sync_result_builder.now_token | ||||
|         since_stream_id = 0 | ||||
|         if sync_result_builder.since_token is not None: | ||||
|             since_stream_id = int(sync_result_builder.since_token.to_device_key) | ||||
| 
 | ||||
|         if since_stream_id: | ||||
|             logger.debug("Deleting messages up to %d", since_stream_id) | ||||
|             yield self.store.delete_messages_for_device( | ||||
|                 user_id, device_id, since_stream_id | ||||
|             ) | ||||
| 
 | ||||
|         logger.debug("Getting messages up to %d", now_token.to_device_key) | ||||
|         messages, stream_id = yield self.store.get_new_messages_for_device( | ||||
|             user_id, device_id, now_token.to_device_key | ||||
|         ) | ||||
|         logger.debug("Got messages up to %d: %r", stream_id, messages) | ||||
|         sync_result_builder.now_token = now_token.copy_and_replace( | ||||
|             "to_device_key", stream_id | ||||
|         ) | ||||
|         sync_result_builder.to_device = messages | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _generate_sync_entry_for_account_data(self, sync_result_builder): | ||||
|         """Generates the account data portion of the sync response. Populates | ||||
|  | @ -1110,6 +1150,7 @@ class SyncResultBuilder(object): | |||
|         self.joined = [] | ||||
|         self.invited = [] | ||||
|         self.archived = [] | ||||
|         self.device = [] | ||||
| 
 | ||||
| 
 | ||||
| class RoomSyncResultBuilder(object): | ||||
|  |  | |||
|  | @ -49,6 +49,7 @@ from synapse.rest.client.v2_alpha import ( | |||
|     notifications, | ||||
|     devices, | ||||
|     thirdparty, | ||||
|     sendtodevice, | ||||
| ) | ||||
| 
 | ||||
| from synapse.http.server import JsonResource | ||||
|  | @ -96,3 +97,4 @@ class ClientRestResource(JsonResource): | |||
|         notifications.register_servlets(hs, client_resource) | ||||
|         devices.register_servlets(hs, client_resource) | ||||
|         thirdparty.register_servlets(hs, client_resource) | ||||
|         sendtodevice.register_servlets(hs, client_resource) | ||||
|  |  | |||
|  | @ -0,0 +1,84 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2016 OpenMarket Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import logging | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| from synapse.http.servlet import parse_json_object_from_request | ||||
| 
 | ||||
| from synapse.http import servlet | ||||
| from synapse.rest.client.v1.transactions import HttpTransactionStore | ||||
| from ._base import client_v2_patterns | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class SendToDeviceRestServlet(servlet.RestServlet): | ||||
|     PATTERNS = client_v2_patterns( | ||||
|         "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", | ||||
|         releases=[], v2_alpha=False | ||||
|     ) | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         """ | ||||
|         Args: | ||||
|             hs (synapse.server.HomeServer): server | ||||
|         """ | ||||
|         super(SendToDeviceRestServlet, self).__init__() | ||||
|         self.hs = hs | ||||
|         self.auth = hs.get_auth() | ||||
|         self.store = hs.get_datastore() | ||||
|         self.is_mine_id = hs.is_mine_id | ||||
|         self.txns = HttpTransactionStore() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_PUT(self, request, message_type, txn_id): | ||||
|         try: | ||||
|             defer.returnValue( | ||||
|                 self.txns.get_client_transaction(request, txn_id) | ||||
|             ) | ||||
|         except KeyError: | ||||
|             pass | ||||
| 
 | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
| 
 | ||||
|         content = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         # TODO: Prod the notifier to wake up sync streams. | ||||
|         # TODO: Implement replication for the messages. | ||||
|         # TODO: Send the messages to remote servers if needed. | ||||
| 
 | ||||
|         local_messages = {} | ||||
|         for user_id, by_device in content["messages"].items(): | ||||
|             if self.is_mine_id(user_id): | ||||
|                 messages_by_device = { | ||||
|                     device_id: { | ||||
|                         "content": message_content, | ||||
|                         "type": message_type, | ||||
|                         "sender": requester.user.to_string(), | ||||
|                     } | ||||
|                     for device_id, message_content in by_device.items() | ||||
|                 } | ||||
|                 local_messages[user_id] = messages_by_device | ||||
| 
 | ||||
|         yield self.store.add_messages_to_device_inbox(local_messages) | ||||
| 
 | ||||
|         response = (200, {}) | ||||
|         self.txns.store_client_transaction(request, txn_id, response) | ||||
|         defer.returnValue(response) | ||||
| 
 | ||||
| 
 | ||||
| def register_servlets(hs, http_server): | ||||
|     SendToDeviceRestServlet(hs).register(http_server) | ||||
|  | @ -97,6 +97,7 @@ class SyncRestServlet(RestServlet): | |||
|             request, allow_guest=True | ||||
|         ) | ||||
|         user = requester.user | ||||
|         device_id = requester.device_id | ||||
| 
 | ||||
|         timeout = parse_integer(request, "timeout", default=0) | ||||
|         since = parse_string(request, "since") | ||||
|  | @ -109,12 +110,12 @@ class SyncRestServlet(RestServlet): | |||
| 
 | ||||
|         logger.info( | ||||
|             "/sync: user=%r, timeout=%r, since=%r," | ||||
|             " set_presence=%r, filter_id=%r" % ( | ||||
|                 user, timeout, since, set_presence, filter_id | ||||
|             " set_presence=%r, filter_id=%r, device_id=%r" % ( | ||||
|                 user, timeout, since, set_presence, filter_id, device_id | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         request_key = (user, timeout, since, filter_id, full_state) | ||||
|         request_key = (user, timeout, since, filter_id, full_state, device_id) | ||||
| 
 | ||||
|         if filter_id: | ||||
|             if filter_id.startswith('{'): | ||||
|  | @ -136,6 +137,7 @@ class SyncRestServlet(RestServlet): | |||
|             filter_collection=filter, | ||||
|             is_guest=requester.is_guest, | ||||
|             request_key=request_key, | ||||
|             device_id=device_id, | ||||
|         ) | ||||
| 
 | ||||
|         if since is not None: | ||||
|  | @ -173,6 +175,7 @@ class SyncRestServlet(RestServlet): | |||
| 
 | ||||
|         response_content = { | ||||
|             "account_data": {"events": sync_result.account_data}, | ||||
|             "to_device": {"events": sync_result.to_device}, | ||||
|             "presence": self.encode_presence( | ||||
|                 sync_result.presence, time_now | ||||
|             ), | ||||
|  |  | |||
|  | @ -36,6 +36,7 @@ from .push_rule import PushRuleStore | |||
| from .media_repository import MediaRepositoryStore | ||||
| from .rejections import RejectionsStore | ||||
| from .event_push_actions import EventPushActionsStore | ||||
| from .deviceinbox import DeviceInboxStore | ||||
| 
 | ||||
| from .state import StateStore | ||||
| from .signatures import SignatureStore | ||||
|  | @ -84,6 +85,7 @@ class DataStore(RoomMemberStore, RoomStore, | |||
|                 OpenIdStore, | ||||
|                 ClientIpStore, | ||||
|                 DeviceStore, | ||||
|                 DeviceInboxStore, | ||||
|                 ): | ||||
| 
 | ||||
|     def __init__(self, db_conn, hs): | ||||
|  | @ -108,6 +110,9 @@ class DataStore(RoomMemberStore, RoomStore, | |||
|         self._presence_id_gen = StreamIdGenerator( | ||||
|             db_conn, "presence_stream", "stream_id" | ||||
|         ) | ||||
|         self._device_inbox_id_gen = StreamIdGenerator( | ||||
|             db_conn, "device_inbox", "stream_id" | ||||
|         ) | ||||
| 
 | ||||
|         self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") | ||||
|         self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id") | ||||
|  |  | |||
|  | @ -0,0 +1,140 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2016 OpenMarket Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import logging | ||||
| import ujson | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from ._base import SQLBaseStore | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class DeviceInboxStore(SQLBaseStore): | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def add_messages_to_device_inbox(self, messages_by_user_then_device): | ||||
|         """ | ||||
|         Args: | ||||
|             messages_by_user_and_device(dict): | ||||
|                 Dictionary of user_id to device_id to message. | ||||
|         Returns: | ||||
|             A deferred that resolves when the messages have been inserted. | ||||
|         """ | ||||
| 
 | ||||
|         def select_devices_txn(txn, user_id, devices): | ||||
|             if not devices: | ||||
|                 return [] | ||||
|             sql = ( | ||||
|                 "SELECT user_id, device_id FROM devices" | ||||
|                 " WHERE user_id = ? AND device_id IN (" | ||||
|                 + ",".join("?" * len(devices)) | ||||
|                 + ")" | ||||
|             ) | ||||
|             # TODO: Maybe this needs to be done in batches if there are | ||||
|             # too many local devices for a given user. | ||||
|             args = [user_id] + devices | ||||
|             txn.execute(sql, args) | ||||
|             return [tuple(row) for row in txn.fetchall()] | ||||
| 
 | ||||
|         def add_messages_to_device_inbox_txn(txn, stream_id): | ||||
|             local_users_and_devices = set() | ||||
|             for user_id, messages_by_device in messages_by_user_then_device.items(): | ||||
|                 local_users_and_devices.update( | ||||
|                     select_devices_txn(txn, user_id, messages_by_device.keys()) | ||||
|                 ) | ||||
| 
 | ||||
|             sql = ( | ||||
|                 "INSERT INTO device_inbox" | ||||
|                 " (user_id, device_id, stream_id, message_json)" | ||||
|                 " VALUES (?,?,?,?)" | ||||
|             ) | ||||
|             rows = [] | ||||
|             for user_id, messages_by_device in messages_by_user_then_device.items(): | ||||
|                 for device_id, message in messages_by_device.items(): | ||||
|                     message_json = ujson.dumps(message) | ||||
|                     # Only insert into the local inbox if the device exists on | ||||
|                     # this server | ||||
|                     if (user_id, device_id) in local_users_and_devices: | ||||
|                         rows.append((user_id, device_id, stream_id, message_json)) | ||||
| 
 | ||||
|             txn.executemany(sql, rows) | ||||
| 
 | ||||
|         with self._device_inbox_id_gen.get_next() as stream_id: | ||||
|             yield self.runInteraction( | ||||
|                 "add_messages_to_device_inbox", | ||||
|                 add_messages_to_device_inbox_txn, | ||||
|                 stream_id | ||||
|             ) | ||||
| 
 | ||||
|     def get_new_messages_for_device( | ||||
|         self, user_id, device_id, current_stream_id, limit=100 | ||||
|     ): | ||||
|         """ | ||||
|         Args: | ||||
|             user_id(str): The recipient user_id. | ||||
|             device_id(str): The recipient device_id. | ||||
|             current_stream_id(int): The current position of the to device | ||||
|                 message stream. | ||||
|         Returns: | ||||
|             Deferred ([dict], int): List of messages for the device and where | ||||
|                 in the stream the messages got to. | ||||
|         """ | ||||
|         def get_new_messages_for_device_txn(txn): | ||||
|             sql = ( | ||||
|                 "SELECT stream_id, message_json FROM device_inbox" | ||||
|                 " WHERE user_id = ? AND device_id = ?" | ||||
|                 " AND stream_id <= ?" | ||||
|                 " ORDER BY stream_id ASC" | ||||
|                 " LIMIT ?" | ||||
|             ) | ||||
|             txn.execute(sql, (user_id, device_id, current_stream_id, limit)) | ||||
|             messages = [] | ||||
|             for row in txn.fetchall(): | ||||
|                 stream_pos = row[0] | ||||
|                 messages.append(ujson.loads(row[1])) | ||||
|             if len(messages) < limit: | ||||
|                 stream_pos = current_stream_id | ||||
|             return (messages, stream_pos) | ||||
| 
 | ||||
|         return self.runInteraction( | ||||
|             "get_new_messages_for_device", get_new_messages_for_device_txn, | ||||
|         ) | ||||
| 
 | ||||
|     def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): | ||||
|         """ | ||||
|         Args: | ||||
|             user_id(str): The recipient user_id. | ||||
|             device_id(str): The recipient device_id. | ||||
|             up_to_stream_id(int): Where to delete messages up to. | ||||
|         Returns: | ||||
|             A deferred that resolves when the messages have been deleted. | ||||
|         """ | ||||
|         def delete_messages_for_device_txn(txn): | ||||
|             sql = ( | ||||
|                 "DELETE FROM device_inbox" | ||||
|                 " WHERE user_id = ? AND device_id = ?" | ||||
|                 " AND stream_id <= ?" | ||||
|             ) | ||||
|             txn.execute(sql, (user_id, device_id, up_to_stream_id)) | ||||
| 
 | ||||
|         return self.runInteraction( | ||||
|             "delete_messages_for_device", delete_messages_for_device_txn | ||||
|         ) | ||||
| 
 | ||||
|     def get_to_device_stream_token(self): | ||||
|         return self._device_inbox_id_gen.get_current_token() | ||||
|  | @ -0,0 +1,24 @@ | |||
| /* Copyright 2016 OpenMarket Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *    http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| CREATE TABLE device_inbox ( | ||||
|     user_id TEXT NOT NULL, | ||||
|     device_id TEXT NOT NULL, | ||||
|     stream_id BIGINT NOT NULL, | ||||
|     message_json TEXT NOT NULL -- {"type":, "sender":, "content",} | ||||
| ); | ||||
| 
 | ||||
| CREATE INDEX device_inbox_user_stream_id ON device_inbox(user_id, device_id, stream_id); | ||||
| CREATE INDEX device_inbox_stream_id ON device_inbox(stream_id); | ||||
|  | @ -43,6 +43,7 @@ class EventSources(object): | |||
|     @defer.inlineCallbacks | ||||
|     def get_current_token(self, direction='f'): | ||||
|         push_rules_key, _ = self.store.get_push_rules_stream_token() | ||||
|         to_device_key = self.store.get_to_device_stream_token() | ||||
| 
 | ||||
|         token = StreamToken( | ||||
|             room_key=( | ||||
|  | @ -61,5 +62,6 @@ class EventSources(object): | |||
|                 yield self.sources["account_data"].get_current_key() | ||||
|             ), | ||||
|             push_rules_key=push_rules_key, | ||||
|             to_device_key=to_device_key, | ||||
|         ) | ||||
|         defer.returnValue(token) | ||||
|  |  | |||
|  | @ -154,6 +154,7 @@ class StreamToken( | |||
|         "receipt_key", | ||||
|         "account_data_key", | ||||
|         "push_rules_key", | ||||
|         "to_device_key", | ||||
|     )) | ||||
| ): | ||||
|     _SEPARATOR = "_" | ||||
|  | @ -190,6 +191,7 @@ class StreamToken( | |||
|             or (int(other.receipt_key) < int(self.receipt_key)) | ||||
|             or (int(other.account_data_key) < int(self.account_data_key)) | ||||
|             or (int(other.push_rules_key) < int(self.push_rules_key)) | ||||
|             or (int(other.to_device_key) < int(self.to_device_key)) | ||||
|         ) | ||||
| 
 | ||||
|     def copy_and_advance(self, key, new_value): | ||||
|  |  | |||
|  | @ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase): | |||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_topo_token_is_accepted(self): | ||||
|         token = "t1-0_0_0_0_0_0" | ||||
|         token = "t1-0_0_0_0_0_0_0" | ||||
|         (code, response) = yield self.mock_resource.trigger_get( | ||||
|             "/rooms/%s/messages?access_token=x&from=%s" % | ||||
|             (self.room_id, token)) | ||||
|  | @ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase): | |||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_stream_token_is_accepted_for_fwd_pagianation(self): | ||||
|         token = "s0_0_0_0_0_0" | ||||
|         token = "s0_0_0_0_0_0_0" | ||||
|         (code, response) = yield self.mock_resource.trigger_get( | ||||
|             "/rooms/%s/messages?access_token=x&from=%s" % | ||||
|             (self.room_id, token)) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Mark Haines
						Mark Haines