Slightly neater(?) arrangement of authentication wrapper for HTTP servlet methods

pull/101/head
Paul "LeoNerd" Evans 2015-03-05 20:33:16 +00:00
parent ba8ac996f9
commit 7644cb79b2
1 changed files with 37 additions and 25 deletions

View File

@ -19,6 +19,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function
import functools
import logging
import simplejson as json
import re
@ -30,8 +31,9 @@ logger = logging.getLogger(__name__)
class TransportLayerServer(object):
"""Handles incoming federation HTTP requests"""
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def _authenticate_request(self, request):
def authenticate_request(self, request):
json_request = {
"method": request.method,
"uri": request.uri,
@ -93,22 +95,6 @@ class TransportLayerServer(object):
defer.returnValue((origin, content))
def _with_authentication(self, handler):
@defer.inlineCallbacks
def new_handler(request, *args, **kwargs):
try:
(origin, content) = yield self._authenticate_request(request)
with self.ratelimiter.ratelimit(origin) as d:
yield d
response = yield handler(
origin, content, request.args, *args, **kwargs
)
except:
logger.exception("_authenticate_request failed")
raise
defer.returnValue(response)
return new_handler
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
@ -116,8 +102,10 @@ class TransportLayerServer(object):
Args:
handler (TransportReceivedHandler)
"""
FederationSendServlet(
handler, self._with_authentication, self.server_name
FederationSendServlet(handler,
authenticator=self,
ratelimiter=self.ratelimiter,
server_name=self.server_name,
).register(self.server)
@log_function
@ -140,13 +128,37 @@ class TransportLayerServer(object):
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
):
servletclass(handler, self._with_authentication).register(self.server)
servletclass(handler,
authenticator=self,
ratelimiter=self.ratelimiter,
).register(self.server)
class BaseFederationServlet(object):
def __init__(self, handler, wrapper):
def __init__(self, handler, authenticator, ratelimiter):
self.handler = handler
self.wrapper = wrapper
self.authenticator = authenticator
self.ratelimiter = ratelimiter
def _wrap(self, code):
authenticator = self.authenticator
ratelimiter = self.ratelimiter
@defer.inlineCallbacks
@functools.wraps(code)
def new_code(request, *args, **kwargs):
try:
(origin, content) = yield authenticator.authenticate_request(request)
with ratelimiter.ratelimit(origin) as d:
yield d
response = yield code(
origin, content, request.args, *args, **kwargs
)
except:
logger.exception("authenticate_request failed")
raise
defer.returnValue(response)
return new_code
def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH)
@ -156,14 +168,14 @@ class BaseFederationServlet(object):
if code is None:
continue
server.register_path(method, pattern, self.wrapper(code))
server.register_path(method, pattern, self._wrap(code))
class FederationSendServlet(BaseFederationServlet):
PATH = "/send/([^/]*)/$"
def __init__(self, handler, wrapper, server_name):
super(FederationSendServlet, self).__init__(handler, wrapper)
def __init__(self, handler, server_name, **kwargs):
super(FederationSendServlet, self).__init__(handler, **kwargs)
self.server_name = server_name
# This is when someone is trying to send us a bunch of data.