Merge branch 'develop' into email_login
commit c50ad14bae

@@ -11,7 +11,9 @@ if [ -f $PID_FILE ]; then
    exit 1
fi

find "$DIR" -name "*.log" -delete
find "$DIR" -name "*.db" -delete
for port in 8080 8081 8082; do
    rm -rf $DIR/$port
    rm -rf $DIR/media_store.$port
done

rm -rf $DIR/etc

@@ -8,14 +8,6 @@ cd "$DIR/.."

mkdir -p demo/etc

# Check the --no-rate-limit param
PARAMS=""
if [ $# -eq 1 ]; then
    if [ $1 = "--no-rate-limit" ]; then
        PARAMS="--rc-messages-per-second 1000 --rc-message-burst-count 1000"
    fi
fi

export PYTHONPATH=$(readlink -f $(pwd))


@@ -31,10 +23,20 @@ for port in 8080 8081 8082; do
    #rm $DIR/etc/$port.config
    python -m synapse.app.homeserver \
        --generate-config \
        --enable_registration \
        -H "localhost:$https_port" \
        --config-path "$DIR/etc/$port.config" \

    # Check script parameters
    if [ $# -eq 1 ]; then
        if [ $1 = "--no-rate-limit" ]; then
            # Set high limits in config file to disable rate limiting
            perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
            perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
        fi
    fi

    perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config

    python -m synapse.app.homeserver \
        --config-path "$DIR/etc/$port.config" \
        -D \

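Note on the change above: the --no-rate-limit flag no longer passes extra
command-line parameters to the homeserver; instead the generated config is
edited in place with perl one-liners after generation. A rough Python
equivalent of those substitutions, for readers unfamiliar with the
perl -p -i -e idiom (the config path is illustrative only):

    import re

    path = "demo/etc/8080.config"  # hypothetical generated config
    with open(path) as f:
        text = f.read()
    # Replace the whole matching line, exactly as the perl s/.../.../g does.
    text = re.sub(r"rc_messages_per_second.*", "rc_messages_per_second: 1000", text)
    text = re.sub(r"rc_message_burst_count.*", "rc_message_burst_count: 1000", text)
    with open(path, "w") as f:
        f.write(text)
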
@@ -16,3 +16,6 @@ ignore =
    docs/*
    pylint.cfg
    tox.ini

[flake8]
max-line-length = 90

setup.py
@@ -48,7 +48,7 @@ setup(
    description="Reference Synapse Home Server",
    install_requires=dependencies['requirements'](include_conditional=True).keys(),
    setup_requires=[
        "Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
        "Twisted>=15.1.0", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
        "setuptools_trial",
        "mock"
    ],

@@ -44,6 +44,11 @@ class Auth(object):
    def check(self, event, auth_events):
        """ Checks if this event is correctly authed.

        Args:
            event: the event being checked.
            auth_events (dict: event-key -> event): the existing room state.


        Returns:
            True if the auth checks pass.
        """
@@ -319,7 +324,7 @@ class Auth(object):
        Returns:
            tuple : of UserID and device string:
                User ID object of the user making the request
                Client ID object of the client instance the user is using
                ClientInfo object of the client instance the user is using
        Raises:
            AuthError if no user by that token exists or the token is invalid.
        """
@@ -347,12 +352,14 @@ class Auth(object):
                if not user_id:
                    raise KeyError

                request.authenticated_entity = user_id

                defer.returnValue(
                    (UserID.from_string(user_id), ClientInfo("", ""))
                )
                return
            except KeyError:
                pass  # normal users won't have this query parameter set
                pass  # normal users won't have the user_id query parameter set.

            user_info = yield self.get_user_by_token(access_token)
            user = user_info["user"]
@@ -420,6 +427,7 @@ class Auth(object):
                    "Unrecognised access token.",
                    errcode=Codes.UNKNOWN_TOKEN
                )
            request.authenticated_entity = service.sender
            defer.returnValue(service)
        except KeyError:
            raise AuthError(
@@ -521,23 +529,22 @@ class Auth(object):

        # Check state_key
        if hasattr(event, "state_key"):
            if not event.state_key.startswith("_"):
                if event.state_key.startswith("@"):
                    if event.state_key != event.user_id:
            if event.state_key.startswith("@"):
                if event.state_key != event.user_id:
                    raise AuthError(
                        403,
                        "You are not allowed to set others state"
                    )
                else:
                    sender_domain = UserID.from_string(
                        event.user_id
                    ).domain

                    if sender_domain != event.state_key:
                        raise AuthError(
                            403,
                            "You are not allowed to set others state"
                        )
                    else:
                        sender_domain = UserID.from_string(
                            event.user_id
                        ).domain

                        if sender_domain != event.state_key:
                            raise AuthError(
                                403,
                                "You are not allowed to set others state"
                            )

        return True

@@ -657,7 +657,8 @@ def run(hs):

    if hs.config.daemonize:

        print hs.config.pid_file
        if hs.config.print_pidfile:
            print hs.config.pid_file

        daemon = Daemonize(
            app="synapse-homeserver",

@@ -138,12 +138,19 @@ class Config(object):
            action="store_true",
            help="Generate a config file for the server name"
        )
        config_parser.add_argument(
            "--generate-keys",
            action="store_true",
            help="Generate any missing key files then exit"
        )
        config_parser.add_argument(
            "-H", "--server-name",
            help="The server name to generate a config file for"
        )
        config_args, remaining_args = config_parser.parse_known_args(argv)

        generate_keys = config_args.generate_keys

        if config_args.generate_config:
            if not config_args.config_path:
                config_parser.error(
@@ -151,51 +158,40 @@ class Config(object):
                    " generated using \"--generate-config -H SERVER_NAME"
                    " -c CONFIG-FILE\""
                )

            config_dir_path = os.path.dirname(config_args.config_path[0])
            config_dir_path = os.path.abspath(config_dir_path)

            server_name = config_args.server_name
            if not server_name:
                print "Must specify a server_name to generate a config for."
                sys.exit(1)
            (config_path,) = config_args.config_path
            if not os.path.exists(config_dir_path):
                os.makedirs(config_dir_path)
            if os.path.exists(config_path):
                print "Config file %r already exists" % (config_path,)
                yaml_config = cls.read_config_file(config_path)
                yaml_name = yaml_config["server_name"]
                if server_name != yaml_name:
                    print (
                        "Config file %r has a different server_name: "
                        " %r != %r" % (config_path, server_name, yaml_name)
                    )
            if not os.path.exists(config_path):
                config_dir_path = os.path.dirname(config_path)
                config_dir_path = os.path.abspath(config_dir_path)

                server_name = config_args.server_name
                if not server_name:
                    print "Must specify a server_name to generate a config for."
                    sys.exit(1)
                config_bytes, config = obj.generate_config(
                    config_dir_path, server_name
                )
                config.update(yaml_config)
                print "Generating any missing keys for %r" % (server_name,)
                obj.invoke_all("generate_files", config)
                sys.exit(0)
            with open(config_path, "wb") as config_file:
                config_bytes, config = obj.generate_config(
                    config_dir_path, server_name
                )
                obj.invoke_all("generate_files", config)
                config_file.write(config_bytes)
                if not os.path.exists(config_dir_path):
                    os.makedirs(config_dir_path)
                with open(config_path, "wb") as config_file:
                    config_bytes, config = obj.generate_config(
                        config_dir_path, server_name
                    )
                    obj.invoke_all("generate_files", config)
                    config_file.write(config_bytes)
                print (
                    "A config file has been generated in %s for server name"
                    " '%s' with corresponding SSL keys and self-signed"
                    " certificates. Please review this file and customise it to"
                    " your needs."
                    "A config file has been generated in %r for server name"
                    " %r with corresponding SSL keys and self-signed"
                    " certificates. Please review this file and customise it"
                    " to your needs."
                ) % (config_path, server_name)
            print (
                "If this server name is incorrect, you will need to regenerate"
                " the SSL certificates"
            )
            sys.exit(0)
                print (
                    "If this server name is incorrect, you will need to"
                    " regenerate the SSL certificates"
                )
                sys.exit(0)
            else:
                print (
                    "Config file %r already exists. Generating any missing key"
                    " files."
                ) % (config_path,)
                generate_keys = True

        parser = argparse.ArgumentParser(
            parents=[config_parser],
@@ -213,7 +209,7 @@ class Config(object):
                " -c CONFIG-FILE\""
            )

        config_dir_path = os.path.dirname(config_args.config_path[0])
        config_dir_path = os.path.dirname(config_args.config_path[-1])
        config_dir_path = os.path.abspath(config_dir_path)

        specified_config = {}
@@ -226,6 +222,10 @@ class Config(object):
        config.pop("log_config")
        config.update(specified_config)

        if generate_keys:
            obj.invoke_all("generate_files", config)
            sys.exit(0)

        obj.invoke_all("read_config", config)

        obj.invoke_all("read_arguments", args)

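The net effect of the config changes above: --generate-config on an existing
config file no longer rewrites it, it only fills in missing key files, and the
new --generate-keys flag regenerates missing keys and exits. A minimal,
self-contained sketch of the flag handling (toy argparse setup; the real logic
lives in the Config class above):

    import argparse
    import os
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument("--generate-config", action="store_true")
    parser.add_argument("--generate-keys", action="store_true")
    parser.add_argument("-c", "--config-path", action="append")
    args = parser.parse_args(["--generate-config", "-c", "homeserver.yaml"])

    generate_keys = args.generate_keys
    if args.generate_config:
        (config_path,) = args.config_path
        if not os.path.exists(config_path):
            print("would write a fresh config and keys, then exit")
            sys.exit(0)
        # Existing config: fall through and only generate missing keys.
        generate_keys = True

    if generate_keys:
        print("would generate any missing key files, then exit")
        sys.exit(0)
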
@@ -24,6 +24,7 @@ class ServerConfig(Config):
        self.web_client = config["web_client"]
        self.soft_file_limit = config["soft_file_limit"]
        self.daemonize = config.get("daemonize")
        self.print_pidfile = config.get("print_pidfile")
        self.use_frozen_dicts = config.get("use_frozen_dicts", True)

        self.listeners = config.get("listeners", [])
@@ -208,12 +209,18 @@ class ServerConfig(Config):
            self.manhole = args.manhole
        if args.daemonize is not None:
            self.daemonize = args.daemonize
        if args.print_pidfile is not None:
            self.print_pidfile = args.print_pidfile

    def add_arguments(self, parser):
        server_group = parser.add_argument_group("server")
        server_group.add_argument("-D", "--daemonize", action='store_true',
                                  default=None,
                                  help="Daemonize the home server")
        server_group.add_argument("--print-pidfile", action='store_true',
                                  default=None,
                                  help="Print the path to the pidfile just"
                                  " before daemonizing")
        server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
                                  type=int,
                                  help="Turn on the twisted telnet manhole"

@@ -23,7 +23,7 @@ from synapse.api.errors import (
    CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util import unwrapFirstError
from synapse.util.expiringcache import ExpiringCache
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
import synapse.metrics
@@ -134,6 +134,36 @@ class FederationClient(FederationBase):
            destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
        )

    @log_function
    def query_client_keys(self, destination, content):
        """Query device keys for a device hosted on a remote server.

        Args:
            destination (str): Domain name of the remote homeserver
            content (dict): The query content.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.inc("client_device_keys")
        return self.transport_layer.query_client_keys(destination, content)

    @log_function
    def claim_client_keys(self, destination, content):
        """Claims one-time keys for a device hosted on a remote server.

        Args:
            destination (str): Domain name of the remote homeserver
            content (dict): The query content.

        Returns:
            a Deferred which will eventually yield a JSON object from the
            response
        """
        sent_queries_counter.inc("client_one_time_keys")
        return self.transport_layer.claim_client_keys(destination, content)

    @defer.inlineCallbacks
    @log_function
    def backfill(self, dest, context, limit, extremities):

@@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError

from synapse.crypto.event_signing import compute_event_signature

import simplejson as json
import logging


@@ -312,6 +313,48 @@ class FederationServer(FederationBase):
            (200, send_content)
        )

    @defer.inlineCallbacks
    @log_function
    def on_query_client_keys(self, origin, content):
        query = []
        for user_id, device_ids in content.get("device_keys", {}).items():
            if not device_ids:
                query.append((user_id, None))
            else:
                for device_id in device_ids:
                    query.append((user_id, device_id))

        results = yield self.store.get_e2e_device_keys(query)

        json_result = {}
        for user_id, device_keys in results.items():
            for device_id, json_bytes in device_keys.items():
                json_result.setdefault(user_id, {})[device_id] = json.loads(
                    json_bytes
                )

        defer.returnValue({"device_keys": json_result})

    @defer.inlineCallbacks
    @log_function
    def on_claim_client_keys(self, origin, content):
        query = []
        for user_id, device_keys in content.get("one_time_keys", {}).items():
            for device_id, algorithm in device_keys.items():
                query.append((user_id, device_id, algorithm))

        results = yield self.store.claim_e2e_one_time_keys(query)

        json_result = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    json_result.setdefault(user_id, {})[device_id] = {
                        key_id: json.loads(json_bytes)
                    }

        defer.returnValue({"one_time_keys": json_result})

    @defer.inlineCallbacks
    @log_function
    def on_get_missing_events(self, origin, room_id, earliest_events,

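For reference, the shapes of the request bodies these two handlers consume
(inferred from the loops above; user IDs, device IDs and algorithm names are
illustrative):

    # on_query_client_keys: an empty device list means "all devices".
    query_request = {
        "device_keys": {
            "@alice:example.com": [],
            "@bob:example.com": ["DEVICE1", "DEVICE2"],
        }
    }

    # on_claim_client_keys: one algorithm name per device to claim from.
    claim_request = {
        "one_time_keys": {
            "@alice:example.com": {"DEVICE1": "curve25519"},
        }
    }
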
@@ -222,6 +222,76 @@ class TransportLayerClient(object):

        defer.returnValue(content)

    @defer.inlineCallbacks
    @log_function
    def query_client_keys(self, destination, query_content):
        """Query the device keys for a list of user ids hosted on a remote
        server.

        Request:
            {
              "device_keys": {
                "<user_id>": ["<device_id>"]
            } }

        Response:
            {
              "device_keys": {
                "<user_id>": {
                  "<device_id>": {...}
            } } }

        Args:
            destination(str): The server to query.
            query_content(dict): The user ids to query.
        Returns:
            A dict containing the device keys.
        """
        path = PREFIX + "/user/keys/query"

        content = yield self.client.post_json(
            destination=destination,
            path=path,
            data=query_content,
        )
        defer.returnValue(content)

    @defer.inlineCallbacks
    @log_function
    def claim_client_keys(self, destination, query_content):
        """Claim one-time keys for a list of devices hosted on a remote server.

        Request:
            {
              "one_time_keys": {
                "<user_id>": {
                    "<device_id>": "<algorithm>"
            } } }

        Response:
            {
              "one_time_keys": {
                "<user_id>": {
                  "<device_id>": {
                    "<algorithm>:<key_id>": "<key_base64>"
            } } } }

        Args:
            destination(str): The server to query.
            query_content(dict): The user ids to query.
        Returns:
            A dict containing the one-time keys.
        """

        path = PREFIX + "/user/keys/claim"

        content = yield self.client.post_json(
            destination=destination,
            path=path,
            data=query_content,
        )
        defer.returnValue(content)

    @defer.inlineCallbacks
    @log_function
    def get_missing_events(self, destination, room_id, earliest_events,

@@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet):
        defer.returnValue((200, content))


class FederationClientKeysQueryServlet(BaseFederationServlet):
    PATH = "/user/keys/query"

    @defer.inlineCallbacks
    def on_POST(self, origin, content, query):
        response = yield self.handler.on_query_client_keys(origin, content)
        defer.returnValue((200, response))


class FederationClientKeysClaimServlet(BaseFederationServlet):
    PATH = "/user/keys/claim"

    @defer.inlineCallbacks
    def on_POST(self, origin, content, query):
        response = yield self.handler.on_claim_client_keys(origin, content)
        defer.returnValue((200, response))


class FederationQueryAuthServlet(BaseFederationServlet):
    PATH = "/query_auth/([^/]*)/([^/]*)"

@@ -373,4 +391,6 @@ SERVLET_CLASSES = (
    FederationQueryAuthServlet,
    FederationGetMissingEventsServlet,
    FederationEventAuthServlet,
    FederationClientKeysQueryServlet,
    FederationClientKeysClaimServlet,
)

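Each servlet's PATH is appended to the shared federation URL prefix. Assuming
the usual /_matrix/federation/v1 prefix (the PREFIX constant itself is not
shown in this diff), the two new endpoints answer POSTs as follows:

    PREFIX = "/_matrix/federation/v1"    # assumed value of the shared PREFIX
    print(PREFIX + "/user/keys/query")   # POST: query device keys
    print(PREFIX + "/user/keys/claim")   # POST: claim one-time keys
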
@@ -22,7 +22,6 @@ from .room import (
from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler
from .login import LoginHandler
from .profile import ProfileHandler
from .presence import PresenceHandler
from .directory import DirectoryHandler
@@ -54,7 +53,6 @@ class Handlers(object):
        self.profile_handler = ProfileHandler(hs)
        self.presence_handler = PresenceHandler(hs)
        self.room_list_handler = RoomListHandler(hs)
        self.login_handler = LoginHandler(hs)
        self.directory_handler = DirectoryHandler(hs)
        self.typing_notification_handler = TypingNotificationHandler(hs)
        self.admin_handler = AdminHandler(hs)

@@ -47,17 +47,24 @@ class AuthHandler(BaseHandler):
        self.sessions = {}

    @defer.inlineCallbacks
    def check_auth(self, flows, clientdict, clientip=None):
    def check_auth(self, flows, clientdict, clientip):
        """
        Takes a dictionary sent by the client in the login / registration
        protocol and handles the login flow.

        As a side effect, this function fills in the 'creds' key on the user's
        session with a map, which maps each auth-type (str) to the relevant
        identity authenticated by that auth-type (mostly str, but for captcha, bool).

        Args:
            flows: list of list of stages
            authdict: The dictionary from the client root level, not the
                      'auth' key: this method prompts for auth if none is sent.
            flows (list): A list of login flows. Each flow is an ordered list of
                          strings representing auth-types. At least one full
                          flow must be completed in order for auth to be successful.
            clientdict: The dictionary from the client root level, not the
                        'auth' key: this method prompts for auth if none is sent.
            clientip (str): The IP address of the client.
        Returns:
            A tuple of authed, dict, dict where authed is true if the client
            A tuple of (authed, dict, dict) where authed is true if the client
            has successfully completed an auth flow. If it is true, the first
            dict contains the authenticated credentials of each stage.

@@ -75,7 +82,7 @@ class AuthHandler(BaseHandler):
            del clientdict['auth']
            if 'session' in authdict:
                sid = authdict['session']
        sess = self._get_session_info(sid)
        session = self._get_session_info(sid)

        if len(clientdict) > 0:
            # This was designed to allow the client to omit the parameters
@@ -87,20 +94,19 @@ class AuthHandler(BaseHandler):
            # on a home server.
            # Revisit: Assuming the REST APIs do sensible validation, the data
            # isn't arbitrary.
            sess['clientdict'] = clientdict
            self._save_session(sess)
            pass
        elif 'clientdict' in sess:
            clientdict = sess['clientdict']
            session['clientdict'] = clientdict
            self._save_session(session)
        elif 'clientdict' in session:
            clientdict = session['clientdict']

        if not authdict:
            defer.returnValue(
                (False, self._auth_dict_for_flows(flows, sess), clientdict)
                (False, self._auth_dict_for_flows(flows, session), clientdict)
            )

        if 'creds' not in sess:
            sess['creds'] = {}
        creds = sess['creds']
        if 'creds' not in session:
            session['creds'] = {}
        creds = session['creds']

        # check auth type currently being presented
        if 'type' in authdict:
@@ -109,15 +115,15 @@ class AuthHandler(BaseHandler):
            result = yield self.checkers[authdict['type']](authdict, clientip)
            if result:
                creds[authdict['type']] = result
                self._save_session(sess)
                self._save_session(session)

        for f in flows:
            if len(set(f) - set(creds.keys())) == 0:
                logger.info("Auth completed with creds: %r", creds)
                self._remove_session(sess)
                self._remove_session(session)
                defer.returnValue((True, creds, clientdict))

        ret = self._auth_dict_for_flows(flows, sess)
        ret = self._auth_dict_for_flows(flows, session)
        ret['completed'] = creds.keys()
        defer.returnValue((False, ret, clientdict))

@@ -151,22 +157,13 @@ class AuthHandler(BaseHandler):
        if "user" not in authdict or "password" not in authdict:
            raise LoginError(400, "", Codes.MISSING_PARAM)

        user = authdict["user"]
        user_id = authdict["user"]
        password = authdict["password"]
        if not user.startswith('@'):
            user = UserID.create(user, self.hs.hostname).to_string()
        if not user_id.startswith('@'):
            user_id = UserID.create(user_id, self.hs.hostname).to_string()

        user_info = yield self.store.get_user_by_id(user_id=user)
        if not user_info:
            logger.warn("Attempted to login as %s but they do not exist", user)
            raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)

        stored_hash = user_info["password_hash"]
        if bcrypt.checkpw(password, stored_hash):
            defer.returnValue(user)
        else:
            logger.warn("Failed password login for user %s", user)
            raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
        self._check_password(user_id, password)
        defer.returnValue(user_id)

    @defer.inlineCallbacks
    def _check_recaptcha(self, authdict, clientip):
@@ -270,6 +267,59 @@ class AuthHandler(BaseHandler):

        return self.sessions[session_id]

    @defer.inlineCallbacks
    def login_with_password(self, user_id, password):
        """
        Authenticates the user with their username and password.

        Used only by the v1 login API.

        Args:
            user_id (str): User ID
            password (str): Password
        Returns:
            The access token for the user's session.
        Raises:
            StoreError if there was a problem storing the token.
            LoginError if there was an authentication problem.
        """
        yield self._check_password(user_id, password)

        reg_handler = self.hs.get_handlers().registration_handler
        access_token = reg_handler.generate_token(user_id)
        logger.info("Logging in user %s", user_id)
        yield self.store.add_access_token_to_user(user_id, access_token)
        defer.returnValue(access_token)

    @defer.inlineCallbacks
    def _check_password(self, user_id, password):
        """Checks that the given password is correct for user_id; raises
        LoginError if not."""
        user_info = yield self.store.get_user_by_id(user_id=user_id)
        if not user_info:
            logger.warn("Attempted to login as %s but they do not exist", user_id)
            raise LoginError(403, "", errcode=Codes.FORBIDDEN)

        stored_hash = user_info["password_hash"]
        if not bcrypt.checkpw(password, stored_hash):
            logger.warn("Failed password login for user %s", user_id)
            raise LoginError(403, "", errcode=Codes.FORBIDDEN)

    @defer.inlineCallbacks
    def set_password(self, user_id, newpassword):
        password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())

        yield self.store.user_set_password_hash(user_id, password_hash)
        yield self.store.user_delete_access_tokens(user_id)
        yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
        yield self.store.flush_user(user_id)

    @defer.inlineCallbacks
    def add_threepid(self, user_id, medium, address, validated_at):
        yield self.store.user_add_threepid(
            user_id, medium, address, validated_at,
            self.hs.get_clock().time_msec()
        )

    def _save_session(self, session):
        # TODO: Persistent storage
        logger.debug("Saving session %s", session)

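The password handling consolidated above hashes with bcrypt on write and
verifies with bcrypt.checkpw on login. A standalone illustration of that round
trip (same library calls as the handler uses):

    import bcrypt

    password = "s3cret"
    stored_hash = bcrypt.hashpw(password, bcrypt.gensalt())

    assert bcrypt.checkpw(password, stored_hash)
    assert not bcrypt.checkpw("wrong password", stored_hash)
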
@@ -70,7 +70,15 @@ class EventStreamHandler(BaseHandler):
                self._streams_per_user[auth_user] += 1

            rm_handler = self.hs.get_handlers().room_member_handler
            room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)

            app_service = yield self.store.get_app_service_by_user_id(
                auth_user.to_string()
            )
            if app_service:
                rooms = yield self.store.get_app_service_rooms(app_service)
                room_ids = set(r.room_id for r in rooms)
            else:
                room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)

            if timeout:
                # If they've set a timeout set a minimum limit.

@@ -229,15 +229,15 @@ class FederationHandler(BaseHandler):

    @defer.inlineCallbacks
    def _filter_events_for_server(self, server_name, room_id, events):
        states = yield self.store.get_state_for_events(
            room_id, [e.event_id for e in events],
        event_to_state = yield self.store.get_state_for_events(
            room_id, frozenset(e.event_id for e in events),
            types=(
                (EventTypes.RoomHistoryVisibility, ""),
                (EventTypes.Member, None),
            )
        )

        events_and_states = zip(events, states)

        def redact_disallowed(event_and_state):
            event, state = event_and_state

        def redact_disallowed(event, state):
            if not state:
                return event

@@ -271,11 +271,10 @@ class FederationHandler(BaseHandler):

            return event

        res = map(redact_disallowed, events_and_states)

        logger.info("_filter_events_for_server %r", res)

        defer.returnValue(res)
        defer.returnValue([
            redact_disallowed(e, event_to_state[e.event_id])
            for e in events
        ])

    @log_function
    @defer.inlineCallbacks
@@ -503,7 +502,7 @@ class FederationHandler(BaseHandler):
        event_ids = list(extremities.keys())

        states = yield defer.gatherResults([
            self.state_handler.resolve_state_groups([e])
            self.state_handler.resolve_state_groups(room_id, [e])
            for e in event_ids
        ])
        states = dict(zip(event_ids, [s[1] for s in states]))

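The same zip-of-parallel-lists to explicit-mapping refactor recurs in the
message and sync handlers below. Its benefit is that each event is matched to
its state by event_id rather than by list position; a toy illustration, with
plain dicts standing in for events and state:

    events = [{"event_id": "$a"}, {"event_id": "$b"}]
    event_to_state = {"$a": {"m.room.member": "join"}, "$b": {}}

    visible = [
        e for e in events
        if event_to_state[e["event_id"]]  # drop events with no matching state
    ]
    print(visible)  # only the event with event_id "$a"
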
@@ -1,83 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer

from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes

import bcrypt
import logging

logger = logging.getLogger(__name__)


class LoginHandler(BaseHandler):

    def __init__(self, hs):
        super(LoginHandler, self).__init__(hs)
        self.hs = hs

    @defer.inlineCallbacks
    def login(self, user, password):
        """Login as the specified user with the specified password.

        Args:
            user (str): The user ID.
            password (str): The password.
        Returns:
            The newly allocated access token.
        Raises:
            StoreError if there was a problem storing the token.
            LoginError if there was an authentication problem.
        """
        # TODO do this better, it can't go in __init__ else it cyclic loops
        if not hasattr(self, "reg_handler"):
            self.reg_handler = self.hs.get_handlers().registration_handler

        # pull out the hash for this user if they exist
        user_info = yield self.store.get_user_by_id(user_id=user)
        if not user_info:
            logger.warn("Attempted to login as %s but they do not exist", user)
            raise LoginError(403, "", errcode=Codes.FORBIDDEN)

        stored_hash = user_info["password_hash"]
        if bcrypt.checkpw(password, stored_hash):
            # generate an access token and store it.
            token = self.reg_handler._generate_token(user)
            logger.info("Adding token %s for user %s", token, user)
            yield self.store.add_access_token_to_user(user, token)
            defer.returnValue(token)
        else:
            logger.warn("Failed password login for user %s", user)
            raise LoginError(403, "", errcode=Codes.FORBIDDEN)

    @defer.inlineCallbacks
    def set_password(self, user_id, newpassword, token_id=None):
        password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())

        yield self.store.user_set_password_hash(user_id, password_hash)
        yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
        yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
            user_id, token_id
        )
        yield self.store.flush_user(user_id)

    @defer.inlineCallbacks
    def add_threepid(self, user_id, medium, address, validated_at):
        yield self.store.user_add_threepid(
            user_id, medium, address, validated_at,
            self.hs.get_clock().time_msec()
        )

@@ -137,15 +137,15 @@ class MessageHandler(BaseHandler):

    @defer.inlineCallbacks
    def _filter_events_for_client(self, user_id, room_id, events):
        states = yield self.store.get_state_for_events(
            room_id, [e.event_id for e in events],
        event_id_to_state = yield self.store.get_state_for_events(
            room_id, frozenset(e.event_id for e in events),
            types=(
                (EventTypes.RoomHistoryVisibility, ""),
                (EventTypes.Member, user_id),
            )
        )

        events_and_states = zip(events, states)

        def allowed(event_and_state):
            event, state = event_and_state

        def allowed(event, state):
            if event.type == EventTypes.RoomHistoryVisibility:
                return True

@@ -175,10 +175,10 @@ class MessageHandler(BaseHandler):

            return True

        events_and_states = filter(allowed, events_and_states)
        defer.returnValue([
            ev
            for ev, _ in events_and_states
            event
            for event in events
            if allowed(event, event_id_to_state[event.event_id])
        ])

    @defer.inlineCallbacks
@@ -401,10 +401,14 @@ class MessageHandler(BaseHandler):
            except:
                logger.exception("Failed to get snapshot")

        yield defer.gatherResults(
            [handle_room(e) for e in room_list],
            consumeErrors=True
        ).addErrback(unwrapFirstError)
        # Only do N rooms at once
        n = 5
        d_list = [handle_room(e) for e in room_list]
        for i in range(0, len(d_list), n):
            yield defer.gatherResults(
                d_list[i:i + n],
                consumeErrors=True
            ).addErrback(unwrapFirstError)

        ret = {
            "rooms": rooms_ret,
@@ -456,20 +460,14 @@ class MessageHandler(BaseHandler):

        @defer.inlineCallbacks
        def get_presence():
            presence_defs = yield defer.DeferredList(
                [
                    presence_handler.get_state(
                        target_user=UserID.from_string(m.user_id),
                        auth_user=auth_user,
                        as_event=True,
                        check_auth=False,
                    )
                    for m in room_members
                ],
                consumeErrors=True,
            states = yield presence_handler.get_states(
                target_users=[UserID.from_string(m.user_id) for m in room_members],
                auth_user=auth_user,
                as_event=True,
                check_auth=False,
            )

            defer.returnValue([p for success, p in presence_defs if success])
            defer.returnValue(states.values())

        receipts_handler = self.hs.get_handlers().receipts_handler

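The batching change in the snapshot code above replaces one big
defer.gatherResults over every room with batches of five. A variant of the
idiom as a reusable helper (hypothetical; it takes callables and creates the
Deferreds lazily, so at most n pieces of work start per batch):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def gather_in_batches(factories, n=5):
        # Run callables returning Deferreds, at most n per batch.
        results = []
        for i in range(0, len(factories), n):
            batch = yield defer.gatherResults(
                [f() for f in factories[i:i + n]],
                consumeErrors=True,
            )
            results.extend(batch)
        defer.returnValue(results)
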
@@ -192,6 +192,20 @@ class PresenceHandler(BaseHandler):

    @defer.inlineCallbacks
    def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
        """Get the current presence state of the given user.

        Args:
            target_user (UserID): The user whose presence we want
            auth_user (UserID): The user requesting the presence, used for
                checking if said user is allowed to see the presence of the
                `target_user`
            as_event (bool): Format the return as an event or not?
            check_auth (bool): Perform the auth checks or not?

        Returns:
            dict: The presence state of the `target_user`, whose format depends
            on the `as_event` argument.
        """
        if self.hs.is_mine(target_user):
            if check_auth:
                visible = yield self.is_presence_visible(
@@ -232,6 +246,81 @@ class PresenceHandler(BaseHandler):
        else:
            defer.returnValue(state)

    @defer.inlineCallbacks
    def get_states(self, target_users, auth_user, as_event=False, check_auth=True):
        """A batched version of the `get_state` method that accepts a list of
        `target_users`

        Args:
            target_users (list): The list of UserID's whose presence we want
            auth_user (UserID): The user requesting the presence, used for
                checking if said user is allowed to see the presence of the
                `target_users`
            as_event (bool): Format the return as an event or not?
            check_auth (bool): Perform the auth checks or not?

        Returns:
            dict: A mapping from user -> presence_state
        """
        local_users, remote_users = partitionbool(
            target_users,
            lambda u: self.hs.is_mine(u)
        )

        if check_auth:
            for user in local_users:
                visible = yield self.is_presence_visible(
                    observer_user=auth_user,
                    observed_user=user
                )

                if not visible:
                    raise SynapseError(404, "Presence information not visible")

        results = {}
        if local_users:
            for user in local_users:
                if user in self._user_cachemap:
                    results[user] = self._user_cachemap[user].get_state()

            local_to_user = {u.localpart: u for u in local_users}

            states = yield self.store.get_presence_states(
                [u.localpart for u in local_users if u not in results]
            )

            for local_part, state in states.items():
                if state is None:
                    continue
                res = {"presence": state["state"]}
                if "status_msg" in state and state["status_msg"]:
                    res["status_msg"] = state["status_msg"]
                results[local_to_user[local_part]] = res

        for user in remote_users:
            # TODO(paul): Have remote server send us permissions set
            results[user] = self._get_or_offline_usercache(user).get_state()

        for state in results.values():
            if "last_active" in state:
                state["last_active_ago"] = int(
                    self.clock.time_msec() - state.pop("last_active")
                )

        if as_event:
            for user, state in results.items():
                content = state
                content["user_id"] = user.to_string()

                if "last_active" in content:
                    content["last_active_ago"] = int(
                        self._clock.time_msec() - content.pop("last_active")
                    )

                results[user] = {"type": "m.presence", "content": content}

        defer.returnValue(results)

    @defer.inlineCallbacks
    @log_function
    def set_state(self, target_user, auth_user, state):

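A hypothetical call site for the new batched API, mirroring the updated
get_presence() in the message handler above (one call for the whole member
list instead of a DeferredList of per-user get_state() calls):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def member_presence(presence_handler, room_members, auth_user):
        states = yield presence_handler.get_states(
            target_users=room_members,  # list of UserID objects
            auth_user=auth_user,
            as_event=True,
            check_auth=False,
        )
        defer.returnValue(states.values())
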
@@ -171,7 +171,6 @@ class ReceiptEventSource(object):

    @defer.inlineCallbacks
    def get_new_events_for_user(self, user, from_key, limit):
        defer.returnValue(([], from_key))
        from_key = int(from_key)
        to_key = yield self.get_current_key()

@@ -194,7 +193,6 @@ class ReceiptEventSource(object):
    @defer.inlineCallbacks
    def get_pagination_rows(self, user, config, key):
        to_key = int(config.from_key)
        defer.returnValue(([], to_key))

        if config.to_key:
            from_key = int(config.to_key)

@@ -91,7 +91,7 @@ class RegistrationHandler(BaseHandler):
            user = UserID(localpart, self.hs.hostname)
            user_id = user.to_string()

            token = self._generate_token(user_id)
            token = self.generate_token(user_id)
            yield self.store.register(
                user_id=user_id,
                token=token,
@@ -111,7 +111,7 @@ class RegistrationHandler(BaseHandler):
                    user_id = user.to_string()
                    yield self.check_user_id_is_valid(user_id)

                    token = self._generate_token(user_id)
                    token = self.generate_token(user_id)
                    yield self.store.register(
                        user_id=user_id,
                        token=token,
@@ -161,7 +161,7 @@ class RegistrationHandler(BaseHandler):
                400, "Invalid user localpart for this application service.",
                errcode=Codes.EXCLUSIVE
            )
        token = self._generate_token(user_id)
        token = self.generate_token(user_id)
        yield self.store.register(
            user_id=user_id,
            token=token,
@@ -208,7 +208,7 @@ class RegistrationHandler(BaseHandler):
        user_id = user.to_string()

        yield self.check_user_id_is_valid(user_id)
        token = self._generate_token(user_id)
        token = self.generate_token(user_id)
        try:
            yield self.store.register(
                user_id=user_id,
@@ -273,7 +273,7 @@ class RegistrationHandler(BaseHandler):
                    errcode=Codes.EXCLUSIVE
                )

    def _generate_token(self, user_id):
    def generate_token(self, user_id):
        # urlsafe variant uses _ and - so use . as the separator and replace
        # all =s with .s so http clients don't quote =s when it is used as
        # query params.

@@ -557,15 +557,9 @@ class RoomMemberHandler(BaseHandler):
        """Returns a list of roomids that the user has any of the given
        membership states in."""

        app_service = yield self.store.get_app_service_by_user_id(
            user.to_string()
        rooms = yield self.store.get_rooms_for_user(
            user.to_string(),
        )
        if app_service:
            rooms = yield self.store.get_app_service_rooms(app_service)
        else:
            rooms = yield self.store.get_rooms_for_user(
                user.to_string(),
            )

        # For some reason the list of events contains duplicates
        # TODO(paul): work out why because I really don't think it should

@@ -96,9 +96,18 @@ class SyncHandler(BaseHandler):
                return self.current_sync_for_user(sync_config, since_token)

            rm_handler = self.hs.get_handlers().room_member_handler
            room_ids = yield rm_handler.get_joined_rooms_for_user(
                sync_config.user

            app_service = yield self.store.get_app_service_by_user_id(
                sync_config.user.to_string()
            )
            if app_service:
                rooms = yield self.store.get_app_service_rooms(app_service)
                room_ids = set(r.room_id for r in rooms)
            else:
                room_ids = yield rm_handler.get_joined_rooms_for_user(
                    sync_config.user
                )

            result = yield self.notifier.wait_for_events(
                sync_config.user, room_ids,
                sync_config.filter, timeout, current_sync_callback
@@ -229,7 +238,16 @@ class SyncHandler(BaseHandler):
        logger.debug("Typing %r", typing_by_room)

        rm_handler = self.hs.get_handlers().room_member_handler
        room_ids = yield rm_handler.get_joined_rooms_for_user(sync_config.user)
        app_service = yield self.store.get_app_service_by_user_id(
            sync_config.user.to_string()
        )
        if app_service:
            rooms = yield self.store.get_app_service_rooms(app_service)
            room_ids = set(r.room_id for r in rooms)
        else:
            room_ids = yield rm_handler.get_joined_rooms_for_user(
                sync_config.user
            )

        # TODO (mjark): Does public mean "published"?
        published_rooms = yield self.store.get_rooms(is_public=True)
@@ -294,15 +312,15 @@ class SyncHandler(BaseHandler):

    @defer.inlineCallbacks
    def _filter_events_for_client(self, user_id, room_id, events):
        states = yield self.store.get_state_for_events(
            room_id, [e.event_id for e in events],
        event_id_to_state = yield self.store.get_state_for_events(
            room_id, frozenset(e.event_id for e in events),
            types=(
                (EventTypes.RoomHistoryVisibility, ""),
                (EventTypes.Member, user_id),
            )
        )

        events_and_states = zip(events, states)

        def allowed(event_and_state):
            event, state = event_and_state

        def allowed(event, state):
            if event.type == EventTypes.RoomHistoryVisibility:
                return True

@@ -331,10 +349,11 @@ class SyncHandler(BaseHandler):
                return membership == Membership.INVITE

            return True
        events_and_states = filter(allowed, events_and_states)

        defer.returnValue([
            ev
            for ev, _ in events_and_states
            event
            for event in events
            if allowed(event, event_id_to_state[event.event_id])
        ])

    @defer.inlineCallbacks

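The app-service branch added to the event-stream handler and (twice) to the
sync handler is the same logic that was just removed from the room-member
handler. A shared helper like this hypothetical one would capture the pattern
in one place:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def rooms_for_possible_app_service_user(store, rm_handler, user):
        # App-service users "belong" to the rooms their service covers,
        # not only to rooms they have joined.
        app_service = yield store.get_app_service_by_user_id(user.to_string())
        if app_service:
            rooms = yield store.get_app_service_rooms(app_service)
            defer.returnValue(set(r.room_id for r in rooms))
        else:
            room_ids = yield rm_handler.get_joined_rooms_for_user(user)
            defer.returnValue(room_ids)
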
|  | @ -16,7 +16,7 @@ | |||
| 
 | ||||
| from twisted.internet import defer, reactor, protocol | ||||
| from twisted.internet.error import DNSLookupError | ||||
| from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool | ||||
| from twisted.web.client import readBody, HTTPConnectionPool, Agent | ||||
| from twisted.web.http_headers import Headers | ||||
| from twisted.web._newclient import ResponseDone | ||||
| 
 | ||||
|  | @ -55,41 +55,17 @@ incoming_responses_counter = metrics.register_counter( | |||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class MatrixFederationHttpAgent(_AgentBase): | ||||
| class MatrixFederationEndpointFactory(object): | ||||
|     def __init__(self, hs): | ||||
|         self.tls_context_factory = hs.tls_context_factory | ||||
| 
 | ||||
|     def __init__(self, reactor, pool=None): | ||||
|         _AgentBase.__init__(self, reactor, pool) | ||||
|     def endpointForURI(self, uri): | ||||
|         destination = uri.netloc | ||||
| 
 | ||||
|     def request(self, destination, endpoint, method, path, params, query, | ||||
|                 headers, body_producer): | ||||
| 
 | ||||
|         outgoing_requests_counter.inc(method) | ||||
| 
 | ||||
|         host = b"" | ||||
|         port = 0 | ||||
|         fragment = b"" | ||||
| 
 | ||||
|         parsed_URI = _URI(b"http", destination, host, port, path, params, | ||||
|                           query, fragment) | ||||
| 
 | ||||
|         # Set the connection pool key to be the destination. | ||||
|         key = destination | ||||
| 
 | ||||
|         d = self._requestWithEndpoint(key, endpoint, method, parsed_URI, | ||||
|                                       headers, body_producer, | ||||
|                                       parsed_URI.originForm) | ||||
| 
 | ||||
|         def _cb(response): | ||||
|             incoming_responses_counter.inc(method, response.code) | ||||
|             return response | ||||
| 
 | ||||
|         def _eb(failure): | ||||
|             incoming_responses_counter.inc(method, "ERR") | ||||
|             return failure | ||||
| 
 | ||||
|         d.addCallbacks(_cb, _eb) | ||||
| 
 | ||||
|         return d | ||||
|         return matrix_federation_endpoint( | ||||
|             reactor, destination, timeout=10, | ||||
|             ssl_context_factory=self.tls_context_factory | ||||
|         ) | ||||
| 
 | ||||
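| # Editor's note -- an illustrative, self-contained sketch (not part of this | ||||
| # commit) of the Agent.usingEndpointFactory pattern relied on above. It | ||||
| # assumes Twisted >= 15.0; FixedHostEndpointFactory is an invented name. | ||||
| from twisted.internet import reactor | ||||
| from twisted.internet.endpoints import TCP4ClientEndpoint | ||||
| from twisted.web.client import Agent | ||||
|  | ||||
| class FixedHostEndpointFactory(object): | ||||
|     """Route every request to one fixed host, ignoring the URI's netloc.""" | ||||
|     def __init__(self, host, port): | ||||
|         self.host = host | ||||
|         self.port = port | ||||
|  | ||||
|     def endpointForURI(self, uri): | ||||
|         # Called by the Agent once per request to obtain a client endpoint. | ||||
|         return TCP4ClientEndpoint(reactor, self.host, self.port) | ||||
|  | ||||
| agent = Agent.usingEndpointFactory( | ||||
|     reactor, FixedHostEndpointFactory("localhost", 8080) | ||||
| ) | ||||
| d = agent.request(b"GET", b"http://ignored.example/path") | ||||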
| 
 | ||||
| class MatrixFederationHttpClient(object): | ||||
|  | @ -107,12 +83,18 @@ class MatrixFederationHttpClient(object): | |||
|         self.server_name = hs.hostname | ||||
|         pool = HTTPConnectionPool(reactor) | ||||
|         pool.maxPersistentPerHost = 10 | ||||
|         self.agent = MatrixFederationHttpAgent(reactor, pool=pool) | ||||
|         self.agent = Agent.usingEndpointFactory( | ||||
|             reactor, MatrixFederationEndpointFactory(hs), pool=pool | ||||
|         ) | ||||
|         self.clock = hs.get_clock() | ||||
|         self.version_string = hs.version_string | ||||
| 
 | ||||
|         self._next_id = 1 | ||||
| 
 | ||||
|     def _create_url(self, destination, path_bytes, param_bytes, query_bytes): | ||||
|         return urlparse.urlunparse( | ||||
|             ("matrix", destination, path_bytes, param_bytes, query_bytes, "") | ||||
|         ) | ||||
| 
 | ||||
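| # Editor's note -- illustrative usage (not part of this commit) of the | ||||
| # _create_url helper above, using Python 2's urlparse as this module does; | ||||
| # the destination and path are invented examples. | ||||
| import urlparse | ||||
|  | ||||
| print(urlparse.urlunparse( | ||||
|     ("matrix", "remote.example.com", | ||||
|      "/_matrix/federation/v1/send/123", "", "", "") | ||||
| ))  # matrix://remote.example.com/_matrix/federation/v1/send/123 | ||||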
|     @defer.inlineCallbacks | ||||
|     def _create_request(self, destination, method, path_bytes, | ||||
|                         body_callback, headers_dict={}, param_bytes=b"", | ||||
|  | @ -123,8 +105,8 @@ class MatrixFederationHttpClient(object): | |||
|         headers_dict[b"User-Agent"] = [self.version_string] | ||||
|         headers_dict[b"Host"] = [destination] | ||||
| 
 | ||||
|         url_bytes = urlparse.urlunparse( | ||||
|             ("", "", path_bytes, param_bytes, query_bytes, "",) | ||||
|         url_bytes = self._create_url( | ||||
|             destination, path_bytes, param_bytes, query_bytes | ||||
|         ) | ||||
| 
 | ||||
|         txn_id = "%s-O-%s" % (method, self._next_id) | ||||
|  | @ -139,8 +121,8 @@ class MatrixFederationHttpClient(object): | |||
|         # (once we have reliable transactions in place) | ||||
|         retries_left = 5 | ||||
| 
 | ||||
|         endpoint = preserve_context_over_fn( | ||||
|             self._getEndpoint, reactor, destination | ||||
|         http_url_bytes = urlparse.urlunparse( | ||||
|             ("", "", path_bytes, param_bytes, query_bytes, "") | ||||
|         ) | ||||
| 
 | ||||
|         log_result = None | ||||
|  | @ -148,17 +130,14 @@ class MatrixFederationHttpClient(object): | |||
|             while True: | ||||
|                 producer = None | ||||
|                 if body_callback: | ||||
|                     producer = body_callback(method, url_bytes, headers_dict) | ||||
|                     producer = body_callback(method, http_url_bytes, headers_dict) | ||||
| 
 | ||||
|                 try: | ||||
|                     def send_request(): | ||||
|                         request_deferred = self.agent.request( | ||||
|                             destination, | ||||
|                             endpoint, | ||||
|                         request_deferred = preserve_context_over_fn( | ||||
|                             self.agent.request, | ||||
|                             method, | ||||
|                             path_bytes, | ||||
|                             param_bytes, | ||||
|                             query_bytes, | ||||
|                             url_bytes, | ||||
|                             Headers(headers_dict), | ||||
|                             producer | ||||
|                         ) | ||||
|  | @ -452,12 +431,6 @@ class MatrixFederationHttpClient(object): | |||
| 
 | ||||
|         defer.returnValue((length, headers)) | ||||
| 
 | ||||
|     def _getEndpoint(self, reactor, destination): | ||||
|         return matrix_federation_endpoint( | ||||
|             reactor, destination, timeout=10, | ||||
|             ssl_context_factory=self.hs.tls_context_factory | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class _ReadBodyToFileProtocol(protocol.Protocol): | ||||
|     def __init__(self, stream, deferred, max_size): | ||||
|  |  | |||
|  | @ -18,8 +18,12 @@ from __future__ import absolute_import | |||
| 
 | ||||
| import logging | ||||
| from resource import getrusage, getpagesize, RUSAGE_SELF | ||||
| import functools | ||||
| import os | ||||
| import stat | ||||
| import time | ||||
| 
 | ||||
| from twisted.internet import reactor | ||||
| 
 | ||||
| from .metric import ( | ||||
|     CounterMetric, CallbackMetric, DistributionMetric, CacheMetric | ||||
|  | @ -144,3 +148,50 @@ def _process_fds(): | |||
|     return counts | ||||
| 
 | ||||
| get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"]) | ||||
| 
 | ||||
| reactor_metrics = get_metrics_for("reactor") | ||||
| tick_time = reactor_metrics.register_distribution("tick_time") | ||||
| pending_calls_metric = reactor_metrics.register_distribution("pending_calls") | ||||
| 
 | ||||
| 
 | ||||
| def runUntilCurrentTimer(func): | ||||
| 
 | ||||
|     @functools.wraps(func) | ||||
|     def f(*args, **kwargs): | ||||
|         now = reactor.seconds() | ||||
|         num_pending = 0 | ||||
| 
 | ||||
|         # _newTimedCalls is one long list of *all* pending calls. The loop | ||||
|         # below is based on the implementation of reactor.runUntilCurrent. | ||||
|         for delayed_call in reactor._newTimedCalls: | ||||
|             if delayed_call.time > now: | ||||
|                 break | ||||
| 
 | ||||
|             if delayed_call.delayed_time > 0: | ||||
|                 continue | ||||
| 
 | ||||
|             num_pending += 1 | ||||
| 
 | ||||
|         num_pending += len(reactor.threadCallQueue) | ||||
| 
 | ||||
|         start = time.time() * 1000 | ||||
|         ret = func(*args, **kwargs) | ||||
|         end = time.time() * 1000 | ||||
|         tick_time.inc_by(end - start) | ||||
|         pending_calls_metric.inc_by(num_pending) | ||||
|         return ret | ||||
| 
 | ||||
|     return f | ||||
| 
 | ||||
| 
 | ||||
| try: | ||||
|     # Ensure the reactor has all the attributes we expect | ||||
|     reactor.runUntilCurrent | ||||
|     reactor._newTimedCalls | ||||
|     reactor.threadCallQueue | ||||
| 
 | ||||
|     # runUntilCurrent is called when we have pending calls. It is called once | ||||
|     # per iteration after fd polling. | ||||
|     reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent) | ||||
| except AttributeError: | ||||
|     pass | ||||
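|  | ||||
| # Editor's note -- a minimal runnable sketch (not part of this commit) of | ||||
| # the same wrap-and-time pattern used by runUntilCurrentTimer, with a plain | ||||
| # list standing in for the tick_time distribution metric. | ||||
| import functools | ||||
| import time | ||||
|  | ||||
| tick_times_ms = [] | ||||
|  | ||||
| def timed(func): | ||||
|     @functools.wraps(func) | ||||
|     def wrapper(*args, **kwargs): | ||||
|         start = time.time() * 1000 | ||||
|         ret = func(*args, **kwargs) | ||||
|         tick_times_ms.append(time.time() * 1000 - start) | ||||
|         return ret | ||||
|     return wrapper | ||||
|  | ||||
| @timed | ||||
| def busy_loop(): | ||||
|     sum(range(100000)) | ||||
|  | ||||
| busy_loop() | ||||
| print(tick_times_ms)  # one elapsed-ms sample per wrapped call | ||||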
|  |  | |||
|  | @ -94,17 +94,14 @@ class PusherPool: | |||
|                 self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def remove_pushers_by_user_access_token(self, user_id, not_access_token_id): | ||||
|     def remove_pushers_by_user(self, user_id): | ||||
|         all = yield self.store.get_all_pushers() | ||||
|         logger.info( | ||||
|             "Removing all pushers for user %s except access token %s", | ||||
|             user_id, not_access_token_id | ||||
|             "Removing all pushers for user %s", | ||||
|             user_id, | ||||
|         ) | ||||
|         for p in all: | ||||
|             if ( | ||||
|                 p['user_name'] == user_id and | ||||
|                 p['access_token'] != not_access_token_id | ||||
|             ): | ||||
|             if p['user_name'] == user_id: | ||||
|                 logger.info( | ||||
|                     "Removing pusher for app id %s, pushkey %s, user %s", | ||||
|                     p['app_id'], p['pushkey'], p['user_name'] | ||||
|  |  | |||
|  | @ -19,7 +19,7 @@ logger = logging.getLogger(__name__) | |||
| 
 | ||||
| REQUIREMENTS = { | ||||
|     "syutil>=0.0.7": ["syutil>=0.0.7"], | ||||
|     "Twisted==14.0.2": ["twisted==14.0.2"], | ||||
|     "Twisted>=15.1.0": ["twisted>=15.1.0"], | ||||
|     "service_identity>=1.0.0": ["service_identity>=1.0.0"], | ||||
|     "pyopenssl>=0.14": ["OpenSSL>=0.14"], | ||||
|     "pyyaml": ["yaml"], | ||||
|  |  | |||
|  | @ -85,9 +85,8 @@ class LoginRestServlet(ClientV1RestServlet): | |||
|             user_id = UserID.create( | ||||
|                 user_id, self.hs.hostname).to_string() | ||||
| 
 | ||||
|         handler = self.handlers.login_handler | ||||
|         token = yield handler.login( | ||||
|             user=user_id, | ||||
|         token = yield self.handlers.auth_handler.login_with_password( | ||||
|             user_id=user_id, | ||||
|             password=login_submission["password"]) | ||||
| 
 | ||||
|         result = { | ||||
|  |  | |||
|  | @ -36,7 +36,6 @@ class PasswordRestServlet(RestServlet): | |||
|         self.hs = hs | ||||
|         self.auth = hs.get_auth() | ||||
|         self.auth_handler = hs.get_handlers().auth_handler | ||||
|         self.login_handler = hs.get_handlers().login_handler | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request): | ||||
|  | @ -47,7 +46,7 @@ class PasswordRestServlet(RestServlet): | |||
|         authed, result, params = yield self.auth_handler.check_auth([ | ||||
|             [LoginType.PASSWORD], | ||||
|             [LoginType.EMAIL_IDENTITY] | ||||
|         ], body) | ||||
|         ], body, self.hs.get_ip_from_request(request)) | ||||
| 
 | ||||
|         if not authed: | ||||
|             defer.returnValue((401, result)) | ||||
|  | @ -79,7 +78,7 @@ class PasswordRestServlet(RestServlet): | |||
|             raise SynapseError(400, "", Codes.MISSING_PARAM) | ||||
|         new_password = params['new_password'] | ||||
| 
 | ||||
|         yield self.login_handler.set_password( | ||||
|         yield self.auth_handler.set_password( | ||||
|             user_id, new_password, None | ||||
|         ) | ||||
| 
 | ||||
|  | @ -95,7 +94,6 @@ class ThreepidRestServlet(RestServlet): | |||
|     def __init__(self, hs): | ||||
|         super(ThreepidRestServlet, self).__init__() | ||||
|         self.hs = hs | ||||
|         self.login_handler = hs.get_handlers().login_handler | ||||
|         self.identity_handler = hs.get_handlers().identity_handler | ||||
|         self.auth = hs.get_auth() | ||||
| 
 | ||||
|  | @ -135,7 +133,7 @@ class ThreepidRestServlet(RestServlet): | |||
|                 logger.warn("Couldn't add 3pid: invalid response from ID sevrer") | ||||
|                 raise SynapseError(500, "Invalid response from ID Server") | ||||
| 
 | ||||
|         yield self.login_handler.add_threepid( | ||||
|         yield self.auth_handler.add_threepid( | ||||
|             auth_user.to_string(), | ||||
|             threepid['medium'], | ||||
|             threepid['address'], | ||||
|  |  | |||
|  | @ -17,6 +17,7 @@ from twisted.internet import defer | |||
| 
 | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.http.servlet import RestServlet | ||||
| from synapse.types import UserID | ||||
| from syutil.jsonutil import encode_canonical_json | ||||
| 
 | ||||
| from ._base import client_v2_pattern | ||||
|  | @ -164,45 +165,63 @@ class KeyQueryServlet(RestServlet): | |||
|         super(KeyQueryServlet, self).__init__() | ||||
|         self.store = hs.get_datastore() | ||||
|         self.auth = hs.get_auth() | ||||
|         self.federation = hs.get_replication_layer() | ||||
|         self.is_mine = hs.is_mine | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request, user_id, device_id): | ||||
|         logger.debug("onPOST") | ||||
|         yield self.auth.get_user_by_req(request) | ||||
|         try: | ||||
|             body = json.loads(request.content.read()) | ||||
|         except: | ||||
|             raise SynapseError(400, "Invalid key JSON") | ||||
|         query = [] | ||||
|         for user_id, device_ids in body.get("device_keys", {}).items(): | ||||
|             if not device_ids: | ||||
|                 query.append((user_id, None)) | ||||
|             else: | ||||
|                 for device_id in device_ids: | ||||
|                     query.append((user_id, device_id)) | ||||
|         results = yield self.store.get_e2e_device_keys(query) | ||||
|         defer.returnValue(self.json_result(request, results)) | ||||
|         result = yield self.handle_request(body) | ||||
|         defer.returnValue(result) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, user_id, device_id): | ||||
|         auth_user, client_info = yield self.auth.get_user_by_req(request) | ||||
|         auth_user_id = auth_user.to_string() | ||||
|         if not user_id: | ||||
|             user_id = auth_user_id | ||||
|         if not device_id: | ||||
|             device_id = None | ||||
|         # Returns a map of user_id->device_id->json_bytes. | ||||
|         results = yield self.store.get_e2e_device_keys([(user_id, device_id)]) | ||||
|         defer.returnValue(self.json_result(request, results)) | ||||
|         user_id = user_id if user_id else auth_user_id | ||||
|         device_ids = [device_id] if device_id else [] | ||||
|         result = yield self.handle_request( | ||||
|             {"device_keys": {user_id: device_ids}} | ||||
|         ) | ||||
|         defer.returnValue(result) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def handle_request(self, body): | ||||
|         local_query = [] | ||||
|         remote_queries = {} | ||||
|         for user_id, device_ids in body.get("device_keys", {}).items(): | ||||
|             user = UserID.from_string(user_id) | ||||
|             if self.is_mine(user): | ||||
|                 if not device_ids: | ||||
|                     local_query.append((user_id, None)) | ||||
|                 else: | ||||
|                     for device_id in device_ids: | ||||
|                         local_query.append((user_id, device_id)) | ||||
|             else: | ||||
|                 remote_queries.setdefault(user.domain, {})[user_id] = list( | ||||
|                     device_ids | ||||
|                 ) | ||||
|         results = yield self.store.get_e2e_device_keys(local_query) | ||||
| 
 | ||||
|     def json_result(self, request, results): | ||||
|         json_result = {} | ||||
|         for user_id, device_keys in results.items(): | ||||
|             for device_id, json_bytes in device_keys.items(): | ||||
|                 json_result.setdefault(user_id, {})[device_id] = json.loads( | ||||
|                     json_bytes | ||||
|                 ) | ||||
|         return (200, {"device_keys": json_result}) | ||||
| 
 | ||||
|         for destination, device_keys in remote_queries.items(): | ||||
|             remote_result = yield self.federation.query_client_keys( | ||||
|                 destination, {"device_keys": device_keys} | ||||
|             ) | ||||
|             for user_id, keys in remote_result["device_keys"].items(): | ||||
|                 if user_id in device_keys: | ||||
|                     json_result[user_id] = keys | ||||
|         defer.returnValue((200, {"device_keys": json_result})) | ||||
| 
 | ||||
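| # Editor's note -- an illustrative sketch (not part of this commit) of the | ||||
| # local/remote split performed by handle_request above: user ids are routed | ||||
| # by the domain half of "@localpart:domain". split_query is an invented name. | ||||
| def split_query(device_keys, my_domain): | ||||
|     local_query, remote_queries = [], {} | ||||
|     for user_id, device_ids in device_keys.items(): | ||||
|         domain = user_id.split(":", 1)[1] | ||||
|         if domain == my_domain: | ||||
|             # An empty device list means "all devices for this user". | ||||
|             for device_id in (device_ids or [None]): | ||||
|                 local_query.append((user_id, device_id)) | ||||
|         else: | ||||
|             remote_queries.setdefault(domain, {})[user_id] = list(device_ids) | ||||
|     return local_query, remote_queries | ||||
|  | ||||
| print(split_query({"@a:here": ["D1"], "@b:there": []}, "here")) | ||||
| # ([('@a:here', 'D1')], {'there': {'@b:there': []}}) | ||||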
| 
 | ||||
| class OneTimeKeyServlet(RestServlet): | ||||
|  | @ -236,14 +255,16 @@ class OneTimeKeyServlet(RestServlet): | |||
|         self.store = hs.get_datastore() | ||||
|         self.auth = hs.get_auth() | ||||
|         self.clock = hs.get_clock() | ||||
|         self.federation = hs.get_replication_layer() | ||||
|         self.is_mine = hs.is_mine | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, user_id, device_id, algorithm): | ||||
|         yield self.auth.get_user_by_req(request) | ||||
|         results = yield self.store.claim_e2e_one_time_keys( | ||||
|             [(user_id, device_id, algorithm)] | ||||
|         result = yield self.handle_request( | ||||
|             {"one_time_keys": {user_id: {device_id: algorithm}}} | ||||
|         ) | ||||
|         defer.returnValue(self.json_result(request, results)) | ||||
|         defer.returnValue(result) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request, user_id, device_id, algorithm): | ||||
|  | @ -252,14 +273,24 @@ class OneTimeKeyServlet(RestServlet): | |||
|             body = json.loads(request.content.read()) | ||||
|         except: | ||||
|             raise SynapseError(400, "Invalid key JSON") | ||||
|         query = [] | ||||
|         for user_id, device_keys in body.get("one_time_keys", {}).items(): | ||||
|             for device_id, algorithm in device_keys.items(): | ||||
|                 query.append((user_id, device_id, algorithm)) | ||||
|         results = yield self.store.claim_e2e_one_time_keys(query) | ||||
|         defer.returnValue(self.json_result(request, results)) | ||||
|         result = yield self.handle_request(body) | ||||
|         defer.returnValue(result) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def handle_request(self, body): | ||||
|         local_query = [] | ||||
|         remote_queries = {} | ||||
|         for user_id, device_keys in body.get("one_time_keys", {}).items(): | ||||
|             user = UserID.from_string(user_id) | ||||
|             if self.is_mine(user): | ||||
|                 for device_id, algorithm in device_keys.items(): | ||||
|                     local_query.append((user_id, device_id, algorithm)) | ||||
|             else: | ||||
|                 remote_queries.setdefault(user.domain, {})[user_id] = ( | ||||
|                     device_keys | ||||
|                 ) | ||||
|         results = yield self.store.claim_e2e_one_time_keys(local_query) | ||||
| 
 | ||||
|     def json_result(self, request, results): | ||||
|         json_result = {} | ||||
|         for user_id, device_keys in results.items(): | ||||
|             for device_id, keys in device_keys.items(): | ||||
|  | @ -267,7 +298,16 @@ class OneTimeKeyServlet(RestServlet): | |||
|                     json_result.setdefault(user_id, {})[device_id] = { | ||||
|                         key_id: json.loads(json_bytes) | ||||
|                     } | ||||
|         return (200, {"one_time_keys": json_result}) | ||||
| 
 | ||||
|         for destination, device_keys in remote_queries.items(): | ||||
|             remote_result = yield self.federation.claim_client_keys( | ||||
|                 destination, {"one_time_keys": device_keys} | ||||
|             ) | ||||
|             for user_id, keys in remote_result["one_time_keys"].items(): | ||||
|                 if user_id in device_keys: | ||||
|                     json_result[user_id] = keys | ||||
| 
 | ||||
|         defer.returnValue((200, {"one_time_keys": json_result})) | ||||
| 
 | ||||
| 
 | ||||
| def register_servlets(hs, http_server): | ||||
|  |  | |||
|  | @ -41,7 +41,7 @@ logger = logging.getLogger(__name__) | |||
| 
 | ||||
| 
 | ||||
| class RegisterRestServlet(RestServlet): | ||||
|     PATTERN = client_v2_pattern("/register*") | ||||
|     PATTERN = client_v2_pattern("/register") | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(RegisterRestServlet, self).__init__() | ||||
|  | @ -50,7 +50,6 @@ class RegisterRestServlet(RestServlet): | |||
|         self.auth_handler = hs.get_handlers().auth_handler | ||||
|         self.registration_handler = hs.get_handlers().registration_handler | ||||
|         self.identity_handler = hs.get_handlers().identity_handler | ||||
|         self.login_handler = hs.get_handlers().login_handler | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request): | ||||
|  | @ -148,7 +147,7 @@ class RegisterRestServlet(RestServlet): | |||
|                 if reqd not in threepid: | ||||
|                     logger.info("Can't add incomplete 3pid") | ||||
|                 else: | ||||
|                     yield self.login_handler.add_threepid( | ||||
|                     yield self.auth_handler.add_threepid( | ||||
|                         user_id, | ||||
|                         threepid['medium'], | ||||
|                         threepid['address'], | ||||
|  | @ -224,6 +223,9 @@ class RegisterRestServlet(RestServlet): | |||
|             if k not in body: | ||||
|                 absent.append(k) | ||||
| 
 | ||||
|         if len(absent) > 0: | ||||
|             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | ||||
| 
 | ||||
|         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||
|             'email', body['email'] | ||||
|         ) | ||||
|  | @ -231,9 +233,6 @@ class RegisterRestServlet(RestServlet): | |||
|         if existingUid is not None: | ||||
|             raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) | ||||
| 
 | ||||
|         if len(absent) > 0: | ||||
|             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | ||||
| 
 | ||||
|         ret = yield self.identity_handler.requestEmailToken(**body) | ||||
|         defer.returnValue((200, ret)) | ||||
| 
 | ||||
|  |  | |||
|  | @ -18,7 +18,7 @@ from twisted.internet import defer | |||
| 
 | ||||
| from synapse.util.logutils import log_function | ||||
| from synapse.util.async import run_on_reactor | ||||
| from synapse.util.expiringcache import ExpiringCache | ||||
| from synapse.util.caches.expiringcache import ExpiringCache | ||||
| from synapse.api.constants import EventTypes | ||||
| from synapse.api.errors import AuthError | ||||
| from synapse.api.auth import AuthEventTypes | ||||
|  | @ -96,7 +96,7 @@ class StateHandler(object): | |||
|             cache.ts = self.clock.time_msec() | ||||
|             state = cache.state | ||||
|         else: | ||||
|             res = yield self.resolve_state_groups(event_ids) | ||||
|             res = yield self.resolve_state_groups(room_id, event_ids) | ||||
|             state = res[1] | ||||
| 
 | ||||
|         if event_type: | ||||
|  | @ -155,13 +155,13 @@ class StateHandler(object): | |||
| 
 | ||||
|         if event.is_state(): | ||||
|             ret = yield self.resolve_state_groups( | ||||
|                 [e for e, _ in event.prev_events], | ||||
|                 event.room_id, [e for e, _ in event.prev_events], | ||||
|                 event_type=event.type, | ||||
|                 state_key=event.state_key, | ||||
|             ) | ||||
|         else: | ||||
|             ret = yield self.resolve_state_groups( | ||||
|                 [e for e, _ in event.prev_events], | ||||
|                 event.room_id, [e for e, _ in event.prev_events], | ||||
|             ) | ||||
| 
 | ||||
|         group, curr_state, prev_state = ret | ||||
|  | @ -180,7 +180,7 @@ class StateHandler(object): | |||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     @log_function | ||||
|     def resolve_state_groups(self, event_ids, event_type=None, state_key=""): | ||||
|     def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""): | ||||
|         """ Given a list of event_ids this method fetches the state at each | ||||
|         event, resolves conflicts between them and returns them. | ||||
| 
 | ||||
|  | @ -205,7 +205,7 @@ class StateHandler(object): | |||
|                 ) | ||||
| 
 | ||||
|         state_groups = yield self.store.get_state_groups( | ||||
|             event_ids | ||||
|             room_id, event_ids | ||||
|         ) | ||||
| 
 | ||||
|         logger.debug( | ||||
|  |  | |||
|  | @ -99,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore, | |||
|         key = (user.to_string(), access_token, device_id, ip) | ||||
| 
 | ||||
|         try: | ||||
|             last_seen = self.client_ip_last_seen.get(*key) | ||||
|             last_seen = self.client_ip_last_seen.get(key) | ||||
|         except KeyError: | ||||
|             last_seen = None | ||||
| 
 | ||||
|  | @ -107,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore, | |||
|         if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: | ||||
|             defer.returnValue(None) | ||||
| 
 | ||||
|         self.client_ip_last_seen.prefill(*key + (now,)) | ||||
|         self.client_ip_last_seen.prefill(key, now) | ||||
| 
 | ||||
|         # It's safe not to lock here: a) no unique constraint, | ||||
|         # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely | ||||
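|  | ||||
| # Editor's note -- a runnable sketch (not part of this commit) of the new | ||||
| # tuple-keyed convention shown above: the refactored Cache takes the whole | ||||
| # key as one tuple instead of *args. TupleCache is a stand-in, not the real | ||||
| # class from synapse.util.caches.descriptors. | ||||
| class TupleCache(object): | ||||
|     def __init__(self): | ||||
|         self.data = {} | ||||
|  | ||||
|     def get(self, key):             # key is a single tuple, not *args | ||||
|         return self.data[key]       # raises KeyError on a miss | ||||
|  | ||||
|     def prefill(self, key, value):  # was prefill(*key + (value,)) | ||||
|         self.data[key] = value | ||||
|  | ||||
| cache = TupleCache() | ||||
| key = ("@user:hs", "token", "device", "1.2.3.4") | ||||
| cache.prefill(key, 1000) | ||||
| print(cache.get(key))  # 1000 | ||||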
|  | @ -354,6 +354,11 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, | |||
|                     ) | ||||
|                 logger.debug("Running script %s", relative_path) | ||||
|                 module.run_upgrade(cur, database_engine) | ||||
|             elif ext == ".pyc": | ||||
|                 # Sometimes .pyc files turn up anyway even though we've | ||||
|                 # disabled their generation; e.g. from distribution package | ||||
|                 # installers. Silently skip it | ||||
|                 pass | ||||
|             elif ext == ".sql": | ||||
|                 # A plain old .sql file, just read and execute it | ||||
|                 logger.debug("Applying schema %s", relative_path) | ||||
|  |  | |||
|  | @ -17,21 +17,20 @@ import logging | |||
| from synapse.api.errors import StoreError | ||||
| from synapse.util.logutils import log_function | ||||
| from synapse.util.logcontext import preserve_context_over_fn, LoggingContext | ||||
| from synapse.util.lrucache import LruCache | ||||
| from synapse.util.caches.dictionary_cache import DictionaryCache | ||||
| from synapse.util.caches.descriptors import Cache | ||||
| import synapse.metrics | ||||
| 
 | ||||
| from util.id_generators import IdGenerator, StreamIdGenerator | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from collections import namedtuple, OrderedDict | ||||
| from collections import namedtuple | ||||
| 
 | ||||
| import functools | ||||
| import sys | ||||
| import time | ||||
| import threading | ||||
| 
 | ||||
| DEBUG_CACHES = False | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|  | @ -47,159 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time") | |||
| sql_query_timer = metrics.register_distribution("query_time", labels=["verb"]) | ||||
| sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"]) | ||||
| 
 | ||||
| caches_by_name = {} | ||||
| cache_counter = metrics.register_cache( | ||||
|     "cache", | ||||
|     lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, | ||||
|     labels=["name"], | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class Cache(object): | ||||
| 
 | ||||
|     def __init__(self, name, max_entries=1000, keylen=1, lru=False): | ||||
|         if lru: | ||||
|             self.cache = LruCache(max_size=max_entries) | ||||
|             self.max_entries = None | ||||
|         else: | ||||
|             self.cache = OrderedDict() | ||||
|             self.max_entries = max_entries | ||||
| 
 | ||||
|         self.name = name | ||||
|         self.keylen = keylen | ||||
|         self.sequence = 0 | ||||
|         self.thread = None | ||||
|         caches_by_name[name] = self.cache | ||||
| 
 | ||||
|     def check_thread(self): | ||||
|         expected_thread = self.thread | ||||
|         if expected_thread is None: | ||||
|             self.thread = threading.current_thread() | ||||
|         else: | ||||
|             if expected_thread is not threading.current_thread(): | ||||
|                 raise ValueError( | ||||
|                     "Cache objects can only be accessed from the main thread" | ||||
|                 ) | ||||
| 
 | ||||
|     def get(self, *keyargs): | ||||
|         if len(keyargs) != self.keylen: | ||||
|             raise ValueError("Expected a key to have %d items", self.keylen) | ||||
| 
 | ||||
|         if keyargs in self.cache: | ||||
|             cache_counter.inc_hits(self.name) | ||||
|             return self.cache[keyargs] | ||||
| 
 | ||||
|         cache_counter.inc_misses(self.name) | ||||
|         raise KeyError() | ||||
| 
 | ||||
|     def update(self, sequence, *args): | ||||
|         self.check_thread() | ||||
|         if self.sequence == sequence: | ||||
|             # Only update the cache if the cache's sequence number matches the | ||||
|             # number that the cache had before the SELECT was started (SYN-369) | ||||
|             self.prefill(*args) | ||||
| 
 | ||||
|     def prefill(self, *args):  # because I can't  *keyargs, value | ||||
|         keyargs = args[:-1] | ||||
|         value = args[-1] | ||||
| 
 | ||||
|         if len(keyargs) != self.keylen: | ||||
|             raise ValueError("Expected a key to have %d items", self.keylen) | ||||
| 
 | ||||
|         if self.max_entries is not None: | ||||
|             while len(self.cache) >= self.max_entries: | ||||
|                 self.cache.popitem(last=False) | ||||
| 
 | ||||
|         self.cache[keyargs] = value | ||||
| 
 | ||||
|     def invalidate(self, *keyargs): | ||||
|         self.check_thread() | ||||
|         if len(keyargs) != self.keylen: | ||||
|             raise ValueError("Expected a key to have %d items", self.keylen) | ||||
|         # Increment the sequence number so that any SELECT statements that | ||||
|         # raced with the INSERT don't update the cache (SYN-369) | ||||
|         self.sequence += 1 | ||||
|         self.cache.pop(keyargs, None) | ||||
| 
 | ||||
|     def invalidate_all(self): | ||||
|         self.check_thread() | ||||
|         self.sequence += 1 | ||||
|         self.cache.clear() | ||||
| 
 | ||||
| 
 | ||||
| class CacheDescriptor(object): | ||||
|     """ A method decorator that applies a memoizing cache around the function. | ||||
| 
 | ||||
|     The function is presumed to take zero or more arguments, which are used in | ||||
|     a tuple as the key for the cache. Hits are served directly from the cache; | ||||
|     misses use the function body to generate the value. | ||||
| 
 | ||||
|     The wrapped function has an additional member, a callable called | ||||
|     "invalidate". This can be used to remove individual entries from the cache. | ||||
| 
 | ||||
|     The wrapped function has another additional callable, called "prefill", | ||||
|     which can be used to insert values into the cache specifically, without | ||||
|     calling the calculation function. | ||||
|     """ | ||||
|     def __init__(self, orig, max_entries=1000, num_args=1, lru=False): | ||||
|         self.orig = orig | ||||
| 
 | ||||
|         self.max_entries = max_entries | ||||
|         self.num_args = num_args | ||||
|         self.lru = lru | ||||
| 
 | ||||
|     def __get__(self, obj, objtype=None): | ||||
|         cache = Cache( | ||||
|             name=self.orig.__name__, | ||||
|             max_entries=self.max_entries, | ||||
|             keylen=self.num_args, | ||||
|             lru=self.lru, | ||||
|         ) | ||||
| 
 | ||||
|         @functools.wraps(self.orig) | ||||
|         @defer.inlineCallbacks | ||||
|         def wrapped(*keyargs): | ||||
|             try: | ||||
|                 cached_result = cache.get(*keyargs[:self.num_args]) | ||||
|                 if DEBUG_CACHES: | ||||
|                     actual_result = yield self.orig(obj, *keyargs) | ||||
|                     if actual_result != cached_result: | ||||
|                         logger.error( | ||||
|                             "Stale cache entry %s%r: cached: %r, actual %r", | ||||
|                             self.orig.__name__, keyargs, | ||||
|                             cached_result, actual_result, | ||||
|                         ) | ||||
|                         raise ValueError("Stale cache entry") | ||||
|                 defer.returnValue(cached_result) | ||||
|             except KeyError: | ||||
|                 # Get the sequence number of the cache before reading from the | ||||
|                 # database so that we can tell if the cache is invalidated | ||||
|                 # while the SELECT is executing (SYN-369) | ||||
|                 sequence = cache.sequence | ||||
| 
 | ||||
|                 ret = yield self.orig(obj, *keyargs) | ||||
| 
 | ||||
|                 cache.update(sequence, *keyargs[:self.num_args] + (ret,)) | ||||
| 
 | ||||
|                 defer.returnValue(ret) | ||||
| 
 | ||||
|         wrapped.invalidate = cache.invalidate | ||||
|         wrapped.invalidate_all = cache.invalidate_all | ||||
|         wrapped.prefill = cache.prefill | ||||
| 
 | ||||
|         obj.__dict__[self.orig.__name__] = wrapped | ||||
| 
 | ||||
|         return wrapped | ||||
| 
 | ||||
| 
 | ||||
| def cached(max_entries=1000, num_args=1, lru=False): | ||||
|     return lambda orig: CacheDescriptor( | ||||
|         orig, | ||||
|         max_entries=max_entries, | ||||
|         num_args=num_args, | ||||
|         lru=lru | ||||
|     ) | ||||
| 
 | ||||
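| # Editor's note -- a runnable sketch (not part of this commit) isolating the | ||||
| # SYN-369 sequence guard from the Cache class removed above (the logic now | ||||
| # lives in synapse.util.caches.descriptors): an update is dropped whenever | ||||
| # an invalidation raced with the computation. SeqCache is an invented name. | ||||
| class SeqCache(object): | ||||
|     def __init__(self):  | ||||
|         self.sequence = 0 | ||||
|         self.data = {} | ||||
|  | ||||
|     def invalidate(self, key): | ||||
|         # Bump the sequence so in-flight reads don't repopulate the cache. | ||||
|         self.sequence += 1 | ||||
|         self.data.pop(key, None) | ||||
|  | ||||
|     def update(self, sequence, key, value): | ||||
|         if self.sequence == sequence:  # no invalidation since the read began | ||||
|             self.data[key] = value | ||||
|  | ||||
| cache = SeqCache() | ||||
| seq = cache.sequence        # snapshot before a slow SELECT | ||||
| cache.invalidate("k")       # concurrent write invalidates the key | ||||
| cache.update(seq, "k", "stale")  # dropped: the sequence has moved on | ||||
| print(cache.data)           # {} | ||||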
| 
 | ||||
| class LoggingTransaction(object): | ||||
|     """An object that almost-transparently proxies for the 'txn' object | ||||
|  | @ -321,6 +167,8 @@ class SQLBaseStore(object): | |||
|         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, | ||||
|                                       max_entries=hs.config.event_cache_size) | ||||
| 
 | ||||
|         self._state_group_cache = DictionaryCache("*stateGroupCache*", 100000) | ||||
| 
 | ||||
|         self._event_fetch_lock = threading.Condition() | ||||
|         self._event_fetch_list = [] | ||||
|         self._event_fetch_ongoing = 0 | ||||
|  |  | |||
|  | @ -13,7 +13,8 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from ._base import SQLBaseStore, cached | ||||
| from ._base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import cached | ||||
| 
 | ||||
| from synapse.api.errors import SynapseError | ||||
| 
 | ||||
|  | @ -104,7 +105,7 @@ class DirectoryStore(SQLBaseStore): | |||
|                 }, | ||||
|                 desc="create_room_alias_association", | ||||
|             ) | ||||
|         self.get_aliases_for_room.invalidate(room_id) | ||||
|         self.get_aliases_for_room.invalidate((room_id,)) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def delete_room_alias(self, room_alias): | ||||
|  | @ -114,7 +115,7 @@ class DirectoryStore(SQLBaseStore): | |||
|             room_alias, | ||||
|         ) | ||||
| 
 | ||||
|         self.get_aliases_for_room.invalidate(room_id) | ||||
|         self.get_aliases_for_room.invalidate((room_id,)) | ||||
|         defer.returnValue(room_id) | ||||
| 
 | ||||
|     def _delete_room_alias_txn(self, txn, room_alias): | ||||
|  |  | |||
|  | @ -15,7 +15,8 @@ | |||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from ._base import SQLBaseStore, cached | ||||
| from ._base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import cached | ||||
| from syutil.base64util import encode_base64 | ||||
| 
 | ||||
| import logging | ||||
|  | @ -362,7 +363,7 @@ class EventFederationStore(SQLBaseStore): | |||
| 
 | ||||
|         for room_id in events_by_room: | ||||
|             txn.call_after( | ||||
|                 self.get_latest_event_ids_in_room.invalidate, room_id | ||||
|                 self.get_latest_event_ids_in_room.invalidate, (room_id,) | ||||
|             ) | ||||
| 
 | ||||
|     def get_backfill_events(self, room_id, event_list, limit): | ||||
|  | @ -505,4 +506,4 @@ class EventFederationStore(SQLBaseStore): | |||
|         query = "DELETE FROM event_forward_extremities WHERE room_id = ?" | ||||
| 
 | ||||
|         txn.execute(query, (room_id,)) | ||||
|         txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id) | ||||
|         txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) | ||||
|  |  | |||
|  | @ -162,8 +162,8 @@ class EventsStore(SQLBaseStore): | |||
|         if current_state: | ||||
|             txn.call_after(self.get_current_state_for_key.invalidate_all) | ||||
|             txn.call_after(self.get_rooms_for_user.invalidate_all) | ||||
|             txn.call_after(self.get_users_in_room.invalidate, event.room_id) | ||||
|             txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) | ||||
|             txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) | ||||
|             txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) | ||||
|             txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,)) | ||||
| 
 | ||||
|             self._simple_delete_txn( | ||||
|  | @ -430,13 +430,13 @@ class EventsStore(SQLBaseStore): | |||
|                 if not context.rejected: | ||||
|                     txn.call_after( | ||||
|                         self.get_current_state_for_key.invalidate, | ||||
|                         event.room_id, event.type, event.state_key | ||||
|                         ) | ||||
|                         (event.room_id, event.type, event.state_key,) | ||||
|                     ) | ||||
| 
 | ||||
|                     if event.type in [EventTypes.Name, EventTypes.Aliases]: | ||||
|                         txn.call_after( | ||||
|                             self.get_room_name_and_aliases.invalidate, | ||||
|                             event.room_id | ||||
|                             (event.room_id,) | ||||
|                         ) | ||||
| 
 | ||||
|                     self._simple_upsert_txn( | ||||
|  | @ -567,8 +567,9 @@ class EventsStore(SQLBaseStore): | |||
|     def _invalidate_get_event_cache(self, event_id): | ||||
|         for check_redacted in (False, True): | ||||
|             for get_prev_content in (False, True): | ||||
|                 self._get_event_cache.invalidate(event_id, check_redacted, | ||||
|                                                  get_prev_content) | ||||
|                 self._get_event_cache.invalidate( | ||||
|                     (event_id, check_redacted, get_prev_content) | ||||
|                 ) | ||||
| 
 | ||||
|     def _get_event_txn(self, txn, event_id, check_redacted=True, | ||||
|                        get_prev_content=False, allow_rejected=False): | ||||
|  | @ -589,7 +590,7 @@ class EventsStore(SQLBaseStore): | |||
|         for event_id in events: | ||||
|             try: | ||||
|                 ret = self._get_event_cache.get( | ||||
|                     event_id, check_redacted, get_prev_content | ||||
|                     (event_id, check_redacted, get_prev_content,) | ||||
|                 ) | ||||
| 
 | ||||
|                 if allow_rejected or not ret.rejected_reason: | ||||
|  | @ -822,7 +823,7 @@ class EventsStore(SQLBaseStore): | |||
|                 ev.unsigned["prev_content"] = prev.get_dict()["content"] | ||||
| 
 | ||||
|         self._get_event_cache.prefill( | ||||
|             ev.event_id, check_redacted, get_prev_content, ev | ||||
|             (ev.event_id, check_redacted, get_prev_content), ev | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue(ev) | ||||
|  | @ -879,7 +880,7 @@ class EventsStore(SQLBaseStore): | |||
|                 ev.unsigned["prev_content"] = prev.get_dict()["content"] | ||||
| 
 | ||||
|         self._get_event_cache.prefill( | ||||
|             ev.event_id, check_redacted, get_prev_content, ev | ||||
|             (ev.event_id, check_redacted, get_prev_content), ev | ||||
|         ) | ||||
| 
 | ||||
|         return ev | ||||
|  |  | |||
|  | @ -13,7 +13,8 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from _base import SQLBaseStore, cached | ||||
| from _base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import cachedInlineCallbacks | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
|  | @ -71,8 +72,7 @@ class KeyStore(SQLBaseStore): | |||
|             desc="store_server_certificate", | ||||
|         ) | ||||
| 
 | ||||
|     @cached() | ||||
|     @defer.inlineCallbacks | ||||
|     @cachedInlineCallbacks() | ||||
|     def get_all_server_verify_keys(self, server_name): | ||||
|         rows = yield self._simple_select_list( | ||||
|             table="server_signature_keys", | ||||
|  | @ -132,7 +132,7 @@ class KeyStore(SQLBaseStore): | |||
|             desc="store_server_verify_key", | ||||
|         ) | ||||
| 
 | ||||
|         self.get_all_server_verify_keys.invalidate(server_name) | ||||
|         self.get_all_server_verify_keys.invalidate((server_name,)) | ||||
| 
 | ||||
|     def store_server_keys_json(self, server_name, key_id, from_server, | ||||
|                                ts_now_ms, ts_expires_ms, key_json_bytes): | ||||
|  |  | |||
|  | @ -13,19 +13,23 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from ._base import SQLBaseStore, cached | ||||
| from ._base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import cached, cachedList | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| 
 | ||||
| class PresenceStore(SQLBaseStore): | ||||
|     def create_presence(self, user_localpart): | ||||
|         return self._simple_insert( | ||||
|         res = self._simple_insert( | ||||
|             table="presence", | ||||
|             values={"user_id": user_localpart}, | ||||
|             desc="create_presence", | ||||
|         ) | ||||
| 
 | ||||
|         self.get_presence_state.invalidate((user_localpart,)) | ||||
|         return res | ||||
| 
 | ||||
|     def has_presence_state(self, user_localpart): | ||||
|         return self._simple_select_one( | ||||
|             table="presence", | ||||
|  | @ -35,6 +39,7 @@ class PresenceStore(SQLBaseStore): | |||
|             desc="has_presence_state", | ||||
|         ) | ||||
| 
 | ||||
|     @cached(max_entries=2000) | ||||
|     def get_presence_state(self, user_localpart): | ||||
|         return self._simple_select_one( | ||||
|             table="presence", | ||||
|  | @ -43,8 +48,27 @@ class PresenceStore(SQLBaseStore): | |||
|             desc="get_presence_state", | ||||
|         ) | ||||
| 
 | ||||
|     @cachedList(get_presence_state.cache, list_name="user_localparts") | ||||
|     def get_presence_states(self, user_localparts): | ||||
|         def f(txn): | ||||
|             results = {} | ||||
|             for user_localpart in user_localparts: | ||||
|                 res = self._simple_select_one_txn( | ||||
|                     txn, | ||||
|                     table="presence", | ||||
|                     keyvalues={"user_id": user_localpart}, | ||||
|                     retcols=["state", "status_msg", "mtime"], | ||||
|                     allow_none=True, | ||||
|                 ) | ||||
|                 if res: | ||||
|                     results[user_localpart] = res | ||||
| 
 | ||||
|             return results | ||||
| 
 | ||||
|         return self.runInteraction("get_presence_states", f) | ||||
| 
 | ||||
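| # Editor's note -- an illustrative sketch (not part of this commit) of what | ||||
| # cachedList adds to get_presence_states: per-user cache hits are served | ||||
| # directly and only the misses are fetched in one batch. A plain dict models | ||||
| # the single-key get_presence_state cache; all names are invented. | ||||
| single_cache = {"alice": {"state": "online"}} | ||||
|  | ||||
| def get_presence_states_sketch(user_localparts): | ||||
|     missing = [u for u in user_localparts if u not in single_cache] | ||||
|     fetched = dict((u, {"state": "offline"}) for u in missing)  # stands in for SQL | ||||
|     single_cache.update(fetched)  # prefill the per-user cache for next time | ||||
|     return dict((u, single_cache[u]) for u in user_localparts) | ||||
|  | ||||
| print(get_presence_states_sketch(["alice", "bob"])) | ||||
| # {'alice': {'state': 'online'}, 'bob': {'state': 'offline'}} | ||||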
|     def set_presence_state(self, user_localpart, new_state): | ||||
|         return self._simple_update_one( | ||||
|         res = self._simple_update_one( | ||||
|             table="presence", | ||||
|             keyvalues={"user_id": user_localpart}, | ||||
|             updatevalues={"state": new_state["state"], | ||||
|  | @ -53,6 +77,9 @@ class PresenceStore(SQLBaseStore): | |||
|             desc="set_presence_state", | ||||
|         ) | ||||
| 
 | ||||
|         self.get_presence_state.invalidate((user_localpart,)) | ||||
|         return res | ||||
| 
 | ||||
|     def allow_presence_visible(self, observed_localpart, observer_userid): | ||||
|         return self._simple_insert( | ||||
|             table="presence_allow_inbound", | ||||
|  | @ -98,7 +125,7 @@ class PresenceStore(SQLBaseStore): | |||
|             updatevalues={"accepted": True}, | ||||
|             desc="set_presence_list_accepted", | ||||
|         ) | ||||
|         self.get_presence_list_accepted.invalidate(observer_localpart) | ||||
|         self.get_presence_list_accepted.invalidate((observer_localpart,)) | ||||
|         defer.returnValue(result) | ||||
| 
 | ||||
|     def get_presence_list(self, observer_localpart, accepted=None): | ||||
|  | @ -133,4 +160,4 @@ class PresenceStore(SQLBaseStore): | |||
|                        "observed_user_id": observed_userid}, | ||||
|             desc="del_presence_list", | ||||
|         ) | ||||
|         self.get_presence_list_accepted.invalidate(observer_localpart) | ||||
|         self.get_presence_list_accepted.invalidate((observer_localpart,)) | ||||
|  |  | |||
|  | @ -13,7 +13,8 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from ._base import SQLBaseStore, cached | ||||
| from ._base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import cachedInlineCallbacks | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| import logging | ||||
|  | @ -23,8 +24,7 @@ logger = logging.getLogger(__name__) | |||
| 
 | ||||
| 
 | ||||
| class PushRuleStore(SQLBaseStore): | ||||
|     @cached() | ||||
|     @defer.inlineCallbacks | ||||
|     @cachedInlineCallbacks() | ||||
|     def get_push_rules_for_user(self, user_name): | ||||
|         rows = yield self._simple_select_list( | ||||
|             table=PushRuleTable.table_name, | ||||
|  | @ -41,8 +41,7 @@ class PushRuleStore(SQLBaseStore): | |||
| 
 | ||||
|         defer.returnValue(rows) | ||||
| 
 | ||||
|     @cached() | ||||
|     @defer.inlineCallbacks | ||||
|     @cachedInlineCallbacks() | ||||
|     def get_push_rules_enabled_for_user(self, user_name): | ||||
|         results = yield self._simple_select_list( | ||||
|             table=PushRuleEnableTable.table_name, | ||||
|  | @ -153,11 +152,11 @@ class PushRuleStore(SQLBaseStore): | |||
|             txn.execute(sql, (user_name, priority_class, new_rule_priority)) | ||||
| 
 | ||||
|         txn.call_after( | ||||
|             self.get_push_rules_for_user.invalidate, user_name | ||||
|             self.get_push_rules_for_user.invalidate, (user_name,) | ||||
|         ) | ||||
| 
 | ||||
|         txn.call_after( | ||||
|             self.get_push_rules_enabled_for_user.invalidate, user_name | ||||
|             self.get_push_rules_enabled_for_user.invalidate, (user_name,) | ||||
|         ) | ||||
| 
 | ||||
|         self._simple_insert_txn( | ||||
|  | @ -189,10 +188,10 @@ class PushRuleStore(SQLBaseStore): | |||
|         new_rule['priority'] = new_prio | ||||
| 
 | ||||
|         txn.call_after( | ||||
|             self.get_push_rules_for_user.invalidate, user_name | ||||
|             self.get_push_rules_for_user.invalidate, (user_name,) | ||||
|         ) | ||||
|         txn.call_after( | ||||
|             self.get_push_rules_enabled_for_user.invalidate, user_name | ||||
|             self.get_push_rules_enabled_for_user.invalidate, (user_name,) | ||||
|         ) | ||||
| 
 | ||||
|         self._simple_insert_txn( | ||||
|  | @ -218,8 +217,8 @@ class PushRuleStore(SQLBaseStore): | |||
|             desc="delete_push_rule", | ||||
|         ) | ||||
| 
 | ||||
|         self.get_push_rules_for_user.invalidate(user_name) | ||||
|         self.get_push_rules_enabled_for_user.invalidate(user_name) | ||||
|         self.get_push_rules_for_user.invalidate((user_name,)) | ||||
|         self.get_push_rules_enabled_for_user.invalidate((user_name,)) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def set_push_rule_enabled(self, user_name, rule_id, enabled): | ||||
|  | @ -240,10 +239,10 @@ class PushRuleStore(SQLBaseStore): | |||
|             {'id': new_id}, | ||||
|         ) | ||||
|         txn.call_after( | ||||
|             self.get_push_rules_for_user.invalidate, user_name | ||||
|             self.get_push_rules_for_user.invalidate, (user_name,) | ||||
|         ) | ||||
|         txn.call_after( | ||||
|             self.get_push_rules_enabled_for_user.invalidate, user_name | ||||
|             self.get_push_rules_enabled_for_user.invalidate, (user_name,) | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -13,12 +13,12 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from ._base import SQLBaseStore, cached | ||||
| from ._base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList | ||||
| from synapse.util.caches import cache_counter, caches_by_name | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.util import unwrapFirstError | ||||
| 
 | ||||
| from blist import sorteddict | ||||
| import logging | ||||
| import ujson as json | ||||
|  | @ -53,19 +53,13 @@ class ReceiptsStore(SQLBaseStore): | |||
|                 self, room_ids, from_key | ||||
|             ) | ||||
| 
 | ||||
|         results = yield defer.gatherResults( | ||||
|             [ | ||||
|                 self.get_linearized_receipts_for_room( | ||||
|                     room_id, to_key, from_key=from_key | ||||
|                 ) | ||||
|                 for room_id in room_ids | ||||
|             ], | ||||
|             consumeErrors=True, | ||||
|         ).addErrback(unwrapFirstError) | ||||
|         results = yield self._get_linearized_receipts_for_rooms( | ||||
|             room_ids, to_key, from_key=from_key | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue([ev for res in results for ev in res]) | ||||
|         defer.returnValue([ev for res in results.values() for ev in res]) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     @cachedInlineCallbacks(num_args=3, max_entries=5000) | ||||
|     def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): | ||||
|         """Get receipts for a single room for sending to clients. | ||||
| 
 | ||||
|  | @ -125,11 +119,70 @@ class ReceiptsStore(SQLBaseStore): | |||
|             "content": content, | ||||
|         }]) | ||||
| 
 | ||||
|     @cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids", | ||||
|                 num_args=3, inlineCallbacks=True) | ||||
|     def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): | ||||
|         if not room_ids: | ||||
|             defer.returnValue({}) | ||||
| 
 | ||||
|         def f(txn): | ||||
|             if from_key: | ||||
|                 sql = ( | ||||
|                     "SELECT * FROM receipts_linearized WHERE" | ||||
|                     " room_id IN (%s) AND stream_id > ? AND stream_id <= ?" | ||||
|                 ) % ( | ||||
|                     ",".join(["?"] * len(room_ids)) | ||||
|                 ) | ||||
|                 args = list(room_ids) | ||||
|                 args.extend([from_key, to_key]) | ||||
| 
 | ||||
|                 txn.execute(sql, args) | ||||
|             else: | ||||
|                 sql = ( | ||||
|                     "SELECT * FROM receipts_linearized WHERE" | ||||
|                     " room_id IN (%s) AND stream_id <= ?" | ||||
|                 ) % ( | ||||
|                     ",".join(["?"] * len(room_ids)) | ||||
|                 ) | ||||
| 
 | ||||
|                 args = list(room_ids) | ||||
|                 args.append(to_key) | ||||
| 
 | ||||
|                 txn.execute(sql, args) | ||||
| 
 | ||||
|             return self.cursor_to_dict(txn) | ||||
| 
 | ||||
|         txn_results = yield self.runInteraction( | ||||
|             "_get_linearized_receipts_for_rooms", f | ||||
|         ) | ||||
| 
 | ||||
|         results = {} | ||||
|         for row in txn_results: | ||||
|             # We want a single event per room, since we want to batch the | ||||
|             # receipts by room, event and type. | ||||
|             room_event = results.setdefault(row["room_id"], { | ||||
|                 "type": "m.receipt", | ||||
|                 "room_id": row["room_id"], | ||||
|                 "content": {}, | ||||
|             }) | ||||
| 
 | ||||
|             # The content is of the form: | ||||
|             # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. } | ||||
|             event_entry = room_event["content"].setdefault(row["event_id"], {}) | ||||
|             receipt_type = event_entry.setdefault(row["receipt_type"], {}) | ||||
| 
 | ||||
|             receipt_type[row["user_id"]] = json.loads(row["data"]) | ||||
| 
 | ||||
|         results = { | ||||
|             room_id: [results[room_id]] if room_id in results else [] | ||||
|             for room_id in room_ids | ||||
|         } | ||||
|         defer.returnValue(results) | ||||
| 
 | ||||
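| # Editor's note -- a runnable sketch (not part of this commit) of the | ||||
| # parameterized IN clause built in f() above: one "?" placeholder per room | ||||
| # id, with the stream bound appended last. Values are invented examples. | ||||
| room_ids = ["!a:hs", "!b:hs"] | ||||
| to_key = 50 | ||||
| sql = ( | ||||
|     "SELECT * FROM receipts_linearized WHERE" | ||||
|     " room_id IN (%s) AND stream_id <= ?" | ||||
| ) % (",".join(["?"] * len(room_ids))) | ||||
| args = list(room_ids) + [to_key] | ||||
| print(sql)   # ...WHERE room_id IN (?,?) AND stream_id <= ? | ||||
| print(args)  # ['!a:hs', '!b:hs', 50] | ||||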
|     def get_max_receipt_stream_id(self): | ||||
|         return self._receipts_id_gen.get_max_token(self) | ||||
| 
 | ||||
|     @cached | ||||
|     @defer.inlineCallbacks | ||||
|     @cachedInlineCallbacks() | ||||
|     def get_graph_receipts_for_room(self, room_id): | ||||
|         """Get receipts for sending to remote servers. | ||||
|         """ | ||||
|  | @ -305,6 +358,8 @@ class _RoomStreamChangeCache(object): | |||
|         self._room_to_key = {} | ||||
|         self._cache = sorteddict() | ||||
|         self._earliest_key = None | ||||
|         self.name = "ReceiptsRoomChangeCache" | ||||
|         caches_by_name[self.name] = self._cache | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_rooms_changed(self, store, room_ids, key): | ||||
|  | @ -318,8 +373,11 @@ class _RoomStreamChangeCache(object): | |||
|             result = set( | ||||
|                 self._cache[k] for k in keys[i:] | ||||
|             ).intersection(room_ids) | ||||
| 
 | ||||
|             cache_counter.inc_hits(self.name) | ||||
|         else: | ||||
|             result = room_ids | ||||
|             cache_counter.inc_misses(self.name) | ||||
| 
 | ||||
|         defer.returnValue(result) | ||||
| 
 | ||||
|  |  | |||
|  | @ -17,7 +17,8 @@ from twisted.internet import defer | |||
| 
 | ||||
| from synapse.api.errors import StoreError, Codes | ||||
| 
 | ||||
| from ._base import SQLBaseStore, cached | ||||
| from ._base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import cached | ||||
| 
 | ||||
| 
 | ||||
| class RegistrationStore(SQLBaseStore): | ||||
|  | @ -111,16 +112,16 @@ class RegistrationStore(SQLBaseStore): | |||
|         }) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def user_delete_access_tokens_apart_from(self, user_id, token_id): | ||||
|     def user_delete_access_tokens(self, user_id): | ||||
|         yield self.runInteraction( | ||||
|             "user_delete_access_tokens_apart_from", | ||||
|             self._user_delete_access_tokens_apart_from, user_id, token_id | ||||
|             "user_delete_access_tokens", | ||||
|             self._user_delete_access_tokens, user_id | ||||
|         ) | ||||
| 
 | ||||
|     def _user_delete_access_tokens_apart_from(self, txn, user_id, token_id): | ||||
|     def _user_delete_access_tokens(self, txn, user_id): | ||||
|         txn.execute( | ||||
|             "DELETE FROM access_tokens WHERE user_id = ? AND id != ?", | ||||
|             (user_id, token_id) | ||||
|             "DELETE FROM access_tokens WHERE user_id = ?", | ||||
|             (user_id, ) | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|  | @ -131,7 +132,7 @@ class RegistrationStore(SQLBaseStore): | |||
|             user_id | ||||
|         ) | ||||
|         for r in rows: | ||||
|             self.get_user_by_token.invalidate(r) | ||||
|             self.get_user_by_token.invalidate((r,)) | ||||
| 
 | ||||
|     @cached() | ||||
|     def get_user_by_token(self, token): | ||||
|  |  | |||
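The change to `invalidate((r,))` above reflects the new descriptor API: cache keys are always tuples, even for single-argument methods, and `Cache.invalidate` raises `TypeError` on a bare value. A minimal sketch, assuming a Synapse environment and a hypothetical store class:

from twisted.internet import defer
from synapse.util.caches.descriptors import cached

class TokenStore(object):
    @cached()
    def get_user_by_token(self, token):
        # Hypothetical lookup; the real method queries access_tokens.
        return defer.succeed({"name": "@user:host"})

store = TokenStore()
store.get_user_by_token("abc")                # cached under the key ("abc",)
store.get_user_by_token.invalidate(("abc",))  # correct: the key is a 1-tuple
# store.get_user_by_token.invalidate("abc")   # would raise TypeError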
|  | @ -17,7 +17,8 @@ from twisted.internet import defer | |||
| 
 | ||||
| from synapse.api.errors import StoreError | ||||
| 
 | ||||
| from ._base import SQLBaseStore, cached | ||||
| from ._base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import cachedInlineCallbacks | ||||
| 
 | ||||
| import collections | ||||
| import logging | ||||
|  | @ -186,8 +187,7 @@ class RoomStore(SQLBaseStore): | |||
|                 } | ||||
|             ) | ||||
| 
 | ||||
|     @cached() | ||||
|     @defer.inlineCallbacks | ||||
|     @cachedInlineCallbacks() | ||||
|     def get_room_name_and_aliases(self, room_id): | ||||
|         def f(txn): | ||||
|             sql = ( | ||||
|  |  | |||
|  | @ -17,7 +17,8 @@ from twisted.internet import defer | |||
| 
 | ||||
| from collections import namedtuple | ||||
| 
 | ||||
| from ._base import SQLBaseStore, cached | ||||
| from ._base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import cached | ||||
| 
 | ||||
| from synapse.api.constants import Membership | ||||
| from synapse.types import UserID | ||||
|  | @ -54,9 +55,9 @@ class RoomMemberStore(SQLBaseStore): | |||
|         ) | ||||
| 
 | ||||
|         for event in events: | ||||
|             txn.call_after(self.get_rooms_for_user.invalidate, event.state_key) | ||||
|             txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) | ||||
|             txn.call_after(self.get_users_in_room.invalidate, event.room_id) | ||||
|             txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) | ||||
|             txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) | ||||
|             txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) | ||||
| 
 | ||||
|     def get_room_member(self, user_id, room_id): | ||||
|         """Retrieve the current state of a room member. | ||||
|  | @ -78,7 +79,7 @@ class RoomMemberStore(SQLBaseStore): | |||
|             lambda events: events[0] if events else None | ||||
|         ) | ||||
| 
 | ||||
|     @cached() | ||||
|     @cached(max_entries=5000) | ||||
|     def get_users_in_room(self, room_id): | ||||
|         def f(txn): | ||||
| 
 | ||||
|  | @ -154,7 +155,7 @@ class RoomMemberStore(SQLBaseStore): | |||
|             RoomsForUser(**r) for r in self.cursor_to_dict(txn) | ||||
|         ] | ||||
| 
 | ||||
|     @cached() | ||||
|     @cached(max_entries=5000) | ||||
|     def get_joined_hosts_for_room(self, room_id): | ||||
|         return self.runInteraction( | ||||
|             "get_joined_hosts_for_room", | ||||
|  |  | |||
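The `txn.call_after(...)` calls above queue each cache invalidation to run only after the transaction commits, so a rolled-back write never evicts live cache entries. A simplified sketch of that pattern (`FakeTxn` and `run_interaction` are illustrative stand-ins, not Synapse's actual transaction machinery):

class FakeTxn(object):
    """Illustrative stand-in for Synapse's transaction wrapper."""
    def __init__(self):
        self.after_callbacks = []

    def call_after(self, callback, *args):
        # Queue the invalidation; it only runs if the txn commits.
        self.after_callbacks.append((callback, args))

def run_interaction(func):
    txn = FakeTxn()
    result = func(txn)  # if this raises, the queued callbacks are dropped
    for callback, args in txn.after_callbacks:
        callback(*args)  # commit succeeded: now invalidate the caches
    return result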
|  | @ -0,0 +1,18 @@ | |||
| /* Copyright 2015 OpenMarket Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *    http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( | ||||
|     room_id, stream_id | ||||
| ); | ||||
|  | @ -13,7 +13,10 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from ._base import SQLBaseStore, cached | ||||
| from ._base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import ( | ||||
|     cached, cachedInlineCallbacks, cachedList | ||||
| ) | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
|  | @ -44,60 +47,25 @@ class StateStore(SQLBaseStore): | |||
|     """ | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_state_groups(self, event_ids): | ||||
|     def get_state_groups(self, room_id, event_ids): | ||||
|         """ Get the state groups for the given list of event_ids | ||||
| 
 | ||||
|         The return value is a dict mapping group names to lists of events. | ||||
|         """ | ||||
|         if not event_ids: | ||||
|             defer.returnValue({}) | ||||
| 
 | ||||
|         def f(txn): | ||||
|             groups = set() | ||||
|             for event_id in event_ids: | ||||
|                 group = self._simple_select_one_onecol_txn( | ||||
|                     txn, | ||||
|                     table="event_to_state_groups", | ||||
|                     keyvalues={"event_id": event_id}, | ||||
|                     retcol="state_group", | ||||
|                     allow_none=True, | ||||
|                 ) | ||||
|                 if group: | ||||
|                     groups.add(group) | ||||
| 
 | ||||
|             res = {} | ||||
|             for group in groups: | ||||
|                 state_ids = self._simple_select_onecol_txn( | ||||
|                     txn, | ||||
|                     table="state_groups_state", | ||||
|                     keyvalues={"state_group": group}, | ||||
|                     retcol="event_id", | ||||
|                 ) | ||||
| 
 | ||||
|                 res[group] = state_ids | ||||
| 
 | ||||
|             return res | ||||
| 
 | ||||
|         states = yield self.runInteraction( | ||||
|             "get_state_groups", | ||||
|             f, | ||||
|         event_to_groups = yield self._get_state_group_for_events( | ||||
|             room_id, event_ids, | ||||
|         ) | ||||
| 
 | ||||
|         state_list = yield defer.gatherResults( | ||||
|             [ | ||||
|                 self._fetch_events_for_group(group, vals) | ||||
|                 for group, vals in states.items() | ||||
|             ], | ||||
|             consumeErrors=True, | ||||
|         ) | ||||
|         groups = set(event_to_groups.values()) | ||||
|         group_to_state = yield self._get_state_for_groups(groups) | ||||
| 
 | ||||
|         defer.returnValue(dict(state_list)) | ||||
| 
 | ||||
|     @cached(num_args=1) | ||||
|     def _fetch_events_for_group(self, key, events): | ||||
|         return self._get_events( | ||||
|             events, get_prev_content=False | ||||
|         ).addCallback( | ||||
|             lambda evs: (key, evs) | ||||
|         ) | ||||
|         defer.returnValue({ | ||||
|             group: state_map.values() | ||||
|             for group, state_map in group_to_state.items() | ||||
|         }) | ||||
| 
 | ||||
|     def _store_state_groups_txn(self, txn, event, context): | ||||
|         return self._store_mult_state_groups_txn(txn, [(event, context)]) | ||||
|  | @ -189,8 +157,7 @@ class StateStore(SQLBaseStore): | |||
|         events = yield self._get_events(event_ids, get_prev_content=False) | ||||
|         defer.returnValue(events) | ||||
| 
 | ||||
|     @cached(num_args=3) | ||||
|     @defer.inlineCallbacks | ||||
|     @cachedInlineCallbacks(num_args=3) | ||||
|     def get_current_state_for_key(self, room_id, event_type, state_key): | ||||
|         def f(txn): | ||||
|             sql = ( | ||||
|  | @ -206,64 +173,254 @@ class StateStore(SQLBaseStore): | |||
|         events = yield self._get_events(event_ids, get_prev_content=False) | ||||
|         defer.returnValue(events) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_state_for_events(self, room_id, event_ids): | ||||
|     def _get_state_groups_from_groups(self, groups_and_types): | ||||
|         """Returns dictionary state_group -> state event ids | ||||
| 
 | ||||
|         Args: | ||||
|             groups_and_types (list): list of 2-tuples (`group`, `types`) | ||||
|         """ | ||||
|         def f(txn): | ||||
|             groups = set() | ||||
|             event_to_group = {} | ||||
|             for event_id in event_ids: | ||||
|                 # TODO: Remove this loop. | ||||
|                 group = self._simple_select_one_onecol_txn( | ||||
|                     txn, | ||||
|                     table="event_to_state_groups", | ||||
|                     keyvalues={"event_id": event_id}, | ||||
|                     retcol="state_group", | ||||
|                     allow_none=True, | ||||
|                 ) | ||||
|                 if group: | ||||
|                     event_to_group[event_id] = group | ||||
|                     groups.add(group) | ||||
|             results = {} | ||||
|             for group, types in groups_and_types: | ||||
|                 if types is not None: | ||||
|                     where_clause = "AND (%s)" % ( | ||||
|                         " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), | ||||
|                     ) | ||||
|                 else: | ||||
|                     where_clause = "" | ||||
| 
 | ||||
|             group_to_state_ids = {} | ||||
|             for group in groups: | ||||
|                 state_ids = self._simple_select_onecol_txn( | ||||
|                     txn, | ||||
|                     table="state_groups_state", | ||||
|                     keyvalues={"state_group": group}, | ||||
|                     retcol="event_id", | ||||
|                 ) | ||||
|                 sql = ( | ||||
|                     "SELECT event_id FROM state_groups_state WHERE" | ||||
|                     " state_group = ? %s" | ||||
|                 ) % (where_clause,) | ||||
| 
 | ||||
|                 group_to_state_ids[group] = state_ids | ||||
|                 args = [group] | ||||
|                 if types is not None: | ||||
|                     args.extend([i for typ in types for i in typ]) | ||||
| 
 | ||||
|             return event_to_group, group_to_state_ids | ||||
|                 txn.execute(sql, args) | ||||
| 
 | ||||
|         res = yield self.runInteraction( | ||||
|             "annotate_events_with_state_groups", | ||||
|                 results[group] = [r[0] for r in txn.fetchall()] | ||||
| 
 | ||||
|             return results | ||||
| 
 | ||||
|         return self.runInteraction( | ||||
|             "_get_state_groups_from_groups", | ||||
|             f, | ||||
|         ) | ||||
| 
 | ||||
|         event_to_group, group_to_state_ids = res | ||||
|     @defer.inlineCallbacks | ||||
|     def get_state_for_events(self, room_id, event_ids, types): | ||||
|         """Given a list of event_ids and type tuples, return a list of state | ||||
|         dicts for each event. The state dicts will only have the type/state_keys | ||||
|         that are in the `types` list. | ||||
| 
 | ||||
|         state_list = yield defer.gatherResults( | ||||
|             [ | ||||
|                 self._fetch_events_for_group(group, vals) | ||||
|                 for group, vals in group_to_state_ids.items() | ||||
|             ], | ||||
|             consumeErrors=True, | ||||
|         Args: | ||||
|             room_id (str) | ||||
|             event_ids (list) | ||||
|             types (list): List of (type, state_key) tuples which are used to | ||||
|                 filter the state fetched. `state_key` may be None, which matches | ||||
|                 any `state_key` | ||||
| 
 | ||||
|         Returns: | ||||
|             deferred: A dict mapping each given event_id to its state dict. | ||||
|             The state dicts are mappings from (type, state_key) -> state_event | ||||
|         """ | ||||
|         event_to_groups = yield self._get_state_group_for_events( | ||||
|             room_id, event_ids, | ||||
|         ) | ||||
| 
 | ||||
|         state_dict = { | ||||
|             group: { | ||||
|                 (ev.type, ev.state_key): ev | ||||
|                 for ev in state | ||||
|             } | ||||
|             for group, state in state_list | ||||
|         groups = set(event_to_groups.values()) | ||||
|         group_to_state = yield self._get_state_for_groups(groups, types) | ||||
| 
 | ||||
|         event_to_state = { | ||||
|             event_id: group_to_state[group] | ||||
|             for event_id, group in event_to_groups.items() | ||||
|         } | ||||
| 
 | ||||
|         defer.returnValue([ | ||||
|             state_dict.get(event_to_group.get(event, None), None) | ||||
|             for event in event_ids | ||||
|         ]) | ||||
|         defer.returnValue({event: event_to_state[event] for event in event_ids}) | ||||
| 
 | ||||
|     @cached(num_args=2, lru=True, max_entries=10000) | ||||
|     def _get_state_group_for_event(self, room_id, event_id): | ||||
|         return self._simple_select_one_onecol( | ||||
|             table="event_to_state_groups", | ||||
|             keyvalues={ | ||||
|                 "event_id": event_id, | ||||
|             }, | ||||
|             retcol="state_group", | ||||
|             allow_none=True, | ||||
|             desc="_get_state_group_for_event", | ||||
|         ) | ||||
| 
 | ||||
|     @cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids", | ||||
|                 num_args=2) | ||||
|     def _get_state_group_for_events(self, room_id, event_ids): | ||||
|         """Returns mapping event_id -> state_group | ||||
|         """ | ||||
|         def f(txn): | ||||
|             results = {} | ||||
|             for event_id in event_ids: | ||||
|                 results[event_id] = self._simple_select_one_onecol_txn( | ||||
|                     txn, | ||||
|                     table="event_to_state_groups", | ||||
|                     keyvalues={ | ||||
|                         "event_id": event_id, | ||||
|                     }, | ||||
|                     retcol="state_group", | ||||
|                     allow_none=True, | ||||
|                 ) | ||||
| 
 | ||||
|             return results | ||||
| 
 | ||||
|         return self.runInteraction("_get_state_group_for_events", f) | ||||
| 
 | ||||
|     def _get_some_state_from_cache(self, group, types): | ||||
|         """Checks if group is in cache. See `_get_state_for_groups` | ||||
| 
 | ||||
|         Returns a 3-tuple (`state_dict`, `missing_types`, `got_all`). | ||||
|         `missing_types` is the list of types that aren't in the cache for that | ||||
|         group. `got_all` is a bool indicating whether we successfully retrieved | ||||
|         all the requested state from the cache; if False we need to query the | ||||
|         DB for the missing state. | ||||
| 
 | ||||
|         Args: | ||||
|             group: The state group to lookup | ||||
|             types (list): List of 2-tuples of the form (`type`, `state_key`), | ||||
|                 where a `state_key` of `None` matches all state_keys for the | ||||
|                 `type`. | ||||
|         """ | ||||
|         is_all, state_dict = self._state_group_cache.get(group) | ||||
| 
 | ||||
|         type_to_key = {} | ||||
|         missing_types = set() | ||||
|         for typ, state_key in types: | ||||
|             if state_key is None: | ||||
|                 type_to_key[typ] = None | ||||
|                 missing_types.add((typ, state_key)) | ||||
|             else: | ||||
|                 if type_to_key.get(typ, object()) is not None: | ||||
|                     type_to_key.setdefault(typ, set()).add(state_key) | ||||
| 
 | ||||
|                 if (typ, state_key) not in state_dict: | ||||
|                     missing_types.add((typ, state_key)) | ||||
| 
 | ||||
|         sentinel = object() | ||||
| 
 | ||||
|         def include(typ, state_key): | ||||
|             valid_state_keys = type_to_key.get(typ, sentinel) | ||||
|             if valid_state_keys is sentinel: | ||||
|                 return False | ||||
|             if valid_state_keys is None: | ||||
|                 return True | ||||
|             if state_key in valid_state_keys: | ||||
|                 return True | ||||
|             return False | ||||
| 
 | ||||
|         got_all = not (missing_types or types is None) | ||||
| 
 | ||||
|         return { | ||||
|             k: v for k, v in state_dict.items() | ||||
|             if include(k[0], k[1]) | ||||
|         }, missing_types, got_all | ||||
| 
 | ||||
|     def _get_all_state_from_cache(self, group): | ||||
|         """Checks if group is in cache. See `_get_state_for_groups` | ||||
| 
 | ||||
|         Returns a 2-tuple (`state_dict`, `got_all`). `got_all` is a bool | ||||
|         indicating whether we successfully retrieved all the requested state | ||||
|         from the cache; if False we need to query the DB for the missing state. | ||||
| 
 | ||||
|         Args: | ||||
|             group: The state group to lookup | ||||
|         """ | ||||
|         is_all, state_dict = self._state_group_cache.get(group) | ||||
|         return state_dict, is_all | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _get_state_for_groups(self, groups, types=None): | ||||
|         """Given list of groups returns dict of group -> list of state events | ||||
|         with matching types. `types` is a list of `(type, state_key)`, where | ||||
|         a `state_key` of None matches all state_keys. If `types` is None then | ||||
|         all events are returned. | ||||
|         """ | ||||
|         results = {} | ||||
|         missing_groups_and_types = [] | ||||
|         if types is not None: | ||||
|             for group in set(groups): | ||||
|                 state_dict, missing_types, got_all = self._get_some_state_from_cache( | ||||
|                     group, types | ||||
|                 ) | ||||
|                 results[group] = state_dict | ||||
| 
 | ||||
|                 if not got_all: | ||||
|                     missing_groups_and_types.append((group, missing_types)) | ||||
|         else: | ||||
|             for group in set(groups): | ||||
|                 state_dict, got_all = self._get_all_state_from_cache( | ||||
|                     group | ||||
|                 ) | ||||
|                 results[group] = state_dict | ||||
| 
 | ||||
|                 if not got_all: | ||||
|                     missing_groups_and_types.append((group, None)) | ||||
| 
 | ||||
|         if not missing_groups_and_types: | ||||
|             defer.returnValue({ | ||||
|                 group: { | ||||
|                     type_tuple: event | ||||
|                     for type_tuple, event in state.items() | ||||
|                     if event | ||||
|                 } | ||||
|                 for group, state in results.items() | ||||
|             }) | ||||
| 
 | ||||
|         # Okay, so we have some missing_types, let's fetch them. | ||||
|         cache_seq_num = self._state_group_cache.sequence | ||||
| 
 | ||||
|         group_state_dict = yield self._get_state_groups_from_groups( | ||||
|             missing_groups_and_types | ||||
|         ) | ||||
| 
 | ||||
|         state_events = yield self._get_events( | ||||
|             [e_id for l in group_state_dict.values() for e_id in l], | ||||
|             get_prev_content=False | ||||
|         ) | ||||
| 
 | ||||
|         state_events = {e.event_id: e for e in state_events} | ||||
| 
 | ||||
|         # Now we want to update the cache with all the things we fetched | ||||
|         # from the database. | ||||
|         for group, state_ids in group_state_dict.items(): | ||||
|             if types: | ||||
|                 # We deliberately put key -> None mappings into the cache to | ||||
|                 # cache absence of the key, on the assumption that if we've | ||||
|                 # explicitly asked for some types then we will probably ask | ||||
|                 # for them again. | ||||
|                 state_dict = {key: None for key in types} | ||||
|                 state_dict.update(results[group]) | ||||
|                 results[group] = state_dict | ||||
|             else: | ||||
|                 state_dict = results[group] | ||||
| 
 | ||||
|             for event_id in state_ids: | ||||
|                 state_event = state_events[event_id] | ||||
|                 state_dict[(state_event.type, state_event.state_key)] = state_event | ||||
| 
 | ||||
|             self._state_group_cache.update( | ||||
|                 cache_seq_num, | ||||
|                 key=group, | ||||
|                 value=state_dict, | ||||
|                 full=(types is None), | ||||
|             ) | ||||
| 
 | ||||
|         # Remove all the entries with None values. The None values were just | ||||
|         # used for bookkeeping in the cache. | ||||
|         for group, state_dict in results.items(): | ||||
|             results[group] = { | ||||
|                 key: event for key, event in state_dict.items() if event | ||||
|             } | ||||
| 
 | ||||
|         defer.returnValue(results) | ||||
| 
 | ||||
| 
 | ||||
| def _make_group_id(clock): | ||||
|  |  | |||
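The None-placeholder bookkeeping above is worth a concrete illustration: for every requested `(type, state_key)` the group cache stores either the event or an explicit None, so a later lookup can tell "known absent" apart from "never fetched". A minimal sketch with made-up values:

# Requested keys and what the database actually returned (illustrative).
types = [("m.room.name", ""), ("m.room.member", "@alice:host")]
fetched = {("m.room.name", ""): "<name event>"}

# Seed every requested key with None to cache absence explicitly, then
# overwrite the keys we really found; this whole dict goes in the cache.
state_dict = {key: None for key in types}
state_dict.update(fetched)

# Cached: {("m.room.name", ""): "<name event>",
#          ("m.room.member", "@alice:host"): None}

# Before returning, the Nones are stripped so callers only see real events.
result = {key: event for key, event in state_dict.items() if event}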
|  | @ -36,6 +36,7 @@ what sort order was used: | |||
| from twisted.internet import defer | ||||
| 
 | ||||
| from ._base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import cachedInlineCallbacks | ||||
| from synapse.api.constants import EventTypes | ||||
| from synapse.types import RoomStreamToken | ||||
| from synapse.util.logutils import log_function | ||||
|  | @ -299,9 +300,8 @@ class StreamStore(SQLBaseStore): | |||
| 
 | ||||
|         defer.returnValue((events, token)) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_recent_events_for_room(self, room_id, limit, end_token, | ||||
|                                    with_feedback=False, from_token=None): | ||||
|     @cachedInlineCallbacks(num_args=4) | ||||
|     def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): | ||||
|         # TODO (erikj): Handle compressed feedback | ||||
| 
 | ||||
|         end_token = RoomStreamToken.parse_stream_token(end_token) | ||||
|  |  | |||
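`num_args=4` keys the cache on the first four named parameters after `self` (`room_id`, `limit`, `end_token`, `from_token`), so calls differing in any of them get separate entries. A small sketch of how the descriptor derives that key, mirroring its use of `inspect` later in this diff (the helper name is hypothetical):

import inspect

def cache_key_for(orig, num_args, self_obj, *args, **kwargs):
    """Build the key the way CacheDescriptor does: the first num_args
    named parameters after self, in declaration order."""
    arg_names = inspect.getargspec(orig).args[1:num_args + 1]
    arg_dict = inspect.getcallargs(orig, self_obj, *args, **kwargs)
    return tuple(arg_dict[name] for name in arg_names)

# For get_recent_events_for_room(self, room_id, limit, end_token,
# from_token=None), a call with three positional arguments produces a key
# like ("!room:host", 10, "s99", None) -- from_token takes its default.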
|  | @ -13,7 +13,8 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from ._base import SQLBaseStore, cached | ||||
| from ._base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import cached | ||||
| 
 | ||||
| from collections import namedtuple | ||||
| 
 | ||||
|  |  | |||
|  | @ -178,7 +178,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): | |||
| 
 | ||||
|     Live tokens start with an "s" followed by the "stream_ordering" id of the | ||||
|     event it comes after. Historic tokens start with a "t" followed by the | ||||
|     "topological_ordering" id of the event it comes after, follewed by "-", | ||||
|     "topological_ordering" id of the event it comes after, followed by "-", | ||||
|     followed by the "stream_ordering" id of the event it comes after. | ||||
|     """ | ||||
|     __slots__ = [] | ||||
|  | @ -211,4 +211,5 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): | |||
|             return "s%d" % (self.stream,) | ||||
| 
 | ||||
| 
 | ||||
| # token_id is the primary key ID of the access token, not the access token itself. | ||||
| ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) | ||||
|  |  | |||
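As the new comment spells out, `token_id` holds the access token row's primary key rather than the token string itself. A trivial illustration with made-up values:

from collections import namedtuple

ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))

client = ClientInfo(device_id="ABCDEF", token_id=42)
assert client.token_id == 42  # database row ID, not the secret token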
|  | @ -51,7 +51,7 @@ class ObservableDeferred(object): | |||
|         object.__setattr__(self, "_observers", set()) | ||||
| 
 | ||||
|         def callback(r): | ||||
|             self._result = (True, r) | ||||
|             object.__setattr__(self, "_result", (True, r)) | ||||
|             while self._observers: | ||||
|                 try: | ||||
|                     self._observers.pop().callback(r) | ||||
|  | @ -60,7 +60,7 @@ class ObservableDeferred(object): | |||
|             return r | ||||
| 
 | ||||
|         def errback(f): | ||||
|             self._result = (False, f) | ||||
|             object.__setattr__(self, "_result", (False, f)) | ||||
|             while self._observers: | ||||
|                 try: | ||||
|                     self._observers.pop().errback(f) | ||||
|  | @ -97,3 +97,8 @@ class ObservableDeferred(object): | |||
| 
 | ||||
|     def __setattr__(self, name, value): | ||||
|         setattr(self._deferred, name, value) | ||||
| 
 | ||||
|     def __repr__(self): | ||||
|         return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % ( | ||||
|             id(self), self._result, self._deferred, | ||||
|         ) | ||||
|  |  | |||
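The switch to `object.__setattr__` matters because `ObservableDeferred.__setattr__` forwards all attribute writes to the wrapped deferred; a plain `self._result = ...` would therefore land on the deferred instead of on the wrapper. A stripped-down sketch of the forwarding pattern (class names hypothetical):

class AttrProxy(object):
    """Minimal illustration of the forwarding ObservableDeferred uses."""
    def __init__(self, target):
        object.__setattr__(self, "_target", target)

    def __getattr__(self, name):
        return getattr(self._target, name)

    def __setattr__(self, name, value):
        # Normal writes are forwarded to the wrapped object...
        setattr(self._target, name, value)

proxy = AttrProxy(type("Target", (object,), {})())
proxy.foo = 1                          # forwarded: lands on the target
assert proxy._target.foo == 1
object.__setattr__(proxy, "_mine", 2)  # bypasses forwarding: lands on proxy
assert proxy.__dict__["_mine"] == 2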
|  | @ -0,0 +1,27 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2015 OpenMarket Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import synapse.metrics | ||||
| 
 | ||||
| DEBUG_CACHES = False | ||||
| 
 | ||||
| metrics = synapse.metrics.get_metrics_for("synapse.util.caches") | ||||
| 
 | ||||
| caches_by_name = {} | ||||
| cache_counter = metrics.register_cache( | ||||
|     "cache", | ||||
|     lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()}, | ||||
|     labels=["name"], | ||||
| ) | ||||
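Each cache registers its backing store in `caches_by_name`, which lets the metrics callback above report per-cache sizes, and bumps the shared counter on lookups. A hedged sketch of how a cache hooks in, using only the `inc_hits`/`inc_misses` calls seen elsewhere in this diff (the class is hypothetical):

from synapse.util.caches import caches_by_name, cache_counter

class MiniCache(object):
    def __init__(self, name):
        self.name = name
        self._store = {}
        # Registering the dict lets the metrics callback report len() by name.
        caches_by_name[name] = self._store

    def get(self, key, default=None):
        if key in self._store:
            cache_counter.inc_hits(self.name)    # counted under labels=["name"]
            return self._store[key]
        cache_counter.inc_misses(self.name)
        return default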
|  | @ -0,0 +1,377 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2015 OpenMarket Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import logging | ||||
| 
 | ||||
| from synapse.util.async import ObservableDeferred | ||||
| from synapse.util import unwrapFirstError | ||||
| from synapse.util.caches.lrucache import LruCache | ||||
| 
 | ||||
| from . import caches_by_name, DEBUG_CACHES, cache_counter | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from collections import OrderedDict | ||||
| 
 | ||||
| import functools | ||||
| import inspect | ||||
| import threading | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| _CacheSentinel = object() | ||||
| 
 | ||||
| 
 | ||||
| class Cache(object): | ||||
| 
 | ||||
|     def __init__(self, name, max_entries=1000, keylen=1, lru=True): | ||||
|         if lru: | ||||
|             self.cache = LruCache(max_size=max_entries) | ||||
|             self.max_entries = None | ||||
|         else: | ||||
|             self.cache = OrderedDict() | ||||
|             self.max_entries = max_entries | ||||
| 
 | ||||
|         self.name = name | ||||
|         self.keylen = keylen | ||||
|         self.sequence = 0 | ||||
|         self.thread = None | ||||
|         caches_by_name[name] = self.cache | ||||
| 
 | ||||
|     def check_thread(self): | ||||
|         expected_thread = self.thread | ||||
|         if expected_thread is None: | ||||
|             self.thread = threading.current_thread() | ||||
|         else: | ||||
|             if expected_thread is not threading.current_thread(): | ||||
|                 raise ValueError( | ||||
|                     "Cache objects can only be accessed from the main thread" | ||||
|                 ) | ||||
| 
 | ||||
|     def get(self, key, default=_CacheSentinel): | ||||
|         val = self.cache.get(key, _CacheSentinel) | ||||
|         if val is not _CacheSentinel: | ||||
|             cache_counter.inc_hits(self.name) | ||||
|             return val | ||||
| 
 | ||||
|         cache_counter.inc_misses(self.name) | ||||
| 
 | ||||
|         if default is _CacheSentinel: | ||||
|             raise KeyError() | ||||
|         else: | ||||
|             return default | ||||
| 
 | ||||
|     def update(self, sequence, key, value): | ||||
|         self.check_thread() | ||||
|         if self.sequence == sequence: | ||||
|             # Only update the cache if the cache's sequence number matches the | ||||
|             # number that the cache had before the SELECT was started (SYN-369) | ||||
|             self.prefill(key, value) | ||||
| 
 | ||||
|     def prefill(self, key, value): | ||||
|         if self.max_entries is not None: | ||||
|             while len(self.cache) >= self.max_entries: | ||||
|                 self.cache.popitem(last=False) | ||||
| 
 | ||||
|         self.cache[key] = value | ||||
| 
 | ||||
|     def invalidate(self, key): | ||||
|         self.check_thread() | ||||
|         if not isinstance(key, tuple): | ||||
|             raise TypeError( | ||||
|                 "The cache key must be a tuple not %r" % (type(key),) | ||||
|             ) | ||||
| 
 | ||||
|         # Increment the sequence number so that any SELECT statements that | ||||
|         # raced with the INSERT don't update the cache (SYN-369) | ||||
|         self.sequence += 1 | ||||
|         self.cache.pop(key, None) | ||||
| 
 | ||||
|     def invalidate_all(self): | ||||
|         self.check_thread() | ||||
|         self.sequence += 1 | ||||
|         self.cache.clear() | ||||
| 
 | ||||
| 
 | ||||
| class CacheDescriptor(object): | ||||
|     """ A method decorator that applies a memoizing cache around the function. | ||||
| 
 | ||||
|     This caches deferreds, rather than the results themselves. Deferreds that | ||||
|     fail are removed from the cache. | ||||
| 
 | ||||
|     The function is presumed to take zero or more arguments, which are used in | ||||
|     a tuple as the key for the cache. Hits are served directly from the cache; | ||||
|     misses use the function body to generate the value. | ||||
| 
 | ||||
|     The wrapped function has an additional member, a callable called | ||||
|     "invalidate". This can be used to remove individual entries from the cache. | ||||
| 
 | ||||
|     The wrapped function has another additional callable, called "prefill", | ||||
|     which can be used to insert values into the cache directly, without | ||||
|     calling the calculation function. | ||||
|     """ | ||||
|     def __init__(self, orig, max_entries=1000, num_args=1, lru=True, | ||||
|                  inlineCallbacks=False): | ||||
|         self.orig = orig | ||||
| 
 | ||||
|         if inlineCallbacks: | ||||
|             self.function_to_call = defer.inlineCallbacks(orig) | ||||
|         else: | ||||
|             self.function_to_call = orig | ||||
| 
 | ||||
|         self.max_entries = max_entries | ||||
|         self.num_args = num_args | ||||
|         self.lru = lru | ||||
| 
 | ||||
|         self.arg_names = inspect.getargspec(orig).args[1:num_args+1] | ||||
| 
 | ||||
|         if len(self.arg_names) < self.num_args: | ||||
|             raise Exception( | ||||
|                 "Not enough explicit positional arguments to key off of for %r." | ||||
|                 " (@cached cannot key off of *args or **kwars)" | ||||
|                 % (orig.__name__,) | ||||
|             ) | ||||
| 
 | ||||
|         self.cache = Cache( | ||||
|             name=self.orig.__name__, | ||||
|             max_entries=self.max_entries, | ||||
|             keylen=self.num_args, | ||||
|             lru=self.lru, | ||||
|         ) | ||||
| 
 | ||||
|     def __get__(self, obj, objtype=None): | ||||
| 
 | ||||
|         @functools.wraps(self.orig) | ||||
|         def wrapped(*args, **kwargs): | ||||
|             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) | ||||
|             cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) | ||||
|             try: | ||||
|                 cached_result_d = self.cache.get(cache_key) | ||||
| 
 | ||||
|                 observer = cached_result_d.observe() | ||||
|                 if DEBUG_CACHES: | ||||
|                     @defer.inlineCallbacks | ||||
|                     def check_result(cached_result): | ||||
|                         actual_result = yield self.function_to_call(obj, *args, **kwargs) | ||||
|                         if actual_result != cached_result: | ||||
|                             logger.error( | ||||
|                                 "Stale cache entry %s%r: cached: %r, actual %r", | ||||
|                                 self.orig.__name__, cache_key, | ||||
|                                 cached_result, actual_result, | ||||
|                             ) | ||||
|                             raise ValueError("Stale cache entry") | ||||
|                         defer.returnValue(cached_result) | ||||
|                     observer.addCallback(check_result) | ||||
| 
 | ||||
|                 return observer | ||||
|             except KeyError: | ||||
|                 # Get the sequence number of the cache before reading from the | ||||
|                 # database so that we can tell if the cache is invalidated | ||||
|                 # while the SELECT is executing (SYN-369) | ||||
|                 sequence = self.cache.sequence | ||||
| 
 | ||||
|                 ret = defer.maybeDeferred( | ||||
|                     self.function_to_call, | ||||
|                     obj, *args, **kwargs | ||||
|                 ) | ||||
| 
 | ||||
|                 def onErr(f): | ||||
|                     self.cache.invalidate(cache_key) | ||||
|                     return f | ||||
| 
 | ||||
|                 ret.addErrback(onErr) | ||||
| 
 | ||||
|                 ret = ObservableDeferred(ret, consumeErrors=True) | ||||
|                 self.cache.update(sequence, cache_key, ret) | ||||
| 
 | ||||
|                 return ret.observe() | ||||
| 
 | ||||
|         wrapped.invalidate = self.cache.invalidate | ||||
|         wrapped.invalidate_all = self.cache.invalidate_all | ||||
|         wrapped.prefill = self.cache.prefill | ||||
| 
 | ||||
|         obj.__dict__[self.orig.__name__] = wrapped | ||||
| 
 | ||||
|         return wrapped | ||||
| 
 | ||||
| 
 | ||||
| class CacheListDescriptor(object): | ||||
|     """Wraps an existing cache to support bulk fetching of keys. | ||||
| 
 | ||||
|     Given a list of keys it looks in the cache to find any hits, then passes | ||||
|     the list of missing keys to the wrapped function. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False): | ||||
|         """ | ||||
|         Args: | ||||
|             orig (function) | ||||
|             cache (Cache) | ||||
|             list_name (str): Name of the argument which is the bulk lookup list | ||||
|             num_args (int) | ||||
|             inlineCallbacks (bool): Whether orig is a generator that should | ||||
|                 be wrapped by defer.inlineCallbacks | ||||
|         """ | ||||
|         self.orig = orig | ||||
| 
 | ||||
|         if inlineCallbacks: | ||||
|             self.function_to_call = defer.inlineCallbacks(orig) | ||||
|         else: | ||||
|             self.function_to_call = orig | ||||
| 
 | ||||
|         self.num_args = num_args | ||||
|         self.list_name = list_name | ||||
| 
 | ||||
|         self.arg_names = inspect.getargspec(orig).args[1:num_args+1] | ||||
|         self.list_pos = self.arg_names.index(self.list_name) | ||||
| 
 | ||||
|         self.cache = cache | ||||
| 
 | ||||
|         self.sentinel = object() | ||||
| 
 | ||||
|         if len(self.arg_names) < self.num_args: | ||||
|             raise Exception( | ||||
|                 "Not enough explicit positional arguments to key off of for %r." | ||||
|                 " (@cached cannot key off of *args or **kwars)" | ||||
|                 % (orig.__name__,) | ||||
|             ) | ||||
| 
 | ||||
|         if self.list_name not in self.arg_names: | ||||
|             raise Exception( | ||||
|                 "Couldn't see arguments %r for %r." | ||||
|                 % (self.list_name, cache.name,) | ||||
|             ) | ||||
| 
 | ||||
|     def __get__(self, obj, objtype=None): | ||||
| 
 | ||||
|         @functools.wraps(self.orig) | ||||
|         def wrapped(*args, **kwargs): | ||||
|             arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) | ||||
|             keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names] | ||||
|             list_args = arg_dict[self.list_name] | ||||
| 
 | ||||
|             # cached is a dict arg -> deferred, where deferred results in a | ||||
|             # 2-tuple (`arg`, `result`) | ||||
|             cached = {} | ||||
|             missing = [] | ||||
|             for arg in list_args: | ||||
|                 key = list(keyargs) | ||||
|                 key[self.list_pos] = arg | ||||
| 
 | ||||
|                 try: | ||||
|                     res = self.cache.get(tuple(key)).observe() | ||||
|                     res.addCallback(lambda r, arg: (arg, r), arg) | ||||
|                     cached[arg] = res | ||||
|                 except KeyError: | ||||
|                     missing.append(arg) | ||||
| 
 | ||||
|             if missing: | ||||
|                 sequence = self.cache.sequence | ||||
|                 args_to_call = dict(arg_dict) | ||||
|                 args_to_call[self.list_name] = missing | ||||
| 
 | ||||
|                 ret_d = defer.maybeDeferred( | ||||
|                     self.function_to_call, | ||||
|                     **args_to_call | ||||
|                 ) | ||||
| 
 | ||||
|                 ret_d = ObservableDeferred(ret_d) | ||||
| 
 | ||||
|                 # We need to create deferreds for each arg in the list so that | ||||
|                 # we can insert the new deferred into the cache. | ||||
|                 for arg in missing: | ||||
|                     observer = ret_d.observe() | ||||
|                     observer.addCallback(lambda r, arg: r.get(arg, None), arg) | ||||
| 
 | ||||
|                     observer = ObservableDeferred(observer) | ||||
| 
 | ||||
|                     key = list(keyargs) | ||||
|                     key[self.list_pos] = arg | ||||
|                     self.cache.update(sequence, tuple(key), observer) | ||||
| 
 | ||||
|                     def invalidate(f, key): | ||||
|                         self.cache.invalidate(key) | ||||
|                         return f | ||||
|                     observer.addErrback(invalidate, tuple(key)) | ||||
| 
 | ||||
|                     res = observer.observe() | ||||
|                     res.addCallback(lambda r, arg: (arg, r), arg) | ||||
| 
 | ||||
|                     cached[arg] = res | ||||
| 
 | ||||
|             return defer.gatherResults( | ||||
|                 cached.values(), | ||||
|                 consumeErrors=True, | ||||
|             ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)) | ||||
| 
 | ||||
|         obj.__dict__[self.orig.__name__] = wrapped | ||||
| 
 | ||||
|         return wrapped | ||||
| 
 | ||||
| 
 | ||||
| def cached(max_entries=1000, num_args=1, lru=True): | ||||
|     return lambda orig: CacheDescriptor( | ||||
|         orig, | ||||
|         max_entries=max_entries, | ||||
|         num_args=num_args, | ||||
|         lru=lru | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False): | ||||
|     return lambda orig: CacheDescriptor( | ||||
|         orig, | ||||
|         max_entries=max_entries, | ||||
|         num_args=num_args, | ||||
|         lru=lru, | ||||
|         inlineCallbacks=True, | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def cachedList(cache, list_name, num_args=1, inlineCallbacks=False): | ||||
|     """Creates a descriptor that wraps a function in a `CacheListDescriptor`. | ||||
| 
 | ||||
|     Used to do batch lookups for an already created cache. A single argument | ||||
|     is specified as a list that is iterated through to look up keys in the | ||||
|     original cache. A new list consisting of the keys that weren't in the cache | ||||
|     gets passed to the original function, the result of which is stored in the | ||||
|     cache. | ||||
| 
 | ||||
|     Args: | ||||
|         cache (Cache): The underlying cache to use. | ||||
|         list_name (str): The name of the argument that is the list to use to | ||||
|             do batch lookups in the cache. | ||||
|         num_args (int): Number of arguments to use as the key in the cache. | ||||
|         inlineCallbacks (bool): Should the function be wrapped in a | ||||
|             `defer.inlineCallbacks`? | ||||
| 
 | ||||
|     Example: | ||||
| 
 | ||||
|         class Example(object): | ||||
|             @cached(num_args=2) | ||||
|             def do_something(self, first_arg): | ||||
|                 ... | ||||
| 
 | ||||
|             @cachedList(do_something.cache, list_name="second_args", num_args=2) | ||||
|             def batch_do_something(self, first_arg, second_args): | ||||
|                 ... | ||||
|     """ | ||||
|     return lambda orig: CacheListDescriptor( | ||||
|         orig, | ||||
|         cache=cache, | ||||
|         list_name=list_name, | ||||
|         num_args=num_args, | ||||
|         inlineCallbacks=inlineCallbacks, | ||||
|     ) | ||||
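A short usage sketch tying the three decorators together (the store and its methods are hypothetical; each cached method must return a deferred, and the `cachedList` function must return a dict keyed by the missing ids):

from twisted.internet import defer
from synapse.util.caches.descriptors import (
    cached, cachedInlineCallbacks, cachedList
)

class ExampleStore(object):
    @cached(max_entries=5000)
    def get_thing(self, thing_id):
        # Returns a deferred; the descriptor caches it under (thing_id,).
        return defer.succeed("thing-%s" % (thing_id,))

    @cachedInlineCallbacks()
    def get_other(self, other_id):
        # Generator body; the descriptor applies defer.inlineCallbacks.
        thing = yield self.get_thing(other_id)
        defer.returnValue(thing.upper())

    @cachedList(cache=get_thing.cache, list_name="thing_ids")
    def get_things(self, thing_ids):
        # Called only with the ids that missed get_thing's cache; must
        # return a dict of id -> value, which is folded back into the cache.
        return defer.succeed(dict((t, "thing-%s" % (t,)) for t in thing_ids))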
|  | @ -0,0 +1,103 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2015 OpenMarket Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from synapse.util.caches.lrucache import LruCache | ||||
| from collections import namedtuple | ||||
| from . import caches_by_name, cache_counter | ||||
| import threading | ||||
| import logging | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value")) | ||||
| 
 | ||||
| 
 | ||||
| class DictionaryCache(object): | ||||
|     """Caches key -> dictionary lookups, supporting caching partial dicts, i.e. | ||||
|     fetching a subset of dictionary keys for a particular key. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, name, max_entries=1000): | ||||
|         self.cache = LruCache(max_size=max_entries) | ||||
| 
 | ||||
|         self.name = name | ||||
|         self.sequence = 0 | ||||
|         self.thread = None | ||||
| 
 | ||||
|         class Sentinel(object): | ||||
|             __slots__ = [] | ||||
| 
 | ||||
|         self.sentinel = Sentinel() | ||||
|         caches_by_name[name] = self.cache | ||||
| 
 | ||||
|     def check_thread(self): | ||||
|         expected_thread = self.thread | ||||
|         if expected_thread is None: | ||||
|             self.thread = threading.current_thread() | ||||
|         else: | ||||
|             if expected_thread is not threading.current_thread(): | ||||
|                 raise ValueError( | ||||
|                     "Cache objects can only be accessed from the main thread" | ||||
|                 ) | ||||
| 
 | ||||
|     def get(self, key, dict_keys=None): | ||||
|         entry = self.cache.get(key, self.sentinel) | ||||
|         if entry is not self.sentinel: | ||||
|             cache_counter.inc_hits(self.name) | ||||
| 
 | ||||
|             if dict_keys is None: | ||||
|                 return DictionaryEntry(entry.full, dict(entry.value)) | ||||
|             else: | ||||
|                 return DictionaryEntry(entry.full, { | ||||
|                     k: entry.value[k] | ||||
|                     for k in dict_keys | ||||
|                     if k in entry.value | ||||
|                 }) | ||||
| 
 | ||||
|         cache_counter.inc_misses(self.name) | ||||
|         return DictionaryEntry(False, {}) | ||||
| 
 | ||||
|     def invalidate(self, key): | ||||
|         self.check_thread() | ||||
| 
 | ||||
|         # Increment the sequence number so that any SELECT statements that | ||||
|         # raced with the INSERT don't update the cache (SYN-369) | ||||
|         self.sequence += 1 | ||||
|         self.cache.pop(key, None) | ||||
| 
 | ||||
|     def invalidate_all(self): | ||||
|         self.check_thread() | ||||
|         self.sequence += 1 | ||||
|         self.cache.clear() | ||||
| 
 | ||||
|     def update(self, sequence, key, value, full=False): | ||||
|         self.check_thread() | ||||
|         if self.sequence == sequence: | ||||
|             # Only update the cache if the cache's sequence number matches the | ||||
|             # number that the cache had before the SELECT was started (SYN-369) | ||||
|             if full: | ||||
|                 self._insert(key, value) | ||||
|             else: | ||||
|                 self._update_or_insert(key, value) | ||||
| 
 | ||||
|     def _update_or_insert(self, key, value): | ||||
|         entry = self.cache.setdefault(key, DictionaryEntry(False, {})) | ||||
|         entry.value.update(value) | ||||
| 
 | ||||
|     def _insert(self, key, value): | ||||
|         self.cache[key] = DictionaryEntry(True, value) | ||||
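A brief sketch of the partial-dict semantics the tests below exercise (key and values made up):

from synapse.util.caches.dictionary_cache import DictionaryCache

cache = DictionaryCache("example")

seq = cache.sequence
cache.update(seq, "group1", {"a": 1, "b": 2}, full=True)

entry = cache.get("group1", dict_keys=["a", "z"])
assert entry.value == {"a": 1}  # only the keys actually cached come back
assert entry.full               # the cached dict is known to be complete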
|  | @ -17,7 +17,9 @@ | |||
| from tests import unittest | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.storage._base import Cache, cached | ||||
| from synapse.util.async import ObservableDeferred | ||||
| 
 | ||||
| from synapse.util.caches.descriptors import Cache, cached | ||||
| 
 | ||||
| 
 | ||||
| class CacheTestCase(unittest.TestCase): | ||||
|  | @ -40,12 +42,12 @@ class CacheTestCase(unittest.TestCase): | |||
|         self.assertEquals(self.cache.get("foo"), 123) | ||||
| 
 | ||||
|     def test_invalidate(self): | ||||
|         self.cache.prefill("foo", 123) | ||||
|         self.cache.invalidate("foo") | ||||
|         self.cache.prefill(("foo",), 123) | ||||
|         self.cache.invalidate(("foo",)) | ||||
| 
 | ||||
|         failed = False | ||||
|         try: | ||||
|             self.cache.get("foo") | ||||
|             self.cache.get(("foo",)) | ||||
|         except KeyError: | ||||
|             failed = True | ||||
| 
 | ||||
|  | @ -139,7 +141,7 @@ class CacheDecoratorTestCase(unittest.TestCase): | |||
| 
 | ||||
|         self.assertEquals(callcount[0], 1) | ||||
| 
 | ||||
|         a.func.invalidate("foo") | ||||
|         a.func.invalidate(("foo",)) | ||||
| 
 | ||||
|         yield a.func("foo") | ||||
| 
 | ||||
|  | @ -151,7 +153,7 @@ class CacheDecoratorTestCase(unittest.TestCase): | |||
|             def func(self, key): | ||||
|                 return key | ||||
| 
 | ||||
|         A().func.invalidate("what") | ||||
|         A().func.invalidate(("what",)) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_max_entries(self): | ||||
|  | @ -178,19 +180,20 @@ class CacheDecoratorTestCase(unittest.TestCase): | |||
|         self.assertTrue(callcount[0] >= 14, | ||||
|             msg="Expected callcount >= 14, got %d" % (callcount[0])) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_prefill(self): | ||||
|         callcount = [0] | ||||
| 
 | ||||
|         d = defer.succeed(123) | ||||
| 
 | ||||
|         class A(object): | ||||
|             @cached() | ||||
|             def func(self, key): | ||||
|                 callcount[0] += 1 | ||||
|                 return key | ||||
|                 return d | ||||
| 
 | ||||
|         a = A() | ||||
| 
 | ||||
|         a.func.prefill("foo", 123) | ||||
|         a.func.prefill(("foo",), ObservableDeferred(d)) | ||||
| 
 | ||||
|         self.assertEquals((yield a.func("foo")), 123) | ||||
|         self.assertEquals(a.func("foo").result, d.result) | ||||
|         self.assertEquals(callcount[0], 0) | ||||
|  |  | |||
|  | @ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase): | |||
|             yield d | ||||
|             self.assertTrue(d.called) | ||||
| 
 | ||||
|             observers[0].assert_called_once("Go") | ||||
|             observers[1].assert_called_once("Go") | ||||
|             observers[0].assert_called_once_with("Go") | ||||
|             observers[1].assert_called_once_with("Go") | ||||
| 
 | ||||
|             self.assertEquals(mock_logger.warning.call_count, 1) | ||||
|             self.assertIsInstance(mock_logger.warning.call_args[0][0], | ||||
|  |  | |||
|  | @ -69,7 +69,7 @@ class StateGroupStore(object): | |||
| 
 | ||||
|         self._next_group = 1 | ||||
| 
 | ||||
|     def get_state_groups(self, event_ids): | ||||
|     def get_state_groups(self, room_id, event_ids): | ||||
|         groups = {} | ||||
|         for event_id in event_ids: | ||||
|             group = self._event_to_state_group.get(event_id) | ||||
|  |  | |||
|  | @ -0,0 +1,101 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2015 OpenMarket Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| from tests import unittest | ||||
| 
 | ||||
| from synapse.util.caches.dictionary_cache import DictionaryCache | ||||
| 
 | ||||
| 
 | ||||
| class DictCacheTestCase(unittest.TestCase): | ||||
| 
 | ||||
|     def setUp(self): | ||||
|         self.cache = DictionaryCache("foobar") | ||||
| 
 | ||||
|     def test_simple_cache_hit_full(self): | ||||
|         key = "test_simple_cache_hit_full" | ||||
| 
 | ||||
|         v = self.cache.get(key) | ||||
|         self.assertEqual((False, {}), v) | ||||
| 
 | ||||
|         seq = self.cache.sequence | ||||
|         test_value = {"test": "test_simple_cache_hit_full"} | ||||
|         self.cache.update(seq, key, test_value, full=True) | ||||
| 
 | ||||
|         c = self.cache.get(key) | ||||
|         self.assertEqual(test_value, c.value) | ||||
| 
 | ||||
|     def test_simple_cache_hit_partial(self): | ||||
|         key = "test_simple_cache_hit_partial" | ||||
| 
 | ||||
|         seq = self.cache.sequence | ||||
|         test_value = { | ||||
|             "test": "test_simple_cache_hit_partial" | ||||
|         } | ||||
|         self.cache.update(seq, key, test_value, full=True) | ||||
| 
 | ||||
|         c = self.cache.get(key, ["test"]) | ||||
|         self.assertEqual(test_value, c.value) | ||||
| 
 | ||||
|     def test_simple_cache_miss_partial(self): | ||||
|         key = "test_simple_cache_miss_partial" | ||||
| 
 | ||||
|         seq = self.cache.sequence | ||||
|         test_value = { | ||||
|             "test": "test_simple_cache_miss_partial" | ||||
|         } | ||||
|         self.cache.update(seq, key, test_value, full=True) | ||||
| 
 | ||||
|         c = self.cache.get(key, ["test2"]) | ||||
|         self.assertEqual({}, c.value) | ||||
| 
 | ||||
|     def test_simple_cache_hit_miss_partial(self): | ||||
|         key = "test_simple_cache_hit_miss_partial" | ||||
| 
 | ||||
|         seq = self.cache.sequence | ||||
|         test_value = { | ||||
|             "test": "test_simple_cache_hit_miss_partial", | ||||
|             "test2": "test_simple_cache_hit_miss_partial2", | ||||
|             "test3": "test_simple_cache_hit_miss_partial3", | ||||
|         } | ||||
|         self.cache.update(seq, key, test_value, full=True) | ||||
| 
 | ||||
|         c = self.cache.get(key, ["test2"]) | ||||
|         self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) | ||||
| 
 | ||||
|     def test_multi_insert(self): | ||||
|         key = "test_simple_cache_hit_miss_partial" | ||||
| 
 | ||||
|         seq = self.cache.sequence | ||||
|         test_value_1 = { | ||||
|             "test": "test_simple_cache_hit_miss_partial", | ||||
|         } | ||||
|         self.cache.update(seq, key, test_value_1, full=False) | ||||
| 
 | ||||
|         seq = self.cache.sequence | ||||
|         test_value_2 = { | ||||
|             "test2": "test_simple_cache_hit_miss_partial2", | ||||
|         } | ||||
|         self.cache.update(seq, key, test_value_2, full=False) | ||||
| 
 | ||||
|         c = self.cache.get(key) | ||||
|         self.assertEqual( | ||||
|             { | ||||
|                 "test": "test_simple_cache_hit_miss_partial", | ||||
|                 "test2": "test_simple_cache_hit_miss_partial2", | ||||
|             }, | ||||
|             c.value | ||||
|         ) | ||||
|  | @ -16,7 +16,7 @@ | |||
| 
 | ||||
| from .. import unittest | ||||
| 
 | ||||
| from synapse.util.lrucache import LruCache | ||||
| from synapse.util.caches.lrucache import LruCache | ||||
| 
 | ||||
| class LruCacheTestCase(unittest.TestCase): | ||||
| 
 | ||||
|  | @ -52,5 +52,3 @@ class LruCacheTestCase(unittest.TestCase): | |||
|         cache["key"] = 1 | ||||
|         self.assertEquals(cache.pop("key"), 1) | ||||
|         self.assertEquals(cache.pop("key"), None) | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||