Merge different Resource implementation classes (#7732)
parent
21a212f8e5
commit
5cdca53aa0
|
@ -0,0 +1 @@
|
|||
Fix "Tried to close a non-active scope!" error messages when opentracing is enabled.
|
|
@ -361,11 +361,7 @@ class BaseFederationServlet(object):
|
|||
continue
|
||||
|
||||
server.register_paths(
|
||||
method,
|
||||
(pattern,),
|
||||
self._wrap(code),
|
||||
self.__class__.__name__,
|
||||
trace=False,
|
||||
method, (pattern,), self._wrap(code), self.__class__.__name__,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -13,13 +13,10 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
||||
from synapse.http.server import wrap_json_request_handler
|
||||
from synapse.http.server import DirectServeJsonResource
|
||||
|
||||
|
||||
class AdditionalResource(Resource):
|
||||
class AdditionalResource(DirectServeJsonResource):
|
||||
"""Resource wrapper for additional_resources
|
||||
|
||||
If the user has configured additional_resources, we need to wrap the
|
||||
|
@ -41,16 +38,10 @@ class AdditionalResource(Resource):
|
|||
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
|
||||
function to be called to handle the request.
|
||||
"""
|
||||
Resource.__init__(self)
|
||||
super().__init__()
|
||||
self._handler = handler
|
||||
|
||||
# required by the request_handler wrapper
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def render(self, request):
|
||||
self._async_render(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@wrap_json_request_handler
|
||||
def _async_render(self, request):
|
||||
# Cheekily pass the result straight through, so we don't need to worry
|
||||
# if its an awaitable or not.
|
||||
return self._handler(request)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import html
|
||||
import logging
|
||||
|
@ -21,7 +22,7 @@ import types
|
|||
import urllib
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import Awaitable, Callable, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
|
||||
import jinja2
|
||||
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
|
||||
|
@ -62,99 +63,43 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
|
|||
"""
|
||||
|
||||
|
||||
def wrap_json_request_handler(h):
|
||||
"""Wraps a request handler method with exception handling.
|
||||
|
||||
Also does the wrapping with request.processing as per wrap_async_request_handler.
|
||||
|
||||
The handler method must have a signature of "handle_foo(self, request)",
|
||||
where "request" must be a SynapseRequest.
|
||||
|
||||
The handler must return a deferred or a coroutine. If the deferred succeeds
|
||||
we assume that a response has been sent. If the deferred fails with a SynapseError we use
|
||||
it to send a JSON response with the appropriate HTTP reponse code. If the
|
||||
deferred fails with any other type of error we send a 500 reponse.
|
||||
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
|
||||
"""Sends a JSON error response to clients.
|
||||
"""
|
||||
|
||||
async def wrapped_request_handler(self, request):
|
||||
try:
|
||||
await h(self, request)
|
||||
except SynapseError as e:
|
||||
code = e.code
|
||||
logger.info("%s SynapseError: %s - %s", request, code, e.msg)
|
||||
if f.check(SynapseError):
|
||||
error_code = f.value.code
|
||||
error_dict = f.value.error_dict()
|
||||
|
||||
# Only respond with an error response if we haven't already started
|
||||
# writing, otherwise lets just kill the connection
|
||||
if request.startedWriting:
|
||||
if request.transport:
|
||||
try:
|
||||
request.transport.abortConnection()
|
||||
except Exception:
|
||||
# abortConnection throws if the connection is already closed
|
||||
pass
|
||||
else:
|
||||
respond_with_json(
|
||||
request,
|
||||
code,
|
||||
e.error_dict(),
|
||||
send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
)
|
||||
logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
|
||||
else:
|
||||
error_code = 500
|
||||
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
|
||||
|
||||
except Exception:
|
||||
# failure.Failure() fishes the original Failure out
|
||||
# of our stack, and thus gives us a sensible stack
|
||||
# trace.
|
||||
f = failure.Failure()
|
||||
logger.error(
|
||||
"Failed handle request via %r: %r",
|
||||
request.request_metrics.name,
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
# Only respond with an error response if we haven't already started
|
||||
# writing, otherwise lets just kill the connection
|
||||
if request.startedWriting:
|
||||
if request.transport:
|
||||
try:
|
||||
request.transport.abortConnection()
|
||||
except Exception:
|
||||
# abortConnection throws if the connection is already closed
|
||||
pass
|
||||
else:
|
||||
respond_with_json(
|
||||
request,
|
||||
500,
|
||||
{"error": "Internal server error", "errcode": Codes.UNKNOWN},
|
||||
send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
)
|
||||
logger.error(
|
||||
"Failed handle request via %r: %r",
|
||||
request.request_metrics.name,
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
|
||||
return wrap_async_request_handler(wrapped_request_handler)
|
||||
|
||||
|
||||
TV = TypeVar("TV")
|
||||
|
||||
|
||||
def wrap_html_request_handler(
|
||||
h: Callable[[TV, SynapseRequest], Awaitable]
|
||||
) -> Callable[[TV, SynapseRequest], Awaitable[None]]:
|
||||
"""Wraps a request handler method with exception handling.
|
||||
|
||||
Also does the wrapping with request.processing as per wrap_async_request_handler.
|
||||
|
||||
The handler method must have a signature of "handle_foo(self, request)",
|
||||
where "request" must be a SynapseRequest.
|
||||
"""
|
||||
|
||||
async def wrapped_request_handler(self, request):
|
||||
try:
|
||||
await h(self, request)
|
||||
except Exception:
|
||||
f = failure.Failure()
|
||||
return_html_error(f, request, HTML_ERROR_TEMPLATE)
|
||||
|
||||
return wrap_async_request_handler(wrapped_request_handler)
|
||||
# Only respond with an error response if we haven't already started writing,
|
||||
# otherwise lets just kill the connection
|
||||
if request.startedWriting:
|
||||
if request.transport:
|
||||
try:
|
||||
request.transport.abortConnection()
|
||||
except Exception:
|
||||
# abortConnection throws if the connection is already closed
|
||||
pass
|
||||
else:
|
||||
respond_with_json(
|
||||
request,
|
||||
error_code,
|
||||
error_dict,
|
||||
send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
)
|
||||
|
||||
|
||||
def return_html_error(
|
||||
|
@ -249,7 +194,113 @@ class HttpServer(object):
|
|||
pass
|
||||
|
||||
|
||||
class JsonResource(HttpServer, resource.Resource):
|
||||
class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
|
||||
"""Base class for resources that have async handlers.
|
||||
|
||||
Sub classes can either implement `_async_render_<METHOD>` to handle
|
||||
requests by method, or override `_async_render` to handle all requests.
|
||||
|
||||
Args:
|
||||
extract_context: Whether to attempt to extract the opentracing
|
||||
context from the request the servlet is handling.
|
||||
"""
|
||||
|
||||
def __init__(self, extract_context=False):
|
||||
super().__init__()
|
||||
|
||||
self._extract_context = extract_context
|
||||
|
||||
def render(self, request):
|
||||
""" This gets called by twisted every time someone sends us a request.
|
||||
"""
|
||||
defer.ensureDeferred(self._async_render_wrapper(request))
|
||||
return NOT_DONE_YET
|
||||
|
||||
@wrap_async_request_handler
|
||||
async def _async_render_wrapper(self, request):
|
||||
"""This is a wrapper that delegates to `_async_render` and handles
|
||||
exceptions, return values, metrics, etc.
|
||||
"""
|
||||
try:
|
||||
request.request_metrics.name = self.__class__.__name__
|
||||
|
||||
with trace_servlet(request, self._extract_context):
|
||||
callback_return = await self._async_render(request)
|
||||
|
||||
if callback_return is not None:
|
||||
code, response = callback_return
|
||||
self._send_response(request, code, response)
|
||||
except Exception:
|
||||
# failure.Failure() fishes the original Failure out
|
||||
# of our stack, and thus gives us a sensible stack
|
||||
# trace.
|
||||
f = failure.Failure()
|
||||
self._send_error_response(f, request)
|
||||
|
||||
async def _async_render(self, request):
|
||||
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
|
||||
no appropriate method exists. Can be overriden in sub classes for
|
||||
different routing.
|
||||
"""
|
||||
|
||||
method_handler = getattr(
|
||||
self, "_async_render_%s" % (request.method.decode("ascii"),), None
|
||||
)
|
||||
if method_handler:
|
||||
raw_callback_return = method_handler(request)
|
||||
|
||||
# Is it synchronous? We'll allow this for now.
|
||||
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
|
||||
callback_return = await raw_callback_return
|
||||
else:
|
||||
callback_return = raw_callback_return
|
||||
|
||||
return callback_return
|
||||
|
||||
_unrecognised_request_handler(request)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _send_response(
|
||||
self, request: SynapseRequest, code: int, response_object: Any,
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _send_error_response(
|
||||
self, f: failure.Failure, request: SynapseRequest,
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DirectServeJsonResource(_AsyncResource):
|
||||
"""A resource that will call `self._async_on_<METHOD>` on new requests,
|
||||
formatting responses and errors as JSON.
|
||||
"""
|
||||
|
||||
def _send_response(
|
||||
self, request, code, response_object,
|
||||
):
|
||||
"""Implements _AsyncResource._send_response
|
||||
"""
|
||||
# TODO: Only enable CORS for the requests that need it.
|
||||
respond_with_json(
|
||||
request,
|
||||
code,
|
||||
response_object,
|
||||
send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
canonical_json=self.canonical_json,
|
||||
)
|
||||
|
||||
def _send_error_response(
|
||||
self, f: failure.Failure, request: SynapseRequest,
|
||||
) -> None:
|
||||
"""Implements _AsyncResource._send_error_response
|
||||
"""
|
||||
return_json_error(f, request)
|
||||
|
||||
|
||||
class JsonResource(DirectServeJsonResource):
|
||||
""" This implements the HttpServer interface and provides JSON support for
|
||||
Resources.
|
||||
|
||||
|
@ -269,17 +320,15 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
"_PathEntry", ["pattern", "callback", "servlet_classname"]
|
||||
)
|
||||
|
||||
def __init__(self, hs, canonical_json=True):
|
||||
resource.Resource.__init__(self)
|
||||
def __init__(self, hs, canonical_json=True, extract_context=False):
|
||||
super().__init__(extract_context)
|
||||
|
||||
self.canonical_json = canonical_json
|
||||
self.clock = hs.get_clock()
|
||||
self.path_regexs = {}
|
||||
self.hs = hs
|
||||
|
||||
def register_paths(
|
||||
self, method, path_patterns, callback, servlet_classname, trace=True
|
||||
):
|
||||
def register_paths(self, method, path_patterns, callback, servlet_classname):
|
||||
"""
|
||||
Registers a request handler against a regular expression. Later request URLs are
|
||||
checked against these regular expressions in order to identify an appropriate
|
||||
|
@ -295,74 +344,23 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
|
||||
servlet_classname (str): The name of the handler to be used in prometheus
|
||||
and opentracing logs.
|
||||
|
||||
trace (bool): Whether we should start a span to trace the servlet.
|
||||
"""
|
||||
method = method.encode("utf-8") # method is bytes on py3
|
||||
|
||||
if trace:
|
||||
# We don't extract the context from the servlet because we can't
|
||||
# trust the sender
|
||||
callback = trace_servlet(servlet_classname)(callback)
|
||||
|
||||
for path_pattern in path_patterns:
|
||||
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||
self.path_regexs.setdefault(method, []).append(
|
||||
self._PathEntry(path_pattern, callback, servlet_classname)
|
||||
)
|
||||
|
||||
def render(self, request):
|
||||
""" This gets called by twisted every time someone sends us a request.
|
||||
"""
|
||||
defer.ensureDeferred(self._async_render(request))
|
||||
return NOT_DONE_YET
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render(self, request):
|
||||
""" This gets called from render() every time someone sends us a request.
|
||||
This checks if anyone has registered a callback for that method and
|
||||
path.
|
||||
"""
|
||||
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
|
||||
|
||||
# Make sure we have a name for this handler in prometheus.
|
||||
request.request_metrics.name = servlet_classname
|
||||
|
||||
# Now trigger the callback. If it returns a response, we send it
|
||||
# here. If it throws an exception, that is handled by the wrapper
|
||||
# installed by @request_handler.
|
||||
kwargs = intern_dict(
|
||||
{
|
||||
name: urllib.parse.unquote(value) if value else value
|
||||
for name, value in group_dict.items()
|
||||
}
|
||||
)
|
||||
|
||||
callback_return = callback(request, **kwargs)
|
||||
|
||||
# Is it synchronous? We'll allow this for now.
|
||||
if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
|
||||
callback_return = await callback_return
|
||||
|
||||
if callback_return is not None:
|
||||
code, response = callback_return
|
||||
self._send_response(request, code, response)
|
||||
|
||||
def _get_handler_for_request(self, request):
|
||||
"""Finds a callback method to handle the given request
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request):
|
||||
def _get_handler_for_request(
|
||||
self, request: SynapseRequest
|
||||
) -> Tuple[Callable, str, Dict[str, str]]:
|
||||
"""Finds a callback method to handle the given request.
|
||||
|
||||
Returns:
|
||||
Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
|
||||
label to use for that method in prometheus metrics, and the
|
||||
dict mapping keys to path components as specified in the
|
||||
handler's path match regexp.
|
||||
|
||||
The callback will normally be a method registered via
|
||||
register_paths, so will return (possibly via Deferred) either
|
||||
None, or a tuple of (http code, response body).
|
||||
A tuple of the callback to use, the name of the servlet, and the
|
||||
key word arguments to pass to the callback
|
||||
"""
|
||||
request_path = request.path.decode("ascii")
|
||||
|
||||
|
@ -377,42 +375,59 @@ class JsonResource(HttpServer, resource.Resource):
|
|||
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
||||
return _unrecognised_request_handler, "unrecognised_request_handler", {}
|
||||
|
||||
def _send_response(
|
||||
self, request, code, response_json_object, response_code_message=None
|
||||
):
|
||||
# TODO: Only enable CORS for the requests that need it.
|
||||
respond_with_json(
|
||||
request,
|
||||
code,
|
||||
response_json_object,
|
||||
send_cors=True,
|
||||
response_code_message=response_code_message,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
canonical_json=self.canonical_json,
|
||||
async def _async_render(self, request):
|
||||
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
|
||||
|
||||
# Make sure we have an appopriate name for this handler in prometheus
|
||||
# (rather than the default of JsonResource).
|
||||
request.request_metrics.name = servlet_classname
|
||||
|
||||
# Now trigger the callback. If it returns a response, we send it
|
||||
# here. If it throws an exception, that is handled by the wrapper
|
||||
# installed by @request_handler.
|
||||
kwargs = intern_dict(
|
||||
{
|
||||
name: urllib.parse.unquote(value) if value else value
|
||||
for name, value in group_dict.items()
|
||||
}
|
||||
)
|
||||
|
||||
raw_callback_return = callback(request, **kwargs)
|
||||
|
||||
class DirectServeResource(resource.Resource):
|
||||
def render(self, request):
|
||||
# Is it synchronous? We'll allow this for now.
|
||||
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
|
||||
callback_return = await raw_callback_return
|
||||
else:
|
||||
callback_return = raw_callback_return
|
||||
|
||||
return callback_return
|
||||
|
||||
|
||||
class DirectServeHtmlResource(_AsyncResource):
|
||||
"""A resource that will call `self._async_on_<METHOD>` on new requests,
|
||||
formatting responses and errors as HTML.
|
||||
"""
|
||||
|
||||
# The error template to use for this resource
|
||||
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
|
||||
|
||||
def _send_response(
|
||||
self, request: SynapseRequest, code: int, response_object: Any,
|
||||
):
|
||||
"""Implements _AsyncResource._send_response
|
||||
"""
|
||||
Render the request, using an asynchronous render handler if it exists.
|
||||
# We expect to get bytes for us to write
|
||||
assert isinstance(response_object, bytes)
|
||||
html_bytes = response_object
|
||||
|
||||
respond_with_html_bytes(request, 200, html_bytes)
|
||||
|
||||
def _send_error_response(
|
||||
self, f: failure.Failure, request: SynapseRequest,
|
||||
) -> None:
|
||||
"""Implements _AsyncResource._send_error_response
|
||||
"""
|
||||
async_render_callback_name = "_async_render_" + request.method.decode("ascii")
|
||||
|
||||
# Try and get the async renderer
|
||||
callback = getattr(self, async_render_callback_name, None)
|
||||
|
||||
# No async renderer for this request method.
|
||||
if not callback:
|
||||
return super().render(request)
|
||||
|
||||
resp = trace_servlet(self.__class__.__name__)(callback)(request)
|
||||
|
||||
# If it's a coroutine, turn it into a Deferred
|
||||
if isinstance(resp, types.CoroutineType):
|
||||
defer.ensureDeferred(resp)
|
||||
|
||||
return NOT_DONE_YET
|
||||
return_html_error(f, request, self.ERROR_TEMPLATE)
|
||||
|
||||
|
||||
class StaticResource(File):
|
||||
|
|
|
@ -169,7 +169,6 @@ import contextlib
|
|||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import types
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Type
|
||||
|
||||
|
@ -182,6 +181,7 @@ from synapse.config import ConfigError
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.http.site import SynapseRequest
|
||||
|
||||
# Helper class
|
||||
|
||||
|
@ -793,48 +793,42 @@ def tag_args(func):
|
|||
return _tag_args_inner
|
||||
|
||||
|
||||
def trace_servlet(servlet_name, extract_context=False):
|
||||
"""Decorator which traces a serlet. It starts a span with some servlet specific
|
||||
tags such as the servlet_name and request information
|
||||
@contextlib.contextmanager
|
||||
def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
|
||||
"""Returns a context manager which traces a request. It starts a span
|
||||
with some servlet specific tags such as the request metrics name and
|
||||
request information.
|
||||
|
||||
Args:
|
||||
servlet_name (str): The name to be used for the span's operation_name
|
||||
extract_context (bool): Whether to attempt to extract the opentracing
|
||||
request
|
||||
extract_context: Whether to attempt to extract the opentracing
|
||||
context from the request the servlet is handling.
|
||||
|
||||
"""
|
||||
|
||||
def _trace_servlet_inner_1(func):
|
||||
if not opentracing:
|
||||
return func
|
||||
if opentracing is None:
|
||||
yield
|
||||
return
|
||||
|
||||
@wraps(func)
|
||||
async def _trace_servlet_inner(request, *args, **kwargs):
|
||||
request_tags = {
|
||||
"request_id": request.get_request_id(),
|
||||
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
||||
tags.HTTP_METHOD: request.get_method(),
|
||||
tags.HTTP_URL: request.get_redacted_uri(),
|
||||
tags.PEER_HOST_IPV6: request.getClientIP(),
|
||||
}
|
||||
request_tags = {
|
||||
"request_id": request.get_request_id(),
|
||||
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
||||
tags.HTTP_METHOD: request.get_method(),
|
||||
tags.HTTP_URL: request.get_redacted_uri(),
|
||||
tags.PEER_HOST_IPV6: request.getClientIP(),
|
||||
}
|
||||
|
||||
if extract_context:
|
||||
scope = start_active_span_from_request(
|
||||
request, servlet_name, tags=request_tags
|
||||
)
|
||||
else:
|
||||
scope = start_active_span(servlet_name, tags=request_tags)
|
||||
request_name = request.request_metrics.name
|
||||
if extract_context:
|
||||
scope = start_active_span_from_request(request, request_name, tags=request_tags)
|
||||
else:
|
||||
scope = start_active_span(request_name, tags=request_tags)
|
||||
|
||||
with scope:
|
||||
result = func(request, *args, **kwargs)
|
||||
with scope:
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# We set the operation name again in case its changed (which happens
|
||||
# with JsonResource).
|
||||
scope.span.set_operation_name(request.request_metrics.name)
|
||||
|
||||
if not isinstance(result, (types.CoroutineType, defer.Deferred)):
|
||||
# Some servlets aren't async and just return results
|
||||
# directly, so we handle that here.
|
||||
return result
|
||||
|
||||
return await result
|
||||
|
||||
return _trace_servlet_inner
|
||||
|
||||
return _trace_servlet_inner_1
|
||||
scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
|
||||
|
|
|
@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication"
|
|||
|
||||
class ReplicationRestResource(JsonResource):
|
||||
def __init__(self, hs):
|
||||
JsonResource.__init__(self, hs, canonical_json=False)
|
||||
# We enable extracting jaeger contexts here as these are internal APIs.
|
||||
super().__init__(hs, canonical_json=False, extract_context=True)
|
||||
self.register_servlets(hs)
|
||||
|
||||
def register_servlets(self, hs):
|
||||
|
|
|
@ -28,11 +28,7 @@ from synapse.api.errors import (
|
|||
RequestSendFailed,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.logging.opentracing import (
|
||||
inject_active_span_byte_dict,
|
||||
trace,
|
||||
trace_servlet,
|
||||
)
|
||||
from synapse.logging.opentracing import inject_active_span_byte_dict, trace
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
|
@ -240,11 +236,8 @@ class ReplicationEndpoint(object):
|
|||
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
|
||||
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
|
||||
|
||||
handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
|
||||
# We don't let register paths trace this servlet using the default tracing
|
||||
# options because we wish to extract the context explicitly.
|
||||
http_server.register_paths(
|
||||
method, [pattern], handler, self.__class__.__name__, trace=False
|
||||
method, [pattern], handler, self.__class__.__name__,
|
||||
)
|
||||
|
||||
def _cached_handler(self, request, txn_id, **kwargs):
|
||||
|
|
|
@ -26,11 +26,7 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.errors import NotFoundError, StoreError, SynapseError
|
||||
from synapse.config import ConfigError
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
respond_with_html,
|
||||
wrap_html_request_handler,
|
||||
)
|
||||
from synapse.http.server import DirectServeHtmlResource, respond_with_html
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.types import UserID
|
||||
|
||||
|
@ -48,7 +44,7 @@ else:
|
|||
return a == b
|
||||
|
||||
|
||||
class ConsentResource(DirectServeResource):
|
||||
class ConsentResource(DirectServeHtmlResource):
|
||||
"""A twisted Resource to display a privacy policy and gather consent to it
|
||||
|
||||
When accessed via GET, returns the privacy policy via a template.
|
||||
|
@ -119,7 +115,6 @@ class ConsentResource(DirectServeResource):
|
|||
|
||||
self._hmac_secret = hs.config.form_secret.encode("utf-8")
|
||||
|
||||
@wrap_html_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
"""
|
||||
Args:
|
||||
|
@ -160,7 +155,6 @@ class ConsentResource(DirectServeResource):
|
|||
except TemplateNotFound:
|
||||
raise NotFoundError("Unknown policy version")
|
||||
|
||||
@wrap_html_request_handler
|
||||
async def _async_render_POST(self, request):
|
||||
"""
|
||||
Args:
|
||||
|
|
|
@ -20,17 +20,13 @@ from signedjson.sign import sign_json
|
|||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.crypto.keyring import ServerKeyFetcher
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
respond_with_json_bytes,
|
||||
wrap_json_request_handler,
|
||||
)
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
|
||||
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RemoteKey(DirectServeResource):
|
||||
class RemoteKey(DirectServeJsonResource):
|
||||
"""HTTP resource for retreiving the TLS certificate and NACL signature
|
||||
verification keys for a collection of servers. Checks that the reported
|
||||
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
|
||||
|
@ -92,13 +88,14 @@ class RemoteKey(DirectServeResource):
|
|||
isLeaf = True
|
||||
|
||||
def __init__(self, hs):
|
||||
super().__init__()
|
||||
|
||||
self.fetcher = ServerKeyFetcher(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||
self.config = hs.config
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
if len(request.postpath) == 1:
|
||||
(server,) = request.postpath
|
||||
|
@ -115,7 +112,6 @@ class RemoteKey(DirectServeResource):
|
|||
|
||||
await self.query_keys(request, query, query_remote_on_cache_miss=True)
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_POST(self, request):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
|
|
|
@ -14,16 +14,10 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
respond_with_json,
|
||||
wrap_json_request_handler,
|
||||
)
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
|
||||
|
||||
class MediaConfigResource(DirectServeResource):
|
||||
class MediaConfigResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -33,11 +27,9 @@ class MediaConfigResource(DirectServeResource):
|
|||
self.auth = hs.get_auth()
|
||||
self.limits_dict = {"m.upload.size": config.max_upload_size}
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
await self.auth.get_user_by_req(request)
|
||||
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
||||
|
||||
def render_OPTIONS(self, request):
|
||||
async def _async_render_OPTIONS(self, request):
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
return NOT_DONE_YET
|
||||
|
|
|
@ -15,18 +15,14 @@
|
|||
import logging
|
||||
|
||||
import synapse.http.servlet
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
set_cors_headers,
|
||||
wrap_json_request_handler,
|
||||
)
|
||||
from synapse.http.server import DirectServeJsonResource, set_cors_headers
|
||||
|
||||
from ._base import parse_media_id, respond_404
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DownloadResource(DirectServeResource):
|
||||
class DownloadResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo):
|
||||
|
@ -34,10 +30,6 @@ class DownloadResource(DirectServeResource):
|
|||
self.media_repo = media_repo
|
||||
self.server_name = hs.hostname
|
||||
|
||||
# this is expected by @wrap_json_request_handler
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
set_cors_headers(request)
|
||||
request.setHeader(
|
||||
|
|
|
@ -34,10 +34,9 @@ from twisted.internet.error import DNSLookupError
|
|||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
DirectServeJsonResource,
|
||||
respond_with_json,
|
||||
respond_with_json_bytes,
|
||||
wrap_json_request_handler,
|
||||
)
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
|
@ -58,7 +57,7 @@ OG_TAG_NAME_MAXLEN = 50
|
|||
OG_TAG_VALUE_MAXLEN = 1000
|
||||
|
||||
|
||||
class PreviewUrlResource(DirectServeResource):
|
||||
class PreviewUrlResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo, media_storage):
|
||||
|
@ -108,11 +107,10 @@ class PreviewUrlResource(DirectServeResource):
|
|||
self._start_expire_url_cache_data, 10 * 1000
|
||||
)
|
||||
|
||||
def render_OPTIONS(self, request):
|
||||
async def _async_render_OPTIONS(self, request):
|
||||
request.setHeader(b"Allow", b"OPTIONS, GET")
|
||||
return respond_with_json(request, 200, {}, send_cors=True)
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
|
||||
# XXX: if get_user_by_req fails, what should we do in an async render?
|
||||
|
|
|
@ -16,11 +16,7 @@
|
|||
|
||||
import logging
|
||||
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
set_cors_headers,
|
||||
wrap_json_request_handler,
|
||||
)
|
||||
from synapse.http.server import DirectServeJsonResource, set_cors_headers
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
|
||||
from ._base import (
|
||||
|
@ -34,7 +30,7 @@ from ._base import (
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThumbnailResource(DirectServeResource):
|
||||
class ThumbnailResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo, media_storage):
|
||||
|
@ -45,9 +41,7 @@ class ThumbnailResource(DirectServeResource):
|
|||
self.media_storage = media_storage
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.server_name = hs.hostname
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
set_cors_headers(request)
|
||||
server_name, media_id, _ = parse_media_id(request)
|
||||
|
|
|
@ -15,20 +15,14 @@
|
|||
|
||||
import logging
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
respond_with_json,
|
||||
wrap_json_request_handler,
|
||||
)
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.servlet import parse_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UploadResource(DirectServeResource):
|
||||
class UploadResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo):
|
||||
|
@ -43,11 +37,9 @@ class UploadResource(DirectServeResource):
|
|||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def render_OPTIONS(self, request):
|
||||
async def _async_render_OPTIONS(self, request):
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_POST(self, request):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
|
|
|
@ -14,18 +14,17 @@
|
|||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from synapse.http.server import DirectServeResource, wrap_html_request_handler
|
||||
from synapse.http.server import DirectServeHtmlResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OIDCCallbackResource(DirectServeResource):
|
||||
class OIDCCallbackResource(DirectServeHtmlResource):
|
||||
isLeaf = 1
|
||||
|
||||
def __init__(self, hs):
|
||||
super().__init__()
|
||||
self._oidc_handler = hs.get_oidc_handler()
|
||||
|
||||
@wrap_html_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
return await self._oidc_handler.handle_oidc_callback(request)
|
||||
await self._oidc_handler.handle_oidc_callback(request)
|
||||
|
|
|
@ -16,10 +16,10 @@
|
|||
from twisted.python import failure
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import DirectServeResource, return_html_error
|
||||
from synapse.http.server import DirectServeHtmlResource, return_html_error
|
||||
|
||||
|
||||
class SAML2ResponseResource(DirectServeResource):
|
||||
class SAML2ResponseResource(DirectServeHtmlResource):
|
||||
"""A Twisted web resource which handles the SAML response"""
|
||||
|
||||
isLeaf = 1
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from synapse.http.additional_resource import AdditionalResource
|
||||
from synapse.http.server import respond_with_json
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
||||
class _AsyncTestCustomEndpoint:
|
||||
def __init__(self, config, module_api):
|
||||
pass
|
||||
|
||||
async def handle_request(self, request):
|
||||
respond_with_json(request, 200, {"some_key": "some_value_async"})
|
||||
|
||||
|
||||
class _SyncTestCustomEndpoint:
|
||||
def __init__(self, config, module_api):
|
||||
pass
|
||||
|
||||
async def handle_request(self, request):
|
||||
respond_with_json(request, 200, {"some_key": "some_value_sync"})
|
||||
|
||||
|
||||
class AdditionalResourceTests(HomeserverTestCase):
|
||||
"""Very basic tests that `AdditionalResource` works correctly with sync
|
||||
and async handlers.
|
||||
"""
|
||||
|
||||
def test_async(self):
|
||||
handler = _AsyncTestCustomEndpoint({}, None).handle_request
|
||||
self.resource = AdditionalResource(self.hs, handler)
|
||||
|
||||
request, channel = self.make_request("GET", "/")
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(request.code, 200)
|
||||
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
|
||||
|
||||
def test_sync(self):
|
||||
handler = _SyncTestCustomEndpoint({}, None).handle_request
|
||||
self.resource = AdditionalResource(self.hs, handler)
|
||||
|
||||
request, channel = self.make_request("GET", "/")
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(request.code, 200)
|
||||
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
|
|
@ -24,12 +24,7 @@ from twisted.web.server import NOT_DONE_YET
|
|||
|
||||
from synapse.api.errors import Codes, RedirectException, SynapseError
|
||||
from synapse.config.server import parse_listener_def
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
JsonResource,
|
||||
OptionsResource,
|
||||
wrap_html_request_handler,
|
||||
)
|
||||
from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
|
||||
from synapse.http.site import SynapseSite, logger
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.util import Clock
|
||||
|
@ -256,12 +251,11 @@ class OptionsResourceTests(unittest.TestCase):
|
|||
|
||||
|
||||
class WrapHtmlRequestHandlerTests(unittest.TestCase):
|
||||
class TestResource(DirectServeResource):
|
||||
class TestResource(DirectServeHtmlResource):
|
||||
callback = None
|
||||
|
||||
@wrap_html_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
return await self.callback(request)
|
||||
await self.callback(request)
|
||||
|
||||
def setUp(self):
|
||||
self.reactor = ThreadedMemoryReactorClock()
|
||||
|
|
Loading…
Reference in New Issue