Merge pull request #128 from matrix-org/unify_http_wrappers

Unify http wrappers
pull/133/head
Mark Haines 2015-04-21 17:01:04 +01:00
commit 8b183781cb
7 changed files with 249 additions and 287 deletions

View File

@ -51,6 +51,80 @@ response_timer = metrics.register_distribution(
labels=["method", "servlet"] labels=["method", "servlet"]
) )
_next_request_id = 0
def request_handler(request_handler):
"""Wraps a method that acts as a request handler with the necessary logging
and exception handling.
The method must have a signature of "handle_foo(self, request)". The
argument "self" must have "version_string" and "clock" attributes. The
argument "request" must be a twisted HTTP request.
The method must return a deferred. If the deferred succeeds we assume that
a response has been sent. If the deferred fails with a SynapseError we use
it to send a JSON response with the appropriate HTTP reponse code. If the
deferred fails with any other type of error we send a 500 reponse.
We insert a unique request-id into the logging context for this request and
log the response and duration for this request.
"""
@defer.inlineCallbacks
def wrapped_request_handler(self, request):
global _next_request_id
request_id = "%s-%s" % (request.method, _next_request_id)
_next_request_id += 1
with LoggingContext(request_id) as request_context:
request_context.request = request_id
code = None
start = self.clock.time_msec()
try:
logger.info(
"Received request: %s %s",
request.method, request.path
)
yield request_handler(self, request)
code = request.code
except CodeMessageException as e:
code = e.code
if isinstance(e, SynapseError):
logger.info(
"%s SynapseError: %s - %s", request, code, e.msg
)
else:
logger.exception(e)
outgoing_responses_counter.inc(request.method, str(code))
respond_with_json(
request, code, cs_exception(e), send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
)
except:
code = 500
logger.exception(
"Failed handle request %s.%s on %r: %r",
request_handler.__module__,
request_handler.__name__,
self,
request
)
respond_with_json(
request,
500,
{"error": "Internal server error"},
send_cors=True
)
finally:
code = str(code) if code else "-"
end = self.clock.time_msec()
logger.info(
"Processed request: %dms %s %s %s",
end-start, code, request.method, request.path
)
return wrapped_request_handler
class HttpServer(object): class HttpServer(object):
""" Interface for registering callbacks on a HTTP server """ Interface for registering callbacks on a HTTP server
@ -115,102 +189,56 @@ class JsonResource(HttpServer, resource.Resource):
def render(self, request): def render(self, request):
""" This get's called by twisted every time someone sends us a request. """ This get's called by twisted every time someone sends us a request.
""" """
self._async_render_with_logging_context(request) self._async_render(request)
return server.NOT_DONE_YET return server.NOT_DONE_YET
_request_id = 0 @request_handler
@defer.inlineCallbacks
def _async_render_with_logging_context(self, request):
request_id = "%s-%s" % (request.method, JsonResource._request_id)
JsonResource._request_id += 1
with LoggingContext(request_id) as request_context:
request_context.request = request_id
yield self._async_render(request)
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render(self, request): def _async_render(self, request):
""" This get's called by twisted every time someone sends us a request. """ This get's called by twisted every time someone sends us a request.
This checks if anyone has registered a callback for that method and This checks if anyone has registered a callback for that method and
path. path.
""" """
code = None
start = self.clock.time_msec() start = self.clock.time_msec()
try: if request.method == "OPTIONS":
# Just say yes to OPTIONS. self._send_response(request, 200, {})
if request.method == "OPTIONS": return
self._send_response(request, 200, {}) # Loop through all the registered callbacks to check if the method
return # and path regex match
for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path)
if not m:
continue
# Loop through all the registered callbacks to check if the method # We found a match! Trigger callback and then return the
# and path regex match # returned response. We pass both the request and any
for path_entry in self.path_regexs.get(request.method, []): # matched groups from the regex to the callback.
m = path_entry.pattern.match(request.path)
if not m:
continue
# We found a match! Trigger callback and then return the callback = path_entry.callback
# returned response. We pass both the request and any
# matched groups from the regex to the callback.
callback = path_entry.callback servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_instance = getattr(callback, "__self__", None) servlet_classname = servlet_instance.__class__.__name__
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
incoming_requests_counter.inc(request.method, servlet_classname)
args = [
urllib.unquote(u).decode("UTF-8") for u in m.groups()
]
logger.info(
"Received request: %s %s",
request.method, request.path
)
code, response = yield callback(request, *args)
self._send_response(request, code, response)
response_timer.inc_by(
self.clock.time_msec() - start, request.method, servlet_classname
)
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
raise UnrecognizedRequestError()
except CodeMessageException as e:
if isinstance(e, SynapseError):
logger.info("%s SynapseError: %s - %s", request, e.code, e.msg)
else: else:
logger.exception(e) servlet_classname = "%r" % callback
incoming_requests_counter.inc(request.method, servlet_classname)
code = e.code args = [
self._send_response( urllib.unquote(u).decode("UTF-8") for u in m.groups()
request, ]
code,
cs_exception(e),
response_code_message=e.response_code_message
)
except Exception as e:
logger.exception(e)
self._send_response(
request,
500,
{"error": "Internal server error"}
)
finally:
code = str(code) if code else "-"
end = self.clock.time_msec() code, response = yield callback(request, *args)
logger.info(
"Processed request: %dms %s %s %s", self._send_response(request, code, response)
end-start, code, request.method, request.path response_timer.inc_by(
self.clock.time_msec() - start, request.method, servlet_classname
) )
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
raise UnrecognizedRequestError()
def _send_response(self, request, code, response_json_object, def _send_response(self, request, code, response_json_object,
response_code_message=None): response_code_message=None):
# could alternatively use request.notifyFinish() and flip a flag when # could alternatively use request.notifyFinish() and flip a flag when
@ -229,20 +257,10 @@ class JsonResource(HttpServer, resource.Resource):
request, code, response_json_object, request, code, response_json_object,
send_cors=True, send_cors=True,
response_code_message=response_code_message, response_code_message=response_code_message,
pretty_print=self._request_user_agent_is_curl, pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string, version_string=self.version_string,
) )
@staticmethod
def _request_user_agent_is_curl(request):
user_agents = request.requestHeaders.getRawHeaders(
"User-Agent", default=[]
)
for user_agent in user_agents:
if "curl" in user_agent:
return True
return False
class RootRedirect(resource.Resource): class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path.""" """Redirects the root '/' path to another path."""
@ -263,8 +281,8 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False, def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False, response_code_message=None, pretty_print=False,
version_string=""): version_string=""):
if not pretty_print: if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) json_bytes = encode_pretty_printed_json(json_object) + "\n"
else: else:
json_bytes = encode_canonical_json(json_object) json_bytes = encode_canonical_json(json_object)
@ -304,3 +322,13 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
request.write(json_bytes) request.write(json_bytes)
request.finish() request.finish()
return NOT_DONE_YET return NOT_DONE_YET
def _request_user_agent_is_curl(request):
user_agents = request.requestHeaders.getRawHeaders(
"User-Agent", default=[]
)
for user_agent in user_agents:
if "curl" in user_agent:
return True
return False

View File

@ -23,6 +23,61 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def parse_integer(request, name, default=None, required=False):
if name in request.args:
try:
return int(request.args[name][0])
except:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing integer query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
def parse_boolean(request, name, default=None, required=False):
if name in request.args:
try:
return {
"true": True,
"false": False,
}[request.args[name][0]]
except:
message = (
"Boolean query parameter %r must be one of"
" ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing boolean query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"):
if name in request.args:
value = request.args[name][0]
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values)
)
raise SynapseError(message)
else:
return value
else:
if required:
message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message)
else:
return default
class RestServlet(object): class RestServlet(object):
""" A Synapse REST Servlet. """ A Synapse REST Servlet.
@ -56,58 +111,3 @@ class RestServlet(object):
http_server.register_path(method, pattern, method_handler) http_server.register_path(method, pattern, method_handler)
else: else:
raise NotImplementedError("RestServlet must register something.") raise NotImplementedError("RestServlet must register something.")
@staticmethod
def parse_integer(request, name, default=None, required=False):
if name in request.args:
try:
return int(request.args[name][0])
except:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing integer query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
@staticmethod
def parse_boolean(request, name, default=None, required=False):
if name in request.args:
try:
return {
"true": True,
"false": False,
}[request.args[name][0]]
except:
message = (
"Boolean query parameter %r must be one of"
" ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing boolean query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
@staticmethod
def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"):
if name in request.args:
value = request.args[name][0]
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values)
)
raise SynapseError(message)
else:
return value
else:
if required:
message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message)
else:
return default

View File

@ -15,7 +15,9 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import RestServlet from synapse.http.servlet import (
RestServlet, parse_string, parse_integer, parse_boolean
)
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.events.utils import ( from synapse.events.utils import (
@ -87,20 +89,20 @@ class SyncRestServlet(RestServlet):
def on_GET(self, request): def on_GET(self, request):
user, client = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
timeout = self.parse_integer(request, "timeout", default=0) timeout = parse_integer(request, "timeout", default=0)
limit = self.parse_integer(request, "limit", required=True) limit = parse_integer(request, "limit", required=True)
gap = self.parse_boolean(request, "gap", default=True) gap = parse_boolean(request, "gap", default=True)
sort = self.parse_string( sort = parse_string(
request, "sort", default="timeline,asc", request, "sort", default="timeline,asc",
allowed_values=self.ALLOWED_SORT allowed_values=self.ALLOWED_SORT
) )
since = self.parse_string(request, "since") since = parse_string(request, "since")
set_presence = self.parse_string( set_presence = parse_string(
request, "set_presence", default="online", request, "set_presence", default="online",
allowed_values=self.ALLOWED_PRESENCE allowed_values=self.ALLOWED_PRESENCE
) )
backfill = self.parse_boolean(request, "backfill", default=False) backfill = parse_boolean(request, "backfill", default=False)
filter_id = self.parse_string(request, "filter", default=None) filter_id = parse_string(request, "filter", default=None)
logger.info( logger.info(
"/sync: user=%r, timeout=%r, limit=%r, gap=%r, sort=%r, since=%r," "/sync: user=%r, timeout=%r, limit=%r, gap=%r, sort=%r, since=%r,"

View File

@ -18,7 +18,7 @@ from .thumbnailer import Thumbnailer
from synapse.http.server import respond_with_json from synapse.http.server import respond_with_json
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.api.errors import ( from synapse.api.errors import (
cs_exception, CodeMessageException, cs_error, Codes, SynapseError cs_error, Codes, SynapseError
) )
from twisted.internet import defer from twisted.internet import defer
@ -32,6 +32,18 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def parse_media_id(request):
try:
server_name, media_id = request.postpath
return (server_name, media_id)
except:
raise SynapseError(
404,
"Invalid media id token %r" % (request.postpath,),
Codes.UNKNOWN,
)
class BaseMediaResource(Resource): class BaseMediaResource(Resource):
isLeaf = True isLeaf = True
@ -45,74 +57,9 @@ class BaseMediaResource(Resource):
self.max_upload_size = hs.config.max_upload_size self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels self.max_image_pixels = hs.config.max_image_pixels
self.filepaths = filepaths self.filepaths = filepaths
self.version_string = hs.version_string
self.downloads = {} self.downloads = {}
@staticmethod
def catch_errors(request_handler):
@defer.inlineCallbacks
def wrapped_request_handler(self, request):
try:
yield request_handler(self, request)
except CodeMessageException as e:
logger.info("Responding with error: %r", e)
respond_with_json(
request, e.code, cs_exception(e), send_cors=True
)
except:
logger.exception(
"Failed handle request %s.%s on %r",
request_handler.__module__,
request_handler.__name__,
self,
)
respond_with_json(
request,
500,
{"error": "Internal server error"},
send_cors=True
)
return wrapped_request_handler
@staticmethod
def _parse_media_id(request):
try:
server_name, media_id = request.postpath
return (server_name, media_id)
except:
raise SynapseError(
404,
"Invalid media id token %r" % (request.postpath,),
Codes.UNKNOWN,
)
@staticmethod
def _parse_integer(request, arg_name, default=None):
try:
if default is None:
return int(request.args[arg_name][0])
else:
return int(request.args.get(arg_name, [default])[0])
except:
raise SynapseError(
400,
"Missing integer argument %r" % (arg_name,),
Codes.UNKNOWN,
)
@staticmethod
def _parse_string(request, arg_name, default=None):
try:
if default is None:
return request.args[arg_name][0]
else:
return request.args.get(arg_name, [default])[0]
except:
raise SynapseError(
400,
"Missing string argument %r" % (arg_name,),
Codes.UNKNOWN,
)
def _respond_404(self, request): def _respond_404(self, request):
respond_with_json( respond_with_json(
request, 404, request, 404,

View File

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .base_resource import BaseMediaResource from .base_resource import BaseMediaResource, parse_media_id
from synapse.http.server import request_handler
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer from twisted.internet import defer
@ -28,15 +29,10 @@ class DownloadResource(BaseMediaResource):
self._async_render_GET(request) self._async_render_GET(request)
return NOT_DONE_YET return NOT_DONE_YET
@BaseMediaResource.catch_errors @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_GET(self, request): def _async_render_GET(self, request):
try: server_name, media_id = parse_media_id(request)
server_name, media_id = request.postpath
except:
self._respond_404(request)
return
if server_name == self.server_name: if server_name == self.server_name:
yield self._respond_local_file(request, media_id) yield self._respond_local_file(request, media_id)
else: else:

View File

@ -14,7 +14,9 @@
# limitations under the License. # limitations under the License.
from .base_resource import BaseMediaResource from .base_resource import BaseMediaResource, parse_media_id
from synapse.http.servlet import parse_string, parse_integer
from synapse.http.server import request_handler
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer from twisted.internet import defer
@ -31,14 +33,14 @@ class ThumbnailResource(BaseMediaResource):
self._async_render_GET(request) self._async_render_GET(request)
return NOT_DONE_YET return NOT_DONE_YET
@BaseMediaResource.catch_errors @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_GET(self, request): def _async_render_GET(self, request):
server_name, media_id = self._parse_media_id(request) server_name, media_id = parse_media_id(request)
width = self._parse_integer(request, "width") width = parse_integer(request, "width")
height = self._parse_integer(request, "height") height = parse_integer(request, "height")
method = self._parse_string(request, "method", "scale") method = parse_string(request, "method", "scale")
m_type = self._parse_string(request, "type", "image/png") m_type = parse_string(request, "type", "image/png")
if server_name == self.server_name: if server_name == self.server_name:
yield self._respond_local_thumbnail( yield self._respond_local_thumbnail(

View File

@ -13,12 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.http.server import respond_with_json from synapse.http.server import respond_with_json, request_handler
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.api.errors import ( from synapse.api.errors import SynapseError
cs_exception, SynapseError, CodeMessageException
)
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer from twisted.internet import defer
@ -69,53 +67,42 @@ class UploadResource(BaseMediaResource):
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_POST(self, request): def _async_render_POST(self, request):
try: auth_user, client = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have
# TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point
# already been uploaded to a tmp file at this point content_length = request.getHeader("Content-Length")
content_length = request.getHeader("Content-Length") if content_length is None:
if content_length is None: raise SynapseError(
raise SynapseError( msg="Request must specify a Content-Length", code=400
msg="Request must specify a Content-Length", code=400 )
) if int(content_length) > self.max_upload_size:
if int(content_length) > self.max_upload_size: raise SynapseError(
raise SynapseError( msg="Upload request body is too large",
msg="Upload request body is too large", code=413,
code=413,
)
headers = request.requestHeaders
if headers.hasHeader("Content-Type"):
media_type = headers.getRawHeaders("Content-Type")[0]
else:
raise SynapseError(
msg="Upload request missing 'Content-Type'",
code=400,
)
# if headers.hasHeader("Content-Disposition"):
# disposition = headers.getRawHeaders("Content-Disposition")[0]
# TODO(markjh): parse content-dispostion
content_uri = yield self.create_content(
media_type, None, request.content.read(),
content_length, auth_user
) )
respond_with_json( headers = request.requestHeaders
request, 200, {"content_uri": content_uri}, send_cors=True
) if headers.hasHeader("Content-Type"):
except CodeMessageException as e: media_type = headers.getRawHeaders("Content-Type")[0]
logger.exception(e) else:
respond_with_json(request, e.code, cs_exception(e), send_cors=True) raise SynapseError(
except: msg="Upload request missing 'Content-Type'",
logger.exception("Failed to store file") code=400,
respond_with_json(
request,
500,
{"error": "Internal server error"},
send_cors=True
) )
# if headers.hasHeader("Content-Disposition"):
# disposition = headers.getRawHeaders("Content-Disposition")[0]
# TODO(markjh): parse content-dispostion
content_uri = yield self.create_content(
media_type, None, request.content.read(),
content_length, auth_user
)
respond_with_json(
request, 200, {"content_uri": content_uri}, send_cors=True
)