From a901ed16b5805adf04b5b8b1b99c14720e5abb3d Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 5 Mar 2015 19:10:57 +0000 Subject: [PATCH 1/5] Move federation API responding code out of weird mix of lambdas into Servlet-style methods on instances --- synapse/federation/transport/server.py | 282 +++++++++++-------------- 1 file changed, 121 insertions(+), 161 deletions(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index ece6dbcf62..eb3e30a189 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -122,14 +122,9 @@ class TransportLayerServer(object): Args: handler (TransportReceivedHandler) """ - self.received_handler = handler - - # This is when someone is trying to send us a bunch of data. - self.server.register_path( - "PUT", - re.compile("^" + PREFIX + "/send/([^/]*)/$"), - self._with_authentication(self._on_send_request) - ) + FederationSendServlet( + handler, self._with_authentication, self.server_name + ).register(self.server) @log_function def register_request_handler(self, handler): @@ -138,136 +133,48 @@ class TransportLayerServer(object): Args: handler (TransportRequestHandler) """ - self.request_handler = handler + for servletclass in ( + FederationPullServlet, + FederationEventServlet, + FederationStateServlet, + FederationBackfillServlet, + FederationQueryServlet, + FederationMakeJoinServlet, + FederationEventServlet, + FederationSendJoinServlet, + FederationInviteServlet, + FederationQueryAuthServlet, + FederationGetMissingEventsServlet, + ): + servletclass(handler, self._with_authentication).register(self.server) - # This is for when someone asks us for everything since version X - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/pull/$"), - self._with_authentication( - lambda origin, content, query: - handler.on_pull_request(query["origin"][0], query["v"]) - ) - ) - # This is when someone asks for a data item for a given server - # data_id pair. - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/event/([^/]*)/$"), - self._with_authentication( - lambda origin, content, query, event_id: - handler.on_pdu_request(origin, event_id) - ) - ) +class BaseFederationServlet(object): + def __init__(self, handler, wrapper): + self.handler = handler + self.wrapper = wrapper - # This is when someone asks for all data for a given context. - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/state/([^/]*)/$"), - self._with_authentication( - lambda origin, content, query, context: - handler.on_context_state_request( - origin, - context, - query.get("event_id", [None])[0], - ) - ) - ) + def register(self, server): + pattern = re.compile("^" + PREFIX + self.PATH) - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/backfill/([^/]*)/$"), - self._with_authentication( - lambda origin, content, query, context: - self._on_backfill_request( - origin, context, query["v"], query["limit"] - ) - ) - ) + for method in ("GET", "PUT", "POST"): + code = getattr(self, "on_%s" % (method), None) + if code is None: + continue - # This is when we receive a server-server Query - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/query/([^/]*)$"), - self._with_authentication( - lambda origin, content, query, query_type: - handler.on_query_request( - query_type, - {k: v[0].decode("utf-8") for k, v in query.items()} - ) - ) - ) + server.register_path(method, pattern, self.wrapper(code)) - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"), - self._with_authentication( - lambda origin, content, query, context, user_id: - self._on_make_join_request( - origin, content, query, context, user_id - ) - ) - ) - self.server.register_path( - "GET", - re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"), - self._with_authentication( - lambda origin, content, query, context, event_id: - handler.on_event_auth( - origin, context, event_id, - ) - ) - ) +class FederationSendServlet(BaseFederationServlet): + PATH = "/send/([^/]*)/$" - self.server.register_path( - "PUT", - re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"), - self._with_authentication( - lambda origin, content, query, context, event_id: - self._on_send_join_request( - origin, content, query, - ) - ) - ) - - self.server.register_path( - "PUT", - re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"), - self._with_authentication( - lambda origin, content, query, context, event_id: - self._on_invite_request( - origin, content, query, - ) - ) - ) - - self.server.register_path( - "POST", - re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"), - self._with_authentication( - lambda origin, content, query, context, event_id: - self._on_query_auth_request( - origin, content, event_id, - ) - ) - ) - - self.server.register_path( - "POST", - re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"), - self._with_authentication( - lambda origin, content, query, room_id: - self._get_missing_events( - origin, content, room_id, - ) - ) - ) + def __init__(self, handler, wrapper, server_name): + super(FederationSendServlet, self).__init__(handler, wrapper) + self.server_name = server_name + # This is when someone is trying to send us a bunch of data. @defer.inlineCallbacks - @log_function - def _on_send_request(self, origin, content, query, transaction_id): + def on_PUT(self, origin, content, query, transaction_id): """ Called on PUT /send// Args: @@ -305,8 +212,7 @@ class TransportLayerServer(object): return try: - handler = self.received_handler - code, response = yield handler.on_incoming_transaction( + code, response = yield self.handler.on_incoming_transaction( transaction_data ) except: @@ -315,65 +221,119 @@ class TransportLayerServer(object): defer.returnValue((code, response)) - @log_function - def _on_backfill_request(self, origin, context, v_list, limits): + +class FederationPullServlet(BaseFederationServlet): + PATH = "/pull/$" + + # This is for when someone asks us for everything since version X + def on_GET(self, origin, content, query): + return self.handler.on_pull_request(query["origin"][0], query["v"]) + + +class FederationEventServlet(BaseFederationServlet): + PATH = "/event/([^/]*)/$" + + # This is when someone asks for a data item for a given server data_id pair. + def on_GET(self, origin, content, query, event_id): + return self.handler.on_pdu_request(origin, event_id) + + +class FederationStateServlet(BaseFederationServlet): + PATH = "/state/([^/]*)/$" + + # This is when someone asks for all data for a given context. + def on_GET(self, origin, content, query, context): + return self.handler.on_context_state_request(origin, context, + query.get("event_id", [None])[0], + ) + + +class FederationBackfillServlet(BaseFederationServlet): + PATH = "/backfill/([^/]*)/$" + + def on_GET(self, origin, content, query, context): + versions = query["v"] + limits = query["limit"] + if not limits: - return defer.succeed( - (400, {"error": "Did not include limit param"}) - ) + return defer.succeed((400, {"error": "Did not include limit param"})) limit = int(limits[-1]) - versions = v_list + return self.handler.on_backfill_request(origin, context, versions, limit) - return self.request_handler.on_backfill_request( - origin, context, versions, limit + +class FederationQueryServlet(BaseFederationServlet): + PATH = "/query/([^/]*)$" + + # This is when we receive a server-server Query + def on_GET(self, origin, content, query, query_type): + return self.handler.on_query_request(query_type, + {k: v[0].decode("utf-8") for k, v in query.items()} ) + +class FederationMakeJoinServlet(BaseFederationServlet): + PATH = "/make_join/([^/]*)/([^/]*)$" + @defer.inlineCallbacks - @log_function - def _on_make_join_request(self, origin, content, query, context, user_id): - content = yield self.request_handler.on_make_join_request( - context, user_id, - ) + def on_GET(self, origin, content, query, context, user_id): + content = yield self.handler.on_make_join_request(context, user_id) defer.returnValue((200, content)) - @defer.inlineCallbacks - @log_function - def _on_send_join_request(self, origin, content, query): - content = yield self.request_handler.on_send_join_request( - origin, content, - ) +class FederationEventAuthServlet(BaseFederationServlet): + PATH = "/event_auth/([^/]*)/([^/]*)$" + + def on_GET(self, origin, content, query, context, event_id): + return self.handler.on_event_auth(origin, context, event_id) + + +class FederationSendJoinServlet(BaseFederationServlet): + PATH = "/send_join/([^/]*)/([^/]*)$" + + @defer.inlineCallbacks + def on_PUT(self, origin, content, query, context, event_id): + # TODO(paul): assert that context/event_id parsed from path actually + # match those given in content + content = yield self.handler.on_send_join_request(origin, content) defer.returnValue((200, content)) - @defer.inlineCallbacks - @log_function - def _on_invite_request(self, origin, content, query): - content = yield self.request_handler.on_invite_request( - origin, content, - ) +class FederationInviteServlet(BaseFederationServlet): + PATH = "/invite/([^/]*)/([^/]*)$" + + @defer.inlineCallbacks + def on_PUT(self, origin, content, query, context, event_id): + # TODO(paul): assert that context/event_id parsed from path actually + # match those given in content + content = yield self.handler.on_invite_request(origin, content) defer.returnValue((200, content)) + +class FederationQueryAuthServlet(BaseFederationServlet): + PATH = "/query_auth/([^/]*)/([^/]*)$" + @defer.inlineCallbacks - @log_function - def _on_query_auth_request(self, origin, content, event_id): - new_content = yield self.request_handler.on_query_auth_request( + def on_POST(self, origin, content, query, context, event_id): + new_content = yield self.handler.on_query_auth_request( origin, content, event_id ) defer.returnValue((200, new_content)) + +class FederationGetMissingEventsServlet(BaseFederationServlet): + PATH = "/get_missing_events/([^/]*)/?$" + @defer.inlineCallbacks - @log_function - def _get_missing_events(self, origin, content, room_id): + def on_POST(self, origin, content, query, room_id): limit = int(content.get("limit", 10)) min_depth = int(content.get("min_depth", 0)) earliest_events = content.get("earliest_events", []) latest_events = content.get("latest_events", []) - content = yield self.request_handler.on_get_missing_events( + content = yield self.handler.on_get_missing_events( origin, room_id=room_id, earliest_events=earliest_events, From ba8ac996f951c872c8815f09a4ffd3a508da6863 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 5 Mar 2015 19:43:17 +0000 Subject: [PATCH 2/5] Remove the dead 'rate_limit_origin' method from TransportLayerServer --- synapse/federation/transport/server.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index eb3e30a189..dc9f1e082b 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -109,12 +109,6 @@ class TransportLayerServer(object): defer.returnValue(response) return new_handler - def rate_limit_origin(self, handler): - def new_handler(origin, *args, **kwargs): - response = yield handler(origin, *args, **kwargs) - defer.returnValue(response) - return new_handler() - @log_function def register_received_handler(self, handler): """ Register a handler that will be fired when we receive data. From 7644cb79b222207ef739a9ca29699f32aa3cee0b Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 5 Mar 2015 20:33:16 +0000 Subject: [PATCH 3/5] Slightly neater(?) arrangement of authentication wrapper for HTTP servlet methods --- synapse/federation/transport/server.py | 62 +++++++++++++++----------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index dc9f1e082b..39b18ae303 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -19,6 +19,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.errors import Codes, SynapseError from synapse.util.logutils import log_function +import functools import logging import simplejson as json import re @@ -30,8 +31,9 @@ logger = logging.getLogger(__name__) class TransportLayerServer(object): """Handles incoming federation HTTP requests""" + # A method just so we can pass 'self' as the authenticator to the Servlets @defer.inlineCallbacks - def _authenticate_request(self, request): + def authenticate_request(self, request): json_request = { "method": request.method, "uri": request.uri, @@ -93,22 +95,6 @@ class TransportLayerServer(object): defer.returnValue((origin, content)) - def _with_authentication(self, handler): - @defer.inlineCallbacks - def new_handler(request, *args, **kwargs): - try: - (origin, content) = yield self._authenticate_request(request) - with self.ratelimiter.ratelimit(origin) as d: - yield d - response = yield handler( - origin, content, request.args, *args, **kwargs - ) - except: - logger.exception("_authenticate_request failed") - raise - defer.returnValue(response) - return new_handler - @log_function def register_received_handler(self, handler): """ Register a handler that will be fired when we receive data. @@ -116,8 +102,10 @@ class TransportLayerServer(object): Args: handler (TransportReceivedHandler) """ - FederationSendServlet( - handler, self._with_authentication, self.server_name + FederationSendServlet(handler, + authenticator=self, + ratelimiter=self.ratelimiter, + server_name=self.server_name, ).register(self.server) @log_function @@ -140,13 +128,37 @@ class TransportLayerServer(object): FederationQueryAuthServlet, FederationGetMissingEventsServlet, ): - servletclass(handler, self._with_authentication).register(self.server) + servletclass(handler, + authenticator=self, + ratelimiter=self.ratelimiter, + ).register(self.server) class BaseFederationServlet(object): - def __init__(self, handler, wrapper): + def __init__(self, handler, authenticator, ratelimiter): self.handler = handler - self.wrapper = wrapper + self.authenticator = authenticator + self.ratelimiter = ratelimiter + + def _wrap(self, code): + authenticator = self.authenticator + ratelimiter = self.ratelimiter + + @defer.inlineCallbacks + @functools.wraps(code) + def new_code(request, *args, **kwargs): + try: + (origin, content) = yield authenticator.authenticate_request(request) + with ratelimiter.ratelimit(origin) as d: + yield d + response = yield code( + origin, content, request.args, *args, **kwargs + ) + except: + logger.exception("authenticate_request failed") + raise + defer.returnValue(response) + return new_code def register(self, server): pattern = re.compile("^" + PREFIX + self.PATH) @@ -156,14 +168,14 @@ class BaseFederationServlet(object): if code is None: continue - server.register_path(method, pattern, self.wrapper(code)) + server.register_path(method, pattern, self._wrap(code)) class FederationSendServlet(BaseFederationServlet): PATH = "/send/([^/]*)/$" - def __init__(self, handler, wrapper, server_name): - super(FederationSendServlet, self).__init__(handler, wrapper) + def __init__(self, handler, server_name, **kwargs): + super(FederationSendServlet, self).__init__(handler, **kwargs) self.server_name = server_name # This is when someone is trying to send us a bunch of data. From 5eab2549ab13c14535de266cc153dc6d5b479590 Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 5 Mar 2015 20:36:05 +0000 Subject: [PATCH 4/5] Append a $ on PATH at registration time, meaning each PATH attribute doesn't need it --- synapse/federation/transport/server.py | 27 +++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 39b18ae303..8f985f8fe3 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -161,7 +161,7 @@ class BaseFederationServlet(object): return new_code def register(self, server): - pattern = re.compile("^" + PREFIX + self.PATH) + pattern = re.compile("^" + PREFIX + self.PATH + "$") for method in ("GET", "PUT", "POST"): code = getattr(self, "on_%s" % (method), None) @@ -172,7 +172,7 @@ class BaseFederationServlet(object): class FederationSendServlet(BaseFederationServlet): - PATH = "/send/([^/]*)/$" + PATH = "/send/([^/]*)/" def __init__(self, handler, server_name, **kwargs): super(FederationSendServlet, self).__init__(handler, **kwargs) @@ -229,7 +229,7 @@ class FederationSendServlet(BaseFederationServlet): class FederationPullServlet(BaseFederationServlet): - PATH = "/pull/$" + PATH = "/pull/" # This is for when someone asks us for everything since version X def on_GET(self, origin, content, query): @@ -237,7 +237,7 @@ class FederationPullServlet(BaseFederationServlet): class FederationEventServlet(BaseFederationServlet): - PATH = "/event/([^/]*)/$" + PATH = "/event/([^/]*)/" # This is when someone asks for a data item for a given server data_id pair. def on_GET(self, origin, content, query, event_id): @@ -245,7 +245,7 @@ class FederationEventServlet(BaseFederationServlet): class FederationStateServlet(BaseFederationServlet): - PATH = "/state/([^/]*)/$" + PATH = "/state/([^/]*)/" # This is when someone asks for all data for a given context. def on_GET(self, origin, content, query, context): @@ -255,7 +255,7 @@ class FederationStateServlet(BaseFederationServlet): class FederationBackfillServlet(BaseFederationServlet): - PATH = "/backfill/([^/]*)/$" + PATH = "/backfill/([^/]*)/" def on_GET(self, origin, content, query, context): versions = query["v"] @@ -270,7 +270,7 @@ class FederationBackfillServlet(BaseFederationServlet): class FederationQueryServlet(BaseFederationServlet): - PATH = "/query/([^/]*)$" + PATH = "/query/([^/]*)" # This is when we receive a server-server Query def on_GET(self, origin, content, query, query_type): @@ -280,7 +280,7 @@ class FederationQueryServlet(BaseFederationServlet): class FederationMakeJoinServlet(BaseFederationServlet): - PATH = "/make_join/([^/]*)/([^/]*)$" + PATH = "/make_join/([^/]*)/([^/]*)" @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): @@ -289,14 +289,14 @@ class FederationMakeJoinServlet(BaseFederationServlet): class FederationEventAuthServlet(BaseFederationServlet): - PATH = "/event_auth/([^/]*)/([^/]*)$" + PATH = "/event_auth/([^/]*)/([^/]*)" def on_GET(self, origin, content, query, context, event_id): return self.handler.on_event_auth(origin, context, event_id) class FederationSendJoinServlet(BaseFederationServlet): - PATH = "/send_join/([^/]*)/([^/]*)$" + PATH = "/send_join/([^/]*)/([^/]*)" @defer.inlineCallbacks def on_PUT(self, origin, content, query, context, event_id): @@ -307,7 +307,7 @@ class FederationSendJoinServlet(BaseFederationServlet): class FederationInviteServlet(BaseFederationServlet): - PATH = "/invite/([^/]*)/([^/]*)$" + PATH = "/invite/([^/]*)/([^/]*)" @defer.inlineCallbacks def on_PUT(self, origin, content, query, context, event_id): @@ -318,7 +318,7 @@ class FederationInviteServlet(BaseFederationServlet): class FederationQueryAuthServlet(BaseFederationServlet): - PATH = "/query_auth/([^/]*)/([^/]*)$" + PATH = "/query_auth/([^/]*)/([^/]*)" @defer.inlineCallbacks def on_POST(self, origin, content, query, context, event_id): @@ -330,7 +330,8 @@ class FederationQueryAuthServlet(BaseFederationServlet): class FederationGetMissingEventsServlet(BaseFederationServlet): - PATH = "/get_missing_events/([^/]*)/?$" + # TODO(paul): Why does this path alone end with "/?" optional? + PATH = "/get_missing_events/([^/]*)/?" @defer.inlineCallbacks def on_POST(self, origin, content, query, room_id): From d79d91a4a7bdd42bc6c4d0324623e11c8bd3c5ef Mon Sep 17 00:00:00 2001 From: "Paul \"LeoNerd\" Evans" Date: Thu, 5 Mar 2015 20:53:33 +0000 Subject: [PATCH 5/5] Appease pep8 --- synapse/federation/transport/server.py | 46 +++++++++++++++----------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 8f985f8fe3..6c624977d7 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -102,7 +102,8 @@ class TransportLayerServer(object): Args: handler (TransportReceivedHandler) """ - FederationSendServlet(handler, + FederationSendServlet( + handler, authenticator=self, ratelimiter=self.ratelimiter, server_name=self.server_name, @@ -115,20 +116,9 @@ class TransportLayerServer(object): Args: handler (TransportRequestHandler) """ - for servletclass in ( - FederationPullServlet, - FederationEventServlet, - FederationStateServlet, - FederationBackfillServlet, - FederationQueryServlet, - FederationMakeJoinServlet, - FederationEventServlet, - FederationSendJoinServlet, - FederationInviteServlet, - FederationQueryAuthServlet, - FederationGetMissingEventsServlet, - ): - servletclass(handler, + for servletclass in SERVLET_CLASSES: + servletclass( + handler, authenticator=self, ratelimiter=self.ratelimiter, ).register(self.server) @@ -138,11 +128,11 @@ class BaseFederationServlet(object): def __init__(self, handler, authenticator, ratelimiter): self.handler = handler self.authenticator = authenticator - self.ratelimiter = ratelimiter + self.ratelimiter = ratelimiter def _wrap(self, code): authenticator = self.authenticator - ratelimiter = self.ratelimiter + ratelimiter = self.ratelimiter @defer.inlineCallbacks @functools.wraps(code) @@ -249,7 +239,9 @@ class FederationStateServlet(BaseFederationServlet): # This is when someone asks for all data for a given context. def on_GET(self, origin, content, query, context): - return self.handler.on_context_state_request(origin, context, + return self.handler.on_context_state_request( + origin, + context, query.get("event_id", [None])[0], ) @@ -274,7 +266,8 @@ class FederationQueryServlet(BaseFederationServlet): # This is when we receive a server-server Query def on_GET(self, origin, content, query, query_type): - return self.handler.on_query_request(query_type, + return self.handler.on_query_request( + query_type, {k: v[0].decode("utf-8") for k, v in query.items()} ) @@ -350,3 +343,18 @@ class FederationGetMissingEventsServlet(BaseFederationServlet): ) defer.returnValue((200, content)) + + +SERVLET_CLASSES = ( + FederationPullServlet, + FederationEventServlet, + FederationStateServlet, + FederationBackfillServlet, + FederationQueryServlet, + FederationMakeJoinServlet, + FederationEventServlet, + FederationSendJoinServlet, + FederationInviteServlet, + FederationQueryAuthServlet, + FederationGetMissingEventsServlet, +)