Merge branch 'develop' of github.com:matrix-org/synapse into erikj/make_database_class

pull/6469/head
Erik Johnston 2019-12-06 11:28:44 +00:00
commit 2ace775d88
56 changed files with 767 additions and 824 deletions

View File

@ -1,7 +1,7 @@
# Configuration file used for testing the 'synapse_port_db' script. # Configuration file used for testing the 'synapse_port_db' script.
# Tells the script to connect to the postgresql database that will be available in the # Tells the script to connect to the postgresql database that will be available in the
# CI's Docker setup at the point where this file is considered. # CI's Docker setup at the point where this file is considered.
server_name: "localhost:8080" server_name: "test"
signing_key_path: "/src/.buildkite/test.signing.key" signing_key_path: "/src/.buildkite/test.signing.key"

View File

@ -1,7 +1,7 @@
# Configuration file used for testing the 'synapse_port_db' script. # Configuration file used for testing the 'synapse_port_db' script.
# Tells the 'update_database' script to connect to the test SQLite database to upgrade its # Tells the 'update_database' script to connect to the test SQLite database to upgrade its
# schema and run background updates on it. # schema and run background updates on it.
server_name: "localhost:8080" server_name: "test"
signing_key_path: "/src/.buildkite/test.signing.key" signing_key_path: "/src/.buildkite/test.signing.key"

1
changelog.d/5925.feature Normal file
View File

@ -0,0 +1 @@
Add admin/v2/users endpoint with pagination. Contributed by Awesome Technologies Innovationslabor GmbH.

1
changelog.d/5925.removal Normal file
View File

@ -0,0 +1 @@
Remove admin/v1/users_paginate endpoint. Contributed by Awesome Technologies Innovationslabor GmbH.

1
changelog.d/6369.doc Normal file
View File

@ -0,0 +1 @@
Update documentation and variables in user contributed systemd reference file.

1
changelog.d/6472.bugfix Normal file
View File

@ -0,0 +1 @@
Improve sanity-checking when receiving events over federation.

1
changelog.d/6480.misc Normal file
View File

@ -0,0 +1 @@
Refactor some code in the event authentication path for clarity.

1
changelog.d/6482.misc Normal file
View File

@ -0,0 +1 @@
Port synapse.rest.client.v1 to async/await.

1
changelog.d/6483.misc Normal file
View File

@ -0,0 +1 @@
Port synapse.rest.client.v2_alpha to async/await.

17
contrib/systemd/README.md Normal file
View File

@ -0,0 +1,17 @@
# Setup Synapse with Systemd
This is a setup for managing synapse with a user contributed systemd unit
file. It provides a `matrix-synapse` systemd unit file that should be tailored
to accommodate your installation in accordance with the installation
instructions provided in [installation instructions](../../INSTALL.md).
## Setup
1. Under the service section, ensure the `User` variable matches which user
you installed synapse under and wish to run it as.
2. Under the service section, ensure the `WorkingDirectory` variable matches
where you have installed synapse.
3. Under the service section, ensure the `ExecStart` variable matches the
appropriate locations of your installation.
4. Copy the `matrix-synapse.service` to `/etc/systemd/system/`
5. Start Synapse: `sudo systemctl start matrix-synapse`
6. Verify Synapse is running: `sudo systemctl status matrix-synapse`
7. *optional* Enable Synapse to start at system boot: `sudo systemctl enable matrix-synapse`

View File

@ -4,8 +4,11 @@
# systemctl enable matrix-synapse # systemctl enable matrix-synapse
# systemctl start matrix-synapse # systemctl start matrix-synapse
# #
# This assumes that Synapse has been installed by a user named
# synapse.
#
# This assumes that Synapse has been installed in a virtualenv in # This assumes that Synapse has been installed in a virtualenv in
# /opt/synapse/env. # the user's home directory: `/home/synapse/synapse/env`.
# #
# **NOTE:** This is an example service file that may change in the future. If you # **NOTE:** This is an example service file that may change in the future. If you
# wish to use this please copy rather than symlink it. # wish to use this please copy rather than symlink it.
@ -23,7 +26,7 @@ User=synapse
Group=nogroup Group=nogroup
WorkingDirectory=/opt/synapse WorkingDirectory=/opt/synapse
ExecStart=/opt/synapse/env/bin/python -m synapse.app.homeserver --config-path=/opt/synapse/homeserver.yaml ExecStart=/home/synapse/synapse/env/bin/python -m synapse.app.homeserver --config-path=/home/synapse/synapse/homeserver.yaml
SyslogIdentifier=matrix-synapse SyslogIdentifier=matrix-synapse
# adjust the cache factor if necessary # adjust the cache factor if necessary

View File

@ -1,3 +1,48 @@
List Accounts
=============
This API returns all local user accounts.
The api is::
GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
including an ``access_token`` of a server admin.
The parameters ``from`` and ``limit`` are required only for pagination.
By default, a ``limit`` of 100 is used.
The parameter ``user_id`` can be used to select only users with user ids that
contain this value.
The parameter ``guests=false`` can be used to exclude guest users,
default is to include guest users.
The parameter ``deactivated=true`` can be used to include deactivated users,
default is to exclude deactivated users.
If the endpoint does not return a ``next_token`` then there are no more users left.
It returns a JSON body like the following:
.. code:: json
{
"users": [
{
"name": "<user_id1>",
"password_hash": "<password_hash1>",
"is_guest": 0,
"admin": 0,
"user_type": null,
"deactivated": 0
}, {
"name": "<user_id2>",
"password_hash": "<password_hash2>",
"is_guest": 0,
"admin": 1,
"user_type": null,
"deactivated": 0
}
],
"next_token": "100"
}
Query Account Query Account
============= =============

View File

@ -151,7 +151,7 @@ class SynchrotronPresence(object):
def set_state(self, user, state, ignore_status_msg=False): def set_state(self, user, state, ignore_status_msg=False):
# TODO Hows this supposed to work? # TODO Hows this supposed to work?
pass return defer.succeed(None)
get_states = __func__(PresenceHandler.get_states) get_states = __func__(PresenceHandler.get_states)
get_state = __func__(PresenceHandler.get_state) get_state = __func__(PresenceHandler.get_state)

View File

@ -56,7 +56,7 @@ class AdminHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_users(self): def get_users(self):
"""Function to reterive a list of users in users table. """Function to retrieve a list of users in users table.
Args: Args:
Returns: Returns:
@ -67,19 +67,22 @@ class AdminHandler(BaseHandler):
return ret return ret
@defer.inlineCallbacks @defer.inlineCallbacks
def get_users_paginate(self, order, start, limit): def get_users_paginate(self, start, limit, name, guests, deactivated):
"""Function to reterive a paginated list of users from """Function to retrieve a paginated list of users from
users list. This will return a json object, which contains users list. This will return a json list of users.
list of users and the total number of users in users table.
Args: Args:
order (str): column name to order the select by this column
start (int): start number to begin the query from start (int): start number to begin the query from
limit (int): number of rows to reterive limit (int): number of rows to retrieve
name (string): filter for user names
guests (bool): whether to in include guest users
deactivated (bool): whether to include deactivated users
Returns: Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count} defer.Deferred: resolves to json list[dict[str, Any]]
""" """
ret = yield self.store.get_users_paginate(order, start, limit) ret = yield self.store.get_users_paginate(
start, limit, name, guests, deactivated
)
return ret return ret

View File

@ -19,11 +19,13 @@
import itertools import itertools
import logging import logging
from typing import Dict, Iterable, Optional, Sequence, Tuple
import six import six
from six import iteritems, itervalues from six import iteritems, itervalues
from six.moves import http_client, zip from six.moves import http_client, zip
import attr
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
@ -45,6 +47,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import auth_types_for_event from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.logging.context import ( from synapse.logging.context import (
@ -72,6 +75,23 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@attr.s
class _NewEventInfo:
"""Holds information about a received event, ready for passing to _handle_new_events
Attributes:
event: the received event
state: the state at that event
auth_events: the auth_event map for that event
"""
event = attr.ib(type=EventBase)
state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None)
def shortstr(iterable, maxitems=5): def shortstr(iterable, maxitems=5):
"""If iterable has maxitems or fewer, return the stringification of a list """If iterable has maxitems or fewer, return the stringification of a list
containing those items. containing those items.
@ -597,14 +617,14 @@ class FederationHandler(BaseHandler):
for e in auth_chain for e in auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create if e.event_id in auth_ids or e.type == EventTypes.Create
} }
event_infos.append({"event": e, "auth_events": auth}) event_infos.append(_NewEventInfo(event=e, auth_events=auth))
seen_ids.add(e.event_id) seen_ids.add(e.event_id)
logger.info( logger.info(
"[%s %s] persisting newly-received auth/state events %s", "[%s %s] persisting newly-received auth/state events %s",
room_id, room_id,
event_id, event_id,
[e["event"].event_id for e in event_infos], [e.event.event_id for e in event_infos],
) )
yield self._handle_new_events(origin, event_infos) yield self._handle_new_events(origin, event_infos)
@ -795,9 +815,9 @@ class FederationHandler(BaseHandler):
a.internal_metadata.outlier = True a.internal_metadata.outlier = True
ev_infos.append( ev_infos.append(
{ _NewEventInfo(
"event": a, event=a,
"auth_events": { auth_events={
( (
auth_events[a_id].type, auth_events[a_id].type,
auth_events[a_id].state_key, auth_events[a_id].state_key,
@ -805,7 +825,7 @@ class FederationHandler(BaseHandler):
for a_id in a.auth_event_ids() for a_id in a.auth_event_ids()
if a_id in auth_events if a_id in auth_events
}, },
} )
) )
# Step 1b: persist the events in the chunk we fetched state for (i.e. # Step 1b: persist the events in the chunk we fetched state for (i.e.
@ -817,10 +837,10 @@ class FederationHandler(BaseHandler):
assert not ev.internal_metadata.is_outlier() assert not ev.internal_metadata.is_outlier()
ev_infos.append( ev_infos.append(
{ _NewEventInfo(
"event": ev, event=ev,
"state": events_to_state[e_id], state=events_to_state[e_id],
"auth_events": { auth_events={
( (
auth_events[a_id].type, auth_events[a_id].type,
auth_events[a_id].state_key, auth_events[a_id].state_key,
@ -828,7 +848,7 @@ class FederationHandler(BaseHandler):
for a_id in ev.auth_event_ids() for a_id in ev.auth_event_ids()
if a_id in auth_events if a_id in auth_events
}, },
} )
) )
yield self._handle_new_events(dest, ev_infos, backfilled=True) yield self._handle_new_events(dest, ev_infos, backfilled=True)
@ -1713,7 +1733,12 @@ class FederationHandler(BaseHandler):
return context return context
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False): def _handle_new_events(
self,
origin: str,
event_infos: Iterable[_NewEventInfo],
backfilled: bool = False,
):
"""Creates the appropriate contexts and persists events. The events """Creates the appropriate contexts and persists events. The events
should not depend on one another, e.g. this should be used to persist should not depend on one another, e.g. this should be used to persist
a bunch of outliers, but not a chunk of individual events that depend a bunch of outliers, but not a chunk of individual events that depend
@ -1723,14 +1748,14 @@ class FederationHandler(BaseHandler):
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def prep(ev_info): def prep(ev_info: _NewEventInfo):
event = ev_info["event"] event = ev_info.event
with nested_logging_context(suffix=event.event_id): with nested_logging_context(suffix=event.event_id):
res = yield self._prep_event( res = yield self._prep_event(
origin, origin,
event, event,
state=ev_info.get("state"), state=ev_info.state,
auth_events=ev_info.get("auth_events"), auth_events=ev_info.auth_events,
backfilled=backfilled, backfilled=backfilled,
) )
return res return res
@ -1744,7 +1769,7 @@ class FederationHandler(BaseHandler):
yield self.persist_events_and_notify( yield self.persist_events_and_notify(
[ [
(ev_info["event"], context) (ev_info.event, context)
for ev_info, context in zip(event_infos, contexts) for ev_info, context in zip(event_infos, contexts)
], ],
backfilled=backfilled, backfilled=backfilled,
@ -1846,7 +1871,14 @@ class FederationHandler(BaseHandler):
yield self.persist_events_and_notify([(event, new_event_context)]) yield self.persist_events_and_notify([(event, new_event_context)])
@defer.inlineCallbacks @defer.inlineCallbacks
def _prep_event(self, origin, event, state, auth_events, backfilled): def _prep_event(
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
auth_events: Optional[Dict[Tuple[str, str], EventBase]],
backfilled: bool,
):
""" """
Args: Args:
@ -1854,7 +1886,7 @@ class FederationHandler(BaseHandler):
event: event:
state: state:
auth_events: auth_events:
backfilled (bool) backfilled:
Returns: Returns:
Deferred, which resolves to synapse.events.snapshot.EventContext Deferred, which resolves to synapse.events.snapshot.EventContext
@ -1890,15 +1922,16 @@ class FederationHandler(BaseHandler):
return context return context
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_for_soft_fail(self, event, state, backfilled): def _check_for_soft_fail(
self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
):
"""Checks if we should soft fail the event, if so marks the event as """Checks if we should soft fail the event, if so marks the event as
such. such.
Args: Args:
event (FrozenEvent) event
state (dict|None): The state at the event if we don't have all the state: The state at the event if we don't have all the event's prev events
event's prev events backfilled: Whether the event is from backfill
backfilled (bool): Whether the event is from backfill
Returns: Returns:
Deferred Deferred
@ -2195,21 +2228,37 @@ class FederationHandler(BaseHandler):
different_auth, different_auth,
) )
# now we state-resolve between our own idea of the auth events, and the remote's
# idea of them.
room_version = yield self.store.get_room_version(event.room_id)
# XXX: currently this checks for redactions but I'm not convinced that is # XXX: currently this checks for redactions but I'm not convinced that is
# necessary? # necessary?
different_events = yield self.store.get_events_as_list(different_auth) different_events = yield self.store.get_events_as_list(different_auth)
local_view = dict(auth_events) for d in different_events:
remote_view = dict(auth_events) if d.room_id != event.room_id:
remote_view.update({(d.type, d.state_key): d for d in different_events}) logger.warning(
"Event %s refers to auth_event %s which is in a different room",
event.event_id,
d.event_id,
)
# don't attempt to resolve the claimed auth events against our own
# in this case: just use our own auth events.
#
# XXX: should we reject the event in this case? It feels like we should,
# but then shouldn't we also do so if we've failed to fetch any of the
# auth events?
return context
# now we state-resolve between our own idea of the auth events, and the remote's
# idea of them.
local_state = auth_events.values()
remote_auth_events = dict(auth_events)
remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
remote_state = remote_auth_events.values()
room_version = yield self.store.get_room_version(event.room_id)
new_state = yield self.state_handler.resolve_events( new_state = yield self.state_handler.resolve_events(
room_version, [list(local_view.values()), list(remote_view.values())], event room_version, (local_state, remote_state), event
) )
logger.info( logger.info(

View File

@ -34,12 +34,12 @@ from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import ( from synapse.rest.admin.users import (
AccountValidityRenewServlet, AccountValidityRenewServlet,
DeactivateAccountRestServlet, DeactivateAccountRestServlet,
GetUsersPaginatedRestServlet,
ResetPasswordRestServlet, ResetPasswordRestServlet,
SearchUsersRestServlet, SearchUsersRestServlet,
UserAdminServlet, UserAdminServlet,
UserRegisterServlet, UserRegisterServlet,
UsersRestServlet, UsersRestServlet,
UsersRestServletV2,
WhoisRestServlet, WhoisRestServlet,
) )
from synapse.util.versionstring import get_version_string from synapse.util.versionstring import get_version_string
@ -191,6 +191,7 @@ def register_servlets(hs, http_server):
SendServerNoticeServlet(hs).register(http_server) SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server) VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server): def register_servlets_for_client_rest_resource(hs, http_server):
@ -201,7 +202,6 @@ def register_servlets_for_client_rest_resource(hs, http_server):
PurgeHistoryRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server)
UsersRestServlet(hs).register(http_server) UsersRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server) ResetPasswordRestServlet(hs).register(http_server)
GetUsersPaginatedRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server) ShutdownRoomRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server)

View File

@ -25,6 +25,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
parse_boolean,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
@ -59,71 +60,45 @@ class UsersRestServlet(RestServlet):
return 200, ret return 200, ret
class GetUsersPaginatedRestServlet(RestServlet): class UsersRestServletV2(RestServlet):
"""Get request to get specific number of users from Synapse. PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),)
This needs user to have administrator access in Synapse.
Example:
http://localhost:8008/_synapse/admin/v1/users_paginate/
@admin:user?access_token=admin_access_token&start=0&limit=10
Returns:
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
PATTERNS = historical_admin_path_patterns( """Get request to list all local users.
"/users_paginate/(?P<target_user_id>[^/]*)" This needs user to have administrator access in Synapse.
)
GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
returns:
200 OK with list of users if success otherwise an error.
The parameters `from` and `limit` are required only for pagination.
By default, a `limit` of 100 is used.
The parameter `user_id` can be used to filter by user id.
The parameter `guests` can be used to exclude guest users.
The parameter `deactivated` can be used to include deactivated users.
"""
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.handlers = hs.get_handlers() self.admin_handler = hs.get_handlers().admin_handler
async def on_GET(self, request, target_user_id): async def on_GET(self, request):
"""Get request to get specific number of users from Synapse.
This needs user to have administrator access in Synapse.
"""
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(target_user_id) start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
user_id = parse_string(request, "user_id", default=None)
guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False)
if not self.hs.is_mine(target_user): users = await self.admin_handler.get_users_paginate(
raise SynapseError(400, "Can only users a local user") start, limit, user_id, guests, deactivated
)
ret = {"users": users}
if len(users) >= limit:
ret["next_token"] = str(start + len(users))
order = "name" # order by name in user table
start = parse_integer(request, "start", required=True)
limit = parse_integer(request, "limit", required=True)
logger.info("limit: %s, start: %s", limit, start)
ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
return 200, ret
async def on_POST(self, request, target_user_id):
"""Post request to get specific number of users from Synapse..
This needs user to have administrator access in Synapse.
Example:
http://localhost:8008/_synapse/admin/v1/users_paginate/
@admin:user?access_token=admin_access_token
JsonBodyToSend:
{
"start": "0",
"limit": "10
}
Returns:
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
await assert_requester_is_admin(self.auth, request)
UserID.from_string(target_user_id)
order = "name" # order by name in user table
params = parse_json_object_from_request(request)
assert_params_in_dict(params, ["limit", "start"])
limit = params["limit"]
start = params["start"]
logger.info("limit: %s, start: %s", limit, start)
ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
return 200, ret return 200, ret

View File

@ -16,8 +16,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -47,17 +45,15 @@ class ClientDirectoryServer(RestServlet):
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, room_alias):
def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
dir_handler = self.handlers.directory_handler dir_handler = self.handlers.directory_handler
res = yield dir_handler.get_association(room_alias) res = await dir_handler.get_association(room_alias)
return 200, res return 200, res
@defer.inlineCallbacks async def on_PUT(self, request, room_alias):
def on_PUT(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -77,26 +73,25 @@ class ClientDirectoryServer(RestServlet):
# TODO(erikj): Check types. # TODO(erikj): Check types.
room = yield self.store.get_room(room_id) room = await self.store.get_room(room_id)
if room is None: if room is None:
raise SynapseError(400, "Room does not exist") raise SynapseError(400, "Room does not exist")
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
yield self.handlers.directory_handler.create_association( await self.handlers.directory_handler.create_association(
requester, room_alias, room_id, servers requester, room_alias, room_id, servers
) )
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_DELETE(self, request, room_alias):
def on_DELETE(self, request, room_alias):
dir_handler = self.handlers.directory_handler dir_handler = self.handlers.directory_handler
try: try:
service = yield self.auth.get_appservice_by_req(request) service = await self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
yield dir_handler.delete_appservice_association(service, room_alias) await dir_handler.delete_appservice_association(service, room_alias)
logger.info( logger.info(
"Application service at %s deleted alias %s", "Application service at %s deleted alias %s",
service.url, service.url,
@ -107,12 +102,12 @@ class ClientDirectoryServer(RestServlet):
# fallback to default user behaviour if they aren't an AS # fallback to default user behaviour if they aren't an AS
pass pass
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user = requester.user user = requester.user
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
yield dir_handler.delete_association(requester, room_alias) await dir_handler.delete_association(requester, room_alias)
logger.info( logger.info(
"User %s deleted alias %s", user.to_string(), room_alias.to_string() "User %s deleted alias %s", user.to_string(), room_alias.to_string()
@ -130,32 +125,29 @@ class ClientDirectoryListServer(RestServlet):
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, room_id):
def on_GET(self, request, room_id): room = await self.store.get_room(room_id)
room = yield self.store.get_room(room_id)
if room is None: if room is None:
raise NotFoundError("Unknown room") raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"} return 200, {"visibility": "public" if room["is_public"] else "private"}
@defer.inlineCallbacks async def on_PUT(self, request, room_id):
def on_PUT(self, request, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public") visibility = content.get("visibility", "public")
yield self.handlers.directory_handler.edit_published_room_list( await self.handlers.directory_handler.edit_published_room_list(
requester, room_id, visibility requester, room_id, visibility
) )
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_DELETE(self, request, room_id):
def on_DELETE(self, request, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
yield self.handlers.directory_handler.edit_published_room_list( await self.handlers.directory_handler.edit_published_room_list(
requester, room_id, "private" requester, room_id, "private"
) )
@ -181,15 +173,14 @@ class ClientAppserviceDirectoryListServer(RestServlet):
def on_DELETE(self, request, network_id, room_id): def on_DELETE(self, request, network_id, room_id):
return self._edit(request, network_id, room_id, "private") return self._edit(request, network_id, room_id, "private")
@defer.inlineCallbacks async def _edit(self, request, network_id, room_id, visibility):
def _edit(self, request, network_id, room_id, visibility): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if not requester.app_service: if not requester.app_service:
raise AuthError( raise AuthError(
403, "Only appservices can edit the appservice published room list" 403, "Only appservices can edit the appservice published room list"
) )
yield self.handlers.directory_handler.edit_published_appservice_room_list( await self.handlers.directory_handler.edit_published_appservice_room_list(
requester.app_service.id, network_id, room_id, visibility requester.app_service.id, network_id, room_id, visibility
) )

View File

@ -16,8 +16,6 @@
"""This module contains REST servlets to do with event streaming, /events.""" """This module contains REST servlets to do with event streaming, /events."""
import logging import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
@ -36,9 +34,8 @@ class EventStreamRestServlet(RestServlet):
self.event_stream_handler = hs.get_event_stream_handler() self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
is_guest = requester.is_guest is_guest = requester.is_guest
room_id = None room_id = None
if is_guest: if is_guest:
@ -57,7 +54,7 @@ class EventStreamRestServlet(RestServlet):
as_client_event = b"raw" not in request.args as_client_event = b"raw" not in request.args
chunk = yield self.event_stream_handler.get_stream( chunk = await self.event_stream_handler.get_stream(
requester.user.to_string(), requester.user.to_string(),
pagin_config, pagin_config,
timeout=timeout, timeout=timeout,
@ -83,14 +80,13 @@ class EventRestServlet(RestServlet):
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks async def on_GET(self, request, event_id):
def on_GET(self, request, event_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request) event = await self.event_handler.get_event(requester.user, None, event_id)
event = yield self.event_handler.get_event(requester.user, None, event_id)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
if event: if event:
event = yield self._event_serializer.serialize_event(event, time_now) event = await self._event_serializer.serialize_event(event, time_now)
return 200, event return 200, event
else: else:
return 404, "Event not found." return 404, "Event not found."

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_boolean from synapse.http.servlet import RestServlet, parse_boolean
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
@ -29,13 +28,12 @@ class InitialSyncRestServlet(RestServlet):
self.initial_sync_handler = hs.get_initial_sync_handler() self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
as_client_event = b"raw" not in request.args as_client_event = b"raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
include_archived = parse_boolean(request, "archived", default=False) include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms( content = await self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
as_client_event=as_client_event, as_client_event=as_client_event,

View File

@ -18,7 +18,6 @@ import xml.etree.ElementTree as ET
from six.moves import urllib from six.moves import urllib
from twisted.internet import defer
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.errors import Codes, LoginError, SynapseError
@ -130,8 +129,7 @@ class LoginRestServlet(RestServlet):
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
self._address_ratelimiter.ratelimit( self._address_ratelimiter.ratelimit(
request.getClientIP(), request.getClientIP(),
time_now_s=self.hs.clock.time(), time_now_s=self.hs.clock.time(),
@ -145,11 +143,11 @@ class LoginRestServlet(RestServlet):
if self.jwt_enabled and ( if self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE login_submission["type"] == LoginRestServlet.JWT_TYPE
): ):
result = yield self.do_jwt_login(login_submission) result = await self.do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission) result = await self.do_token_login(login_submission)
else: else:
result = yield self._do_other_login(login_submission) result = await self._do_other_login(login_submission)
except KeyError: except KeyError:
raise SynapseError(400, "Missing JSON keys.") raise SynapseError(400, "Missing JSON keys.")
@ -158,8 +156,7 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data result["well_known"] = well_known_data
return 200, result return 200, result
@defer.inlineCallbacks async def _do_other_login(self, login_submission):
def _do_other_login(self, login_submission):
"""Handle non-token/saml/jwt logins """Handle non-token/saml/jwt logins
Args: Args:
@ -219,20 +216,20 @@ class LoginRestServlet(RestServlet):
( (
canonical_user_id, canonical_user_id,
callback_3pid, callback_3pid,
) = yield self.auth_handler.check_password_provider_3pid( ) = await self.auth_handler.check_password_provider_3pid(
medium, address, login_submission["password"] medium, address, login_submission["password"]
) )
if canonical_user_id: if canonical_user_id:
# Authentication through password provider and 3pid succeeded # Authentication through password provider and 3pid succeeded
result = yield self._complete_login( result = await self._complete_login(
canonical_user_id, login_submission, callback_3pid canonical_user_id, login_submission, callback_3pid
) )
return result return result
# No password providers were able to handle this 3pid # No password providers were able to handle this 3pid
# Check local store # Check local store
user_id = yield self.hs.get_datastore().get_user_id_by_threepid( user_id = await self.hs.get_datastore().get_user_id_by_threepid(
medium, address medium, address
) )
if not user_id: if not user_id:
@ -280,7 +277,7 @@ class LoginRestServlet(RestServlet):
) )
try: try:
canonical_user_id, callback = yield self.auth_handler.validate_login( canonical_user_id, callback = await self.auth_handler.validate_login(
identifier["user"], login_submission identifier["user"], login_submission
) )
except LoginError: except LoginError:
@ -297,13 +294,12 @@ class LoginRestServlet(RestServlet):
) )
raise raise
result = yield self._complete_login( result = await self._complete_login(
canonical_user_id, login_submission, callback canonical_user_id, login_submission, callback
) )
return result return result
@defer.inlineCallbacks async def _complete_login(
def _complete_login(
self, user_id, login_submission, callback=None, create_non_existant_users=False self, user_id, login_submission, callback=None, create_non_existant_users=False
): ):
"""Called when we've successfully authed the user and now need to """Called when we've successfully authed the user and now need to
@ -337,15 +333,15 @@ class LoginRestServlet(RestServlet):
) )
if create_non_existant_users: if create_non_existant_users:
user_id = yield self.auth_handler.check_user_exists(user_id) user_id = await self.auth_handler.check_user_exists(user_id)
if not user_id: if not user_id:
user_id = yield self.registration_handler.register_user( user_id = await self.registration_handler.register_user(
localpart=UserID.from_string(user_id).localpart localpart=UserID.from_string(user_id).localpart
) )
device_id = login_submission.get("device_id") device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name") initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device( device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )
@ -357,23 +353,21 @@ class LoginRestServlet(RestServlet):
} }
if callback is not None: if callback is not None:
yield callback(result) await callback(result)
return result return result
@defer.inlineCallbacks async def do_token_login(self, login_submission):
def do_token_login(self, login_submission):
token = login_submission["token"] token = login_submission["token"]
auth_handler = self.auth_handler auth_handler = self.auth_handler
user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id( user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
token token
) )
result = yield self._complete_login(user_id, login_submission) result = await self._complete_login(user_id, login_submission)
return result return result
@defer.inlineCallbacks async def do_jwt_login(self, login_submission):
def do_jwt_login(self, login_submission):
token = login_submission.get("token", None) token = login_submission.get("token", None)
if token is None: if token is None:
raise LoginError( raise LoginError(
@ -397,7 +391,7 @@ class LoginRestServlet(RestServlet):
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID(user, self.hs.hostname).to_string() user_id = UserID(user, self.hs.hostname).to_string()
result = yield self._complete_login( result = await self._complete_login(
user_id, login_submission, create_non_existant_users=True user_id, login_submission, create_non_existant_users=True
) )
return result return result
@ -460,8 +454,7 @@ class CasTicketServlet(RestServlet):
self._sso_auth_handler = SSOAuthHandler(hs) self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request):
client_redirect_url = parse_string(request, "redirectUrl", required=True) client_redirect_url = parse_string(request, "redirectUrl", required=True)
uri = self.cas_server_url + "/proxyValidate" uri = self.cas_server_url + "/proxyValidate"
args = { args = {
@ -469,12 +462,12 @@ class CasTicketServlet(RestServlet):
"service": self.cas_service_url, "service": self.cas_service_url,
} }
try: try:
body = yield self._http_client.get_raw(uri, args) body = await self._http_client.get_raw(uri, args)
except PartialDownloadError as pde: except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed, # Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data # even if that's being used old-http style to signal end-of-data
body = pde.response body = pde.response
result = yield self.handle_cas_response(request, body, client_redirect_url) result = await self.handle_cas_response(request, body, client_redirect_url)
return result return result
def handle_cas_response(self, request, cas_response_body, client_redirect_url): def handle_cas_response(self, request, cas_response_body, client_redirect_url):
@ -555,8 +548,7 @@ class SSOAuthHandler(object):
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
self._macaroon_gen = hs.get_macaroon_generator() self._macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks async def on_successful_auth(
def on_successful_auth(
self, username, request, client_redirect_url, user_display_name=None self, username, request, client_redirect_url, user_display_name=None
): ):
"""Called once the user has successfully authenticated with the SSO. """Called once the user has successfully authenticated with the SSO.
@ -582,9 +574,9 @@ class SSOAuthHandler(object):
""" """
localpart = map_username_to_mxid_localpart(username) localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string() user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = yield self._auth_handler.check_user_exists(user_id) registered_user_id = await self._auth_handler.check_user_exists(user_id)
if not registered_user_id: if not registered_user_id:
registered_user_id = yield self._registration_handler.register_user( registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=user_display_name localpart=localpart, default_display_name=user_display_name
) )

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
@ -35,17 +33,16 @@ class LogoutRestServlet(RestServlet):
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if requester.device_id is None: if requester.device_id is None:
# the acccess token wasn't associated with a device. # the acccess token wasn't associated with a device.
# Just delete the access token # Just delete the access token
access_token = self.auth.get_access_token_from_request(request) access_token = self.auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token) await self._auth_handler.delete_access_token(access_token)
else: else:
yield self._device_handler.delete_device( await self._device_handler.delete_device(
requester.user.to_string(), requester.device_id requester.user.to_string(), requester.device_id
) )
@ -64,17 +61,16 @@ class LogoutAllRestServlet(RestServlet):
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
# first delete all of the user's devices # first delete all of the user's devices
yield self._device_handler.delete_all_devices_for_user(user_id) await self._device_handler.delete_all_devices_for_user(user_id)
# .. and then delete any access tokens which weren't associated with # .. and then delete any access tokens which weren't associated with
# devices. # devices.
yield self._auth_handler.delete_access_tokens_for_user(user_id) await self._auth_handler.delete_access_tokens_for_user(user_id)
return 200, {} return 200, {}

View File

@ -19,8 +19,6 @@ import logging
from six import string_types from six import string_types
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -40,27 +38,25 @@ class PresenceStatusRestServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if requester.user != user: if requester.user != user:
allowed = yield self.presence_handler.is_visible( allowed = await self.presence_handler.is_visible(
observed_user=user, observer_user=requester.user observed_user=user, observer_user=requester.user
) )
if not allowed: if not allowed:
raise AuthError(403, "You are not allowed to see their presence.") raise AuthError(403, "You are not allowed to see their presence.")
state = yield self.presence_handler.get_state(target_user=user) state = await self.presence_handler.get_state(target_user=user)
state = format_user_presence_state(state, self.clock.time_msec()) state = format_user_presence_state(state, self.clock.time_msec())
return 200, state return 200, state
@defer.inlineCallbacks async def on_PUT(self, request, user_id):
def on_PUT(self, request, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if requester.user != user: if requester.user != user:
@ -86,7 +82,7 @@ class PresenceStatusRestServlet(RestServlet):
raise SynapseError(400, "Unable to parse state") raise SynapseError(400, "Unable to parse state")
if self.hs.config.use_presence: if self.hs.config.use_presence:
yield self.presence_handler.set_state(user, state) await self.presence_handler.set_state(user, state)
return 200, {} return 200, {}

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """ """ This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
@ -30,19 +29,18 @@ class ProfileDisplaynameRestServlet(RestServlet):
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id):
requester_user = None requester_user = None
if self.hs.config.require_auth_for_profile_requests: if self.hs.config.require_auth_for_profile_requests:
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user = requester.user requester_user = requester.user
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
yield self.profile_handler.check_profile_query_allowed(user, requester_user) await self.profile_handler.check_profile_query_allowed(user, requester_user)
displayname = yield self.profile_handler.get_displayname(user) displayname = await self.profile_handler.get_displayname(user)
ret = {} ret = {}
if displayname is not None: if displayname is not None:
@ -50,11 +48,10 @@ class ProfileDisplaynameRestServlet(RestServlet):
return 200, ret return 200, ret
@defer.inlineCallbacks async def on_PUT(self, request, user_id):
def on_PUT(self, request, user_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
is_admin = yield self.auth.is_server_admin(requester.user) is_admin = await self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -63,7 +60,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
except Exception: except Exception:
return 400, "Unable to parse name" return 400, "Unable to parse name"
yield self.profile_handler.set_displayname(user, requester, new_name, is_admin) await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
return 200, {} return 200, {}
@ -80,19 +77,18 @@ class ProfileAvatarURLRestServlet(RestServlet):
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id):
requester_user = None requester_user = None
if self.hs.config.require_auth_for_profile_requests: if self.hs.config.require_auth_for_profile_requests:
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user = requester.user requester_user = requester.user
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
yield self.profile_handler.check_profile_query_allowed(user, requester_user) await self.profile_handler.check_profile_query_allowed(user, requester_user)
avatar_url = yield self.profile_handler.get_avatar_url(user) avatar_url = await self.profile_handler.get_avatar_url(user)
ret = {} ret = {}
if avatar_url is not None: if avatar_url is not None:
@ -100,11 +96,10 @@ class ProfileAvatarURLRestServlet(RestServlet):
return 200, ret return 200, ret
@defer.inlineCallbacks async def on_PUT(self, request, user_id):
def on_PUT(self, request, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
is_admin = yield self.auth.is_server_admin(requester.user) is_admin = await self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
try: try:
@ -112,7 +107,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
except Exception: except Exception:
return 400, "Unable to parse name" return 400, "Unable to parse name"
yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) await self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
return 200, {} return 200, {}
@ -129,20 +124,19 @@ class ProfileRestServlet(RestServlet):
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id):
requester_user = None requester_user = None
if self.hs.config.require_auth_for_profile_requests: if self.hs.config.require_auth_for_profile_requests:
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user = requester.user requester_user = requester.user
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
yield self.profile_handler.check_profile_query_allowed(user, requester_user) await self.profile_handler.check_profile_query_allowed(user, requester_user)
displayname = yield self.profile_handler.get_displayname(user) displayname = await self.profile_handler.get_displayname(user)
avatar_url = yield self.profile_handler.get_avatar_url(user) avatar_url = await self.profile_handler.get_avatar_url(user)
ret = {} ret = {}
if displayname is not None: if displayname is not None:

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
NotFoundError, NotFoundError,
@ -46,8 +45,7 @@ class PushRuleRestServlet(RestServlet):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None self._is_worker = hs.config.worker_app is not None
@defer.inlineCallbacks async def on_PUT(self, request, path):
def on_PUT(self, request, path):
if self._is_worker: if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker") raise Exception("Cannot handle PUT /push_rules on worker")
@ -57,7 +55,7 @@ class PushRuleRestServlet(RestServlet):
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, str(e)) raise SynapseError(400, str(e))
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
raise SynapseError(400, "rule_id may not contain slashes") raise SynapseError(400, "rule_id may not contain slashes")
@ -67,7 +65,7 @@ class PushRuleRestServlet(RestServlet):
user_id = requester.user.to_string() user_id = requester.user.to_string()
if "attr" in spec: if "attr" in spec:
yield self.set_rule_attr(user_id, spec, content) await self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id) self.notify_user(user_id)
return 200, {} return 200, {}
@ -91,7 +89,7 @@ class PushRuleRestServlet(RestServlet):
after = _namespaced_rule_id(spec, after) after = _namespaced_rule_id(spec, after)
try: try:
yield self.store.add_push_rule( await self.store.add_push_rule(
user_id=user_id, user_id=user_id,
rule_id=_namespaced_rule_id_from_spec(spec), rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class, priority_class=priority_class,
@ -108,20 +106,19 @@ class PushRuleRestServlet(RestServlet):
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_DELETE(self, request, path):
def on_DELETE(self, request, path):
if self._is_worker: if self._is_worker:
raise Exception("Cannot handle DELETE /push_rules on worker") raise Exception("Cannot handle DELETE /push_rules on worker")
spec = _rule_spec_from_path([x for x in path.split("/")]) spec = _rule_spec_from_path([x for x in path.split("/")])
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try: try:
yield self.store.delete_push_rule(user_id, namespaced_rule_id) await self.store.delete_push_rule(user_id, namespaced_rule_id)
self.notify_user(user_id) self.notify_user(user_id)
return 200, {} return 200, {}
except StoreError as e: except StoreError as e:
@ -130,15 +127,14 @@ class PushRuleRestServlet(RestServlet):
else: else:
raise raise
@defer.inlineCallbacks async def on_GET(self, request, path):
def on_GET(self, request, path): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference # is probably not going to make a whole lot of difference
rules = yield self.store.get_push_rules_for_user(user_id) rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(requester.user, rules) rules = format_push_rules_for_user(requester.user, rules)

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -39,12 +37,11 @@ class PushersRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = requester.user user = requester.user
pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
allowed_keys = [ allowed_keys = [
"app_display_name", "app_display_name",
@ -78,9 +75,8 @@ class PushersSetRestServlet(RestServlet):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool() self.pusher_pool = self.hs.get_pusherpool()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = requester.user user = requester.user
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -91,7 +87,7 @@ class PushersSetRestServlet(RestServlet):
and "kind" in content and "kind" in content
and content["kind"] is None and content["kind"] is None
): ):
yield self.pusher_pool.remove_pusher( await self.pusher_pool.remove_pusher(
content["app_id"], content["pushkey"], user_id=user.to_string() content["app_id"], content["pushkey"], user_id=user.to_string()
) )
return 200, {} return 200, {}
@ -117,14 +113,14 @@ class PushersSetRestServlet(RestServlet):
append = content["append"] append = content["append"]
if not append: if not append:
yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content["app_id"], app_id=content["app_id"],
pushkey=content["pushkey"], pushkey=content["pushkey"],
not_user_id=user.to_string(), not_user_id=user.to_string(),
) )
try: try:
yield self.pusher_pool.add_pusher( await self.pusher_pool.add_pusher(
user_id=user.to_string(), user_id=user.to_string(),
access_token=requester.access_token_id, access_token=requester.access_token_id,
kind=content["kind"], kind=content["kind"],
@ -164,16 +160,15 @@ class PushersRemoveRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool() self.pusher_pool = self.hs.get_pusherpool()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user user = requester.user
app_id = parse_string(request, "app_id", required=True) app_id = parse_string(request, "app_id", required=True)
pushkey = parse_string(request, "pushkey", required=True) pushkey = parse_string(request, "pushkey", required=True)
try: try:
yield self.pusher_pool.remove_pusher( await self.pusher_pool.remove_pusher(
app_id=app_id, pushkey=pushkey, user_id=user.to_string() app_id=app_id, pushkey=pushkey, user_id=user.to_string()
) )
except StoreError as se: except StoreError as se:

View File

@ -17,8 +17,6 @@ import base64
import hashlib import hashlib
import hmac import hmac
from twisted.internet import defer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
@ -31,9 +29,8 @@ class VoipRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(
requester = yield self.auth.get_user_by_req(
request, self.hs.config.turn_allow_guests request, self.hs.config.turn_allow_guests
) )

View File

@ -78,7 +78,7 @@ def interactive_auth_handler(orig):
""" """
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
res = defer.maybeDeferred(orig, *args, **kwargs) res = defer.ensureDeferred(orig(*args, **kwargs))
res.addErrback(_catch_incomplete_interactive_auth) res.addErrback(_catch_incomplete_interactive_auth)
return res return res

View File

@ -18,8 +18,6 @@ import logging
from six.moves import http_client from six.moves import http_client
from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour
@ -67,8 +65,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
template_text=template_text, template_text=template_text,
) )
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning( logger.warning(
@ -95,7 +92,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email "email", email
) )
@ -106,7 +103,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
assert self.hs.config.account_threepid_delegate_email assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request # Have the configured identity server handle the request
ret = yield self.identity_handler.requestEmailToken( ret = await self.identity_handler.requestEmailToken(
self.hs.config.account_threepid_delegate_email, self.hs.config.account_threepid_delegate_email,
email, email,
client_secret, client_secret,
@ -115,7 +112,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
) )
else: else:
# Send password reset emails from Synapse # Send password reset emails from Synapse
sid = yield self.identity_handler.send_threepid_validation( sid = await self.identity_handler.send_threepid_validation(
email, email,
client_secret, client_secret,
send_attempt, send_attempt,
@ -153,8 +150,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
[self.config.email_password_reset_template_failure_html], [self.config.email_password_reset_template_failure_html],
) )
@defer.inlineCallbacks async def on_GET(self, request, medium):
def on_GET(self, request, medium):
# We currently only handle threepid token submissions for email # We currently only handle threepid token submissions for email
if medium != "email": if medium != "email":
raise SynapseError( raise SynapseError(
@ -176,7 +172,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
# Attempt to validate a 3PID session # Attempt to validate a 3PID session
try: try:
# Mark the session as valid # Mark the session as valid
next_link = yield self.store.validate_threepid_session( next_link = await self.store.validate_threepid_session(
sid, client_secret, token, self.clock.time_msec() sid, client_secret, token, self.clock.time_msec()
) )
@ -218,8 +214,7 @@ class PasswordRestServlet(RestServlet):
self._set_password_handler = hs.get_set_password_handler() self._set_password_handler = hs.get_set_password_handler()
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
# there are two possibilities here. Either the user does not have an # there are two possibilities here. Either the user does not have an
@ -233,14 +228,14 @@ class PasswordRestServlet(RestServlet):
# In the second case, we require a password to confirm their identity. # In the second case, we require a password to confirm their identity.
if self.auth.has_access_token(request): if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
params = yield self.auth_handler.validate_user_via_ui_auth( params = await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request) requester, body, self.hs.get_ip_from_request(request)
) )
user_id = requester.user.to_string() user_id = requester.user.to_string()
else: else:
requester = None requester = None
result, params, _ = yield self.auth_handler.check_auth( result, params, _ = await self.auth_handler.check_auth(
[[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request) [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
) )
@ -254,7 +249,7 @@ class PasswordRestServlet(RestServlet):
# (See add_threepid in synapse/handlers/auth.py) # (See add_threepid in synapse/handlers/auth.py)
threepid["address"] = threepid["address"].lower() threepid["address"] = threepid["address"].lower()
# if using email, we must know about the email they're authing with! # if using email, we must know about the email they're authing with!
threepid_user_id = yield self.datastore.get_user_id_by_threepid( threepid_user_id = await self.datastore.get_user_id_by_threepid(
threepid["medium"], threepid["address"] threepid["medium"], threepid["address"]
) )
if not threepid_user_id: if not threepid_user_id:
@ -267,7 +262,7 @@ class PasswordRestServlet(RestServlet):
assert_params_in_dict(params, ["new_password"]) assert_params_in_dict(params, ["new_password"])
new_password = params["new_password"] new_password = params["new_password"]
yield self._set_password_handler.set_password(user_id, new_password, requester) await self._set_password_handler.set_password(user_id, new_password, requester)
return 200, {} return 200, {}
@ -286,8 +281,7 @@ class DeactivateAccountRestServlet(RestServlet):
self._deactivate_account_handler = hs.get_deactivate_account_handler() self._deactivate_account_handler = hs.get_deactivate_account_handler()
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
erase = body.get("erase", False) erase = body.get("erase", False)
if not isinstance(erase, bool): if not isinstance(erase, bool):
@ -297,19 +291,19 @@ class DeactivateAccountRestServlet(RestServlet):
Codes.BAD_JSON, Codes.BAD_JSON,
) )
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
# allow ASes to dectivate their own users # allow ASes to dectivate their own users
if requester.app_service: if requester.app_service:
yield self._deactivate_account_handler.deactivate_account( await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase requester.user.to_string(), erase
) )
return 200, {} return 200, {}
yield self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request) requester, body, self.hs.get_ip_from_request(request)
) )
result = yield self._deactivate_account_handler.deactivate_account( result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase, id_server=body.get("id_server") requester.user.to_string(), erase, id_server=body.get("id_server")
) )
if result: if result:
@ -346,8 +340,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
template_text=template_text, template_text=template_text,
) )
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning( logger.warning(
@ -371,7 +364,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
existing_user_id = yield self.store.get_user_id_by_threepid( existing_user_id = await self.store.get_user_id_by_threepid(
"email", body["email"] "email", body["email"]
) )
@ -382,7 +375,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
assert self.hs.config.account_threepid_delegate_email assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request # Have the configured identity server handle the request
ret = yield self.identity_handler.requestEmailToken( ret = await self.identity_handler.requestEmailToken(
self.hs.config.account_threepid_delegate_email, self.hs.config.account_threepid_delegate_email,
email, email,
client_secret, client_secret,
@ -391,7 +384,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
) )
else: else:
# Send threepid validation emails from Synapse # Send threepid validation emails from Synapse
sid = yield self.identity_handler.send_threepid_validation( sid = await self.identity_handler.send_threepid_validation(
email, email,
client_secret, client_secret,
send_attempt, send_attempt,
@ -414,8 +407,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict( assert_params_in_dict(
body, ["client_secret", "country", "phone_number", "send_attempt"] body, ["client_secret", "country", "phone_number", "send_attempt"]
@ -435,7 +427,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
existing_user_id = yield self.store.get_user_id_by_threepid("msisdn", msisdn) existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None: if existing_user_id is not None:
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
@ -450,7 +442,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
"Adding phone numbers to user account is not supported by this homeserver", "Adding phone numbers to user account is not supported by this homeserver",
) )
ret = yield self.identity_handler.requestMsisdnToken( ret = await self.identity_handler.requestMsisdnToken(
self.hs.config.account_threepid_delegate_msisdn, self.hs.config.account_threepid_delegate_msisdn,
country, country,
phone_number, phone_number,
@ -484,8 +476,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
[self.config.email_add_threepid_template_failure_html], [self.config.email_add_threepid_template_failure_html],
) )
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config: if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning( logger.warning(
@ -508,7 +499,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Attempt to validate a 3PID session # Attempt to validate a 3PID session
try: try:
# Mark the session as valid # Mark the session as valid
next_link = yield self.store.validate_threepid_session( next_link = await self.store.validate_threepid_session(
sid, client_secret, token, self.clock.time_msec() sid, client_secret, token, self.clock.time_msec()
) )
@ -558,8 +549,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
if not self.config.account_threepid_delegate_msisdn: if not self.config.account_threepid_delegate_msisdn:
raise SynapseError( raise SynapseError(
400, 400,
@ -571,7 +561,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
assert_params_in_dict(body, ["client_secret", "sid", "token"]) assert_params_in_dict(body, ["client_secret", "sid", "token"])
# Proxy submit_token request to msisdn threepid delegate # Proxy submit_token request to msisdn threepid delegate
response = yield self.identity_handler.proxy_msisdn_submit_token( response = await self.identity_handler.proxy_msisdn_submit_token(
self.config.account_threepid_delegate_msisdn, self.config.account_threepid_delegate_msisdn,
body["client_secret"], body["client_secret"],
body["sid"], body["sid"],
@ -591,17 +581,15 @@ class ThreepidRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore() self.datastore = self.hs.get_datastore()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
threepids = yield self.datastore.user_get_threepids(requester.user.to_string()) threepids = await self.datastore.user_get_threepids(requester.user.to_string())
return 200, {"threepids": threepids} return 200, {"threepids": threepids}
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -615,11 +603,11 @@ class ThreepidRestServlet(RestServlet):
client_secret = threepid_creds["client_secret"] client_secret = threepid_creds["client_secret"]
sid = threepid_creds["sid"] sid = threepid_creds["sid"]
validation_session = yield self.identity_handler.validate_threepid_session( validation_session = await self.identity_handler.validate_threepid_session(
client_secret, sid client_secret, sid
) )
if validation_session: if validation_session:
yield self.auth_handler.add_threepid( await self.auth_handler.add_threepid(
user_id, user_id,
validation_session["medium"], validation_session["medium"],
validation_session["address"], validation_session["address"],
@ -643,9 +631,8 @@ class ThreepidAddRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -653,15 +640,15 @@ class ThreepidAddRestServlet(RestServlet):
client_secret = body["client_secret"] client_secret = body["client_secret"]
sid = body["sid"] sid = body["sid"]
yield self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request) requester, body, self.hs.get_ip_from_request(request)
) )
validation_session = yield self.identity_handler.validate_threepid_session( validation_session = await self.identity_handler.validate_threepid_session(
client_secret, sid client_secret, sid
) )
if validation_session: if validation_session:
yield self.auth_handler.add_threepid( await self.auth_handler.add_threepid(
user_id, user_id,
validation_session["medium"], validation_session["medium"],
validation_session["address"], validation_session["address"],
@ -683,8 +670,7 @@ class ThreepidBindRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["id_server", "sid", "client_secret"]) assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
@ -693,10 +679,10 @@ class ThreepidBindRestServlet(RestServlet):
client_secret = body["client_secret"] client_secret = body["client_secret"]
id_access_token = body.get("id_access_token") # optional id_access_token = body.get("id_access_token") # optional
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
yield self.identity_handler.bind_threepid( await self.identity_handler.bind_threepid(
client_secret, sid, user_id, id_server, id_access_token client_secret, sid, user_id, id_server, id_access_token
) )
@ -713,12 +699,11 @@ class ThreepidUnbindRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.datastore = self.hs.get_datastore() self.datastore = self.hs.get_datastore()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
"""Unbind the given 3pid from a specific identity server, or identity servers that are """Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound known to have this 3pid bound
""" """
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"]) assert_params_in_dict(body, ["medium", "address"])
@ -728,7 +713,7 @@ class ThreepidUnbindRestServlet(RestServlet):
# Attempt to unbind the threepid from an identity server. If id_server is None, try to # Attempt to unbind the threepid from an identity server. If id_server is None, try to
# unbind from all identity servers this threepid has been added to in the past # unbind from all identity servers this threepid has been added to in the past
result = yield self.identity_handler.try_unbind_threepid( result = await self.identity_handler.try_unbind_threepid(
requester.user.to_string(), requester.user.to_string(),
{"address": address, "medium": medium, "id_server": id_server}, {"address": address, "medium": medium, "id_server": id_server},
) )
@ -743,16 +728,15 @@ class ThreepidDeleteRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"]) assert_params_in_dict(body, ["medium", "address"])
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
try: try:
ret = yield self.auth_handler.delete_threepid( ret = await self.auth_handler.delete_threepid(
user_id, body["medium"], body["address"], body.get("id_server") user_id, body["medium"], body["address"], body.get("id_server")
) )
except Exception: except Exception:
@ -777,9 +761,8 @@ class WhoamiRestServlet(RestServlet):
super(WhoamiRestServlet, self).__init__() super(WhoamiRestServlet, self).__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
return 200, {"user_id": requester.user.to_string()} return 200, {"user_id": requester.user.to_string()}

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -41,15 +39,14 @@ class AccountDataServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks async def on_PUT(self, request, user_id, account_data_type):
def on_PUT(self, request, user_id, account_data_type): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
max_id = yield self.store.add_account_data_for_user( max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body user_id, account_data_type, body
) )
@ -57,13 +54,12 @@ class AccountDataServlet(RestServlet):
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_GET(self, request, user_id, account_data_type):
def on_GET(self, request, user_id, account_data_type): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.") raise AuthError(403, "Cannot get account data for other users.")
event = yield self.store.get_global_account_data_by_type_for_user( event = await self.store.get_global_account_data_by_type_for_user(
account_data_type, user_id account_data_type, user_id
) )
@ -91,9 +87,8 @@ class RoomAccountDataServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks async def on_PUT(self, request, user_id, room_id, account_data_type):
def on_PUT(self, request, user_id, room_id, account_data_type): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.") raise AuthError(403, "Cannot add account data for other users.")
@ -106,7 +101,7 @@ class RoomAccountDataServlet(RestServlet):
" Use /rooms/!roomId:server.name/read_markers", " Use /rooms/!roomId:server.name/read_markers",
) )
max_id = yield self.store.add_account_data_to_room( max_id = await self.store.add_account_data_to_room(
user_id, room_id, account_data_type, body user_id, room_id, account_data_type, body
) )
@ -114,13 +109,12 @@ class RoomAccountDataServlet(RestServlet):
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_GET(self, request, user_id, room_id, account_data_type):
def on_GET(self, request, user_id, room_id, account_data_type): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.") raise AuthError(403, "Cannot get account data for other users.")
event = yield self.store.get_account_data_for_room_and_type( event = await self.store.get_account_data_for_room_and_type(
user_id, room_id, account_data_type user_id, room_id, account_data_type
) )

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
@ -45,13 +43,12 @@ class AccountValidityRenewServlet(RestServlet):
self.success_html = hs.config.account_validity.account_renewed_html_content self.success_html = hs.config.account_validity.account_renewed_html_content
self.failure_html = hs.config.account_validity.invalid_token_html_content self.failure_html = hs.config.account_validity.invalid_token_html_content
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request):
if b"token" not in request.args: if b"token" not in request.args:
raise SynapseError(400, "Missing renewal token") raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0] renewal_token = request.args[b"token"][0]
token_valid = yield self.account_activity_handler.renew_account( token_valid = await self.account_activity_handler.renew_account(
renewal_token.decode("utf8") renewal_token.decode("utf8")
) )
@ -67,7 +64,6 @@ class AccountValidityRenewServlet(RestServlet):
request.setHeader(b"Content-Length", b"%d" % (len(response),)) request.setHeader(b"Content-Length", b"%d" % (len(response),))
request.write(response.encode("utf8")) request.write(response.encode("utf8"))
finish_request(request) finish_request(request)
defer.returnValue(None)
class AccountValiditySendMailServlet(RestServlet): class AccountValiditySendMailServlet(RestServlet):
@ -85,18 +81,17 @@ class AccountValiditySendMailServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.account_validity = self.hs.config.account_validity self.account_validity = self.hs.config.account_validity
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
if not self.account_validity.renew_by_email_enabled: if not self.account_validity.renew_by_email_enabled:
raise AuthError( raise AuthError(
403, "Account renewal via email is disabled on this server." 403, "Account renewal via email is disabled on this server."
) )
requester = yield self.auth.get_user_by_req(request, allow_expired=True) requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
yield self.account_activity_handler.send_renewal_email_to_user(user_id) await self.account_activity_handler.send_renewal_email_to_user(user_id)
defer.returnValue((200, {})) return 200, {}
def register_servlets(hs, http_server): def register_servlets(hs, http_server):

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_API_PREFIX from synapse.api.urls import CLIENT_API_PREFIX
@ -171,8 +169,7 @@ class AuthRestServlet(RestServlet):
else: else:
raise SynapseError(404, "Unknown auth stage type") raise SynapseError(404, "Unknown auth stage type")
@defer.inlineCallbacks async def on_POST(self, request, stagetype):
def on_POST(self, request, stagetype):
session = parse_string(request, "session") session = parse_string(request, "session")
if not session: if not session:
@ -186,7 +183,7 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session} authdict = {"response": response, "session": session}
success = yield self.auth_handler.add_oob_auth( success = await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
) )
@ -215,7 +212,7 @@ class AuthRestServlet(RestServlet):
session = request.args["session"][0] session = request.args["session"][0]
authdict = {"session": session} authdict = {"session": session}
success = yield self.auth_handler.add_oob_auth( success = await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
) )

View File

@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
import logging import logging
from twisted.internet import defer
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
@ -40,10 +38,9 @@ class CapabilitiesRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True) user = await self.store.get_user_by_id(requester.user.to_string())
user = yield self.store.get_user_by_id(requester.user.to_string())
change_password = bool(user["password_hash"]) change_password = bool(user["password_hash"])
response = { response = {

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api import errors from synapse.api import errors
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -42,10 +40,9 @@ class DevicesRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True) devices = await self.device_handler.get_devices_by_user(
devices = yield self.device_handler.get_devices_by_user(
requester.user.to_string() requester.user.to_string()
) )
return 200, {"devices": devices} return 200, {"devices": devices}
@ -67,9 +64,8 @@ class DeleteDevicesRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
try: try:
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -84,11 +80,11 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"]) assert_params_in_dict(body, ["devices"])
yield self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request) requester, body, self.hs.get_ip_from_request(request)
) )
yield self.device_handler.delete_devices( await self.device_handler.delete_devices(
requester.user.to_string(), body["devices"] requester.user.to_string(), body["devices"]
) )
return 200, {} return 200, {}
@ -108,18 +104,16 @@ class DeviceRestServlet(RestServlet):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks async def on_GET(self, request, device_id):
def on_GET(self, request, device_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True) device = await self.device_handler.get_device(
device = yield self.device_handler.get_device(
requester.user.to_string(), device_id requester.user.to_string(), device_id
) )
return 200, device return 200, device
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_DELETE(self, request, device_id):
def on_DELETE(self, request, device_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
try: try:
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -132,19 +126,18 @@ class DeviceRestServlet(RestServlet):
else: else:
raise raise
yield self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request) requester, body, self.hs.get_ip_from_request(request)
) )
yield self.device_handler.delete_device(requester.user.to_string(), device_id) await self.device_handler.delete_device(requester.user.to_string(), device_id)
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_PUT(self, request, device_id):
def on_PUT(self, request, device_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
yield self.device_handler.update_device( await self.device_handler.update_device(
requester.user.to_string(), device_id, body requester.user.to_string(), device_id, body
) )
return 200, {} return 200, {}

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID from synapse.types import UserID
@ -35,10 +33,9 @@ class GetFilterRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
@defer.inlineCallbacks async def on_GET(self, request, user_id, filter_id):
def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if target_user != requester.user: if target_user != requester.user:
raise AuthError(403, "Cannot get filters for other users") raise AuthError(403, "Cannot get filters for other users")
@ -52,7 +49,7 @@ class GetFilterRestServlet(RestServlet):
raise SynapseError(400, "Invalid filter_id") raise SynapseError(400, "Invalid filter_id")
try: try:
filter_collection = yield self.filtering.get_user_filter( filter_collection = await self.filtering.get_user_filter(
user_localpart=target_user.localpart, filter_id=filter_id user_localpart=target_user.localpart, filter_id=filter_id
) )
except StoreError as e: except StoreError as e:
@ -72,11 +69,10 @@ class CreateFilterRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
@defer.inlineCallbacks async def on_POST(self, request, user_id):
def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if target_user != requester.user: if target_user != requester.user:
raise AuthError(403, "Cannot create filters for other users") raise AuthError(403, "Cannot create filters for other users")
@ -87,7 +83,7 @@ class CreateFilterRestServlet(RestServlet):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit)
filter_id = yield self.filtering.add_user_filter( filter_id = await self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_filter=content user_localpart=target_user.localpart, user_filter=content
) )

View File

@ -16,8 +16,6 @@
import logging import logging
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID from synapse.types import GroupID
@ -38,24 +36,22 @@ class GroupServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
group_description = yield self.groups_handler.get_group_profile( group_description = await self.groups_handler.get_group_profile(
group_id, requester_user_id group_id, requester_user_id
) )
return 200, group_description return 200, group_description
@defer.inlineCallbacks async def on_POST(self, request, group_id):
def on_POST(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
yield self.groups_handler.update_group_profile( await self.groups_handler.update_group_profile(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -74,12 +70,11 @@ class GroupSummaryServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
get_group_summary = yield self.groups_handler.get_group_summary( get_group_summary = await self.groups_handler.get_group_summary(
group_id, requester_user_id group_id, requester_user_id
) )
@ -106,13 +101,12 @@ class GroupSummaryRoomsCatServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id, category_id, room_id):
def on_PUT(self, request, group_id, category_id, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_summary_room( resp = await self.groups_handler.update_group_summary_room(
group_id, group_id,
requester_user_id, requester_user_id,
room_id=room_id, room_id=room_id,
@ -122,12 +116,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp return 200, resp
@defer.inlineCallbacks async def on_DELETE(self, request, group_id, category_id, room_id):
def on_DELETE(self, request, group_id, category_id, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_summary_room( resp = await self.groups_handler.delete_group_summary_room(
group_id, requester_user_id, room_id=room_id, category_id=category_id group_id, requester_user_id, room_id=room_id, category_id=category_id
) )
@ -148,35 +141,32 @@ class GroupCategoryServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id, category_id):
def on_GET(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_category( category = await self.groups_handler.get_group_category(
group_id, requester_user_id, category_id=category_id group_id, requester_user_id, category_id=category_id
) )
return 200, category return 200, category
@defer.inlineCallbacks async def on_PUT(self, request, group_id, category_id):
def on_PUT(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_category( resp = await self.groups_handler.update_group_category(
group_id, requester_user_id, category_id=category_id, content=content group_id, requester_user_id, category_id=category_id, content=content
) )
return 200, resp return 200, resp
@defer.inlineCallbacks async def on_DELETE(self, request, group_id, category_id):
def on_DELETE(self, request, group_id, category_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_category( resp = await self.groups_handler.delete_group_category(
group_id, requester_user_id, category_id=category_id group_id, requester_user_id, category_id=category_id
) )
@ -195,12 +185,11 @@ class GroupCategoriesServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_categories( category = await self.groups_handler.get_group_categories(
group_id, requester_user_id group_id, requester_user_id
) )
@ -219,35 +208,32 @@ class GroupRoleServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id, role_id):
def on_GET(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_role( category = await self.groups_handler.get_group_role(
group_id, requester_user_id, role_id=role_id group_id, requester_user_id, role_id=role_id
) )
return 200, category return 200, category
@defer.inlineCallbacks async def on_PUT(self, request, group_id, role_id):
def on_PUT(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_role( resp = await self.groups_handler.update_group_role(
group_id, requester_user_id, role_id=role_id, content=content group_id, requester_user_id, role_id=role_id, content=content
) )
return 200, resp return 200, resp
@defer.inlineCallbacks async def on_DELETE(self, request, group_id, role_id):
def on_DELETE(self, request, group_id, role_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_role( resp = await self.groups_handler.delete_group_role(
group_id, requester_user_id, role_id=role_id group_id, requester_user_id, role_id=role_id
) )
@ -266,12 +252,11 @@ class GroupRolesServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_roles( category = await self.groups_handler.get_group_roles(
group_id, requester_user_id group_id, requester_user_id
) )
@ -298,13 +283,12 @@ class GroupSummaryUsersRoleServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id, role_id, user_id):
def on_PUT(self, request, group_id, role_id, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_summary_user( resp = await self.groups_handler.update_group_summary_user(
group_id, group_id,
requester_user_id, requester_user_id,
user_id=user_id, user_id=user_id,
@ -314,12 +298,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp return 200, resp
@defer.inlineCallbacks async def on_DELETE(self, request, group_id, role_id, user_id):
def on_DELETE(self, request, group_id, role_id, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_summary_user( resp = await self.groups_handler.delete_group_summary_user(
group_id, requester_user_id, user_id=user_id, role_id=role_id group_id, requester_user_id, user_id=user_id, role_id=role_id
) )
@ -338,12 +321,11 @@ class GroupRoomServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_rooms_in_group( result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id group_id, requester_user_id
) )
@ -362,12 +344,11 @@ class GroupUsersServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_users_in_group( result = await self.groups_handler.get_users_in_group(
group_id, requester_user_id group_id, requester_user_id
) )
@ -386,12 +367,11 @@ class GroupInvitedUsersServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, group_id):
def on_GET(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_invited_users_in_group( result = await self.groups_handler.get_invited_users_in_group(
group_id, requester_user_id group_id, requester_user_id
) )
@ -409,14 +389,13 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id):
def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.set_group_join_policy( result = await self.groups_handler.set_group_join_policy(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -436,9 +415,8 @@ class GroupCreateServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
# TODO: Create group on remote server # TODO: Create group on remote server
@ -446,7 +424,7 @@ class GroupCreateServlet(RestServlet):
localpart = content.pop("localpart") localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string() group_id = GroupID(localpart, self.server_name).to_string()
result = yield self.groups_handler.create_group( result = await self.groups_handler.create_group(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -467,24 +445,22 @@ class GroupAdminRoomsServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id, room_id):
def on_PUT(self, request, group_id, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.add_room_to_group( result = await self.groups_handler.add_room_to_group(
group_id, requester_user_id, room_id, content group_id, requester_user_id, room_id, content
) )
return 200, result return 200, result
@defer.inlineCallbacks async def on_DELETE(self, request, group_id, room_id):
def on_DELETE(self, request, group_id, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
result = yield self.groups_handler.remove_room_from_group( result = await self.groups_handler.remove_room_from_group(
group_id, requester_user_id, room_id group_id, requester_user_id, room_id
) )
@ -506,13 +482,12 @@ class GroupAdminRoomsConfigServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id, room_id, config_key):
def on_PUT(self, request, group_id, room_id, config_key): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.update_room_in_group( result = await self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content group_id, requester_user_id, room_id, config_key, content
) )
@ -535,14 +510,13 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks async def on_PUT(self, request, group_id, user_id):
def on_PUT(self, request, group_id, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
config = content.get("config", {}) config = content.get("config", {})
result = yield self.groups_handler.invite( result = await self.groups_handler.invite(
group_id, user_id, requester_user_id, config group_id, user_id, requester_user_id, config
) )
@ -563,13 +537,12 @@ class GroupAdminUsersKickServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id, user_id):
def on_PUT(self, request, group_id, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.remove_user_from_group( result = await self.groups_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content group_id, user_id, requester_user_id, content
) )
@ -588,13 +561,12 @@ class GroupSelfLeaveServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id):
def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.remove_user_from_group( result = await self.groups_handler.remove_user_from_group(
group_id, requester_user_id, requester_user_id, content group_id, requester_user_id, requester_user_id, content
) )
@ -613,13 +585,12 @@ class GroupSelfJoinServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id):
def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.join_group( result = await self.groups_handler.join_group(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -638,13 +609,12 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_PUT(self, request, group_id):
def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
result = yield self.groups_handler.accept_invite( result = await self.groups_handler.accept_invite(
group_id, requester_user_id, content group_id, requester_user_id, content
) )
@ -663,14 +633,13 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def on_PUT(self, request, group_id):
def on_PUT(self, request, group_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
publicise = content["publicise"] publicise = content["publicise"]
yield self.store.update_group_publicity(group_id, requester_user_id, publicise) await self.store.update_group_publicity(group_id, requester_user_id, publicise)
return 200, {} return 200, {}
@ -688,11 +657,10 @@ class PublicisedGroupsForUserServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
result = yield self.groups_handler.get_publicised_groups_for_user(user_id) result = await self.groups_handler.get_publicised_groups_for_user(user_id)
return 200, result return 200, result
@ -710,14 +678,13 @@ class PublicisedGroupsForUsersServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
user_ids = content["user_ids"] user_ids = content["user_ids"]
result = yield self.groups_handler.bulk_get_publicised_groups(user_ids) result = await self.groups_handler.bulk_get_publicised_groups(user_ids)
return 200, result return 200, result
@ -734,12 +701,11 @@ class GroupsForUserServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler() self.groups_handler = hs.get_groups_local_handler()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string() requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_joined_groups(requester_user_id) result = await self.groups_handler.get_joined_groups(requester_user_id)
return 200, result return 200, result

View File

@ -16,8 +16,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -71,9 +69,8 @@ class KeyUploadServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@trace(opname="upload_keys") @trace(opname="upload_keys")
@defer.inlineCallbacks async def on_POST(self, request, device_id):
def on_POST(self, request, device_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -103,7 +100,7 @@ class KeyUploadServlet(RestServlet):
400, "To upload keys, you must pass device_id when authenticating" 400, "To upload keys, you must pass device_id when authenticating"
) )
result = yield self.e2e_keys_handler.upload_keys_for_user( result = await self.e2e_keys_handler.upload_keys_for_user(
user_id, device_id, body user_id, device_id, body
) )
return 200, result return 200, result
@ -154,13 +151,12 @@ class KeyQueryServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.query_devices(body, timeout, user_id) result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
return 200, result return 200, result
@ -185,9 +181,8 @@ class KeyChangesServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from") from_token_string = parse_string(request, "from")
set_tag("from", from_token_string) set_tag("from", from_token_string)
@ -200,7 +195,7 @@ class KeyChangesServlet(RestServlet):
user_id = requester.user.to_string() user_id = requester.user.to_string()
results = yield self.device_handler.get_user_ids_changed(user_id, from_token) results = await self.device_handler.get_user_ids_changed(user_id, from_token)
return 200, results return 200, results
@ -231,12 +226,11 @@ class OneTimeKeyServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000) timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout) result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
return 200, result return 200, result
@ -263,17 +257,16 @@ class SigningKeyUploadServlet(RestServlet):
self.auth_handler = hs.get_auth_handler() self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
yield self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request) requester, body, self.hs.get_ip_from_request(request)
) )
result = yield self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
return 200, result return 200, result
@ -315,13 +308,12 @@ class SignaturesUploadServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler() self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.upload_signatures_for_device_keys( result = await self.e2e_keys_handler.upload_signatures_for_device_keys(
user_id, body user_id, body
) )
return 200, result return 200, result

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_integer, parse_string
@ -35,9 +33,8 @@ class NotificationsServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
from_token = parse_string(request, "from", required=False) from_token = parse_string(request, "from", required=False)
@ -46,16 +43,16 @@ class NotificationsServlet(RestServlet):
limit = min(limit, 500) limit = min(limit, 500)
push_actions = yield self.store.get_push_actions_for_user( push_actions = await self.store.get_push_actions_for_user(
user_id, from_token, limit, only_highlight=(only == "highlight") user_id, from_token, limit, only_highlight=(only == "highlight")
) )
receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
user_id, "m.read" user_id, "m.read"
) )
notif_event_ids = [pa["event_id"] for pa in push_actions] notif_event_ids = [pa["event_id"] for pa in push_actions]
notif_events = yield self.store.get_events(notif_event_ids) notif_events = await self.store.get_events(notif_event_ids)
returned_push_actions = [] returned_push_actions = []
@ -68,7 +65,7 @@ class NotificationsServlet(RestServlet):
"actions": pa["actions"], "actions": pa["actions"],
"ts": pa["received_ts"], "ts": pa["received_ts"],
"event": ( "event": (
yield self._event_serializer.serialize_event( await self._event_serializer.serialize_event(
notif_events[pa["event_id"]], notif_events[pa["event_id"]],
self.clock.time_msec(), self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id, event_format=format_event_for_client_v2_without_room_id,

View File

@ -16,8 +16,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
@ -68,9 +66,8 @@ class IdTokenServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
@defer.inlineCallbacks async def on_POST(self, request, user_id):
def on_POST(self, request, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot request tokens for other users.") raise AuthError(403, "Cannot request tokens for other users.")
@ -81,7 +78,7 @@ class IdTokenServlet(RestServlet):
token = random_string(24) token = random_string(24)
ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS
yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
return ( return (
200, 200,

View File

@ -20,8 +20,6 @@ from typing import List, Union
from six import string_types from six import string_types
from twisted.internet import defer
import synapse import synapse
import synapse.types import synapse.types
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
@ -102,8 +100,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
template_text=template_text, template_text=template_text,
) )
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config: if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning( logger.warning(
@ -129,7 +126,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", body["email"] "email", body["email"]
) )
@ -140,7 +137,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
assert self.hs.config.account_threepid_delegate_email assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request # Have the configured identity server handle the request
ret = yield self.identity_handler.requestEmailToken( ret = await self.identity_handler.requestEmailToken(
self.hs.config.account_threepid_delegate_email, self.hs.config.account_threepid_delegate_email,
email, email,
client_secret, client_secret,
@ -149,7 +146,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
) )
else: else:
# Send registration emails from Synapse # Send registration emails from Synapse
sid = yield self.identity_handler.send_threepid_validation( sid = await self.identity_handler.send_threepid_validation(
email, email,
client_secret, client_secret,
send_attempt, send_attempt,
@ -175,8 +172,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
assert_params_in_dict( assert_params_in_dict(
@ -197,7 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
) )
existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid( existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn "msisdn", msisdn
) )
@ -215,7 +211,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
400, "Registration by phone number is not supported on this homeserver" 400, "Registration by phone number is not supported on this homeserver"
) )
ret = yield self.identity_handler.requestMsisdnToken( ret = await self.identity_handler.requestMsisdnToken(
self.hs.config.account_threepid_delegate_msisdn, self.hs.config.account_threepid_delegate_msisdn,
country, country,
phone_number, phone_number,
@ -258,8 +254,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
[self.config.email_registration_template_failure_html], [self.config.email_registration_template_failure_html],
) )
@defer.inlineCallbacks async def on_GET(self, request, medium):
def on_GET(self, request, medium):
if medium != "email": if medium != "email":
raise SynapseError( raise SynapseError(
400, "This medium is currently not supported for registration" 400, "This medium is currently not supported for registration"
@ -280,7 +275,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
# Attempt to validate a 3PID session # Attempt to validate a 3PID session
try: try:
# Mark the session as valid # Mark the session as valid
next_link = yield self.store.validate_threepid_session( next_link = await self.store.validate_threepid_session(
sid, client_secret, token, self.clock.time_msec() sid, client_secret, token, self.clock.time_msec()
) )
@ -338,8 +333,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
), ),
) )
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request):
if not self.hs.config.enable_registration: if not self.hs.config.enable_registration:
raise SynapseError( raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
@ -347,11 +341,11 @@ class UsernameAvailabilityRestServlet(RestServlet):
ip = self.hs.get_ip_from_request(request) ip = self.hs.get_ip_from_request(request)
with self.ratelimiter.ratelimit(ip) as wait_deferred: with self.ratelimiter.ratelimit(ip) as wait_deferred:
yield wait_deferred await wait_deferred
username = parse_string(request, "username", required=True) username = parse_string(request, "username", required=True)
yield self.registration_handler.check_username(username) await self.registration_handler.check_username(username)
return 200, {"available": True} return 200, {"available": True}
@ -382,8 +376,7 @@ class RegisterRestServlet(RestServlet):
) )
@interactive_auth_handler @interactive_auth_handler
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
client_addr = request.getClientIP() client_addr = request.getClientIP()
@ -408,7 +401,7 @@ class RegisterRestServlet(RestServlet):
kind = request.args[b"kind"][0] kind = request.args[b"kind"][0]
if kind == b"guest": if kind == b"guest":
ret = yield self._do_guest_registration(body, address=client_addr) ret = await self._do_guest_registration(body, address=client_addr)
return ret return ret
elif kind != b"user": elif kind != b"user":
raise UnrecognizedRequestError( raise UnrecognizedRequestError(
@ -435,7 +428,7 @@ class RegisterRestServlet(RestServlet):
appservice = None appservice = None
if self.auth.has_access_token(request): if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request) appservice = await self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes which have completely # fork off as soon as possible for ASes which have completely
# different registration flows to normal users # different registration flows to normal users
@ -455,7 +448,7 @@ class RegisterRestServlet(RestServlet):
access_token = self.auth.get_access_token_from_request(request) access_token = self.auth.get_access_token_from_request(request)
if isinstance(desired_username, string_types): if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration( result = await self._do_appservice_registration(
desired_username, access_token, body desired_username, access_token, body
) )
return 200, result # we throw for non 200 responses return 200, result # we throw for non 200 responses
@ -495,13 +488,13 @@ class RegisterRestServlet(RestServlet):
) )
if desired_username is not None: if desired_username is not None:
yield self.registration_handler.check_username( await self.registration_handler.check_username(
desired_username, desired_username,
guest_access_token=guest_access_token, guest_access_token=guest_access_token,
assigned_user_id=registered_user_id, assigned_user_id=registered_user_id,
) )
auth_result, params, session_id = yield self.auth_handler.check_auth( auth_result, params, session_id = await self.auth_handler.check_auth(
self._registration_flows, body, self.hs.get_ip_from_request(request) self._registration_flows, body, self.hs.get_ip_from_request(request)
) )
@ -557,7 +550,7 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"] medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"] address = auth_result[login_type]["address"]
existing_user_id = yield self.store.get_user_id_by_threepid( existing_user_id = await self.store.get_user_id_by_threepid(
medium, address medium, address
) )
@ -568,7 +561,7 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_IN_USE, Codes.THREEPID_IN_USE,
) )
registered_user_id = yield self.registration_handler.register_user( registered_user_id = await self.registration_handler.register_user(
localpart=desired_username, localpart=desired_username,
password=new_password, password=new_password,
guest_access_token=guest_access_token, guest_access_token=guest_access_token,
@ -581,7 +574,7 @@ class RegisterRestServlet(RestServlet):
if is_threepid_reserved( if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid self.hs.config.mau_limits_reserved_threepids, threepid
): ):
yield self.store.upsert_monthly_active_user(registered_user_id) await self.store.upsert_monthly_active_user(registered_user_id)
# remember that we've now registered that user account, and with # remember that we've now registered that user account, and with
# what user ID (since the user may not have specified) # what user ID (since the user may not have specified)
@ -591,12 +584,12 @@ class RegisterRestServlet(RestServlet):
registered = True registered = True
return_dict = yield self._create_registration_details( return_dict = await self._create_registration_details(
registered_user_id, params registered_user_id, params
) )
if registered: if registered:
yield self.registration_handler.post_registration_actions( await self.registration_handler.post_registration_actions(
user_id=registered_user_id, user_id=registered_user_id,
auth_result=auth_result, auth_result=auth_result,
access_token=return_dict.get("access_token"), access_token=return_dict.get("access_token"),
@ -607,15 +600,13 @@ class RegisterRestServlet(RestServlet):
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
@defer.inlineCallbacks async def _do_appservice_registration(self, username, as_token, body):
def _do_appservice_registration(self, username, as_token, body): user_id = await self.registration_handler.appservice_register(
user_id = yield self.registration_handler.appservice_register(
username, as_token username, as_token
) )
return (yield self._create_registration_details(user_id, body)) return await self._create_registration_details(user_id, body)
@defer.inlineCallbacks async def _create_registration_details(self, user_id, params):
def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user """Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token. Allocates device_id if one was not given; also creates access_token.
@ -631,18 +622,17 @@ class RegisterRestServlet(RestServlet):
if not params.get("inhibit_login", False): if not params.get("inhibit_login", False):
device_id = params.get("device_id") device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device( device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=False user_id, device_id, initial_display_name, is_guest=False
) )
result.update({"access_token": access_token, "device_id": device_id}) result.update({"access_token": access_token, "device_id": device_id})
return result return result
@defer.inlineCallbacks async def _do_guest_registration(self, params, address=None):
def _do_guest_registration(self, params, address=None):
if not self.hs.config.allow_guest_access: if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled") raise SynapseError(403, "Guest access is disabled")
user_id = yield self.registration_handler.register_user( user_id = await self.registration_handler.register_user(
make_guest=True, address=address make_guest=True, address=address
) )
@ -650,7 +640,7 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it. # we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device( device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=True user_id, device_id, initial_display_name, is_guest=True
) )

View File

@ -21,8 +21,6 @@ any time to reflect changes in the MSC.
import logging import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -86,11 +84,10 @@ class RelationSendServlet(RestServlet):
request, self.on_PUT_or_POST, request, *args, **kwargs request, self.on_PUT_or_POST, request, *args, **kwargs
) )
@defer.inlineCallbacks async def on_PUT_or_POST(
def on_PUT_or_POST(
self, request, room_id, parent_id, relation_type, event_type, txn_id=None self, request, room_id, parent_id, relation_type, event_type, txn_id=None
): ):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
if event_type == EventTypes.Member: if event_type == EventTypes.Member:
# Add relations to a membership is meaningless, so we just deny it # Add relations to a membership is meaningless, so we just deny it
@ -114,7 +111,7 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
} }
event = yield self.event_creation_handler.create_and_send_nonmember_event( event = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id requester, event_dict=event_dict, txn_id=txn_id
) )
@ -140,17 +137,18 @@ class RelationPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks async def on_GET(
def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): self, request, room_id, parent_id, relation_type=None, event_type=None
requester = yield self.auth.get_user_by_req(request, allow_guest=True) ):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.check_in_room_or_world_readable( await self.auth.check_in_room_or_world_readable(
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
# This gets the original event and checks that a) the event exists and # This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it. # b) the user is allowed to view it.
event = yield self.event_handler.get_event(requester.user, room_id, parent_id) event = await self.event_handler.get_event(requester.user, room_id, parent_id)
limit = parse_integer(request, "limit", default=5) limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from") from_token = parse_string(request, "from")
@ -167,7 +165,7 @@ class RelationPaginationServlet(RestServlet):
if to_token: if to_token:
to_token = RelationPaginationToken.from_string(to_token) to_token = RelationPaginationToken.from_string(to_token)
pagination_chunk = yield self.store.get_relations_for_event( pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id, event_id=parent_id,
relation_type=relation_type, relation_type=relation_type,
event_type=event_type, event_type=event_type,
@ -176,7 +174,7 @@ class RelationPaginationServlet(RestServlet):
to_token=to_token, to_token=to_token,
) )
events = yield self.store.get_events_as_list( events = await self.store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk] [c["event_id"] for c in pagination_chunk.chunk]
) )
@ -184,13 +182,13 @@ class RelationPaginationServlet(RestServlet):
# We set bundle_aggregations to False when retrieving the original # We set bundle_aggregations to False when retrieving the original
# event because we want the content before relations were applied to # event because we want the content before relations were applied to
# it. # it.
original_event = yield self._event_serializer.serialize_event( original_event = await self._event_serializer.serialize_event(
event, now, bundle_aggregations=False event, now, bundle_aggregations=False
) )
# Similarly, we don't allow relations to be applied to relations, so we # Similarly, we don't allow relations to be applied to relations, so we
# return the original relations without any aggregations on top of them # return the original relations without any aggregations on top of them
# here. # here.
events = yield self._event_serializer.serialize_events( events = await self._event_serializer.serialize_events(
events, now, bundle_aggregations=False events, now, bundle_aggregations=False
) )
@ -232,17 +230,18 @@ class RelationAggregationPaginationServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks async def on_GET(
def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): self, request, room_id, parent_id, relation_type=None, event_type=None
requester = yield self.auth.get_user_by_req(request, allow_guest=True) ):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.check_in_room_or_world_readable( await self.auth.check_in_room_or_world_readable(
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
# This checks that a) the event exists and b) the user is allowed to # This checks that a) the event exists and b) the user is allowed to
# view it. # view it.
event = yield self.event_handler.get_event(requester.user, room_id, parent_id) event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type not in (RelationTypes.ANNOTATION, None): if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
@ -262,7 +261,7 @@ class RelationAggregationPaginationServlet(RestServlet):
if to_token: if to_token:
to_token = AggregationPaginationToken.from_string(to_token) to_token = AggregationPaginationToken.from_string(to_token)
pagination_chunk = yield self.store.get_aggregation_groups_for_event( pagination_chunk = await self.store.get_aggregation_groups_for_event(
event_id=parent_id, event_id=parent_id,
event_type=event_type, event_type=event_type,
limit=limit, limit=limit,
@ -311,17 +310,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.check_in_room_or_world_readable( await self.auth.check_in_room_or_world_readable(
room_id, requester.user.to_string() room_id, requester.user.to_string()
) )
# This checks that a) the event exists and b) the user is allowed to # This checks that a) the event exists and b) the user is allowed to
# view it. # view it.
yield self.event_handler.get_event(requester.user, room_id, parent_id) await self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type != RelationTypes.ANNOTATION: if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'") raise SynapseError(400, "Relation type must be 'annotation'")
@ -336,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
if to_token: if to_token:
to_token = RelationPaginationToken.from_string(to_token) to_token = RelationPaginationToken.from_string(to_token)
result = yield self.store.get_relations_for_event( result = await self.store.get_relations_for_event(
event_id=parent_id, event_id=parent_id,
relation_type=relation_type, relation_type=relation_type,
event_type=event_type, event_type=event_type,
@ -346,12 +344,12 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
to_token=to_token, to_token=to_token,
) )
events = yield self.store.get_events_as_list( events = await self.store.get_events_as_list(
[c["event_id"] for c in result.chunk] [c["event_id"] for c in result.chunk]
) )
now = self.clock.time_msec() now = self.clock.time_msec()
events = yield self._event_serializer.serialize_events(events, now) events = await self._event_serializer.serialize_events(events, now)
return_value = result.to_dict() return_value = result.to_dict()
return_value["chunk"] = events return_value["chunk"] = events

View File

@ -18,8 +18,6 @@ import logging
from six import string_types from six import string_types
from six.moves import http_client from six.moves import http_client
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -42,9 +40,8 @@ class ReportEventRestServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def on_POST(self, request, room_id, event_id):
def on_POST(self, request, room_id, event_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -63,7 +60,7 @@ class ReportEventRestServlet(RestServlet):
Codes.BAD_JSON, Codes.BAD_JSON,
) )
yield self.store.add_event_report( await self.store.add_event_report(
room_id=room_id, room_id=room_id,
event_id=event_id, event_id=event_id,
user_id=user_id, user_id=user_id,

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@ -43,8 +41,7 @@ class RoomKeysServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
@defer.inlineCallbacks async def on_PUT(self, request, room_id, session_id):
def on_PUT(self, request, room_id, session_id):
""" """
Uploads one or more encrypted E2E room keys for backup purposes. Uploads one or more encrypted E2E room keys for backup purposes.
room_id: the ID of the room the keys are for (optional) room_id: the ID of the room the keys are for (optional)
@ -123,7 +120,7 @@ class RoomKeysServlet(RestServlet):
} }
} }
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
version = parse_string(request, "version") version = parse_string(request, "version")
@ -134,11 +131,10 @@ class RoomKeysServlet(RestServlet):
if room_id: if room_id:
body = {"rooms": {room_id: body}} body = {"rooms": {room_id: body}}
ret = yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
return 200, ret return 200, ret
@defer.inlineCallbacks async def on_GET(self, request, room_id, session_id):
def on_GET(self, request, room_id, session_id):
""" """
Retrieves one or more encrypted E2E room keys for backup purposes. Retrieves one or more encrypted E2E room keys for backup purposes.
Symmetric with the PUT version of the API. Symmetric with the PUT version of the API.
@ -190,11 +186,11 @@ class RoomKeysServlet(RestServlet):
} }
} }
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
version = parse_string(request, "version") version = parse_string(request, "version")
room_keys = yield self.e2e_room_keys_handler.get_room_keys( room_keys = await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id user_id, version, room_id, session_id
) )
@ -220,8 +216,7 @@ class RoomKeysServlet(RestServlet):
return 200, room_keys return 200, room_keys
@defer.inlineCallbacks async def on_DELETE(self, request, room_id, session_id):
def on_DELETE(self, request, room_id, session_id):
""" """
Deletes one or more encrypted E2E room keys for a user for backup purposes. Deletes one or more encrypted E2E room keys for a user for backup purposes.
@ -235,11 +230,11 @@ class RoomKeysServlet(RestServlet):
the version must already have been created via the /change_secret API. the version must already have been created via the /change_secret API.
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
version = parse_string(request, "version") version = parse_string(request, "version")
ret = yield self.e2e_room_keys_handler.delete_room_keys( ret = await self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id user_id, version, room_id, session_id
) )
return 200, ret return 200, ret
@ -257,8 +252,7 @@ class RoomKeysNewVersionServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
""" """
Create a new backup version for this user's room_keys with the given Create a new backup version for this user's room_keys with the given
info. The version is allocated by the server and returned to the user info. The version is allocated by the server and returned to the user
@ -288,11 +282,11 @@ class RoomKeysNewVersionServlet(RestServlet):
"version": 12345 "version": 12345
} }
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
info = parse_json_object_from_request(request) info = parse_json_object_from_request(request)
new_version = yield self.e2e_room_keys_handler.create_version(user_id, info) new_version = await self.e2e_room_keys_handler.create_version(user_id, info)
return 200, {"version": new_version} return 200, {"version": new_version}
# we deliberately don't have a PUT /version, as these things really should # we deliberately don't have a PUT /version, as these things really should
@ -311,8 +305,7 @@ class RoomKeysVersionServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
@defer.inlineCallbacks async def on_GET(self, request, version):
def on_GET(self, request, version):
""" """
Retrieve the version information about a given version of the user's Retrieve the version information about a given version of the user's
room_keys backup. If the version part is missing, returns info about the room_keys backup. If the version part is missing, returns info about the
@ -330,18 +323,17 @@ class RoomKeysVersionServlet(RestServlet):
"auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K" "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
} }
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
try: try:
info = yield self.e2e_room_keys_handler.get_version_info(user_id, version) info = await self.e2e_room_keys_handler.get_version_info(user_id, version)
except SynapseError as e: except SynapseError as e:
if e.code == 404: if e.code == 404:
raise SynapseError(404, "No backup found", Codes.NOT_FOUND) raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
return 200, info return 200, info
@defer.inlineCallbacks async def on_DELETE(self, request, version):
def on_DELETE(self, request, version):
""" """
Delete the information about a given version of the user's Delete the information about a given version of the user's
room_keys backup. If the version part is missing, deletes the most room_keys backup. If the version part is missing, deletes the most
@ -354,14 +346,13 @@ class RoomKeysVersionServlet(RestServlet):
if version is None: if version is None:
raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND) raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND)
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
yield self.e2e_room_keys_handler.delete_version(user_id, version) await self.e2e_room_keys_handler.delete_version(user_id, version)
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_PUT(self, request, version):
def on_PUT(self, request, version):
""" """
Update the information about a given version of the user's room_keys backup. Update the information about a given version of the user's room_keys backup.
@ -382,7 +373,7 @@ class RoomKeysVersionServlet(RestServlet):
Content-Type: application/json Content-Type: application/json
{} {}
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
info = parse_json_object_from_request(request) info = parse_json_object_from_request(request)
@ -391,7 +382,7 @@ class RoomKeysVersionServlet(RestServlet):
400, "No version specified to update", Codes.MISSING_PARAM 400, "No version specified to update", Codes.MISSING_PARAM
) )
yield self.e2e_room_keys_handler.update_version(user_id, version, info) await self.e2e_room_keys_handler.update_version(user_id, version, info)
return 200, {} return 200, {}

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -59,9 +57,8 @@ class RoomUpgradeRestServlet(RestServlet):
self._room_creation_handler = hs.get_room_creation_handler() self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth() self._auth = hs.get_auth()
@defer.inlineCallbacks async def on_POST(self, request, room_id):
def on_POST(self, request, room_id): requester = await self._auth.get_user_by_req(request)
requester = yield self._auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("new_version",)) assert_params_in_dict(content, ("new_version",))
@ -74,7 +71,7 @@ class RoomUpgradeRestServlet(RestServlet):
Codes.UNSUPPORTED_ROOM_VERSION, Codes.UNSUPPORTED_ROOM_VERSION,
) )
new_room_id = yield self._room_creation_handler.upgrade_room( new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version requester, room_id, new_version
) )

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.http import servlet from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.logging.opentracing import set_tag, trace from synapse.logging.opentracing import set_tag, trace
@ -51,15 +49,14 @@ class SendToDeviceRestServlet(servlet.RestServlet):
request, self._put, request, message_type, txn_id request, self._put, request, message_type, txn_id
) )
@defer.inlineCallbacks async def _put(self, request, message_type, txn_id):
def _put(self, request, message_type, txn_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
sender_user_id = requester.user.to_string() sender_user_id = requester.user.to_string()
yield self.device_message_handler.send_device_message( await self.device_message_handler.send_device_message(
sender_user_id, message_type, content["messages"] sender_user_id, message_type, content["messages"]
) )

View File

@ -18,8 +18,6 @@ import logging
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
@ -87,8 +85,7 @@ class SyncRestServlet(RestServlet):
self._server_notices_sender = hs.get_server_notices_sender() self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request):
if b"from" in request.args: if b"from" in request.args:
# /events used to use 'from', but /sync uses 'since'. # /events used to use 'from', but /sync uses 'since'.
# Lets be helpful and whine if we see a 'from'. # Lets be helpful and whine if we see a 'from'.
@ -96,7 +93,7 @@ class SyncRestServlet(RestServlet):
400, "'from' is not a valid query parameter. Did you mean 'since'?" 400, "'from' is not a valid query parameter. Did you mean 'since'?"
) )
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = requester.user user = requester.user
device_id = requester.device_id device_id = requester.device_id
@ -138,7 +135,7 @@ class SyncRestServlet(RestServlet):
filter_collection = FilterCollection(filter_object) filter_collection = FilterCollection(filter_object)
else: else:
try: try:
filter_collection = yield self.filtering.get_user_filter( filter_collection = await self.filtering.get_user_filter(
user.localpart, filter_id user.localpart, filter_id
) )
except StoreError as err: except StoreError as err:
@ -161,20 +158,20 @@ class SyncRestServlet(RestServlet):
since_token = None since_token = None
# send any outstanding server notices to the user. # send any outstanding server notices to the user.
yield self._server_notices_sender.on_user_syncing(user.to_string()) await self._server_notices_sender.on_user_syncing(user.to_string())
affect_presence = set_presence != PresenceState.OFFLINE affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence: if affect_presence:
yield self.presence_handler.set_state( await self.presence_handler.set_state(
user, {"presence": set_presence}, True user, {"presence": set_presence}, True
) )
context = yield self.presence_handler.user_syncing( context = await self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence user.to_string(), affect_presence=affect_presence
) )
with context: with context:
sync_result = yield self.sync_handler.wait_for_sync_for_user( sync_result = await self.sync_handler.wait_for_sync_for_user(
sync_config, sync_config,
since_token=since_token, since_token=since_token,
timeout=timeout, timeout=timeout,
@ -182,14 +179,13 @@ class SyncRestServlet(RestServlet):
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
response_content = yield self.encode_response( response_content = await self.encode_response(
time_now, sync_result, requester.access_token_id, filter_collection time_now, sync_result, requester.access_token_id, filter_collection
) )
return 200, response_content return 200, response_content
@defer.inlineCallbacks async def encode_response(self, time_now, sync_result, access_token_id, filter):
def encode_response(self, time_now, sync_result, access_token_id, filter):
if filter.event_format == "client": if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id event_formatter = format_event_for_client_v2_without_room_id
elif filter.event_format == "federation": elif filter.event_format == "federation":
@ -197,7 +193,7 @@ class SyncRestServlet(RestServlet):
else: else:
raise Exception("Unknown event format %s" % (filter.event_format,)) raise Exception("Unknown event format %s" % (filter.event_format,))
joined = yield self.encode_joined( joined = await self.encode_joined(
sync_result.joined, sync_result.joined,
time_now, time_now,
access_token_id, access_token_id,
@ -205,11 +201,11 @@ class SyncRestServlet(RestServlet):
event_formatter, event_formatter,
) )
invited = yield self.encode_invited( invited = await self.encode_invited(
sync_result.invited, time_now, access_token_id, event_formatter sync_result.invited, time_now, access_token_id, event_formatter
) )
archived = yield self.encode_archived( archived = await self.encode_archived(
sync_result.archived, sync_result.archived,
time_now, time_now,
access_token_id, access_token_id,
@ -250,8 +246,9 @@ class SyncRestServlet(RestServlet):
] ]
} }
@defer.inlineCallbacks async def encode_joined(
def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter): self, rooms, time_now, token_id, event_fields, event_formatter
):
""" """
Encode the joined rooms in a sync result Encode the joined rooms in a sync result
@ -272,7 +269,7 @@ class SyncRestServlet(RestServlet):
""" """
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = yield self.encode_room( joined[room.room_id] = await self.encode_room(
room, room,
time_now, time_now,
token_id, token_id,
@ -283,8 +280,7 @@ class SyncRestServlet(RestServlet):
return joined return joined
@defer.inlineCallbacks async def encode_invited(self, rooms, time_now, token_id, event_formatter):
def encode_invited(self, rooms, time_now, token_id, event_formatter):
""" """
Encode the invited rooms in a sync result Encode the invited rooms in a sync result
@ -304,7 +300,7 @@ class SyncRestServlet(RestServlet):
""" """
invited = {} invited = {}
for room in rooms: for room in rooms:
invite = yield self._event_serializer.serialize_event( invite = await self._event_serializer.serialize_event(
room.invite, room.invite,
time_now, time_now,
token_id=token_id, token_id=token_id,
@ -319,8 +315,9 @@ class SyncRestServlet(RestServlet):
return invited return invited
@defer.inlineCallbacks async def encode_archived(
def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter): self, rooms, time_now, token_id, event_fields, event_formatter
):
""" """
Encode the archived rooms in a sync result Encode the archived rooms in a sync result
@ -341,7 +338,7 @@ class SyncRestServlet(RestServlet):
""" """
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = yield self.encode_room( joined[room.room_id] = await self.encode_room(
room, room,
time_now, time_now,
token_id, token_id,
@ -352,8 +349,7 @@ class SyncRestServlet(RestServlet):
return joined return joined
@defer.inlineCallbacks async def encode_room(
def encode_room(
self, room, time_now, token_id, joined, only_fields, event_formatter self, room, time_now, token_id, joined, only_fields, event_formatter
): ):
""" """
@ -401,8 +397,8 @@ class SyncRestServlet(RestServlet):
event.room_id, event.room_id,
) )
serialized_state = yield serialize(state_events) serialized_state = await serialize(state_events)
serialized_timeline = yield serialize(timeline_events) serialized_timeline = await serialize(timeline_events)
account_data = room.account_data account_data = room.account_data

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -37,13 +35,12 @@ class TagListServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def on_GET(self, request, user_id, room_id):
def on_GET(self, request, user_id, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.") raise AuthError(403, "Cannot get tags for other users.")
tags = yield self.store.get_tags_for_room(user_id, room_id) tags = await self.store.get_tags_for_room(user_id, room_id)
return 200, {"tags": tags} return 200, {"tags": tags}
@ -64,27 +61,25 @@ class TagServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@defer.inlineCallbacks async def on_PUT(self, request, user_id, room_id, tag):
def on_PUT(self, request, user_id, room_id, tag): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_DELETE(self, request, user_id, room_id, tag):
def on_DELETE(self, request, user_id, room_id, tag): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string(): if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) self.notifier.on_new_event("account_data_key", max_id, users=[user_id])

View File

@ -16,8 +16,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
@ -35,11 +33,10 @@ class ThirdPartyProtocolsServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
protocols = yield self.appservice_handler.get_3pe_protocols() protocols = await self.appservice_handler.get_3pe_protocols()
return 200, protocols return 200, protocols
@ -52,11 +49,10 @@ class ThirdPartyProtocolServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks async def on_GET(self, request, protocol):
def on_GET(self, request, protocol): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
protocols = yield self.appservice_handler.get_3pe_protocols( protocols = await self.appservice_handler.get_3pe_protocols(
only_protocol=protocol only_protocol=protocol
) )
if protocol in protocols: if protocol in protocols:
@ -74,14 +70,13 @@ class ThirdPartyUserServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks async def on_GET(self, request, protocol):
def on_GET(self, request, protocol): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args fields = request.args
fields.pop(b"access_token", None) fields.pop(b"access_token", None)
results = yield self.appservice_handler.query_3pe( results = await self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields ThirdPartyEntityKind.USER, protocol, fields
) )
@ -97,14 +92,13 @@ class ThirdPartyLocationServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks async def on_GET(self, request, protocol):
def on_GET(self, request, protocol): await self.auth.get_user_by_req(request, allow_guest=True)
yield self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args fields = request.args
fields.pop(b"access_token", None) fields.pop(b"access_token", None)
results = yield self.appservice_handler.query_3pe( results = await self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields ThirdPartyEntityKind.LOCATION, protocol, fields
) )

View File

@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
@ -32,8 +30,7 @@ class TokenRefreshRestServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__() super(TokenRefreshRestServlet, self).__init__()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
raise AuthError(403, "tokenrefresh is no longer supported.") raise AuthError(403, "tokenrefresh is no longer supported.")

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -38,8 +36,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler() self.user_directory_handler = hs.get_user_directory_handler()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
"""Searches for users in directory """Searches for users in directory
Returns: Returns:
@ -56,7 +53,7 @@ class UserDirectorySearchRestServlet(RestServlet):
] ]
} }
""" """
requester = yield self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string() user_id = requester.user.to_string()
if not self.hs.config.user_directory_search_enabled: if not self.hs.config.user_directory_search_enabled:
@ -72,7 +69,7 @@ class UserDirectorySearchRestServlet(RestServlet):
except Exception: except Exception:
raise SynapseError(400, "`search_term` is required field") raise SynapseError(400, "`search_term` is required field")
results = yield self.user_directory_handler.search_users( results = await self.user_directory_handler.search_users(
user_id, search_term, limit user_id, search_term, limit
) )

View File

@ -19,8 +19,6 @@ import calendar
import logging import logging
import time import time
from twisted.internet import defer
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import ( from synapse.storage.util.id_generators import (
@ -476,7 +474,7 @@ class DataStore(
) )
def get_users(self): def get_users(self):
"""Function to reterive a list of users in users table. """Function to retrieve a list of users in users table.
Args: Args:
Returns: Returns:
@ -485,38 +483,59 @@ class DataStore(
return self.db.simple_select_list( return self.db.simple_select_list(
table="users", table="users",
keyvalues={}, keyvalues={},
retcols=["name", "password_hash", "is_guest", "admin", "user_type"], retcols=[
"name",
"password_hash",
"is_guest",
"admin",
"user_type",
"deactivated",
],
desc="get_users", desc="get_users",
) )
@defer.inlineCallbacks def get_users_paginate(
def get_users_paginate(self, order, start, limit): self, start, limit, name=None, guests=True, deactivated=False
"""Function to reterive a paginated list of users from ):
users list. This will return a json object, which contains """Function to retrieve a paginated list of users from
list of users and the total number of users in users table. users list. This will return a json list of users.
Args: Args:
order (str): column name to order the select by this column
start (int): start number to begin the query from start (int): start number to begin the query from
limit (int): number of rows to reterive limit (int): number of rows to retrieve
name (string): filter for user names
guests (bool): whether to in include guest users
deactivated (bool): whether to include deactivated users
Returns: Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count} defer.Deferred: resolves to list[dict[str, Any]]
""" """
users = yield self.db.runInteraction( name_filter = {}
"get_users_paginate", if name:
self.db.simple_select_list_paginate_txn, name_filter["name"] = "%" + name + "%"
attr_filter = {}
if not guests:
attr_filter["is_guest"] = False
if not deactivated:
attr_filter["deactivated"] = False
return self.db.simple_select_list_paginate(
desc="get_users_paginate",
table="users", table="users",
keyvalues={"is_guest": False}, orderby="name",
orderby=order,
start=start, start=start,
limit=limit, limit=limit,
retcols=["name", "password_hash", "is_guest", "admin", "user_type"], filters=name_filter,
keyvalues=attr_filter,
retcols=[
"name",
"password_hash",
"is_guest",
"admin",
"user_type",
"deactivated",
],
) )
count = yield self.db.runInteraction(
"get_users_paginate", self.get_user_count_txn
)
retval = {"users": users, "total": count}
return retval
def search_users(self, term): def search_users(self, term):
"""Function to search users list for one or more users with """Function to search users list for one or more users with

View File

@ -260,11 +260,11 @@ class StatsStore(StateDeltasStore):
slice_list = self.db.simple_select_list_paginate_txn( slice_list = self.db.simple_select_list_paginate_txn(
txn, txn,
table + "_historical", table + "_historical",
{id_col: stats_id},
"end_ts", "end_ts",
start, start,
size, size,
retcols=selected_columns + ["bucket_size", "end_ts"], retcols=selected_columns + ["bucket_size", "end_ts"],
keyvalues={id_col: stats_id},
order_direction="DESC", order_direction="DESC",
) )

View File

@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import random
import sys import sys
import time import time
from typing import Iterable, Tuple from typing import Iterable, Tuple
@ -1321,11 +1320,12 @@ class Database(object):
def simple_select_list_paginate( def simple_select_list_paginate(
self, self,
table, table,
keyvalues,
orderby, orderby,
start, start,
limit, limit,
retcols, retcols,
filters=None,
keyvalues=None,
order_direction="ASC", order_direction="ASC",
desc="simple_select_list_paginate", desc="simple_select_list_paginate",
): ):
@ -1336,6 +1336,9 @@ class Database(object):
Args: Args:
table (str): the table name table (str): the table name
filters (dict[str, T] | None):
column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause.
keyvalues (dict[str, T] | None): keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not column names and values to select the rows with, or None to not
apply a WHERE clause. apply a WHERE clause.
@ -1351,11 +1354,12 @@ class Database(object):
desc, desc,
self.simple_select_list_paginate_txn, self.simple_select_list_paginate_txn,
table, table,
keyvalues,
orderby, orderby,
start, start,
limit, limit,
retcols, retcols,
filters=filters,
keyvalues=keyvalues,
order_direction=order_direction, order_direction=order_direction,
) )
@ -1364,11 +1368,12 @@ class Database(object):
cls, cls,
txn, txn,
table, table,
keyvalues,
orderby, orderby,
start, start,
limit, limit,
retcols, retcols,
filters=None,
keyvalues=None,
order_direction="ASC", order_direction="ASC",
): ):
""" """
@ -1376,16 +1381,23 @@ class Database(object):
of row numbers, which may return zero or number of rows from start to limit, of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts. returning the result as a list of dicts.
Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
select attributes with exact matches. All constraints are joined together
using 'AND'.
Args: Args:
txn : Transaction object txn : Transaction object
table (str): the table name table (str): the table name
keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
orderby (str): Column to order the results by. orderby (str): Column to order the results by.
start (int): Index to begin the query at. start (int): Index to begin the query at.
limit (int): Number of results to return. limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return retcols (iterable[str]): the names of the columns to return
filters (dict[str, T] | None):
column names and values to filter the rows with, or None to not
apply a WHERE ? LIKE ? clause.
keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
order_direction (str): Whether the results should be ordered "ASC" or "DESC". order_direction (str): Whether the results should be ordered "ASC" or "DESC".
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] defer.Deferred: resolves to list[dict[str, Any]]
@ -1393,10 +1405,15 @@ class Database(object):
if order_direction not in ["ASC", "DESC"]: if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.") raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
where_clause = "WHERE " if filters or keyvalues else ""
arg_list = []
if filters:
where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
arg_list += list(filters.values())
where_clause += " AND " if filters and keyvalues else ""
if keyvalues: if keyvalues:
where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues) where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
else: arg_list += list(keyvalues.values())
where_clause = ""
sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % ( sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
", ".join(retcols), ", ".join(retcols),
@ -1405,22 +1422,10 @@ class Database(object):
orderby, orderby,
order_direction, order_direction,
) )
txn.execute(sql, list(keyvalues.values()) + [limit, start]) txn.execute(sql, arg_list + [limit, start])
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
def get_user_count_txn(self, txn):
"""Get a total number of registered users in the users list.
Args:
txn : Transaction object
Returns:
int : number of users
"""
sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
txn.execute(sql_count)
return txn.fetchone()[0]
def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"): def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.

View File

@ -15,6 +15,8 @@
from mock import Mock from mock import Mock
from twisted.internet import defer
from synapse.rest.client.v1 import presence from synapse.rest.client.v1 import presence
from synapse.types import UserID from synapse.types import UserID
@ -36,6 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
) )
hs.presence_handler = Mock() hs.presence_handler = Mock()
hs.presence_handler.set_state.return_value = defer.succeed(None)
return hs return hs

View File

@ -52,6 +52,14 @@ class MockHandlerProfileTestCase(unittest.TestCase):
] ]
) )
self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
Mock()
)
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
self.addCleanup, self.addCleanup,
"test", "test",
@ -63,7 +71,7 @@ class MockHandlerProfileTestCase(unittest.TestCase):
) )
def _get_user_by_req(request=None, allow_guest=False): def _get_user_by_req(request=None, allow_guest=False):
return synapse.types.create_requester(myid) return defer.succeed(synapse.types.create_requester(myid))
hs.get_auth().get_user_by_req = _get_user_by_req hs.get_auth().get_user_by_req = _get_user_by_req

View File

@ -461,7 +461,9 @@ class MockHttpResource(HttpServer):
try: try:
args = [urlparse.unquote(u) for u in matcher.groups()] args = [urlparse.unquote(u) for u in matcher.groups()]
(code, response) = yield func(mock_request, *args) (code, response) = yield defer.ensureDeferred(
func(mock_request, *args)
)
return code, response return code, response
except CodeMessageException as e: except CodeMessageException as e:
return (e.code, cs_error(e.msg, code=e.errcode)) return (e.code, cs_error(e.msg, code=e.errcode))