107 lines
3.4 KiB
Python
107 lines
3.4 KiB
Python
# Copyright 2014-2016 OpenMarket Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""This module contains base REST classes for constructing client v1 servlets.
|
|
"""
|
|
import logging
|
|
import re
|
|
from typing import Any, Awaitable, Callable, Iterable, Pattern, Tuple, TypeVar, cast
|
|
|
|
from synapse.api.errors import InteractiveAuthIncompleteError
|
|
from synapse.api.urls import CLIENT_API_PREFIX
|
|
from synapse.types import JsonDict, StrCollection
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def client_patterns(
|
|
path_regex: str,
|
|
releases: StrCollection = ("r0", "v3"),
|
|
unstable: bool = True,
|
|
v1: bool = False,
|
|
) -> Iterable[Pattern]:
|
|
"""Creates a regex compiled client path with the correct client path
|
|
prefix.
|
|
|
|
Args:
|
|
path_regex: The regex string to match. This should NOT have a ^
|
|
as this will be prefixed.
|
|
releases: An iterable of releases to include this endpoint under.
|
|
unstable: If true, include this endpoint under the "unstable" prefix.
|
|
v1: If true, include this endpoint under the "api/v1" prefix.
|
|
Returns:
|
|
An iterable of patterns.
|
|
"""
|
|
versions = []
|
|
|
|
if v1:
|
|
versions.append("api/v1")
|
|
versions.extend(releases)
|
|
if unstable:
|
|
versions.append("unstable")
|
|
|
|
if len(versions) == 1:
|
|
versions_str = versions[0]
|
|
elif len(versions) > 1:
|
|
versions_str = "(" + "|".join(versions) + ")"
|
|
else:
|
|
raise RuntimeError("Must have at least one version for a URL")
|
|
|
|
return [re.compile("^" + CLIENT_API_PREFIX + "/" + versions_str + path_regex)]
|
|
|
|
|
|
def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None:
|
|
"""
|
|
Enforces a maximum limit of a timeline query.
|
|
|
|
Params:
|
|
filter_json: The timeline query to modify.
|
|
filter_timeline_limit: The maximum limit to allow, passing -1 will
|
|
disable enforcing a maximum limit.
|
|
"""
|
|
if filter_timeline_limit < 0:
|
|
return # no upper limits
|
|
timeline = filter_json.get("room", {}).get("timeline", {})
|
|
if "limit" in timeline:
|
|
filter_json["room"]["timeline"]["limit"] = min(
|
|
filter_json["room"]["timeline"]["limit"], filter_timeline_limit
|
|
)
|
|
|
|
|
|
C = TypeVar("C", bound=Callable[..., Awaitable[Tuple[int, JsonDict]]])
|
|
|
|
|
|
def interactive_auth_handler(orig: C) -> C:
|
|
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
|
|
|
|
Takes a on_POST method which returns an Awaitable (errcode, body) response
|
|
and adds exception handling to turn a InteractiveAuthIncompleteError into
|
|
a 401 response.
|
|
|
|
Normal usage is:
|
|
|
|
@interactive_auth_handler
|
|
async def on_POST(self, request):
|
|
# ...
|
|
await self.auth_handler.check_auth
|
|
"""
|
|
|
|
async def wrapped(*args: Any, **kwargs: Any) -> Tuple[int, JsonDict]:
|
|
try:
|
|
return await orig(*args, **kwargs)
|
|
except InteractiveAuthIncompleteError as e:
|
|
return 401, e.result
|
|
|
|
return cast(C, wrapped)
|