Add new API appservice specific public room list
							parent
							
								
									194b6259c5
								
							
						
					
					
						commit
						f32fb65552
					
				|  | @ -89,6 +89,9 @@ class ApplicationService(object): | |||
|         self.namespaces = self._check_namespaces(namespaces) | ||||
|         self.id = id | ||||
| 
 | ||||
|         if "|" in self.id: | ||||
|             raise Exception("application service ID cannot contain '|' character") | ||||
| 
 | ||||
|         # .protocols is a publicly visible field | ||||
|         if protocols: | ||||
|             self.protocols = set(protocols) | ||||
|  |  | |||
|  | @ -19,6 +19,7 @@ from synapse.api.errors import CodeMessageException | |||
| from synapse.http.client import SimpleHttpClient | ||||
| from synapse.events.utils import serialize_event | ||||
| from synapse.util.caches.response_cache import ResponseCache | ||||
| from synapse.types import ThirdPartyInstanceID | ||||
| 
 | ||||
| import logging | ||||
| import urllib | ||||
|  | @ -177,6 +178,14 @@ class ApplicationServiceApi(SimpleHttpClient): | |||
|                                    " valid result", uri) | ||||
|                     defer.returnValue(None) | ||||
| 
 | ||||
|                 for instance in info.get("instances", []): | ||||
|                     instance["appservice_id"] = service.id | ||||
|                     network_id = instance.get("network_id", None) | ||||
|                     if network_id is not None: | ||||
|                         instance["network_id"] = ThirdPartyInstanceID( | ||||
|                             service.id, network_id, | ||||
|                         ).to_string() | ||||
| 
 | ||||
|                 defer.returnValue(info) | ||||
|             except Exception as ex: | ||||
|                 logger.warning("query_3pe_protocol to %s threw exception %s", | ||||
|  |  | |||
|  | @ -655,12 +655,15 @@ class FederationClient(FederationBase): | |||
|         raise RuntimeError("Failed to send to any server.") | ||||
| 
 | ||||
|     def get_public_rooms(self, destination, limit=None, since_token=None, | ||||
|                          search_filter=None): | ||||
|                          search_filter=None, include_all_networks=False, | ||||
|                          third_party_instance_id=None): | ||||
|         if destination == self.server_name: | ||||
|             return | ||||
| 
 | ||||
|         return self.transport_layer.get_public_rooms( | ||||
|             destination, limit, since_token, search_filter | ||||
|             destination, limit, since_token, search_filter, | ||||
|             include_all_networks=include_all_networks, | ||||
|             third_party_instance_id=third_party_instance_id, | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|  |  | |||
|  | @ -249,10 +249,15 @@ class TransportLayerClient(object): | |||
|     @defer.inlineCallbacks | ||||
|     @log_function | ||||
|     def get_public_rooms(self, remote_server, limit, since_token, | ||||
|                          search_filter=None): | ||||
|                          search_filter=None, include_all_networks=False, | ||||
|                          third_party_instance_id=None): | ||||
|         path = PREFIX + "/publicRooms" | ||||
| 
 | ||||
|         args = {} | ||||
|         args = { | ||||
|             "include_all_networks": "true" if include_all_networks else "false", | ||||
|         } | ||||
|         if third_party_instance_id: | ||||
|             args["third_party_instance_id"] = third_party_instance_id, | ||||
|         if limit: | ||||
|             args["limit"] = [str(limit)] | ||||
|         if since_token: | ||||
|  |  | |||
|  | @ -20,9 +20,11 @@ from synapse.api.errors import Codes, SynapseError | |||
| from synapse.http.server import JsonResource | ||||
| from synapse.http.servlet import ( | ||||
|     parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, | ||||
|     parse_boolean_from_args, | ||||
| ) | ||||
| from synapse.util.ratelimitutils import FederationRateLimiter | ||||
| from synapse.util.versionstring import get_version_string | ||||
| from synapse.types import ThirdPartyInstanceID | ||||
| 
 | ||||
| import functools | ||||
| import logging | ||||
|  | @ -558,8 +560,23 @@ class PublicRoomList(BaseFederationServlet): | |||
|     def on_GET(self, origin, content, query): | ||||
|         limit = parse_integer_from_args(query, "limit", 0) | ||||
|         since_token = parse_string_from_args(query, "since", None) | ||||
|         include_all_networks = parse_boolean_from_args( | ||||
|             query, "include_all_networks", False | ||||
|         ) | ||||
|         third_party_instance_id = parse_string_from_args( | ||||
|             query, "third_party_instance_id", None | ||||
|         ) | ||||
| 
 | ||||
|         if include_all_networks: | ||||
|             network_tuple = None | ||||
|         elif third_party_instance_id: | ||||
|             network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) | ||||
|         else: | ||||
|             network_tuple = ThirdPartyInstanceID(None, None) | ||||
| 
 | ||||
|         data = yield self.room_list_handler.get_local_public_room_list( | ||||
|             limit, since_token | ||||
|             limit, since_token, | ||||
|             network_tuple=network_tuple | ||||
|         ) | ||||
|         defer.returnValue((200, data)) | ||||
| 
 | ||||
|  |  | |||
|  | @ -339,3 +339,15 @@ class DirectoryHandler(BaseHandler): | |||
|         yield self.auth.check_can_change_room_list(room_id, requester.user) | ||||
| 
 | ||||
|         yield self.store.set_room_is_public(room_id, visibility == "public") | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def edit_published_appservice_room_list(self, appservice_id, network_id, | ||||
|                                             room_id, visibility): | ||||
|         """Edit the appservice/network specific public room list. | ||||
|         """ | ||||
|         if visibility not in ["public", "private"]: | ||||
|             raise SynapseError(400, "Invalid visibility setting") | ||||
| 
 | ||||
|         yield self.store.set_room_is_public_appservice( | ||||
|             room_id, appservice_id, network_id, visibility == "public" | ||||
|         ) | ||||
|  |  | |||
|  | @ -22,6 +22,7 @@ from synapse.api.constants import ( | |||
| ) | ||||
| from synapse.util.async import concurrently_execute | ||||
| from synapse.util.caches.response_cache import ResponseCache | ||||
| from synapse.types import ThirdPartyInstanceID | ||||
| 
 | ||||
| from collections import namedtuple | ||||
| from unpaddedbase64 import encode_base64, decode_base64 | ||||
|  | @ -34,6 +35,10 @@ logger = logging.getLogger(__name__) | |||
| REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 | ||||
| 
 | ||||
| 
 | ||||
| # This is used to indicate we should only return rooms published to the main list. | ||||
| EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) | ||||
| 
 | ||||
| 
 | ||||
| class RoomListHandler(BaseHandler): | ||||
|     def __init__(self, hs): | ||||
|         super(RoomListHandler, self).__init__(hs) | ||||
|  | @ -41,10 +46,27 @@ class RoomListHandler(BaseHandler): | |||
|         self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000) | ||||
| 
 | ||||
|     def get_local_public_room_list(self, limit=None, since_token=None, | ||||
|                                    search_filter=None): | ||||
|         if search_filter: | ||||
|                                    search_filter=None, | ||||
|                                    network_tuple=EMTPY_THIRD_PARTY_ID,): | ||||
|         """Generate a local public room list. | ||||
| 
 | ||||
|         There are multiple different lists: the main one plus one per third | ||||
|         party network. A client can ask for a specific list or to return all. | ||||
| 
 | ||||
|         Args: | ||||
|             limit (int) | ||||
|             since_token (str) | ||||
|             search_filter (dict) | ||||
|             network_tuple (ThirdPartyInstanceID): Which public list to use. | ||||
|                 This can be (None, None) to indicate the main list, or a particular | ||||
|                 appservice and network id to use an appservice specific one. | ||||
|                 Setting to None returns all public rooms across all lists. | ||||
|         """ | ||||
|         if search_filter or network_tuple is not (None, None): | ||||
|             # We explicitly don't bother caching searches. | ||||
|             return self._get_public_room_list(limit, since_token, search_filter) | ||||
|             return self._get_public_room_list( | ||||
|                 limit, since_token, search_filter, network_tuple=network_tuple, | ||||
|             ) | ||||
| 
 | ||||
|         result = self.response_cache.get((limit, since_token)) | ||||
|         if not result: | ||||
|  | @ -56,7 +78,8 @@ class RoomListHandler(BaseHandler): | |||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _get_public_room_list(self, limit=None, since_token=None, | ||||
|                               search_filter=None): | ||||
|                               search_filter=None, | ||||
|                               network_tuple=EMTPY_THIRD_PARTY_ID,): | ||||
|         if since_token and since_token != "END": | ||||
|             since_token = RoomListNextBatch.from_token(since_token) | ||||
|         else: | ||||
|  | @ -73,14 +96,15 @@ class RoomListHandler(BaseHandler): | |||
|             current_public_id = yield self.store.get_current_public_room_stream_id() | ||||
|             public_room_stream_id = since_token.public_room_stream_id | ||||
|             newly_visible, newly_unpublished = yield self.store.get_public_room_changes( | ||||
|                 public_room_stream_id, current_public_id | ||||
|                 public_room_stream_id, current_public_id, | ||||
|                 network_tuple=network_tuple, | ||||
|             ) | ||||
|         else: | ||||
|             stream_token = yield self.store.get_room_max_stream_ordering() | ||||
|             public_room_stream_id = yield self.store.get_current_public_room_stream_id() | ||||
| 
 | ||||
|         room_ids = yield self.store.get_public_room_ids_at_stream_id( | ||||
|             public_room_stream_id | ||||
|             public_room_stream_id, network_tuple=network_tuple, | ||||
|         ) | ||||
| 
 | ||||
|         # We want to return rooms in a particular order: the number of joined | ||||
|  | @ -311,7 +335,8 @@ class RoomListHandler(BaseHandler): | |||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_remote_public_room_list(self, server_name, limit=None, since_token=None, | ||||
|                                     search_filter=None): | ||||
|                                     search_filter=None, include_all_networks=False, | ||||
|                                     third_party_instance_id=None,): | ||||
|         if search_filter: | ||||
|             # We currently don't support searching across federation, so we have | ||||
|             # to do it manually without pagination | ||||
|  | @ -320,6 +345,8 @@ class RoomListHandler(BaseHandler): | |||
| 
 | ||||
|         res = yield self._get_remote_list_cached( | ||||
|             server_name, limit=limit, since_token=since_token, | ||||
|             include_all_networks=include_all_networks, | ||||
|             third_party_instance_id=third_party_instance_id, | ||||
|         ) | ||||
| 
 | ||||
|         if search_filter: | ||||
|  | @ -332,22 +359,30 @@ class RoomListHandler(BaseHandler): | |||
|         defer.returnValue(res) | ||||
| 
 | ||||
|     def _get_remote_list_cached(self, server_name, limit=None, since_token=None, | ||||
|                                 search_filter=None): | ||||
|                                 search_filter=None, include_all_networks=False, | ||||
|                                 third_party_instance_id=None,): | ||||
|         repl_layer = self.hs.get_replication_layer() | ||||
|         if search_filter: | ||||
|             # We can't cache when asking for search | ||||
|             return repl_layer.get_public_rooms( | ||||
|                 server_name, limit=limit, since_token=since_token, | ||||
|                 search_filter=search_filter, | ||||
|                 search_filter=search_filter, include_all_networks=include_all_networks, | ||||
|                 third_party_instance_id=third_party_instance_id, | ||||
|             ) | ||||
| 
 | ||||
|         result = self.remote_response_cache.get((server_name, limit, since_token)) | ||||
|         key = ( | ||||
|             server_name, limit, since_token, include_all_networks, | ||||
|             third_party_instance_id, | ||||
|         ) | ||||
|         result = self.remote_response_cache.get(key) | ||||
|         if not result: | ||||
|             result = self.remote_response_cache.set( | ||||
|                 (server_name, limit, since_token), | ||||
|                 key, | ||||
|                 repl_layer.get_public_rooms( | ||||
|                     server_name, limit=limit, since_token=since_token, | ||||
|                     search_filter=search_filter, | ||||
|                     include_all_networks=include_all_networks, | ||||
|                     third_party_instance_id=third_party_instance_id, | ||||
|                 ) | ||||
|             ) | ||||
|         return result | ||||
|  |  | |||
|  | @ -78,12 +78,16 @@ def parse_boolean(request, name, default=None, required=False): | |||
|             parameter is present and not one of "true" or "false". | ||||
|     """ | ||||
| 
 | ||||
|     if name in request.args: | ||||
|     return parse_boolean_from_args(request.args, name, default, required) | ||||
| 
 | ||||
| 
 | ||||
| def parse_boolean_from_args(args, name, default=None, required=False): | ||||
|     if name in args: | ||||
|         try: | ||||
|             return { | ||||
|                 "true": True, | ||||
|                 "false": False, | ||||
|             }[request.args[name][0]] | ||||
|             }[args[name][0]] | ||||
|         except: | ||||
|             message = ( | ||||
|                 "Boolean query parameter %r must be one of" | ||||
|  |  | |||
|  | @ -475,7 +475,7 @@ class ReplicationResource(Resource): | |||
|             ) | ||||
|             upto_token = _position_from_rows(public_rooms_rows, current_position) | ||||
|             writer.write_header_and_rows("public_rooms", public_rooms_rows, ( | ||||
|                 "position", "room_id", "visibility" | ||||
|                 "position", "room_id", "visibility", "appservice_id", "network_id", | ||||
|             ), position=upto_token) | ||||
| 
 | ||||
|     def federation(self, writer, current_token, limit, request_streams, federation_ack): | ||||
|  |  | |||
|  | @ -31,6 +31,7 @@ logger = logging.getLogger(__name__) | |||
| def register_servlets(hs, http_server): | ||||
|     ClientDirectoryServer(hs).register(http_server) | ||||
|     ClientDirectoryListServer(hs).register(http_server) | ||||
|     ClientAppserviceDirectoryListServer(hs).register(http_server) | ||||
| 
 | ||||
| 
 | ||||
| class ClientDirectoryServer(ClientV1RestServlet): | ||||
|  | @ -184,3 +185,36 @@ class ClientDirectoryListServer(ClientV1RestServlet): | |||
|         ) | ||||
| 
 | ||||
|         defer.returnValue((200, {})) | ||||
| 
 | ||||
| 
 | ||||
| class ClientAppserviceDirectoryListServer(ClientV1RestServlet): | ||||
|     PATTERNS = client_path_patterns( | ||||
|         "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$" | ||||
|     ) | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(ClientAppserviceDirectoryListServer, self).__init__(hs) | ||||
|         self.store = hs.get_datastore() | ||||
|         self.handlers = hs.get_handlers() | ||||
| 
 | ||||
|     def on_PUT(self, request, network_id, room_id): | ||||
|         content = parse_json_object_from_request(request) | ||||
|         visibility = content.get("visibility", "public") | ||||
|         return self._edit(request, network_id, room_id, visibility) | ||||
| 
 | ||||
|     def on_DELETE(self, request, network_id, room_id): | ||||
|         return self._edit(request, network_id, room_id, "private") | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _edit(self, request, network_id, room_id, visibility): | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
|         if not requester.app_service: | ||||
|             raise AuthError( | ||||
|                 403, "Only appservices can edit the appservice published room list" | ||||
|             ) | ||||
| 
 | ||||
|         yield self.handlers.directory_handler.edit_published_appservice_room_list( | ||||
|             requester.app_service.id, network_id, room_id, visibility, | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue((200, {})) | ||||
|  |  | |||
|  | @ -21,7 +21,7 @@ from synapse.api.errors import SynapseError, Codes, AuthError | |||
| from synapse.streams.config import PaginationConfig | ||||
| from synapse.api.constants import EventTypes, Membership | ||||
| from synapse.api.filtering import Filter | ||||
| from synapse.types import UserID, RoomID, RoomAlias | ||||
| from synapse.types import UserID, RoomID, RoomAlias, ThirdPartyInstanceID | ||||
| from synapse.events.utils import serialize_event, format_event_for_client_v2 | ||||
| from synapse.http.servlet import ( | ||||
|     parse_json_object_from_request, parse_string, parse_integer | ||||
|  | @ -321,6 +321,20 @@ class PublicRoomListRestServlet(ClientV1RestServlet): | |||
|         since_token = content.get("since", None) | ||||
|         search_filter = content.get("filter", None) | ||||
| 
 | ||||
|         include_all_networks = content.get("include_all_networks", False) | ||||
|         third_party_instance_id = content.get("third_party_instance_id", None) | ||||
| 
 | ||||
|         if include_all_networks: | ||||
|             network_tuple = None | ||||
|             if third_party_instance_id is not None: | ||||
|                 raise SynapseError( | ||||
|                     400, "Can't use include_all_networks with an explicit network" | ||||
|                 ) | ||||
|         elif third_party_instance_id is None: | ||||
|             network_tuple = ThirdPartyInstanceID(None, None) | ||||
|         else: | ||||
|             network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) | ||||
| 
 | ||||
|         handler = self.hs.get_room_list_handler() | ||||
|         if server: | ||||
|             data = yield handler.get_remote_public_room_list( | ||||
|  | @ -328,12 +342,15 @@ class PublicRoomListRestServlet(ClientV1RestServlet): | |||
|                 limit=limit, | ||||
|                 since_token=since_token, | ||||
|                 search_filter=search_filter, | ||||
|                 include_all_networks=include_all_networks, | ||||
|                 third_party_instance_id=third_party_instance_id, | ||||
|             ) | ||||
|         else: | ||||
|             data = yield handler.get_local_public_room_list( | ||||
|                 limit=limit, | ||||
|                 since_token=since_token, | ||||
|                 search_filter=search_filter, | ||||
|                 network_tuple=network_tuple, | ||||
|             ) | ||||
| 
 | ||||
|         defer.returnValue((200, data)) | ||||
|  |  | |||
|  | @ -106,7 +106,11 @@ class RoomStore(SQLBaseStore): | |||
|             entries = self._simple_select_list_txn( | ||||
|                 txn, | ||||
|                 table="public_room_list_stream", | ||||
|                 keyvalues={"room_id": room_id}, | ||||
|                 keyvalues={ | ||||
|                     "room_id": room_id, | ||||
|                     "appservice_id": None, | ||||
|                     "network_id": None, | ||||
|                 }, | ||||
|                 retcols=("stream_id", "visibility"), | ||||
|             ) | ||||
| 
 | ||||
|  | @ -124,6 +128,8 @@ class RoomStore(SQLBaseStore): | |||
|                         "stream_id": next_id, | ||||
|                         "room_id": room_id, | ||||
|                         "visibility": is_public, | ||||
|                         "appservice_id": None, | ||||
|                         "network_id": None, | ||||
|                     } | ||||
|                 ) | ||||
| 
 | ||||
|  | @ -133,6 +139,73 @@ class RoomStore(SQLBaseStore): | |||
|                 set_room_is_public_txn, next_id, | ||||
|             ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def set_room_is_public_appservice(self, room_id, appservice_id, network_id, | ||||
|                                       is_public): | ||||
|         """Edit the appservice/network specific public room list. | ||||
|         """ | ||||
|         def set_room_is_public_appservice_txn(txn, next_id): | ||||
|             if is_public: | ||||
|                 try: | ||||
|                     self._simple_insert_txn( | ||||
|                         txn, | ||||
|                         table="appservice_room_list", | ||||
|                         values={ | ||||
|                             "appservice_id": appservice_id, | ||||
|                             "network_id": "network_id", | ||||
|                             "room_id": room_id | ||||
|                         }, | ||||
|                     ) | ||||
|                 except self.database_engine.module.IntegrityError: | ||||
|                     # We've already inserted, nothing to do. | ||||
|                     return | ||||
|             else: | ||||
|                 self._simple_delete_txn( | ||||
|                     txn, | ||||
|                     table="appservice_room_list", | ||||
|                     keyvalues={ | ||||
|                         "appservice_id": appservice_id, | ||||
|                         "network_id": network_id, | ||||
|                         "room_id": room_id | ||||
|                     }, | ||||
|                 ) | ||||
| 
 | ||||
|             entries = self._simple_select_list_txn( | ||||
|                 txn, | ||||
|                 table="public_room_list_stream", | ||||
|                 keyvalues={ | ||||
|                     "room_id": room_id, | ||||
|                     "appservice_id": appservice_id, | ||||
|                     "network_id": network_id, | ||||
|                 }, | ||||
|                 retcols=("stream_id", "visibility"), | ||||
|             ) | ||||
| 
 | ||||
|             entries.sort(key=lambda r: r["stream_id"]) | ||||
| 
 | ||||
|             add_to_stream = True | ||||
|             if entries: | ||||
|                 add_to_stream = bool(entries[-1]["visibility"]) != is_public | ||||
| 
 | ||||
|             if add_to_stream: | ||||
|                 self._simple_insert_txn( | ||||
|                     txn, | ||||
|                     table="public_room_list_stream", | ||||
|                     values={ | ||||
|                         "stream_id": next_id, | ||||
|                         "room_id": room_id, | ||||
|                         "visibility": is_public, | ||||
|                         "appservice_id": appservice_id, | ||||
|                         "network_id": network_id, | ||||
|                     } | ||||
|                 ) | ||||
| 
 | ||||
|         with self._public_room_id_gen.get_next() as next_id: | ||||
|             yield self.runInteraction( | ||||
|                 "set_room_is_public_appservice", | ||||
|                 set_room_is_public_appservice_txn, next_id, | ||||
|             ) | ||||
| 
 | ||||
|     def get_public_room_ids(self): | ||||
|         return self._simple_select_onecol( | ||||
|             table="rooms", | ||||
|  | @ -259,38 +332,95 @@ class RoomStore(SQLBaseStore): | |||
|     def get_current_public_room_stream_id(self): | ||||
|         return self._public_room_id_gen.get_current_token() | ||||
| 
 | ||||
|     def get_public_room_ids_at_stream_id(self, stream_id): | ||||
|     def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): | ||||
|         """Get pulbic rooms for a particular list, or across all lists. | ||||
| 
 | ||||
|         Args: | ||||
|             stream_id (int) | ||||
|             network_tuple (ThirdPartyInstanceID): The list to use (None, None) | ||||
|                 means the main list, None means all lsits. | ||||
|         """ | ||||
|         return self.runInteraction( | ||||
|             "get_public_room_ids_at_stream_id", | ||||
|             self.get_public_room_ids_at_stream_id_txn, stream_id | ||||
|             self.get_public_room_ids_at_stream_id_txn, | ||||
|             stream_id, network_tuple=network_tuple | ||||
|         ) | ||||
| 
 | ||||
|     def get_public_room_ids_at_stream_id_txn(self, txn, stream_id): | ||||
|     def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, | ||||
|                                              network_tuple): | ||||
|         return { | ||||
|             rm | ||||
|             for rm, vis in self.get_published_at_stream_id_txn(txn, stream_id).items() | ||||
|             for rm, vis in self.get_published_at_stream_id_txn( | ||||
|                 txn, stream_id, network_tuple=network_tuple | ||||
|             ).items() | ||||
|             if vis | ||||
|         } | ||||
| 
 | ||||
|     def get_published_at_stream_id_txn(self, txn, stream_id): | ||||
|         sql = (""" | ||||
|             SELECT room_id, visibility FROM public_room_list_stream | ||||
|             INNER JOIN ( | ||||
|                 SELECT room_id, max(stream_id) AS stream_id | ||||
|     def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): | ||||
|         if network_tuple: | ||||
|             # We want to get from a particular list. No aggregation required. | ||||
| 
 | ||||
|             sql = (""" | ||||
|                 SELECT room_id, visibility FROM public_room_list_stream | ||||
|                 INNER JOIN ( | ||||
|                     SELECT room_id, max(stream_id) AS stream_id | ||||
|                     FROM public_room_list_stream | ||||
|                     WHERE stream_id <= ? %s | ||||
|                     GROUP BY room_id | ||||
|                 ) grouped USING (room_id, stream_id) | ||||
|             """) | ||||
| 
 | ||||
|             if network_tuple.appservice_id is not None: | ||||
|                 txn.execute( | ||||
|                     sql % ("AND appservice_id = ? AND network_id = ?",), | ||||
|                     (stream_id, network_tuple.appservice_id, network_tuple.network_id,) | ||||
|                 ) | ||||
|             else: | ||||
|                 txn.execute( | ||||
|                     sql % ("AND appservice_id IS NULL",), | ||||
|                     (stream_id,) | ||||
|                 ) | ||||
|             return dict(txn.fetchall()) | ||||
|         else: | ||||
|             # We want to get from all lists, so we need to aggregate the results | ||||
| 
 | ||||
|             logger.info("Executing full list") | ||||
| 
 | ||||
|             sql = (""" | ||||
|                 SELECT room_id, visibility | ||||
|                 FROM public_room_list_stream | ||||
|                 WHERE stream_id <= ? | ||||
|                 GROUP BY room_id | ||||
|             ) grouped USING (room_id, stream_id) | ||||
|         """) | ||||
|                 INNER JOIN ( | ||||
|                     SELECT | ||||
|                         room_id, max(stream_id) AS stream_id, appservice_id, | ||||
|                         network_id | ||||
|                     FROM public_room_list_stream | ||||
|                     WHERE stream_id <= ? | ||||
|                     GROUP BY room_id, appservice_id, network_id | ||||
|                 ) grouped USING (room_id, stream_id) | ||||
|             """) | ||||
| 
 | ||||
|         txn.execute(sql, (stream_id,)) | ||||
|         return dict(txn.fetchall()) | ||||
|             txn.execute( | ||||
|                 sql, | ||||
|                 (stream_id,) | ||||
|             ) | ||||
| 
 | ||||
|     def get_public_room_changes(self, prev_stream_id, new_stream_id): | ||||
|             results = {} | ||||
|             # A room is visible if its visible on any list. | ||||
|             for room_id, visibility in txn.fetchall(): | ||||
|                 results[room_id] = bool(visibility) or results.get(room_id, False) | ||||
| 
 | ||||
|             return results | ||||
| 
 | ||||
|     def get_public_room_changes(self, prev_stream_id, new_stream_id, | ||||
|                                 network_tuple): | ||||
|         def get_public_room_changes_txn(txn): | ||||
|             then_rooms = self.get_public_room_ids_at_stream_id_txn(txn, prev_stream_id) | ||||
|             then_rooms = self.get_public_room_ids_at_stream_id_txn( | ||||
|                 txn, prev_stream_id, network_tuple | ||||
|             ) | ||||
| 
 | ||||
|             now_rooms_dict = self.get_published_at_stream_id_txn(txn, new_stream_id) | ||||
|             now_rooms_dict = self.get_published_at_stream_id_txn( | ||||
|                 txn, new_stream_id, network_tuple | ||||
|             ) | ||||
| 
 | ||||
|             now_rooms_visible = set( | ||||
|                 rm for rm, vis in now_rooms_dict.items() if vis | ||||
|  | @ -311,7 +441,8 @@ class RoomStore(SQLBaseStore): | |||
|     def get_all_new_public_rooms(self, prev_id, current_id, limit): | ||||
|         def get_all_new_public_rooms(txn): | ||||
|             sql = (""" | ||||
|                 SELECT stream_id, room_id, visibility FROM public_room_list_stream | ||||
|                 SELECT stream_id, room_id, visibility, appservice_id, network_id | ||||
|                 FROM public_room_list_stream | ||||
|                 WHERE stream_id > ? AND stream_id <= ? | ||||
|                 ORDER BY stream_id ASC | ||||
|                 LIMIT ? | ||||
|  |  | |||
|  | @ -0,0 +1,27 @@ | |||
| /* Copyright 2016 OpenMarket Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *    http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| CREATE TABLE appservice_room_list( | ||||
|     appservice_id TEXT NOT NULL, | ||||
|     network_id TEXT NOT NULL, | ||||
|     room_id TEXT NOT NULL | ||||
| ); | ||||
| 
 | ||||
| CREATE UNIQUE INDEX appservice_room_list_idx ON appservice_room_list( | ||||
|     appservice_id, network_id, room_id | ||||
| ); | ||||
| 
 | ||||
| ALTER TABLE public_room_list_stream ADD COLUMN appservice_id TEXT; | ||||
| ALTER TABLE public_room_list_stream ADD COLUMN network_id TEXT; | ||||
|  | @ -274,3 +274,37 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): | |||
|             return "t%d-%d" % (self.topological, self.stream) | ||||
|         else: | ||||
|             return "s%d" % (self.stream,) | ||||
| 
 | ||||
| 
 | ||||
| class ThirdPartyInstanceID( | ||||
|         namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id")) | ||||
| ): | ||||
|     # Deny iteration because it will bite you if you try to create a singleton | ||||
|     # set by: | ||||
|     #    users = set(user) | ||||
|     def __iter__(self): | ||||
|         raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) | ||||
| 
 | ||||
|     # Because this class is a namedtuple of strings, it is deeply immutable. | ||||
|     def __copy__(self): | ||||
|         return self | ||||
| 
 | ||||
|     def __deepcopy__(self, memo): | ||||
|         return self | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_string(cls, s): | ||||
|         bits = s.split("|", 2) | ||||
|         if len(bits) != 2: | ||||
|             raise SynapseError(400, "Invalid ID %r" % (s,)) | ||||
| 
 | ||||
|         return cls(appservice_id=bits[0], network_id=bits[1]) | ||||
| 
 | ||||
|     def to_string(self): | ||||
|         return "%s|%s" % (self.appservice_id, self.network_id,) | ||||
| 
 | ||||
|     __str__ = to_string | ||||
| 
 | ||||
|     @classmethod | ||||
|     def create(cls, appservice_id, network_id,): | ||||
|         return cls(appservice_id=appservice_id, network_id=network_id) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Erik Johnston
						Erik Johnston