Merge different Resource implementation classes (#7732)

pull/7789/head
Erik Johnston 2020-07-03 19:02:19 +01:00 committed by GitHub
parent 21a212f8e5
commit 5cdca53aa0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 330 additions and 326 deletions

1
changelog.d/7732.bugfix Normal file
View File

@ -0,0 +1 @@
Fix "Tried to close a non-active scope!" error messages when opentracing is enabled.

View File

@ -361,11 +361,7 @@ class BaseFederationServlet(object):
continue
server.register_paths(
method,
(pattern,),
self._wrap(code),
self.__class__.__name__,
trace=False,
method, (pattern,), self._wrap(code), self.__class__.__name__,
)

View File

@ -13,13 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.http.server import wrap_json_request_handler
from synapse.http.server import DirectServeJsonResource
class AdditionalResource(Resource):
class AdditionalResource(DirectServeJsonResource):
"""Resource wrapper for additional_resources
If the user has configured additional_resources, we need to wrap the
@ -41,16 +38,10 @@ class AdditionalResource(Resource):
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
function to be called to handle the request.
"""
Resource.__init__(self)
super().__init__()
self._handler = handler
# required by the request_handler wrapper
self.clock = hs.get_clock()
def render(self, request):
self._async_render(request)
return NOT_DONE_YET
@wrap_json_request_handler
def _async_render(self, request):
# Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not.
return self._handler(request)

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import collections
import html
import logging
@ -21,7 +22,7 @@ import types
import urllib
from http import HTTPStatus
from io import BytesIO
from typing import Awaitable, Callable, TypeVar, Union
from typing import Any, Callable, Dict, Tuple, Union
import jinja2
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@ -62,99 +63,43 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
"""
def wrap_json_request_handler(h):
"""Wraps a request handler method with exception handling.
Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)",
where "request" must be a SynapseRequest.
The handler must return a deferred or a coroutine. If the deferred succeeds
we assume that a response has been sent. If the deferred fails with a SynapseError we use
it to send a JSON response with the appropriate HTTP reponse code. If the
deferred fails with any other type of error we send a 500 reponse.
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"""Sends a JSON error response to clients.
"""
async def wrapped_request_handler(self, request):
try:
await h(self, request)
except SynapseError as e:
code = e.code
logger.info("%s SynapseError: %s - %s", request, code, e.msg)
if f.check(SynapseError):
error_code = f.value.code
error_dict = f.value.error_dict()
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection
if request.startedWriting:
if request.transport:
try:
request.transport.abortConnection()
except Exception:
# abortConnection throws if the connection is already closed
pass
else:
respond_with_json(
request,
code,
e.error_dict(),
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
else:
error_code = 500
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
except Exception:
# failure.Failure() fishes the original Failure out
# of our stack, and thus gives us a sensible stack
# trace.
f = failure.Failure()
logger.error(
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
)
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection
if request.startedWriting:
if request.transport:
try:
request.transport.abortConnection()
except Exception:
# abortConnection throws if the connection is already closed
pass
else:
respond_with_json(
request,
500,
{"error": "Internal server error", "errcode": Codes.UNKNOWN},
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
logger.error(
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
)
return wrap_async_request_handler(wrapped_request_handler)
TV = TypeVar("TV")
def wrap_html_request_handler(
h: Callable[[TV, SynapseRequest], Awaitable]
) -> Callable[[TV, SynapseRequest], Awaitable[None]]:
"""Wraps a request handler method with exception handling.
Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)",
where "request" must be a SynapseRequest.
"""
async def wrapped_request_handler(self, request):
try:
await h(self, request)
except Exception:
f = failure.Failure()
return_html_error(f, request, HTML_ERROR_TEMPLATE)
return wrap_async_request_handler(wrapped_request_handler)
# Only respond with an error response if we haven't already started writing,
# otherwise lets just kill the connection
if request.startedWriting:
if request.transport:
try:
request.transport.abortConnection()
except Exception:
# abortConnection throws if the connection is already closed
pass
else:
respond_with_json(
request,
error_code,
error_dict,
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
def return_html_error(
@ -249,7 +194,113 @@ class HttpServer(object):
pass
class JsonResource(HttpServer, resource.Resource):
class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
"""Base class for resources that have async handlers.
Sub classes can either implement `_async_render_<METHOD>` to handle
requests by method, or override `_async_render` to handle all requests.
Args:
extract_context: Whether to attempt to extract the opentracing
context from the request the servlet is handling.
"""
def __init__(self, extract_context=False):
super().__init__()
self._extract_context = extract_context
def render(self, request):
""" This gets called by twisted every time someone sends us a request.
"""
defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET
@wrap_async_request_handler
async def _async_render_wrapper(self, request):
"""This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc.
"""
try:
request.request_metrics.name = self.__class__.__name__
with trace_servlet(request, self._extract_context):
callback_return = await self._async_render(request)
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
except Exception:
# failure.Failure() fishes the original Failure out
# of our stack, and thus gives us a sensible stack
# trace.
f = failure.Failure()
self._send_error_response(f, request)
async def _async_render(self, request):
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overriden in sub classes for
different routing.
"""
method_handler = getattr(
self, "_async_render_%s" % (request.method.decode("ascii"),), None
)
if method_handler:
raw_callback_return = method_handler(request)
# Is it synchronous? We'll allow this for now.
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return
return callback_return
_unrecognised_request_handler(request)
@abc.abstractmethod
def _send_response(
self, request: SynapseRequest, code: int, response_object: Any,
) -> None:
raise NotImplementedError()
@abc.abstractmethod
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
) -> None:
raise NotImplementedError()
class DirectServeJsonResource(_AsyncResource):
"""A resource that will call `self._async_on_<METHOD>` on new requests,
formatting responses and errors as JSON.
"""
def _send_response(
self, request, code, response_object,
):
"""Implements _AsyncResource._send_response
"""
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request,
code,
response_object,
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
)
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
) -> None:
"""Implements _AsyncResource._send_error_response
"""
return_json_error(f, request)
class JsonResource(DirectServeJsonResource):
""" This implements the HttpServer interface and provides JSON support for
Resources.
@ -269,17 +320,15 @@ class JsonResource(HttpServer, resource.Resource):
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)
def __init__(self, hs, canonical_json=True):
resource.Resource.__init__(self)
def __init__(self, hs, canonical_json=True, extract_context=False):
super().__init__(extract_context)
self.canonical_json = canonical_json
self.clock = hs.get_clock()
self.path_regexs = {}
self.hs = hs
def register_paths(
self, method, path_patterns, callback, servlet_classname, trace=True
):
def register_paths(self, method, path_patterns, callback, servlet_classname):
"""
Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate
@ -295,74 +344,23 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname (str): The name of the handler to be used in prometheus
and opentracing logs.
trace (bool): Whether we should start a span to trace the servlet.
"""
method = method.encode("utf-8") # method is bytes on py3
if trace:
# We don't extract the context from the servlet because we can't
# trust the sender
callback = trace_servlet(servlet_classname)(callback)
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback, servlet_classname)
)
def render(self, request):
""" This gets called by twisted every time someone sends us a request.
"""
defer.ensureDeferred(self._async_render(request))
return NOT_DONE_YET
@wrap_json_request_handler
async def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
"""
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
# Make sure we have a name for this handler in prometheus.
request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
kwargs = intern_dict(
{
name: urllib.parse.unquote(value) if value else value
for name, value in group_dict.items()
}
)
callback_return = callback(request, **kwargs)
# Is it synchronous? We'll allow this for now.
if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await callback_return
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
def _get_handler_for_request(self, request):
"""Finds a callback method to handle the given request
Args:
request (twisted.web.http.Request):
def _get_handler_for_request(
self, request: SynapseRequest
) -> Tuple[Callable, str, Dict[str, str]]:
"""Finds a callback method to handle the given request.
Returns:
Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
label to use for that method in prometheus metrics, and the
dict mapping keys to path components as specified in the
handler's path match regexp.
The callback will normally be a method registered via
register_paths, so will return (possibly via Deferred) either
None, or a tuple of (http code, response body).
A tuple of the callback to use, the name of the servlet, and the
key word arguments to pass to the callback
"""
request_path = request.path.decode("ascii")
@ -377,42 +375,59 @@ class JsonResource(HttpServer, resource.Resource):
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, "unrecognised_request_handler", {}
def _send_response(
self, request, code, response_json_object, response_code_message=None
):
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request,
code,
response_json_object,
send_cors=True,
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
async def _async_render(self, request):
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
# Make sure we have an appopriate name for this handler in prometheus
# (rather than the default of JsonResource).
request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
kwargs = intern_dict(
{
name: urllib.parse.unquote(value) if value else value
for name, value in group_dict.items()
}
)
raw_callback_return = callback(request, **kwargs)
class DirectServeResource(resource.Resource):
def render(self, request):
# Is it synchronous? We'll allow this for now.
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return
return callback_return
class DirectServeHtmlResource(_AsyncResource):
"""A resource that will call `self._async_on_<METHOD>` on new requests,
formatting responses and errors as HTML.
"""
# The error template to use for this resource
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
def _send_response(
self, request: SynapseRequest, code: int, response_object: Any,
):
"""Implements _AsyncResource._send_response
"""
Render the request, using an asynchronous render handler if it exists.
# We expect to get bytes for us to write
assert isinstance(response_object, bytes)
html_bytes = response_object
respond_with_html_bytes(request, 200, html_bytes)
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
) -> None:
"""Implements _AsyncResource._send_error_response
"""
async_render_callback_name = "_async_render_" + request.method.decode("ascii")
# Try and get the async renderer
callback = getattr(self, async_render_callback_name, None)
# No async renderer for this request method.
if not callback:
return super().render(request)
resp = trace_servlet(self.__class__.__name__)(callback)(request)
# If it's a coroutine, turn it into a Deferred
if isinstance(resp, types.CoroutineType):
defer.ensureDeferred(resp)
return NOT_DONE_YET
return_html_error(f, request, self.ERROR_TEMPLATE)
class StaticResource(File):

View File

@ -169,7 +169,6 @@ import contextlib
import inspect
import logging
import re
import types
from functools import wraps
from typing import TYPE_CHECKING, Dict, Optional, Type
@ -182,6 +181,7 @@ from synapse.config import ConfigError
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.http.site import SynapseRequest
# Helper class
@ -793,48 +793,42 @@ def tag_args(func):
return _tag_args_inner
def trace_servlet(servlet_name, extract_context=False):
"""Decorator which traces a serlet. It starts a span with some servlet specific
tags such as the servlet_name and request information
@contextlib.contextmanager
def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
"""Returns a context manager which traces a request. It starts a span
with some servlet specific tags such as the request metrics name and
request information.
Args:
servlet_name (str): The name to be used for the span's operation_name
extract_context (bool): Whether to attempt to extract the opentracing
request
extract_context: Whether to attempt to extract the opentracing
context from the request the servlet is handling.
"""
def _trace_servlet_inner_1(func):
if not opentracing:
return func
if opentracing is None:
yield
return
@wraps(func)
async def _trace_servlet_inner(request, *args, **kwargs):
request_tags = {
"request_id": request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),
tags.PEER_HOST_IPV6: request.getClientIP(),
}
request_tags = {
"request_id": request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),
tags.PEER_HOST_IPV6: request.getClientIP(),
}
if extract_context:
scope = start_active_span_from_request(
request, servlet_name, tags=request_tags
)
else:
scope = start_active_span(servlet_name, tags=request_tags)
request_name = request.request_metrics.name
if extract_context:
scope = start_active_span_from_request(request, request_name, tags=request_tags)
else:
scope = start_active_span(request_name, tags=request_tags)
with scope:
result = func(request, *args, **kwargs)
with scope:
try:
yield
finally:
# We set the operation name again in case its changed (which happens
# with JsonResource).
scope.span.set_operation_name(request.request_metrics.name)
if not isinstance(result, (types.CoroutineType, defer.Deferred)):
# Some servlets aren't async and just return results
# directly, so we handle that here.
return result
return await result
return _trace_servlet_inner
return _trace_servlet_inner_1
scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)

View File

@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication"
class ReplicationRestResource(JsonResource):
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
# We enable extracting jaeger contexts here as these are internal APIs.
super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs)
def register_servlets(self, hs):

View File

@ -28,11 +28,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
from synapse.logging.opentracing import (
inject_active_span_byte_dict,
trace,
trace_servlet,
)
from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@ -240,11 +236,8 @@ class ReplicationEndpoint(object):
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
# We don't let register paths trace this servlet using the default tracing
# options because we wish to extract the context explicitly.
http_server.register_paths(
method, [pattern], handler, self.__class__.__name__, trace=False
method, [pattern], handler, self.__class__.__name__,
)
def _cached_handler(self, request, txn_id, **kwargs):

View File

@ -26,11 +26,7 @@ from twisted.internet import defer
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import (
DirectServeResource,
respond_with_html,
wrap_html_request_handler,
)
from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_string
from synapse.types import UserID
@ -48,7 +44,7 @@ else:
return a == b
class ConsentResource(DirectServeResource):
class ConsentResource(DirectServeHtmlResource):
"""A twisted Resource to display a privacy policy and gather consent to it
When accessed via GET, returns the privacy policy via a template.
@ -119,7 +115,6 @@ class ConsentResource(DirectServeResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8")
@wrap_html_request_handler
async def _async_render_GET(self, request):
"""
Args:
@ -160,7 +155,6 @@ class ConsentResource(DirectServeResource):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")
@wrap_html_request_handler
async def _async_render_POST(self, request):
"""
Args:

View File

@ -20,17 +20,13 @@ from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import (
DirectServeResource,
respond_with_json_bytes,
wrap_json_request_handler,
)
from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
from synapse.http.servlet import parse_integer, parse_json_object_from_request
logger = logging.getLogger(__name__)
class RemoteKey(DirectServeResource):
class RemoteKey(DirectServeJsonResource):
"""HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@ -92,13 +88,14 @@ class RemoteKey(DirectServeResource):
isLeaf = True
def __init__(self, hs):
super().__init__()
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config
@wrap_json_request_handler
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
(server,) = request.postpath
@ -115,7 +112,6 @@ class RemoteKey(DirectServeResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
@wrap_json_request_handler
async def _async_render_POST(self, request):
content = parse_json_object_from_request(request)

View File

@ -14,16 +14,10 @@
# limitations under the License.
#
from twisted.web.server import NOT_DONE_YET
from synapse.http.server import (
DirectServeResource,
respond_with_json,
wrap_json_request_handler,
)
from synapse.http.server import DirectServeJsonResource, respond_with_json
class MediaConfigResource(DirectServeResource):
class MediaConfigResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs):
@ -33,11 +27,9 @@ class MediaConfigResource(DirectServeResource):
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
@wrap_json_request_handler
async def _async_render_GET(self, request):
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
def render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET

View File

@ -15,18 +15,14 @@
import logging
import synapse.http.servlet
from synapse.http.server import (
DirectServeResource,
set_cors_headers,
wrap_json_request_handler,
)
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from ._base import parse_media_id, respond_404
logger = logging.getLogger(__name__)
class DownloadResource(DirectServeResource):
class DownloadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
@ -34,10 +30,6 @@ class DownloadResource(DirectServeResource):
self.media_repo = media_repo
self.server_name = hs.hostname
# this is expected by @wrap_json_request_handler
self.clock = hs.get_clock()
@wrap_json_request_handler
async def _async_render_GET(self, request):
set_cors_headers(request)
request.setHeader(

View File

@ -34,10 +34,9 @@ from twisted.internet.error import DNSLookupError
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
DirectServeResource,
DirectServeJsonResource,
respond_with_json,
respond_with_json_bytes,
wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
@ -58,7 +57,7 @@ OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
class PreviewUrlResource(DirectServeResource):
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
@ -108,11 +107,10 @@ class PreviewUrlResource(DirectServeResource):
self._start_expire_url_cache_data, 10 * 1000
)
def render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request):
request.setHeader(b"Allow", b"OPTIONS, GET")
return respond_with_json(request, 200, {}, send_cors=True)
respond_with_json(request, 200, {}, send_cors=True)
@wrap_json_request_handler
async def _async_render_GET(self, request):
# XXX: if get_user_by_req fails, what should we do in an async render?

View File

@ -16,11 +16,7 @@
import logging
from synapse.http.server import (
DirectServeResource,
set_cors_headers,
wrap_json_request_handler,
)
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
from ._base import (
@ -34,7 +30,7 @@ from ._base import (
logger = logging.getLogger(__name__)
class ThumbnailResource(DirectServeResource):
class ThumbnailResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
@ -45,9 +41,7 @@ class ThumbnailResource(DirectServeResource):
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
self.clock = hs.get_clock()
@wrap_json_request_handler
async def _async_render_GET(self, request):
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)

View File

@ -15,20 +15,14 @@
import logging
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import (
DirectServeResource,
respond_with_json,
wrap_json_request_handler,
)
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__)
class UploadResource(DirectServeResource):
class UploadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
@ -43,11 +37,9 @@ class UploadResource(DirectServeResource):
self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock()
def render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
@wrap_json_request_handler
async def _async_render_POST(self, request):
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have

View File

@ -14,18 +14,17 @@
# limitations under the License.
import logging
from synapse.http.server import DirectServeResource, wrap_html_request_handler
from synapse.http.server import DirectServeHtmlResource
logger = logging.getLogger(__name__)
class OIDCCallbackResource(DirectServeResource):
class OIDCCallbackResource(DirectServeHtmlResource):
isLeaf = 1
def __init__(self, hs):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
@wrap_html_request_handler
async def _async_render_GET(self, request):
return await self._oidc_handler.handle_oidc_callback(request)
await self._oidc_handler.handle_oidc_callback(request)

View File

@ -16,10 +16,10 @@
from twisted.python import failure
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeResource, return_html_error
from synapse.http.server import DirectServeHtmlResource, return_html_error
class SAML2ResponseResource(DirectServeResource):
class SAML2ResponseResource(DirectServeHtmlResource):
"""A Twisted web resource which handles the SAML response"""
isLeaf = 1

View File

@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json
from tests.unittest import HomeserverTestCase
class _AsyncTestCustomEndpoint:
def __init__(self, config, module_api):
pass
async def handle_request(self, request):
respond_with_json(request, 200, {"some_key": "some_value_async"})
class _SyncTestCustomEndpoint:
def __init__(self, config, module_api):
pass
async def handle_request(self, request):
respond_with_json(request, 200, {"some_key": "some_value_sync"})
class AdditionalResourceTests(HomeserverTestCase):
"""Very basic tests that `AdditionalResource` works correctly with sync
and async handlers.
"""
def test_async(self):
handler = _AsyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/")
self.render(request)
self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler)
request, channel = self.make_request("GET", "/")
self.render(request)
self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})

View File

@ -24,12 +24,7 @@ from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.config.server import parse_listener_def
from synapse.http.server import (
DirectServeResource,
JsonResource,
OptionsResource,
wrap_html_request_handler,
)
from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
from synapse.http.site import SynapseSite, logger
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
@ -256,12 +251,11 @@ class OptionsResourceTests(unittest.TestCase):
class WrapHtmlRequestHandlerTests(unittest.TestCase):
class TestResource(DirectServeResource):
class TestResource(DirectServeHtmlResource):
callback = None
@wrap_html_request_handler
async def _async_render_GET(self, request):
return await self.callback(request)
await self.callback(request)
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()