Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

matrix-org-hotfixes-identity
Richard van der Hoff 2019-06-04 11:59:55 +01:00
commit e91a68ef3a
66 changed files with 791 additions and 573 deletions

1
changelog.d/5226.misc Normal file
View File

@ -0,0 +1 @@
The base classes for the v1 and v2_alpha REST APIs have been unified.

1
changelog.d/5276.feature Normal file
View File

@ -0,0 +1 @@
Allow configuring a range for the account validity startup job.

1
changelog.d/5296.misc Normal file
View File

@ -0,0 +1 @@
Refactor keyring.VerifyKeyRequest to use attr.s.

1
changelog.d/5299.misc Normal file
View File

@ -0,0 +1 @@
Rewrite get_server_verify_keys, again.

1
changelog.d/5300.bugfix Normal file
View File

@ -0,0 +1 @@
Fix noisy 'no key for server' logs.

1
changelog.d/5307.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug where a notary server would sometimes forget old keys.

1
changelog.d/5321.bugfix Normal file
View File

@ -0,0 +1 @@
Ensure that we have an up-to-date copy of the signing key when validating incoming federation requests.

1
changelog.d/5328.misc Normal file
View File

@ -0,0 +1 @@
The base classes for the v1 and v2_alpha REST APIs have been unified.

1
changelog.d/5332.misc Normal file
View File

@ -0,0 +1 @@
Improve docstrings on MatrixFederationClient.

1
changelog.d/5333.bugfix Normal file
View File

@ -0,0 +1 @@
Fix various problems which made the signing-key notary server time out for some requests.

1
changelog.d/5334.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug which would make certain operations (such as room joins) block for 20 minutes while attemoting to fetch verification keys.

1
changelog.d/5335.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a bug where we could rapidly mark a server as unreachable even though it was only down for a few minutes.

View File

@ -763,7 +763,9 @@ uploads_path: "DATADIR/uploads"
# This means that, if a validity period is set, and Synapse is restarted (it will # This means that, if a validity period is set, and Synapse is restarted (it will
# then derive an expiration date from the current validity period), and some time # then derive an expiration date from the current validity period), and some time
# after that the validity period changes and Synapse is restarted, the users' # after that the validity period changes and Synapse is restarted, the users'
# expiration dates won't be updated unless their account is manually renewed. # expiration dates won't be updated unless their account is manually renewed. This
# date will be randomly selected within a range [now + period - d ; now + period],
# where d is equal to 10% of the validity period.
# #
#account_validity: #account_validity:
# enabled: True # enabled: True

View File

@ -20,9 +20,7 @@ class CallVisitor(ast.NodeVisitor):
else: else:
return return
if name == "client_path_patterns": if name == "client_patterns":
PATTERNS_V1.append(node.args[0].s)
elif name == "client_v2_patterns":
PATTERNS_V2.append(node.args[0].s) PATTERNS_V2.append(node.args[0].s)

View File

@ -37,8 +37,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.client.v2_alpha._base import client_v2_patterns
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
@ -49,11 +48,11 @@ from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.frontend_proxy") logger = logging.getLogger("synapse.app.frontend_proxy")
class PresenceStatusStubServlet(ClientV1RestServlet): class PresenceStatusStubServlet(RestServlet):
PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status") PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
def __init__(self, hs): def __init__(self, hs):
super(PresenceStatusStubServlet, self).__init__(hs) super(PresenceStatusStubServlet, self).__init__()
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.main_uri = hs.config.worker_main_http_uri self.main_uri = hs.config.worker_main_http_uri
@ -84,7 +83,7 @@ class PresenceStatusStubServlet(ClientV1RestServlet):
class KeyUploadServlet(RestServlet): class KeyUploadServlet(RestServlet):
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs):
""" """

View File

@ -39,6 +39,8 @@ class AccountValidityConfig(Config):
else: else:
self.renew_email_subject = "Renew your %(app)s account" self.renew_email_subject = "Renew your %(app)s account"
self.startup_job_max_delta = self.period * 10. / 100.
if self.renew_by_email_enabled and "public_baseurl" not in synapse_config: if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
raise ConfigError("Can't send renewal emails without 'public_baseurl'") raise ConfigError("Can't send renewal emails without 'public_baseurl'")
@ -129,7 +131,9 @@ class RegistrationConfig(Config):
# This means that, if a validity period is set, and Synapse is restarted (it will # This means that, if a validity period is set, and Synapse is restarted (it will
# then derive an expiration date from the current validity period), and some time # then derive an expiration date from the current validity period), and some time
# after that the validity period changes and Synapse is restarted, the users' # after that the validity period changes and Synapse is restarted, the users'
# expiration dates won't be updated unless their account is manually renewed. # expiration dates won't be updated unless their account is manually renewed. This
# date will be randomly selected within a range [now + period - d ; now + period],
# where d is equal to 10%% of the validity period.
# #
#account_validity: #account_validity:
# enabled: True # enabled: True

View File

@ -15,12 +15,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
from collections import namedtuple from collections import defaultdict
import six import six
from six import raise_from from six import raise_from
from six.moves import urllib from six.moves import urllib
import attr
from signedjson.key import ( from signedjson.key import (
decode_verify_key_bytes, decode_verify_key_bytes,
encode_verify_key_base64, encode_verify_key_base64,
@ -45,6 +46,7 @@ from synapse.api.errors import (
) )
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext, unwrapFirstError from synapse.util import logcontext, unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.logcontext import ( from synapse.util.logcontext import (
LoggingContext, LoggingContext,
PreserveLoggingContext, PreserveLoggingContext,
@ -57,22 +59,36 @@ from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VerifyKeyRequest = namedtuple( @attr.s(slots=True, cmp=False)
"VerifyRequest", ("server_name", "key_ids", "json_object", "deferred") class VerifyKeyRequest(object):
) """
""" A request for a verify key to verify a JSON object.
A request for a verify key to verify a JSON object.
Attributes: Attributes:
server_name(str): The name of the server to verify against. server_name(str): The name of the server to verify against.
key_ids(set(str)): The set of key_ids to that could be used to verify the
key_ids(set[str]): The set of key_ids to that could be used to verify the
JSON object JSON object
json_object(dict): The JSON object to verify. json_object(dict): The JSON object to verify.
minimum_valid_until_ts (int): time at which we require the signing key to
be valid. (0 implies we don't care)
deferred(Deferred[str, str, nacl.signing.VerifyKey]): deferred(Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no a verify key has been fetched. The deferreds' callbacks are run with no
logcontext. logcontext.
"""
If we are unable to find a key which satisfies the request, the deferred
errbacks with an M_UNAUTHORIZED SynapseError.
"""
server_name = attr.ib()
key_ids = attr.ib()
json_object = attr.ib()
minimum_valid_until_ts = attr.ib()
deferred = attr.ib(default=attr.Factory(defer.Deferred))
class KeyLookupError(ValueError): class KeyLookupError(ValueError):
@ -80,14 +96,16 @@ class KeyLookupError(ValueError):
class Keyring(object): class Keyring(object):
def __init__(self, hs): def __init__(self, hs, key_fetchers=None):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._key_fetchers = ( if key_fetchers is None:
key_fetchers = (
StoreKeyFetcher(hs), StoreKeyFetcher(hs),
PerspectivesKeyFetcher(hs), PerspectivesKeyFetcher(hs),
ServerKeyFetcher(hs), ServerKeyFetcher(hs),
) )
self._key_fetchers = key_fetchers
# map from server name to Deferred. Has an entry for each server with # map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download # an ongoing key download; the Deferred completes once the download
@ -96,9 +114,25 @@ class Keyring(object):
# These are regular, logcontext-agnostic Deferreds. # These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {} self.key_downloads = {}
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object, validity_time):
"""Verify that a JSON object has been signed by a given server
Args:
server_name (str): name of the server which must have signed this object
json_object (dict): object to be checked
validity_time (int): timestamp at which we require the signing key to
be valid. (0 implies we don't care)
Returns:
Deferred[None]: completes if the the object was correctly signed, otherwise
errbacks with an error
"""
req = server_name, json_object, validity_time
return logcontext.make_deferred_yieldable( return logcontext.make_deferred_yieldable(
self.verify_json_objects_for_server([(server_name, json_object)])[0] self.verify_json_objects_for_server((req,))[0]
) )
def verify_json_objects_for_server(self, server_and_json): def verify_json_objects_for_server(self, server_and_json):
@ -106,10 +140,12 @@ class Keyring(object):
necessary. necessary.
Args: Args:
server_and_json (list): List of pairs of (server_name, json_object) server_and_json (iterable[Tuple[str, dict, int]):
Iterable of triplets of (server_name, json_object, validity_time)
validity_time is a timestamp at which the signing key must be valid.
Returns: Returns:
List<Deferred>: for each input pair, a deferred indicating success List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each json object's signature for the given or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel server_name. The deferreds run their callbacks in the sentinel
logcontext. logcontext.
@ -118,12 +154,12 @@ class Keyring(object):
verify_requests = [] verify_requests = []
handle = preserve_fn(_handle_key_deferred) handle = preserve_fn(_handle_key_deferred)
def process(server_name, json_object): def process(server_name, json_object, validity_time):
"""Process an entry in the request list """Process an entry in the request list
Given a (server_name, json_object) pair from the request list, Given a (server_name, json_object, validity_time) triplet from the request
adds a key request to verify_requests, and returns a deferred which will list, adds a key request to verify_requests, and returns a deferred which
complete or fail (in the sentinel context) when verification completes. will complete or fail (in the sentinel context) when verification completes.
""" """
key_ids = signature_ids(json_object, server_name) key_ids = signature_ids(json_object, server_name)
@ -134,11 +170,16 @@ class Keyring(object):
) )
) )
logger.debug("Verifying for %s with key_ids %s", server_name, key_ids) logger.debug(
"Verifying for %s with key_ids %s, min_validity %i",
server_name,
key_ids,
validity_time,
)
# add the key request to the queue, but don't start it off yet. # add the key request to the queue, but don't start it off yet.
verify_request = VerifyKeyRequest( verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, defer.Deferred() server_name, key_ids, json_object, validity_time
) )
verify_requests.append(verify_request) verify_requests.append(verify_request)
@ -150,8 +191,8 @@ class Keyring(object):
return handle(verify_request) return handle(verify_request)
results = [ results = [
process(server_name, json_object) process(server_name, json_object, validity_time)
for server_name, json_object in server_and_json for server_name, json_object, validity_time in server_and_json
] ]
if verify_requests: if verify_requests:
@ -270,29 +311,78 @@ class Keyring(object):
verify_requests (list[VerifyKeyRequest]): list of verify requests verify_requests (list[VerifyKeyRequest]): list of verify requests
""" """
remaining_requests = set(
(rq for rq in verify_requests if not rq.deferred.called)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def do_iterations(): def do_iterations():
with Measure(self.clock, "get_server_verify_keys"): with Measure(self.clock, "get_server_verify_keys"):
# dict[str, set(str)]: keys to fetch for each server for f in self._key_fetchers:
missing_keys = {} if not remaining_requests:
for verify_request in verify_requests: return
missing_keys.setdefault(verify_request.server_name, set()).update( yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
verify_request.key_ids
# look for any requests which weren't satisfied
with PreserveLoggingContext():
for verify_request in remaining_requests:
verify_request.deferred.errback(
SynapseError(
401,
"No key for %s with ids in %s (min_validity %i)"
% (
verify_request.server_name,
verify_request.key_ids,
verify_request.minimum_valid_until_ts,
),
Codes.UNAUTHORIZED,
)
) )
for f in self._key_fetchers: def on_err(err):
results = yield f.get_keys(missing_keys.items()) # we don't really expect to get here, because any errors should already
# have been caught and logged. But if we do, let's log the error and make
# sure that all of the deferreds are resolved.
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
with PreserveLoggingContext():
for verify_request in remaining_requests:
if not verify_request.deferred.called:
verify_request.deferred.errback(err)
# We now need to figure out which verify requests we have keys run_in_background(do_iterations).addErrback(on_err)
# for and which we don't
missing_keys = {}
requests_missing_keys = []
for verify_request in verify_requests:
if verify_request.deferred.called:
# We've already called this deferred, which probably
# means that we've already found a key for it.
continue
@defer.inlineCallbacks
def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
"""Use a key fetcher to attempt to satisfy some key requests
Args:
fetcher (KeyFetcher): fetcher to use to fetch the keys
remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
Any successfully-completed requests will be removed from the list.
"""
# dict[str, dict[str, int]]: keys to fetch.
# server_name -> key_id -> min_valid_ts
missing_keys = defaultdict(dict)
for verify_request in remaining_requests:
# any completed requests should already have been removed
assert not verify_request.deferred.called
keys_for_server = missing_keys[verify_request.server_name]
for key_id in verify_request.key_ids:
# If we have several requests for the same key, then we only need to
# request that key once, but we should do so with the greatest
# min_valid_until_ts of the requests, so that we can satisfy all of
# the requests.
keys_for_server[key_id] = max(
keys_for_server.get(key_id, -1),
verify_request.minimum_valid_until_ts
)
results = yield fetcher.get_keys(missing_keys)
completed = list()
for verify_request in remaining_requests:
server_name = verify_request.server_name server_name = verify_request.server_name
# see if any of the keys we got this time are sufficient to # see if any of the keys we got this time are sufficient to
@ -300,54 +390,33 @@ class Keyring(object):
result_keys = results.get(server_name, {}) result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids: for key_id in verify_request.key_ids:
fetch_key_result = result_keys.get(key_id) fetch_key_result = result_keys.get(key_id)
if fetch_key_result: if not fetch_key_result:
# we didn't get a result for this key
continue
if (
fetch_key_result.valid_until_ts
< verify_request.minimum_valid_until_ts
):
# key was not valid at this point
continue
with PreserveLoggingContext(): with PreserveLoggingContext():
verify_request.deferred.callback( verify_request.deferred.callback(
( (server_name, key_id, fetch_key_result.verify_key)
server_name,
key_id,
fetch_key_result.verify_key,
) )
) completed.append(verify_request)
break
else:
# The else block is only reached if the loop above
# doesn't break.
missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)
if not missing_keys:
break break
with PreserveLoggingContext(): remaining_requests.difference_update(completed)
for verify_request in requests_missing_keys:
verify_request.deferred.errback(
SynapseError(
401,
"No key for %s with id %s"
% (verify_request.server_name, verify_request.key_ids),
Codes.UNAUTHORIZED,
)
)
def on_err(err):
with PreserveLoggingContext():
for verify_request in verify_requests:
if not verify_request.deferred.called:
verify_request.deferred.errback(err)
run_in_background(do_iterations).addErrback(on_err)
class KeyFetcher(object): class KeyFetcher(object):
def get_keys(self, server_name_and_key_ids): def get_keys(self, keys_to_fetch):
""" """
Args: Args:
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]): keys_to_fetch (dict[str, dict[str, int]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for the keys to be fetched. server_name -> key_id -> min_valid_ts
Note that the iterables may be iterated more than once.
Returns: Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
@ -363,13 +432,15 @@ class StoreKeyFetcher(KeyFetcher):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids): def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
keys_to_fetch = ( keys_to_fetch = (
(server_name, key_id) (server_name, key_id)
for server_name, key_ids in server_name_and_key_ids for server_name, keys_for_server in keys_to_fetch.items()
for key_id in key_ids for key_id in keys_for_server.keys()
) )
res = yield self.store.get_server_verify_keys(keys_to_fetch) res = yield self.store.get_server_verify_keys(keys_to_fetch)
keys = {} keys = {}
for (server_name, key_id), key in res.items(): for (server_name, key_id), key in res.items():
@ -384,7 +455,7 @@ class BaseV2KeyFetcher(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def process_v2_response( def process_v2_response(
self, from_server, response_json, time_added_ms, requested_ids=[] self, from_server, response_json, time_added_ms
): ):
"""Parse a 'Server Keys' structure from the result of a /key request """Parse a 'Server Keys' structure from the result of a /key request
@ -407,10 +478,6 @@ class BaseV2KeyFetcher(object):
time_added_ms (int): the timestamp to record in server_keys_json time_added_ms (int): the timestamp to record in server_keys_json
requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response
Returns: Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
""" """
@ -466,11 +533,6 @@ class BaseV2KeyFetcher(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json) signed_key_json_bytes = encode_canonical_json(signed_key_json)
# for reasons I don't quite understand, we store this json for the key ids we
# requested, as well as those we got.
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
yield logcontext.make_deferred_yieldable( yield logcontext.make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
@ -483,7 +545,7 @@ class BaseV2KeyFetcher(object):
ts_expires_ms=ts_valid_until_ms, ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes, key_json_bytes=signed_key_json_bytes,
) )
for key_id in updated_key_ids for key_id in verify_keys
], ],
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
@ -502,14 +564,14 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.perspective_servers = self.config.perspectives self.perspective_servers = self.config.perspectives
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids): def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
@defer.inlineCallbacks @defer.inlineCallbacks
def get_key(perspective_name, perspective_keys): def get_key(perspective_name, perspective_keys):
try: try:
result = yield self.get_server_verify_key_v2_indirect( result = yield self.get_server_verify_key_v2_indirect(
server_name_and_key_ids, perspective_name, perspective_keys keys_to_fetch, perspective_name, perspective_keys
) )
defer.returnValue(result) defer.returnValue(result)
except KeyLookupError as e: except KeyLookupError as e:
@ -543,13 +605,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_indirect( def get_server_verify_key_v2_indirect(
self, server_names_and_key_ids, perspective_name, perspective_keys self, keys_to_fetch, perspective_name, perspective_keys
): ):
""" """
Args: Args:
server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]): keys_to_fetch (dict[str, dict[str, int]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for the keys to be fetched. server_name -> key_id -> min_valid_ts
perspective_name (str): name of the notary server to query for the keys perspective_name (str): name of the notary server to query for the keys
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server notary server
@ -563,12 +627,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
""" """
logger.info( logger.info(
"Requesting keys %s from notary server %s", "Requesting keys %s from notary server %s",
server_names_and_key_ids, keys_to_fetch.items(),
perspective_name, perspective_name,
) )
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.
try: try:
query_response = yield self.client.post_json( query_response = yield self.client.post_json(
destination=perspective_name, destination=perspective_name,
@ -576,12 +638,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
data={ data={
u"server_keys": { u"server_keys": {
server_name: { server_name: {
key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids key_id: {u"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
} }
for server_name, key_ids in server_names_and_key_ids for server_name, server_keys in keys_to_fetch.items()
} }
}, },
long_retries=True,
) )
except (NotRetryingDestination, RequestSendFailed) as e: except (NotRetryingDestination, RequestSendFailed) as e:
raise_from(KeyLookupError("Failed to connect to remote server"), e) raise_from(KeyLookupError("Failed to connect to remote server"), e)
@ -687,34 +749,54 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.client = hs.get_http_client() self.client = hs.get_http_client()
def get_keys(self, keys_to_fetch):
"""
Args:
keys_to_fetch (dict[str, iterable[str]]):
the keys to be fetched. server_name -> key_ids
Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
map from server_name -> key_id -> FetchKeyResult
"""
results = {}
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids): def get_key(key_to_fetch_item):
"""see KeyFetcher.get_keys""" server_name, key_ids = key_to_fetch_item
results = yield logcontext.make_deferred_yieldable( try:
defer.gatherResults( keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
[ results[server_name] = keys
run_in_background( except KeyLookupError as e:
self.get_server_verify_key_v2_direct, server_name, key_ids logger.warning(
) "Error looking up keys %s from %s: %s", key_ids, server_name, e
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
) )
except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name)
merged = {} return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
for result in results: lambda _: results
merged.update(result)
defer.returnValue(
{server_name: keys for server_name, keys in merged.items() if keys}
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids): def get_server_verify_key_v2_direct(self, server_name, key_ids):
"""
Args:
server_name (str):
key_ids (iterable[str]):
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
Raises:
KeyLookupError if there was a problem making the lookup
"""
keys = {} # type: dict[str, FetchKeyResult] keys = {} # type: dict[str, FetchKeyResult]
for requested_key_id in key_ids: for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another.
if requested_key_id in keys: if requested_key_id in keys:
continue continue
@ -725,6 +807,19 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
path="/_matrix/key/v2/server/" path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id), + urllib.parse.quote(requested_key_id),
ignore_backoff=True, ignore_backoff=True,
# we only give the remote server 10s to respond. It should be an
# easy request to handle, so if it doesn't reply within 10s, it's
# probably not going to.
#
# Furthermore, when we are acting as a notary server, we cannot
# wait all day for all of the origin servers, as the requesting
# server will otherwise time out before we can respond.
#
# (Note that get_json may make 4 attempts, so this can still take
# almost 45 seconds to fetch the headers, plus up to another 60s to
# read the response).
timeout=10000,
) )
except (NotRetryingDestination, RequestSendFailed) as e: except (NotRetryingDestination, RequestSendFailed) as e:
raise_from(KeyLookupError("Failed to connect to remote server"), e) raise_from(KeyLookupError("Failed to connect to remote server"), e)
@ -739,7 +834,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
from_server=server_name, from_server=server_name,
requested_ids=[requested_key_id],
response_json=response, response_json=response,
time_added_ms=time_now_ms, time_added_ms=time_now_ms,
) )
@ -750,7 +844,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
) )
keys.update(response_keys) keys.update(response_keys)
defer.returnValue({server_name: keys}) defer.returnValue(keys)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -767,31 +861,8 @@ def _handle_key_deferred(verify_request):
SynapseError if there was a problem performing the verification SynapseError if there was a problem performing the verification
""" """
server_name = verify_request.server_name server_name = verify_request.server_name
try:
with PreserveLoggingContext(): with PreserveLoggingContext():
_, key_id, verify_key = yield verify_request.deferred _, key_id, verify_key = yield verify_request.deferred
except KeyLookupError as e:
logger.warn(
"Failed to download keys for %s: %s %s",
server_name,
type(e).__name__,
str(e),
)
raise SynapseError(
502, "Error downloading keys for %s" % (server_name,), Codes.UNAUTHORIZED
)
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
server_name,
type(e).__name__,
str(e),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, verify_request.key_ids),
Codes.UNAUTHORIZED,
)
json_object = verify_request.json_object json_object = verify_request.json_object

View File

@ -265,7 +265,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
] ]
more_deferreds = keyring.verify_json_objects_for_server([ more_deferreds = keyring.verify_json_objects_for_server([
(p.sender_domain, p.redacted_pdu_json) (p.sender_domain, p.redacted_pdu_json, 0)
for p in pdus_to_check_sender for p in pdus_to_check_sender
]) ])
@ -298,7 +298,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
] ]
more_deferreds = keyring.verify_json_objects_for_server([ more_deferreds = keyring.verify_json_objects_for_server([
(get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json) (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, 0)
for p in pdus_to_check_event_id for p in pdus_to_check_event_id
]) ])

View File

@ -94,6 +94,7 @@ class NoAuthenticationError(AuthenticationError):
class Authenticator(object): class Authenticator(object):
def __init__(self, hs): def __init__(self, hs):
self._clock = hs.get_clock()
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -102,6 +103,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks @defer.inlineCallbacks
def authenticate_request(self, request, content): def authenticate_request(self, request, content):
now = self._clock.time_msec()
json_request = { json_request = {
"method": request.method.decode('ascii'), "method": request.method.decode('ascii'),
"uri": request.uri.decode('ascii'), "uri": request.uri.decode('ascii'),
@ -138,7 +140,7 @@ class Authenticator(object):
401, "Missing Authorization headers", Codes.UNAUTHORIZED, 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
) )
yield self.keyring.verify_json_for_server(origin, json_request) yield self.keyring.verify_json_for_server(origin, json_request, now)
logger.info("Request from %s", origin) logger.info("Request from %s", origin)
request.authenticated_entity = origin request.authenticated_entity = origin

View File

@ -97,10 +97,11 @@ class GroupAttestationSigning(object):
# TODO: We also want to check that *new* attestations that people give # TODO: We also want to check that *new* attestations that people give
# us to store are valid for at least a little while. # us to store are valid for at least a little while.
if valid_until_ms < self.clock.time_msec(): now = self.clock.time_msec()
if valid_until_ms < now:
raise SynapseError(400, "Attestation expired") raise SynapseError(400, "Attestation expired")
yield self.keyring.verify_json_for_server(server_name, attestation) yield self.keyring.verify_json_for_server(server_name, attestation, now)
def create_attestation(self, group_id, user_id): def create_attestation(self, group_id, user_id):
"""Create an attestation for the group_id and user_id with default """Create an attestation for the group_id and user_id with default

View File

@ -285,7 +285,24 @@ class MatrixFederationHttpClient(object):
request (MatrixFederationRequest): details of request to be sent request (MatrixFederationRequest): details of request to be sent
timeout (int|None): number of milliseconds to wait for the response headers timeout (int|None): number of milliseconds to wait for the response headers
(including connecting to the server). 60s by default. (including connecting to the server), *for each attempt*.
60s by default.
long_retries (bool): whether to use the long retry algorithm.
The regular retry algorithm makes 4 attempts, with intervals
[0.5s, 1s, 2s].
The long retry algorithm makes 11 attempts, with intervals
[4s, 16s, 60s, 60s, ...]
Both algorithms add -20%/+40% jitter to the retry intervals.
Note that the above intervals are *in addition* to the time spent
waiting for the request to complete (up to `timeout` ms).
NB: the long retry algorithm takes over 20 minutes to complete, with
a default timeout of 60s!
ignore_backoff (bool): true to ignore the historical backoff data ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway. and try the request anyway.
@ -566,10 +583,14 @@ class MatrixFederationHttpClient(object):
the request body. This will be encoded as JSON. the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to json_data_callback (callable): A callable returning the dict to
use as the request body. use as the request body.
long_retries (bool): A boolean that indicates whether we should
retry for a short or long time. long_retries (bool): whether to use the long retry algorithm. See
timeout(int): How long to try (in ms) the destination for before docs on _send_request for details.
giving up. None indicates no timeout.
timeout (int|None): number of milliseconds to wait for the response headers
(including connecting to the server), *for each attempt*.
self._default_timeout (60s) by default.
ignore_backoff (bool): true to ignore the historical backoff data ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway. and try the request anyway.
backoff_on_404 (bool): True if we should count a 404 response as backoff_on_404 (bool): True if we should count a 404 response as
@ -627,15 +648,22 @@ class MatrixFederationHttpClient(object):
Args: Args:
destination (str): The remote server to send the HTTP request destination (str): The remote server to send the HTTP request
to. to.
path (str): The HTTP path. path (str): The HTTP path.
data (dict): A dict containing the data that will be used as data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON. the request body. This will be encoded as JSON.
long_retries (bool): A boolean that indicates whether we should
retry for a short or long time. long_retries (bool): whether to use the long retry algorithm. See
timeout(int): How long to try (in ms) the destination for before docs on _send_request for details.
giving up. None indicates no timeout.
timeout (int|None): number of milliseconds to wait for the response headers
(including connecting to the server), *for each attempt*.
self._default_timeout (60s) by default.
ignore_backoff (bool): true to ignore the historical backoff data and ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway. try the request anyway.
args (dict): query params args (dict): query params
Returns: Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
@ -686,14 +714,19 @@ class MatrixFederationHttpClient(object):
Args: Args:
destination (str): The remote server to send the HTTP request destination (str): The remote server to send the HTTP request
to. to.
path (str): The HTTP path. path (str): The HTTP path.
args (dict|None): A dictionary used to create query strings, defaults to args (dict|None): A dictionary used to create query strings, defaults to
None. None.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will timeout (int|None): number of milliseconds to wait for the response headers
be retried. (including connecting to the server), *for each attempt*.
self._default_timeout (60s) by default.
ignore_backoff (bool): true to ignore the historical backoff data ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway. and try the request anyway.
try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3. the request. Workaround for #3622 in Synapse <= v0.99.3.
@ -742,12 +775,18 @@ class MatrixFederationHttpClient(object):
destination (str): The remote server to send the HTTP request destination (str): The remote server to send the HTTP request
to. to.
path (str): The HTTP path. path (str): The HTTP path.
long_retries (bool): A boolean that indicates whether we should
retry for a short or long time. long_retries (bool): whether to use the long retry algorithm. See
timeout(int): How long to try (in ms) the destination for before docs on _send_request for details.
giving up. None indicates no timeout.
timeout (int|None): number of milliseconds to wait for the response headers
(including connecting to the server), *for each attempt*.
self._default_timeout (60s) by default.
ignore_backoff (bool): true to ignore the historical backoff data and ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway. try the request anyway.
args (dict): query params
Returns: Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body. result will be the decoded JSON body.

View File

@ -1,65 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
import logging
import re
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.servlet import RestServlet
from synapse.rest.client.transactions import HttpTransactionCache
logger = logging.getLogger(__name__)
def client_path_patterns(path_regex, releases=(0,), include_in_unstable=True):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
patterns = [re.compile("^" + CLIENT_API_PREFIX + "/api/v1" + path_regex)]
if include_in_unstable:
unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns
class ClientV1RestServlet(RestServlet):
"""A base Synapse REST Servlet for the client version 1 API.
"""
# This subclass was presumably created to allow the auth for the v1
# protocol version to be different, however this behaviour was removed.
# it may no longer be necessary
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)

View File

@ -19,11 +19,10 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import RoomAlias from synapse.types import RoomAlias
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,13 +32,14 @@ def register_servlets(hs, http_server):
ClientAppserviceDirectoryListServer(hs).register(http_server) ClientAppserviceDirectoryListServer(hs).register(http_server)
class ClientDirectoryServer(ClientV1RestServlet): class ClientDirectoryServer(RestServlet):
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$") PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(ClientDirectoryServer, self).__init__(hs) super(ClientDirectoryServer, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_alias): def on_GET(self, request, room_alias):
@ -120,13 +120,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class ClientDirectoryListServer(ClientV1RestServlet): class ClientDirectoryListServer(RestServlet):
PATTERNS = client_path_patterns("/directory/list/room/(?P<room_id>[^/]*)$") PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(ClientDirectoryListServer, self).__init__(hs) super(ClientDirectoryListServer, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -162,15 +163,16 @@ class ClientDirectoryListServer(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class ClientAppserviceDirectoryListServer(ClientV1RestServlet): class ClientAppserviceDirectoryListServer(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns(
"/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$" "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs):
super(ClientAppserviceDirectoryListServer, self).__init__(hs) super(ClientAppserviceDirectoryListServer, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
def on_PUT(self, request, network_id, room_id): def on_PUT(self, request, network_id, room_id):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)

View File

@ -19,21 +19,22 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EventStreamRestServlet(ClientV1RestServlet): class EventStreamRestServlet(RestServlet):
PATTERNS = client_path_patterns("/events$") PATTERNS = client_patterns("/events$", v1=True)
DEFAULT_LONGPOLL_TIME_MS = 30000 DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs): def __init__(self, hs):
super(EventStreamRestServlet, self).__init__(hs) super(EventStreamRestServlet, self).__init__()
self.event_stream_handler = hs.get_event_stream_handler() self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -76,11 +77,11 @@ class EventStreamRestServlet(ClientV1RestServlet):
# TODO: Unit test gets, with and without auth, with different kinds of events. # TODO: Unit test gets, with and without auth, with different kinds of events.
class EventRestServlet(ClientV1RestServlet): class EventRestServlet(RestServlet):
PATTERNS = client_path_patterns("/events/(?P<event_id>[^/]*)$") PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(EventRestServlet, self).__init__(hs) super(EventRestServlet, self).__init__()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()

View File

@ -15,19 +15,19 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import parse_boolean from synapse.http.servlet import RestServlet, parse_boolean
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns
# TODO: Needs unit testing # TODO: Needs unit testing
class InitialSyncRestServlet(ClientV1RestServlet): class InitialSyncRestServlet(RestServlet):
PATTERNS = client_path_patterns("/initialSync$") PATTERNS = client_patterns("/initialSync$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(InitialSyncRestServlet, self).__init__(hs) super(InitialSyncRestServlet, self).__init__()
self.initial_sync_handler = hs.get_initial_sync_handler() self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View File

@ -29,12 +29,11 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -81,15 +80,16 @@ def login_id_thirdparty_from_phone(identifier):
} }
class LoginRestServlet(ClientV1RestServlet): class LoginRestServlet(RestServlet):
PATTERNS = client_path_patterns("/login$") PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas" CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso" SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token" TOKEN_TYPE = "m.login.token"
JWT_TYPE = "m.login.jwt" JWT_TYPE = "m.login.jwt"
def __init__(self, hs): def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs) super(LoginRestServlet, self).__init__()
self.hs = hs
self.jwt_enabled = hs.config.jwt_enabled self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm self.jwt_algorithm = hs.config.jwt_algorithm
@ -371,7 +371,7 @@ class LoginRestServlet(ClientV1RestServlet):
class CasRedirectServlet(RestServlet): class CasRedirectServlet(RestServlet):
PATTERNS = client_path_patterns("/login/(cas|sso)/redirect") PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(CasRedirectServlet, self).__init__() super(CasRedirectServlet, self).__init__()
@ -394,27 +394,27 @@ class CasRedirectServlet(RestServlet):
finish_request(request) finish_request(request)
class CasTicketServlet(ClientV1RestServlet): class CasTicketServlet(RestServlet):
PATTERNS = client_path_patterns("/login/cas/ticket") PATTERNS = client_patterns("/login/cas/ticket", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(CasTicketServlet, self).__init__(hs) super(CasTicketServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs) self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_simple_http_client()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
client_redirect_url = parse_string(request, "redirectUrl", required=True) client_redirect_url = parse_string(request, "redirectUrl", required=True)
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate" uri = self.cas_server_url + "/proxyValidate"
args = { args = {
"ticket": parse_string(request, "ticket", required=True), "ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url "service": self.cas_service_url
} }
try: try:
body = yield http_client.get_raw(uri, args) body = yield self._http_client.get_raw(uri, args)
except PartialDownloadError as pde: except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed, # Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data # even if that's being used old-http style to signal end-of-data

View File

@ -17,17 +17,18 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_patterns from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LogoutRestServlet(ClientV1RestServlet): class LogoutRestServlet(RestServlet):
PATTERNS = client_path_patterns("/logout$") PATTERNS = client_patterns("/logout$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs) super(LogoutRestServlet, self).__init__()
self._auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() self._device_handler = hs.get_device_handler()
@ -41,7 +42,7 @@ class LogoutRestServlet(ClientV1RestServlet):
if requester.device_id is None: if requester.device_id is None:
# the acccess token wasn't associated with a device. # the acccess token wasn't associated with a device.
# Just delete the access token # Just delete the access token
access_token = self._auth.get_access_token_from_request(request) access_token = self.auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token) yield self._auth_handler.delete_access_token(access_token)
else: else:
yield self._device_handler.delete_device( yield self._device_handler.delete_device(
@ -50,11 +51,11 @@ class LogoutRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class LogoutAllRestServlet(ClientV1RestServlet): class LogoutAllRestServlet(RestServlet):
PATTERNS = client_path_patterns("/logout/all$") PATTERNS = client_patterns("/logout/all$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs) super(LogoutAllRestServlet, self).__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler() self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler() self._device_handler = hs.get_device_handler()

View File

@ -23,21 +23,22 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(ClientV1RestServlet): class PresenceStatusRestServlet(RestServlet):
PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status") PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(PresenceStatusRestServlet, self).__init__(hs) super(PresenceStatusRestServlet, self).__init__()
self.hs = hs
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):

View File

@ -16,18 +16,19 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """ """ This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_patterns
class ProfileDisplaynameRestServlet(RestServlet):
class ProfileDisplaynameRestServlet(ClientV1RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
def __init__(self, hs): def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs) super(ProfileDisplaynameRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -71,12 +72,14 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
return (200, {}) return (200, {})
class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs) super(ProfileAvatarURLRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -119,12 +122,14 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
return (200, {}) return (200, {})
class ProfileRestServlet(ClientV1RestServlet): class ProfileRestServlet(RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs) super(ProfileRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):

View File

@ -21,22 +21,22 @@ from synapse.api.errors import (
SynapseError, SynapseError,
UnrecognizedRequestError, UnrecognizedRequestError,
) )
from synapse.http.servlet import parse_json_value_from_request, parse_string from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from .base import ClientV1RestServlet, client_path_patterns
class PushRuleRestServlet(RestServlet):
class PushRuleRestServlet(ClientV1RestServlet): PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
PATTERNS = client_path_patterns("/(?P<path>pushrules/.*)$")
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash") "Unrecognised request: You probably wanted a trailing slash")
def __init__(self, hs): def __init__(self, hs):
super(PushRuleRestServlet, self).__init__(hs) super(PushRuleRestServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None self._is_worker = hs.config.worker_app is not None

View File

@ -26,17 +26,18 @@ from synapse.http.servlet import (
parse_string, parse_string,
) )
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.rest.client.v2_alpha._base import client_patterns
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PushersRestServlet(ClientV1RestServlet): class PushersRestServlet(RestServlet):
PATTERNS = client_path_patterns("/pushers$") PATTERNS = client_patterns("/pushers$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(PushersRestServlet, self).__init__(hs) super(PushersRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -69,11 +70,13 @@ class PushersRestServlet(ClientV1RestServlet):
return 200, {} return 200, {}
class PushersSetRestServlet(ClientV1RestServlet): class PushersSetRestServlet(RestServlet):
PATTERNS = client_path_patterns("/pushers/set$") PATTERNS = client_patterns("/pushers/set$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(PushersSetRestServlet, self).__init__(hs) super(PushersSetRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool() self.pusher_pool = self.hs.get_pusherpool()
@ -141,7 +144,7 @@ class PushersRemoveRestServlet(RestServlet):
""" """
To allow pusher to be delete by clicking a link (ie. GET request) To allow pusher to be delete by clicking a link (ie. GET request)
""" """
PATTERNS = client_path_patterns("/pushers/remove$") PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>" SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs): def __init__(self, hs):

View File

@ -28,37 +28,45 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2 from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet,
assert_params_in_dict, assert_params_in_dict,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet): class TransactionRestServlet(RestServlet):
def __init__(self, hs):
super(TransactionRestServlet, self).__init__()
self.txns = HttpTransactionCache(hs)
class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here # No PATTERN; we have custom dispatch rules here
def __init__(self, hs): def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs) super(RoomCreateRestServlet, self).__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler() self._room_creation_handler = hs.get_room_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
PATTERNS = "/createRoom" PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity # define CORS for all of /rooms in RoomCreateRestServlet for simplicity
http_server.register_paths("OPTIONS", http_server.register_paths("OPTIONS",
client_path_patterns("/rooms(?:/.*)?$"), client_patterns("/rooms(?:/.*)?$", v1=True),
self.on_OPTIONS) self.on_OPTIONS)
# define CORS for /createRoom[/txnid] # define CORS for /createRoom[/txnid]
http_server.register_paths("OPTIONS", http_server.register_paths("OPTIONS",
client_path_patterns("/createRoom(?:/.*)?$"), client_patterns("/createRoom(?:/.*)?$", v1=True),
self.on_OPTIONS) self.on_OPTIONS)
def on_PUT(self, request, txn_id): def on_PUT(self, request, txn_id):
@ -85,13 +93,14 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events # TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet): class RoomStateEventRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs) super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
# /room/$roomid/state/$eventtype # /room/$roomid/state/$eventtype
@ -102,16 +111,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$") "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
http_server.register_paths("GET", http_server.register_paths("GET",
client_path_patterns(state_key), client_patterns(state_key, v1=True),
self.on_GET) self.on_GET)
http_server.register_paths("PUT", http_server.register_paths("PUT",
client_path_patterns(state_key), client_patterns(state_key, v1=True),
self.on_PUT) self.on_PUT)
http_server.register_paths("GET", http_server.register_paths("GET",
client_path_patterns(no_state_key), client_patterns(no_state_key, v1=True),
self.on_GET_no_state_key) self.on_GET_no_state_key)
http_server.register_paths("PUT", http_server.register_paths("PUT",
client_path_patterns(no_state_key), client_patterns(no_state_key, v1=True),
self.on_PUT_no_state_key) self.on_PUT_no_state_key)
def on_GET_no_state_key(self, request, room_id, event_type): def on_GET_no_state_key(self, request, room_id, event_type):
@ -185,11 +194,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback # TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet): class RoomSendEventRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs) super(RoomSendEventRestServlet, self).__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
@ -229,10 +239,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet): class JoinRoomAliasServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs) super(JoinRoomAliasServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
# /join/$room_identifier[/$txn_id] # /join/$room_identifier[/$txn_id]
@ -291,8 +302,13 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class PublicRoomListRestServlet(ClientV1RestServlet): class PublicRoomListRestServlet(TransactionRestServlet):
PATTERNS = client_path_patterns("/publicRooms$") PATTERNS = client_patterns("/publicRooms$", v1=True)
def __init__(self, hs):
super(PublicRoomListRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -382,12 +398,13 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(RoomMemberListRestServlet, self).__init__(hs) super(RoomMemberListRestServlet, self).__init__()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -436,12 +453,13 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
# deprecated in favour of /members?membership=join? # deprecated in favour of /members?membership=join?
# except it does custom AS logic and has a simpler return format # except it does custom AS logic and has a simpler return format
class JoinedRoomMemberListRestServlet(ClientV1RestServlet): class JoinedRoomMemberListRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(JoinedRoomMemberListRestServlet, self).__init__(hs) super(JoinedRoomMemberListRestServlet, self).__init__()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -457,12 +475,13 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
# TODO: Needs better unit testing # TODO: Needs better unit testing
class RoomMessageListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(RoomMessageListRestServlet, self).__init__(hs) super(RoomMessageListRestServlet, self).__init__()
self.pagination_handler = hs.get_pagination_handler() self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -491,12 +510,13 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomStateRestServlet(ClientV1RestServlet): class RoomStateRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(RoomStateRestServlet, self).__init__(hs) super(RoomStateRestServlet, self).__init__()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -511,12 +531,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomInitialSyncRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(RoomInitialSyncRestServlet, self).__init__(hs) super(RoomInitialSyncRestServlet, self).__init__()
self.initial_sync_handler = hs.get_initial_sync_handler() self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -530,16 +551,17 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class RoomEventServlet(ClientV1RestServlet): class RoomEventServlet(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs):
super(RoomEventServlet, self).__init__(hs) super(RoomEventServlet, self).__init__()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
@ -554,16 +576,17 @@ class RoomEventServlet(ClientV1RestServlet):
defer.returnValue((404, "Event not found.")) defer.returnValue((404, "Event not found."))
class RoomEventContextServlet(ClientV1RestServlet): class RoomEventContextServlet(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs):
super(RoomEventContextServlet, self).__init__(hs) super(RoomEventContextServlet, self).__init__()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler() self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_id): def on_GET(self, request, room_id, event_id):
@ -609,10 +632,11 @@ class RoomEventContextServlet(ClientV1RestServlet):
defer.returnValue((200, results)) defer.returnValue((200, results))
class RoomForgetRestServlet(ClientV1RestServlet): class RoomForgetRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs) super(RoomForgetRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@ -639,11 +663,12 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet): class RoomMembershipRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs) super(RoomMembershipRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
@ -722,11 +747,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
) )
class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs) super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@ -757,15 +783,16 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
) )
class RoomTypingRestServlet(ClientV1RestServlet): class RoomTypingRestServlet(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True
) )
def __init__(self, hs): def __init__(self, hs):
super(RoomTypingRestServlet, self).__init__(hs) super(RoomTypingRestServlet, self).__init__()
self.presence_handler = hs.get_presence_handler() self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler() self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id): def on_PUT(self, request, room_id, user_id):
@ -798,14 +825,13 @@ class RoomTypingRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class SearchRestServlet(ClientV1RestServlet): class SearchRestServlet(RestServlet):
PATTERNS = client_path_patterns( PATTERNS = client_patterns("/search$", v1=True)
"/search$"
)
def __init__(self, hs): def __init__(self, hs):
super(SearchRestServlet, self).__init__(hs) super(SearchRestServlet, self).__init__()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -823,12 +849,13 @@ class SearchRestServlet(ClientV1RestServlet):
defer.returnValue((200, results)) defer.returnValue((200, results))
class JoinedRoomsRestServlet(ClientV1RestServlet): class JoinedRoomsRestServlet(RestServlet):
PATTERNS = client_path_patterns("/joined_rooms$") PATTERNS = client_patterns("/joined_rooms$", v1=True)
def __init__(self, hs): def __init__(self, hs):
super(JoinedRoomsRestServlet, self).__init__(hs) super(JoinedRoomsRestServlet, self).__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -853,18 +880,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
""" """
http_server.register_paths( http_server.register_paths(
"POST", "POST",
client_path_patterns(regex_string + "$"), client_patterns(regex_string + "$", v1=True),
servlet.on_POST servlet.on_POST
) )
http_server.register_paths( http_server.register_paths(
"PUT", "PUT",
client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"), client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_PUT servlet.on_PUT
) )
if with_get: if with_get:
http_server.register_paths( http_server.register_paths(
"GET", "GET",
client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"), client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_GET servlet.on_GET
) )

View File

@ -19,11 +19,17 @@ import hmac
from twisted.internet import defer from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_patterns from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
class VoipRestServlet(ClientV1RestServlet): class VoipRestServlet(RestServlet):
PATTERNS = client_path_patterns("/voip/turnServer$") PATTERNS = client_patterns("/voip/turnServer$", v1=True)
def __init__(self, hs):
super(VoipRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View File

@ -26,8 +26,7 @@ from synapse.api.urls import CLIENT_API_PREFIX
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def client_v2_patterns(path_regex, releases=(0,), def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
unstable=True):
"""Creates a regex compiled client path with the correct client path """Creates a regex compiled client path with the correct client path
prefix. prefix.
@ -41,6 +40,9 @@ def client_v2_patterns(path_regex, releases=(0,),
if unstable: if unstable:
unstable_prefix = CLIENT_API_PREFIX + "/unstable" unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex)) patterns.append(re.compile("^" + unstable_prefix + path_regex))
if v1:
v1_prefix = CLIENT_API_PREFIX + "/api/v1"
patterns.append(re.compile("^" + v1_prefix + path_regex))
for release in releases: for release in releases:
new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex)) patterns.append(re.compile("^" + new_prefix + path_regex))

View File

@ -30,13 +30,13 @@ from synapse.http.servlet import (
from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import check_3pid_allowed from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet): class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/email/requestToken$") PATTERNS = client_patterns("/account/password/email/requestToken$")
def __init__(self, hs): def __init__(self, hs):
super(EmailPasswordRequestTokenRestServlet, self).__init__() super(EmailPasswordRequestTokenRestServlet, self).__init__()
@ -70,7 +70,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
class MsisdnPasswordRequestTokenRestServlet(RestServlet): class MsisdnPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$") PATTERNS = client_patterns("/account/password/msisdn/requestToken$")
def __init__(self, hs): def __init__(self, hs):
super(MsisdnPasswordRequestTokenRestServlet, self).__init__() super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
@ -108,7 +108,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
class PasswordRestServlet(RestServlet): class PasswordRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password$") PATTERNS = client_patterns("/account/password$")
def __init__(self, hs): def __init__(self, hs):
super(PasswordRestServlet, self).__init__() super(PasswordRestServlet, self).__init__()
@ -180,7 +180,7 @@ class PasswordRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$") PATTERNS = client_patterns("/account/deactivate$")
def __init__(self, hs): def __init__(self, hs):
super(DeactivateAccountRestServlet, self).__init__() super(DeactivateAccountRestServlet, self).__init__()
@ -228,7 +228,7 @@ class DeactivateAccountRestServlet(RestServlet):
class EmailThreepidRequestTokenRestServlet(RestServlet): class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") PATTERNS = client_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
@ -263,7 +263,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet): class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$") PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
@ -300,7 +300,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
class ThreepidRestServlet(RestServlet): class ThreepidRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid$") PATTERNS = client_patterns("/account/3pid$")
def __init__(self, hs): def __init__(self, hs):
super(ThreepidRestServlet, self).__init__() super(ThreepidRestServlet, self).__init__()
@ -364,7 +364,7 @@ class ThreepidRestServlet(RestServlet):
class ThreepidDeleteRestServlet(RestServlet): class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/delete$") PATTERNS = client_patterns("/account/3pid/delete$")
def __init__(self, hs): def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__() super(ThreepidDeleteRestServlet, self).__init__()
@ -401,7 +401,7 @@ class ThreepidDeleteRestServlet(RestServlet):
class WhoamiRestServlet(RestServlet): class WhoamiRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/whoami$") PATTERNS = client_patterns("/account/whoami$")
def __init__(self, hs): def __init__(self, hs):
super(WhoamiRestServlet, self).__init__() super(WhoamiRestServlet, self).__init__()

View File

@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,7 +30,7 @@ class AccountDataServlet(RestServlet):
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
) )
@ -79,7 +79,7 @@ class RoomAccountDataServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)" "/user/(?P<user_id>[^/]*)"
"/rooms/(?P<room_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)"
"/account_data/(?P<account_data_type>[^/]*)" "/account_data/(?P<account_data_type>[^/]*)"

View File

@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, SynapseError
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet): class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_v2_patterns("/account_validity/renew$") PATTERNS = client_patterns("/account_validity/renew$")
SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>" SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>"
def __init__(self, hs): def __init__(self, hs):
@ -60,7 +60,7 @@ class AccountValidityRenewServlet(RestServlet):
class AccountValiditySendMailServlet(RestServlet): class AccountValiditySendMailServlet(RestServlet):
PATTERNS = client_v2_patterns("/account_validity/send_mail$") PATTERNS = client_patterns("/account_validity/send_mail$")
def __init__(self, hs): def __init__(self, hs):
""" """

View File

@ -23,7 +23,7 @@ from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet, parse_string from synapse.http.servlet import RestServlet, parse_string
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -122,7 +122,7 @@ class AuthRestServlet(RestServlet):
cannot be handled in the normal flow (with requests to the same endpoint). cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth. Current use is for web fallback auth.
""" """
PATTERNS = client_v2_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs): def __init__(self, hs):
super(AuthRestServlet, self).__init__() super(AuthRestServlet, self).__init__()

View File

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
class CapabilitiesRestServlet(RestServlet): class CapabilitiesRestServlet(RestServlet):
"""End point to expose the capabilities of the server.""" """End point to expose the capabilities of the server."""
PATTERNS = client_v2_patterns("/capabilities$") PATTERNS = client_patterns("/capabilities$")
def __init__(self, hs): def __init__(self, hs):
""" """

View File

@ -24,13 +24,13 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
) )
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet): class DevicesRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices$") PATTERNS = client_patterns("/devices$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -56,7 +56,7 @@ class DeleteDevicesRestServlet(RestServlet):
API for bulk deletion of devices. Accepts a JSON object with a devices API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth. key which lists the device_ids to delete. Requires user interactive auth.
""" """
PATTERNS = client_v2_patterns("/delete_devices") PATTERNS = client_patterns("/delete_devices")
def __init__(self, hs): def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__() super(DeleteDevicesRestServlet, self).__init__()
@ -95,7 +95,7 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet): class DeviceRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$") PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")
def __init__(self, hs): def __init__(self, hs):
""" """

View File

@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID from synapse.types import UserID
from ._base import client_v2_patterns, set_timeline_upper_limit from ._base import client_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet): class GetFilterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs):
super(GetFilterRestServlet, self).__init__() super(GetFilterRestServlet, self).__init__()
@ -63,7 +63,7 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter") PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs): def __init__(self, hs):
super(CreateFilterRestServlet, self).__init__() super(CreateFilterRestServlet, self).__init__()

View File

@ -21,7 +21,7 @@ from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID from synapse.types import GroupID
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class GroupServlet(RestServlet): class GroupServlet(RestServlet):
"""Get the group profile """Get the group profile
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
def __init__(self, hs): def __init__(self, hs):
super(GroupServlet, self).__init__() super(GroupServlet, self).__init__()
@ -65,7 +65,7 @@ class GroupServlet(RestServlet):
class GroupSummaryServlet(RestServlet): class GroupSummaryServlet(RestServlet):
"""Get the full group summary """Get the full group summary
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
def __init__(self, hs): def __init__(self, hs):
super(GroupSummaryServlet, self).__init__() super(GroupSummaryServlet, self).__init__()
@ -93,7 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
- /groups/:group/summary/rooms/:room_id - /groups/:group/summary/rooms/:room_id
- /groups/:group/summary/categories/:category/rooms/:room_id - /groups/:group/summary/categories/:category/rooms/:room_id
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary" "/groups/(?P<group_id>[^/]*)/summary"
"(/categories/(?P<category_id>[^/]+))?" "(/categories/(?P<category_id>[^/]+))?"
"/rooms/(?P<room_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)$"
@ -137,7 +137,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
class GroupCategoryServlet(RestServlet): class GroupCategoryServlet(RestServlet):
"""Get/add/update/delete a group category """Get/add/update/delete a group category
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
) )
@ -189,7 +189,7 @@ class GroupCategoryServlet(RestServlet):
class GroupCategoriesServlet(RestServlet): class GroupCategoriesServlet(RestServlet):
"""Get all group categories """Get all group categories
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/$" "/groups/(?P<group_id>[^/]*)/categories/$"
) )
@ -214,7 +214,7 @@ class GroupCategoriesServlet(RestServlet):
class GroupRoleServlet(RestServlet): class GroupRoleServlet(RestServlet):
"""Get/add/update/delete a group role """Get/add/update/delete a group role
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$" "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
) )
@ -266,7 +266,7 @@ class GroupRoleServlet(RestServlet):
class GroupRolesServlet(RestServlet): class GroupRolesServlet(RestServlet):
"""Get all group roles """Get all group roles
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/roles/$" "/groups/(?P<group_id>[^/]*)/roles/$"
) )
@ -295,7 +295,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
- /groups/:group/summary/users/:room_id - /groups/:group/summary/users/:room_id
- /groups/:group/summary/roles/:role/users/:user_id - /groups/:group/summary/roles/:role/users/:user_id
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary" "/groups/(?P<group_id>[^/]*)/summary"
"(/roles/(?P<role_id>[^/]+))?" "(/roles/(?P<role_id>[^/]+))?"
"/users/(?P<user_id>[^/]*)$" "/users/(?P<user_id>[^/]*)$"
@ -339,7 +339,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
class GroupRoomServlet(RestServlet): class GroupRoomServlet(RestServlet):
"""Get all rooms in a group """Get all rooms in a group
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
def __init__(self, hs): def __init__(self, hs):
super(GroupRoomServlet, self).__init__() super(GroupRoomServlet, self).__init__()
@ -360,7 +360,7 @@ class GroupRoomServlet(RestServlet):
class GroupUsersServlet(RestServlet): class GroupUsersServlet(RestServlet):
"""Get all users in a group """Get all users in a group
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
def __init__(self, hs): def __init__(self, hs):
super(GroupUsersServlet, self).__init__() super(GroupUsersServlet, self).__init__()
@ -381,7 +381,7 @@ class GroupUsersServlet(RestServlet):
class GroupInvitedUsersServlet(RestServlet): class GroupInvitedUsersServlet(RestServlet):
"""Get users invited to a group """Get users invited to a group
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
def __init__(self, hs): def __init__(self, hs):
super(GroupInvitedUsersServlet, self).__init__() super(GroupInvitedUsersServlet, self).__init__()
@ -405,7 +405,7 @@ class GroupInvitedUsersServlet(RestServlet):
class GroupSettingJoinPolicyServlet(RestServlet): class GroupSettingJoinPolicyServlet(RestServlet):
"""Set group join policy """Set group join policy
""" """
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
def __init__(self, hs): def __init__(self, hs):
super(GroupSettingJoinPolicyServlet, self).__init__() super(GroupSettingJoinPolicyServlet, self).__init__()
@ -431,7 +431,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
class GroupCreateServlet(RestServlet): class GroupCreateServlet(RestServlet):
"""Create a group """Create a group
""" """
PATTERNS = client_v2_patterns("/create_group$") PATTERNS = client_patterns("/create_group$")
def __init__(self, hs): def __init__(self, hs):
super(GroupCreateServlet, self).__init__() super(GroupCreateServlet, self).__init__()
@ -462,7 +462,7 @@ class GroupCreateServlet(RestServlet):
class GroupAdminRoomsServlet(RestServlet): class GroupAdminRoomsServlet(RestServlet):
"""Add a room to the group """Add a room to the group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$" "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
) )
@ -499,7 +499,7 @@ class GroupAdminRoomsServlet(RestServlet):
class GroupAdminRoomsConfigServlet(RestServlet): class GroupAdminRoomsConfigServlet(RestServlet):
"""Update the config of a room in a group """Update the config of a room in a group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)" "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)$" "/config/(?P<config_key>[^/]*)$"
) )
@ -526,7 +526,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
class GroupAdminUsersInviteServlet(RestServlet): class GroupAdminUsersInviteServlet(RestServlet):
"""Invite a user to the group """Invite a user to the group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$" "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
) )
@ -555,7 +555,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
class GroupAdminUsersKickServlet(RestServlet): class GroupAdminUsersKickServlet(RestServlet):
"""Kick a user from the group """Kick a user from the group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$" "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
) )
@ -581,7 +581,7 @@ class GroupAdminUsersKickServlet(RestServlet):
class GroupSelfLeaveServlet(RestServlet): class GroupSelfLeaveServlet(RestServlet):
"""Leave a joined group """Leave a joined group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/leave$" "/groups/(?P<group_id>[^/]*)/self/leave$"
) )
@ -607,7 +607,7 @@ class GroupSelfLeaveServlet(RestServlet):
class GroupSelfJoinServlet(RestServlet): class GroupSelfJoinServlet(RestServlet):
"""Attempt to join a group, or knock """Attempt to join a group, or knock
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/join$" "/groups/(?P<group_id>[^/]*)/self/join$"
) )
@ -633,7 +633,7 @@ class GroupSelfJoinServlet(RestServlet):
class GroupSelfAcceptInviteServlet(RestServlet): class GroupSelfAcceptInviteServlet(RestServlet):
"""Accept a group invite """Accept a group invite
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/accept_invite$" "/groups/(?P<group_id>[^/]*)/self/accept_invite$"
) )
@ -659,7 +659,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
class GroupSelfUpdatePublicityServlet(RestServlet): class GroupSelfUpdatePublicityServlet(RestServlet):
"""Update whether we publicise a users membership of a group """Update whether we publicise a users membership of a group
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/update_publicity$" "/groups/(?P<group_id>[^/]*)/self/update_publicity$"
) )
@ -686,7 +686,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
class PublicisedGroupsForUserServlet(RestServlet): class PublicisedGroupsForUserServlet(RestServlet):
"""Get the list of groups a user is advertising """Get the list of groups a user is advertising
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/publicised_groups/(?P<user_id>[^/]*)$" "/publicised_groups/(?P<user_id>[^/]*)$"
) )
@ -711,7 +711,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
class PublicisedGroupsForUsersServlet(RestServlet): class PublicisedGroupsForUsersServlet(RestServlet):
"""Get the list of groups a user is advertising """Get the list of groups a user is advertising
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/publicised_groups$" "/publicised_groups$"
) )
@ -739,7 +739,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
class GroupsForUserServlet(RestServlet): class GroupsForUserServlet(RestServlet):
"""Get all groups the logged in user is joined to """Get all groups the logged in user is joined to
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/joined_groups$" "/joined_groups$"
) )

View File

@ -26,7 +26,7 @@ from synapse.http.servlet import (
) )
from synapse.types import StreamToken from synapse.types import StreamToken
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class KeyUploadServlet(RestServlet):
}, },
} }
""" """
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -130,7 +130,7 @@ class KeyQueryServlet(RestServlet):
} } } } } } } } } } } }
""" """
PATTERNS = client_v2_patterns("/keys/query$") PATTERNS = client_patterns("/keys/query$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -159,7 +159,7 @@ class KeyChangesServlet(RestServlet):
200 OK 200 OK
{ "changed": ["@foo:example.com"] } { "changed": ["@foo:example.com"] }
""" """
PATTERNS = client_v2_patterns("/keys/changes$") PATTERNS = client_patterns("/keys/changes$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -209,7 +209,7 @@ class OneTimeKeyServlet(RestServlet):
} } } } } } } }
""" """
PATTERNS = client_v2_patterns("/keys/claim$") PATTERNS = client_patterns("/keys/claim$")
def __init__(self, hs): def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__() super(OneTimeKeyServlet, self).__init__()

View File

@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_integer, parse_string
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet): class NotificationsServlet(RestServlet):
PATTERNS = client_v2_patterns("/notifications$") PATTERNS = client_patterns("/notifications$")
def __init__(self, hs): def __init__(self, hs):
super(NotificationsServlet, self).__init__() super(NotificationsServlet, self).__init__()

View File

@ -22,7 +22,7 @@ from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,7 +56,7 @@ class IdTokenServlet(RestServlet):
"expires_in": 3600, "expires_in": 3600,
} }
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/openid/request_token" "/user/(?P<user_id>[^/]*)/openid/request_token"
) )

View File

@ -19,13 +19,13 @@ from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet): class ReadMarkerRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$") PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
def __init__(self, hs): def __init__(self, hs):
super(ReadMarkerRestServlet, self).__init__() super(ReadMarkerRestServlet, self).__init__()

View File

@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReceiptRestServlet(RestServlet): class ReceiptRestServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)"
"/receipt/(?P<receipt_type>[^/]*)" "/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$" "/(?P<event_id>[^/]*)$"

View File

@ -43,7 +43,7 @@ from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.threepids import check_3pid_allowed from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler from ._base import client_patterns, interactive_auth_handler
# We ought to be using hmac.compare_digest() but on older pythons it doesn't # We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison # exist. It's a _really minor_ security flaw to use plain string comparison
@ -60,7 +60,7 @@ logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet): class EmailRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/email/requestToken$") PATTERNS = client_patterns("/register/email/requestToken$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -98,7 +98,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
class MsisdnRegisterRequestTokenRestServlet(RestServlet): class MsisdnRegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/msisdn/requestToken$") PATTERNS = client_patterns("/register/msisdn/requestToken$")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -142,7 +142,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
class UsernameAvailabilityRestServlet(RestServlet): class UsernameAvailabilityRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/available") PATTERNS = client_patterns("/register/available")
def __init__(self, hs): def __init__(self, hs):
""" """
@ -182,7 +182,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet): class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register$") PATTERNS = client_patterns("/register$")
def __init__(self, hs): def __init__(self, hs):
""" """

View File

@ -34,7 +34,7 @@ from synapse.http.servlet import (
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -66,12 +66,12 @@ class RelationSendServlet(RestServlet):
def register(self, http_server): def register(self, http_server):
http_server.register_paths( http_server.register_paths(
"POST", "POST",
client_v2_patterns(self.PATTERN + "$", releases=()), client_patterns(self.PATTERN + "$", releases=()),
self.on_PUT_or_POST, self.on_PUT_or_POST,
) )
http_server.register_paths( http_server.register_paths(
"PUT", "PUT",
client_v2_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()), client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
self.on_PUT, self.on_PUT,
) )
@ -120,7 +120,7 @@ class RelationPaginationServlet(RestServlet):
filtered by relation type and event type. filtered by relation type and event type.
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=(), releases=(),
@ -197,7 +197,7 @@ class RelationAggregationPaginationServlet(RestServlet):
} }
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=(), releases=(),
@ -269,7 +269,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
} }
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
"/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$", "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
releases=(), releases=(),

View File

@ -27,13 +27,13 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
) )
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReportEventRestServlet(RestServlet): class ReportEventRestServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$"
) )

View File

@ -24,13 +24,13 @@ from synapse.http.servlet import (
parse_string, parse_string,
) )
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RoomKeysServlet(RestServlet): class RoomKeysServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$" "/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
) )
@ -256,7 +256,7 @@ class RoomKeysServlet(RestServlet):
class RoomKeysNewVersionServlet(RestServlet): class RoomKeysNewVersionServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/room_keys/version$" "/room_keys/version$"
) )
@ -314,7 +314,7 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/room_keys/version(/(?P<version>[^/]+))?$" "/room_keys/version(/(?P<version>[^/]+))?$"
) )

View File

@ -25,7 +25,7 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
) )
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,7 +47,7 @@ class RoomUpgradeRestServlet(RestServlet):
Args: Args:
hs (synapse.server.HomeServer): hs (synapse.server.HomeServer):
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
# /rooms/$roomid/upgrade # /rooms/$roomid/upgrade
"/rooms/(?P<room_id>[^/]*)/upgrade$", "/rooms/(?P<room_id>[^/]*)/upgrade$",
) )

View File

@ -21,13 +21,13 @@ from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet): class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
) )

View File

@ -32,7 +32,7 @@ from synapse.handlers.sync import SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken from synapse.types import StreamToken
from ._base import client_v2_patterns, set_timeline_upper_limit from ._base import client_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -73,7 +73,7 @@ class SyncRestServlet(RestServlet):
} }
""" """
PATTERNS = client_v2_patterns("/sync$") PATTERNS = client_patterns("/sync$")
ALLOWED_PRESENCE = set(["online", "offline", "unavailable"]) ALLOWED_PRESENCE = set(["online", "offline", "unavailable"])
def __init__(self, hs): def __init__(self, hs):

View File

@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,7 +29,7 @@ class TagListServlet(RestServlet):
""" """
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags" "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
) )
@ -54,7 +54,7 @@ class TagServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
""" """
PATTERNS = client_v2_patterns( PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)" "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
) )

View File

@ -21,13 +21,13 @@ from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet): class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocols") PATTERNS = client_patterns("/thirdparty/protocols")
def __init__(self, hs): def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__() super(ThirdPartyProtocolsServlet, self).__init__()
@ -44,7 +44,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet): class ThirdPartyProtocolServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$") PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs): def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__() super(ThirdPartyProtocolServlet, self).__init__()
@ -66,7 +66,7 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet): class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$") PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__() super(ThirdPartyUserServlet, self).__init__()
@ -89,7 +89,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$") PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs): def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__() super(ThirdPartyLocationServlet, self).__init__()

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns from ._base import client_patterns
class TokenRefreshRestServlet(RestServlet): class TokenRefreshRestServlet(RestServlet):
@ -26,7 +26,7 @@ class TokenRefreshRestServlet(RestServlet):
Exchanges refresh tokens for a pair of an access token and a new refresh Exchanges refresh tokens for a pair of an access token and a new refresh
token. token.
""" """
PATTERNS = client_v2_patterns("/tokenrefresh") PATTERNS = client_patterns("/tokenrefresh")
def __init__(self, hs): def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__() super(TokenRefreshRestServlet, self).__init__()

View File

@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns from ._base import client_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet): class UserDirectorySearchRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/user_directory/search$") PATTERNS = client_patterns("/user_directory/search$")
def __init__(self, hs): def __init__(self, hs):
""" """

View File

@ -20,7 +20,7 @@ from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import KeyLookupError, ServerKeyFetcher from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.http.servlet import parse_integer, parse_json_object_from_request
@ -215,15 +215,7 @@ class RemoteKey(Resource):
json_results.add(bytes(result["key_json"])) json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss: if cache_misses and query_remote_on_cache_miss:
for server_name, key_ids in cache_misses.items(): yield self.fetcher.get_keys(cache_misses)
try:
yield self.fetcher.get_server_verify_key_v2_direct(
server_name, key_ids
)
except KeyLookupError as e:
logger.info("Failed to fetch key: %s", e)
except Exception:
logger.exception("Failed to get key for %r", server_name)
yield self.query_keys( yield self.query_keys(
request, query, query_remote_on_cache_miss=False request, query, query_remote_on_cache_miss=False
) )

View File

@ -16,6 +16,7 @@
# limitations under the License. # limitations under the License.
import itertools import itertools
import logging import logging
import random
import sys import sys
import threading import threading
import time import time
@ -247,6 +248,8 @@ class SQLBaseStore(object):
self._check_safe_to_upsert, self._check_safe_to_upsert,
) )
self.rand = random.SystemRandom()
if self._account_validity.enabled: if self._account_validity.enabled:
self._clock.call_later( self._clock.call_later(
0.0, 0.0,
@ -308,21 +311,36 @@ class SQLBaseStore(object):
res = self.cursor_to_dict(txn) res = self.cursor_to_dict(txn)
if res: if res:
for user in res: for user in res:
self.set_expiration_date_for_user_txn(txn, user["name"]) self.set_expiration_date_for_user_txn(
txn,
user["name"],
use_delta=True,
)
yield self.runInteraction( yield self.runInteraction(
"get_users_with_no_expiration_date", "get_users_with_no_expiration_date",
select_users_with_no_expiration_date_txn, select_users_with_no_expiration_date_txn,
) )
def set_expiration_date_for_user_txn(self, txn, user_id): def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
"""Sets an expiration date to the account with the given user ID. """Sets an expiration date to the account with the given user ID.
Args: Args:
user_id (str): User ID to set an expiration date for. user_id (str): User ID to set an expiration date for.
use_delta (bool): If set to False, the expiration date for the user will be
now + validity period. If set to True, this expiration date will be a
random value in the [now + period - d ; now + period] range, d being a
delta equal to 10% of the validity period.
""" """
now_ms = self._clock.time_msec() now_ms = self._clock.time_msec()
expiration_ts = now_ms + self._account_validity.period expiration_ts = now_ms + self._account_validity.period
if use_delta:
expiration_ts = self.rand.randrange(
expiration_ts - self._account_validity.startup_job_max_delta,
expiration_ts,
)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
"account_validity", "account_validity",

View File

@ -46,8 +46,7 @@ class NotRetryingDestination(Exception):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_retry_limiter(destination, clock, store, ignore_backoff=False, def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
**kwargs):
"""For a given destination check if we have previously failed to """For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination. send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a If we are not ready to retry the destination, this will raise a
@ -60,8 +59,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
clock (synapse.util.clock): timing source clock (synapse.util.clock): timing source
store (synapse.storage.transactions.TransactionStore): datastore store (synapse.storage.transactions.TransactionStore): datastore
ignore_backoff (bool): true to ignore the historical backoff data and ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway. We will still update the next try the request anyway. We will still reset the retry_interval on success.
retry_interval on success/failure.
Example usage: Example usage:
@ -75,13 +73,12 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
""" """
retry_last_ts, retry_interval = (0, 0) retry_last_ts, retry_interval = (0, 0)
retry_timings = yield store.get_destination_retry_timings( retry_timings = yield store.get_destination_retry_timings(destination)
destination
)
if retry_timings: if retry_timings:
retry_last_ts, retry_interval = ( retry_last_ts, retry_interval = (
retry_timings["retry_last_ts"], retry_timings["retry_interval"] retry_timings["retry_last_ts"],
retry_timings["retry_interval"],
) )
now = int(clock.time_msec()) now = int(clock.time_msec())
@ -93,22 +90,31 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
destination=destination, destination=destination,
) )
# if we are ignoring the backoff data, we should also not increment the backoff
# when we get another failure - otherwise a server can very quickly reach the
# maximum backoff even though it might only have been down briefly
backoff_on_failure = not ignore_backoff
defer.returnValue( defer.returnValue(
RetryDestinationLimiter( RetryDestinationLimiter(
destination, destination, clock, store, retry_interval, backoff_on_failure, **kwargs
clock,
store,
retry_interval,
**kwargs
) )
) )
class RetryDestinationLimiter(object): class RetryDestinationLimiter(object):
def __init__(self, destination, clock, store, retry_interval, def __init__(
self,
destination,
clock,
store,
retry_interval,
min_retry_interval=10 * 60 * 1000, min_retry_interval=10 * 60 * 1000,
max_retry_interval=24 * 60 * 60 * 1000, max_retry_interval=24 * 60 * 60 * 1000,
multiplier_retry_interval=5, backoff_on_404=False): multiplier_retry_interval=5,
backoff_on_404=False,
backoff_on_failure=True,
):
"""Marks the destination as "down" if an exception is thrown in the """Marks the destination as "down" if an exception is thrown in the
context, except for CodeMessageException with code < 500. context, except for CodeMessageException with code < 500.
@ -128,6 +134,9 @@ class RetryDestinationLimiter(object):
multiplier_retry_interval (int): The multiplier to use to increase multiplier_retry_interval (int): The multiplier to use to increase
the retry interval after a failed request. the retry interval after a failed request.
backoff_on_404 (bool): Back off if we get a 404 backoff_on_404 (bool): Back off if we get a 404
backoff_on_failure (bool): set to False if we should not increase the
retry interval on a failure.
""" """
self.clock = clock self.clock = clock
self.store = store self.store = store
@ -138,6 +147,7 @@ class RetryDestinationLimiter(object):
self.max_retry_interval = max_retry_interval self.max_retry_interval = max_retry_interval
self.multiplier_retry_interval = multiplier_retry_interval self.multiplier_retry_interval = multiplier_retry_interval
self.backoff_on_404 = backoff_on_404 self.backoff_on_404 = backoff_on_404
self.backoff_on_failure = backoff_on_failure
def __enter__(self): def __enter__(self):
pass pass
@ -173,10 +183,13 @@ class RetryDestinationLimiter(object):
if not self.retry_interval: if not self.retry_interval:
return return
logger.debug("Connection to %s was successful; clearing backoff", logger.debug(
self.destination) "Connection to %s was successful; clearing backoff", self.destination
)
retry_last_ts = 0 retry_last_ts = 0
self.retry_interval = 0 self.retry_interval = 0
elif not self.backoff_on_failure:
return
else: else:
# We couldn't connect. # We couldn't connect.
if self.retry_interval: if self.retry_interval:
@ -190,7 +203,10 @@ class RetryDestinationLimiter(object):
logger.info( logger.info(
"Connection to %s was unsuccessful (%s(%s)); backoff now %i", "Connection to %s was unsuccessful (%s(%s)); backoff now %i",
self.destination, exc_type, exc_val, self.retry_interval self.destination,
exc_type,
exc_val,
self.retry_interval,
) )
retry_last_ts = int(self.clock.time_msec()) retry_last_ts = int(self.clock.time_msec())
@ -201,9 +217,7 @@ class RetryDestinationLimiter(object):
self.destination, retry_last_ts, self.retry_interval self.destination, retry_last_ts, self.retry_interval
) )
except Exception: except Exception:
logger.exception( logger.exception("Failed to store destination_retry_timings")
"Failed to store destination_retry_timings",
)
# we deliberately do this in the background. # we deliberately do this in the background.
synapse.util.logcontext.run_in_background(store_retry_timings) synapse.util.logcontext.run_in_background(store_retry_timings)

View File

@ -21,4 +21,4 @@ import tests.patch_inline_callbacks
# attempt to do the patch before we load any synapse code # attempt to do the patch before we load any synapse code
tests.patch_inline_callbacks.do_patch() tests.patch_inline_callbacks.do_patch()
util.DEFAULT_TIMEOUT_DURATION = 10 util.DEFAULT_TIMEOUT_DURATION = 20

View File

@ -19,16 +19,13 @@ from mock import Mock
import canonicaljson import canonicaljson
import signedjson.key import signedjson.key
import signedjson.sign import signedjson.sign
from signedjson.key import get_verify_key
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.crypto import keyring from synapse.crypto import keyring
from synapse.crypto.keyring import ( from synapse.crypto.keyring import PerspectivesKeyFetcher, ServerKeyFetcher
KeyLookupError,
PerspectivesKeyFetcher,
ServerKeyFetcher,
)
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
@ -137,7 +134,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
context_11.request = "11" context_11.request = "11"
res_deferreds = kr.verify_json_objects_for_server( res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1), ("server11", {})] [("server10", json1, 0), ("server11", {}, 0)]
) )
# the unsigned json should be rejected pretty quickly # the unsigned json should be rejected pretty quickly
@ -174,7 +171,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.return_value = defer.Deferred() self.http_client.post_json.return_value = defer.Deferred()
res_deferreds_2 = kr.verify_json_objects_for_server( res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)] [("server10", json1, 0)]
) )
res_deferreds_2[0].addBoth(self.check_context, None) res_deferreds_2[0].addBoth(self.check_context, None)
yield logcontext.make_deferred_yieldable(res_deferreds_2[0]) yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
@ -197,31 +194,108 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs) kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1) key1 = signedjson.key.generate_signing_key(1)
key1_id = "%s:%s" % (key1.alg, key1.version)
r = self.hs.datastore.store_server_verify_keys( r = self.hs.datastore.store_server_verify_keys(
"server9", "server9",
time.time() * 1000, time.time() * 1000,
[ [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
(
"server9",
key1_id,
FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
),
],
) )
self.get_success(r) self.get_success(r)
json1 = {} json1 = {}
signedjson.sign.sign_json(json1, "server9", key1) signedjson.sign.sign_json(json1, "server9", key1)
# should fail immediately on an unsigned object # should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {}) d = _verify_json_for_server(kr, "server9", {}, 0)
self.failureResultOf(d, SynapseError) self.failureResultOf(d, SynapseError)
d = _verify_json_for_server(kr, "server9", json1) # should suceed on a signed object
self.assertFalse(d.called) d = _verify_json_for_server(kr, "server9", json1, 500)
# self.assertFalse(d.called)
self.get_success(d) self.get_success(d)
def test_verify_json_dedupes_key_requests(self):
"""Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key(1)
def get_keys(keys_to_fetch):
# there should only be one request object (with the max validity)
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed(
{
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
}
)
mock_fetcher = keyring.KeyFetcher()
mock_fetcher.get_keys = Mock(side_effect=get_keys)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
json1 = {}
signedjson.sign.sign_json(json1, "server1", key1)
# the first request should succeed; the second should fail because the key
# has expired
results = kr.verify_json_objects_for_server(
[("server1", json1, 500), ("server1", json1, 1500)]
)
self.assertEqual(len(results), 2)
self.get_success(results[0])
e = self.get_failure(results[1], SynapseError).value
self.assertEqual(e.errcode, "M_UNAUTHORIZED")
self.assertEqual(e.code, 401)
# there should have been a single call to the fetcher
mock_fetcher.get_keys.assert_called_once()
def test_verify_json_falls_back_to_other_fetchers(self):
"""If the first fetcher cannot provide a recent enough key, we fall back"""
key1 = signedjson.key.generate_signing_key(1)
def get_keys1(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed(
{
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
}
}
)
def get_keys2(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed(
{
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
}
}
)
mock_fetcher1 = keyring.KeyFetcher()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
mock_fetcher2 = keyring.KeyFetcher()
mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
json1 = {}
signedjson.sign.sign_json(json1, "server1", key1)
results = kr.verify_json_objects_for_server(
[("server1", json1, 1200), ("server1", json1, 1500)]
)
self.assertEqual(len(results), 2)
self.get_success(results[0])
e = self.get_failure(results[1], SynapseError).value
self.assertEqual(e.errcode, "M_UNAUTHORIZED")
self.assertEqual(e.code, 401)
# there should have been a single call to each fetcher
mock_fetcher1.get_keys.assert_called_once()
mock_fetcher2.get_keys.assert_called_once()
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
@ -260,8 +334,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json self.http_client.get_json.side_effect = get_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))] keys_to_fetch = {SERVER_NAME: {"key1": 0}}
keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) keys = self.get_success(fetcher.get_keys(keys_to_fetch))
k = keys[SERVER_NAME][testverifykey_id] k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey) self.assertEqual(k.verify_key, testverifykey)
@ -286,11 +360,11 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
) )
# change the server name: it should cause a rejection # change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER" response["server_name"] = "OTHER_SERVER"
self.get_failure(
fetcher.get_keys(server_name_and_key_ids), KeyLookupError keys = self.get_success(fetcher.get_keys(keys_to_fetch))
) self.assertEqual(keys, {})
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
@ -342,8 +416,8 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json self.http_client.post_json.side_effect = post_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))] keys_to_fetch = {SERVER_NAME: {"key1": 0}}
keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) keys = self.get_success(fetcher.get_keys(keys_to_fetch))
self.assertIn(SERVER_NAME, keys) self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id] k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
@ -401,7 +475,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def get_key_from_perspectives(response): def get_key_from_perspectives(response):
fetcher = PerspectivesKeyFetcher(self.hs) fetcher = PerspectivesKeyFetcher(self.hs)
server_name_and_key_ids = [(SERVER_NAME, ("key1",))] keys_to_fetch = {SERVER_NAME: {"key1": 0}}
def post_json(destination, path, data, **kwargs): def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name) self.assertEqual(destination, self.mock_perspective_server.server_name)
@ -410,9 +484,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json self.http_client.post_json.side_effect = post_json
return self.get_success( return self.get_success(fetcher.get_keys(keys_to_fetch))
fetcher.get_keys(server_name_and_key_ids)
)
# start with a valid response so we can check we are testing the right thing # start with a valid response so we can check we are testing the right thing
response = build_response() response = build_response()
@ -435,6 +507,11 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig") self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
def get_key_id(key):
"""Get the matrix ID tag for a given SigningKey or VerifyKey"""
return "%s:%s" % (key.alg, key.version)
@defer.inlineCallbacks @defer.inlineCallbacks
def run_in_context(f, *args, **kwargs): def run_in_context(f, *args, **kwargs):
with LoggingContext("testctx") as ctx: with LoggingContext("testctx") as ctx:
@ -445,14 +522,16 @@ def run_in_context(f, *args, **kwargs):
defer.returnValue(rv) defer.returnValue(rv)
def _verify_json_for_server(keyring, server_name, json_object): def _verify_json_for_server(keyring, server_name, json_object, validity_time):
"""thin wrapper around verify_json_for_server which makes sure it is wrapped """thin wrapper around verify_json_for_server which makes sure it is wrapped
with the patched defer.inlineCallbacks. with the patched defer.inlineCallbacks.
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def v(): def v():
rv1 = yield keyring.verify_json_for_server(server_name, json_object) rv1 = yield keyring.verify_json_for_server(
server_name, json_object, validity_time
)
defer.returnValue(rv1) defer.returnValue(rv1)
return run_in_context(v) return run_in_context(v)

View File

@ -408,7 +408,6 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
users_in_room = self.get_success(self.store.get_users_in_room(room_id)) users_in_room = self.get_success(self.store.get_users_in_room(room_id))
self.assertEqual([], users_in_room) self.assertEqual([], users_in_room)
@unittest.DEBUG
def test_shutdown_room_block_peek(self): def test_shutdown_room_block_peek(self):
"""Test that a world_readable room can no longer be peeked into after """Test that a world_readable room can no longer be peeked into after
it has been shut down. it has been shut down.

View File

@ -30,7 +30,7 @@ from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver from ....utils import MockHttpResource, setup_test_homeserver
myid = "@1234ABCD:test" myid = "@1234ABCD:test"
PATH_PREFIX = "/_matrix/client/api/v1" PATH_PREFIX = "/_matrix/client/r0"
class MockHandlerProfileTestCase(unittest.TestCase): class MockHandlerProfileTestCase(unittest.TestCase):

View File

@ -436,6 +436,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
self.validity_period = 10 self.validity_period = 10
self.max_delta = self.validity_period * 10. / 100.
config = self.default_config() config = self.default_config()
@ -453,14 +454,18 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
def test_background_job(self): def test_background_job(self):
""" """
Tests whether the account validity startup background job does the right thing, Tests the same thing as test_background_job, except that it sets the
which is sticking an expiration date to every account that doesn't already have startup_job_max_delta parameter and checks that the expiration date is within the
one. allowed range.
""" """
user_id = self.register_user("kermit", "user") user_id = self.register_user("kermit_delta", "user")
self.hs.config.account_validity.startup_job_max_delta = self.max_delta
now_ms = self.hs.clock.time_msec() now_ms = self.hs.clock.time_msec()
self.get_success(self.store._set_expiration_date_when_missing()) self.get_success(self.store._set_expiration_date_when_missing())
res = self.get_success(self.store.get_expiration_ts_for_user(user_id)) res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
self.assertEqual(res, now_ms + self.validity_period)
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period)