Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

michaelkaye/matrix_org_hotfixes_increase_replication_timeout
Erik Johnston 2019-10-02 14:09:29 +01:00
commit 610219d53d
45 changed files with 899 additions and 860 deletions

1
.gitignore vendored
View File

@ -10,6 +10,7 @@
*.tac *.tac
_trial_temp/ _trial_temp/
_trial_temp*/ _trial_temp*/
/out
# stuff that is likely to exist when you run a server locally # stuff that is likely to exist when you run a server locally
/*.db /*.db

1
changelog.d/5978.misc Normal file
View File

@ -0,0 +1 @@
Move lookup-related functions from RoomMemberHandler to IdentityHandler.

1
changelog.d/6019.misc Normal file
View File

@ -0,0 +1 @@
Improve performance of the public room list directory.

1
changelog.d/6077.misc Normal file
View File

@ -0,0 +1 @@
Edit header dicts docstrings in SimpleHttpClient to note that `str` or `bytes` can be passed as header keys.

1
changelog.d/6101.misc Normal file
View File

@ -0,0 +1 @@
Kill off half-implemented password-reset via sms.

1
changelog.d/6108.misc Normal file
View File

@ -0,0 +1 @@
Remove `get_user_by_req` opentracing span and add some tags.

1
changelog.d/6115.misc Normal file
View File

@ -0,0 +1 @@
Drop some unused database tables.

1
changelog.d/6125.feature Normal file
View File

@ -0,0 +1 @@
Reject all pending invites for a user during deactivation.

1
changelog.d/6144.bugfix Normal file
View File

@ -0,0 +1 @@
Prevent user push rules being deleted from a room when it is upgraded.

1
changelog.d/6150.misc Normal file
View File

@ -0,0 +1 @@
Expand type-checking on modules imported by synapse.config.

View File

@ -179,7 +179,6 @@ class Auth(object):
def get_public_keys(self, invite_event): def get_public_keys(self, invite_event):
return event_auth.get_public_keys(invite_event) return event_auth.get_public_keys(invite_event)
@opentracing.trace
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_req( def get_user_by_req(
self, request, allow_guest=False, rights="access", allow_expired=False self, request, allow_guest=False, rights="access", allow_expired=False
@ -212,6 +211,7 @@ class Auth(object):
if user_id: if user_id:
request.authenticated_entity = user_id request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self.hs.config.track_appservice_user_ips: if ip_addr and self.hs.config.track_appservice_user_ips:
yield self.store.insert_client_ip( yield self.store.insert_client_ip(
@ -263,6 +263,8 @@ class Auth(object):
request.authenticated_entity = user.to_string() request.authenticated_entity = user.to_string()
opentracing.set_tag("authenticated_entity", user.to_string()) opentracing.set_tag("authenticated_entity", user.to_string())
if device_id:
opentracing.set_tag("device_id", device_id)
return synapse.types.create_requester( return synapse.types.create_requester(
user, token_id, is_guest, device_id, app_service=app_service user, token_id, is_guest, device_id, app_service=app_service

View File

@ -17,6 +17,7 @@
"""Contains exceptions and error codes.""" """Contains exceptions and error codes."""
import logging import logging
from typing import Dict
from six import iteritems from six import iteritems
from six.moves import http_client from six.moves import http_client
@ -111,7 +112,7 @@ class ProxiedRequestError(SynapseError):
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None): def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
super(ProxiedRequestError, self).__init__(code, msg, errcode) super(ProxiedRequestError, self).__init__(code, msg, errcode)
if additional_fields is None: if additional_fields is None:
self._additional_fields = {} self._additional_fields = {} # type: Dict
else: else:
self._additional_fields = dict(additional_fields) self._additional_fields = dict(additional_fields)

View File

@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict
import attr import attr
@ -102,4 +105,4 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V4, RoomVersions.V4,
RoomVersions.V5, RoomVersions.V5,
) )
} # type: dict[str, RoomVersion] } # type: Dict[str, RoomVersion]

View File

@ -263,7 +263,9 @@ def start(hs, listeners=None):
refresh_certificate(hs) refresh_certificate(hs)
# Start the tracer # Start the tracer
synapse.logging.opentracing.init_tracer(hs.config) synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
hs.config
)
# It is now safe to start your Synapse. # It is now safe to start your Synapse.
hs.start_listening(listeners) hs.start_listening(listeners)

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict
from six import string_types from six import string_types
from six.moves.urllib import parse as urlparse from six.moves.urllib import parse as urlparse
@ -56,8 +57,8 @@ def load_appservices(hostname, config_files):
return [] return []
# Dicts of value -> filename # Dicts of value -> filename
seen_as_tokens = {} seen_as_tokens = {} # type: Dict[str, str]
seen_ids = {} seen_ids = {} # type: Dict[str, str]
appservices = [] appservices = []

View File

@ -73,8 +73,8 @@ DEFAULT_CONFIG = """\
class ConsentConfig(Config): class ConsentConfig(Config):
def __init__(self): def __init__(self, *args):
super(ConsentConfig, self).__init__() super(ConsentConfig, self).__init__(*args)
self.user_consent_version = None self.user_consent_version = None
self.user_consent_template_dir = None self.user_consent_template_dir = None

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, List
from synapse.util.module_loader import load_module from synapse.util.module_loader import load_module
from ._base import Config from ._base import Config
@ -22,7 +24,7 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config): class PasswordAuthProviderConfig(Config):
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.password_providers = [] self.password_providers = [] # type: List[Any]
providers = [] providers = []
# We want to be backwards compatible with the old `ldap_config` # We want to be backwards compatible with the old `ldap_config`

View File

@ -15,6 +15,7 @@
import os import os
from collections import namedtuple from collections import namedtuple
from typing import Dict, List
from synapse.python_dependencies import DependencyException, check_requirements from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module from synapse.util.module_loader import load_module
@ -61,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
Dictionary mapping from media type string to list of Dictionary mapping from media type string to list of
ThumbnailRequirement tuples. ThumbnailRequirement tuples.
""" """
requirements = {} requirements = {} # type: Dict[str, List]
for size in thumbnail_sizes: for size in thumbnail_sizes:
width = size["width"] width = size["width"]
height = size["height"] height = size["height"]
@ -130,7 +131,7 @@ class ContentRepositoryConfig(Config):
# #
# We don't create the storage providers here as not all workers need # We don't create the storage providers here as not all workers need
# them to be started. # them to be started.
self.media_storage_providers = [] self.media_storage_providers = [] # type: List[tuple]
for provider_config in storage_providers: for provider_config in storage_providers:
# We special case the module "file_system" so as not to need to # We special case the module "file_system" so as not to need to

View File

@ -19,6 +19,7 @@ import logging
import os.path import os.path
import re import re
from textwrap import indent from textwrap import indent
from typing import List
import attr import attr
import yaml import yaml
@ -243,7 +244,7 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile. # events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
self.listeners = [] self.listeners = [] # type: List[dict]
for listener in config.get("listeners", []): for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int): if not isinstance(listener.get("port", None), int):
raise ConfigError( raise ConfigError(
@ -287,7 +288,10 @@ class ServerConfig(Config):
validator=attr.validators.instance_of(bool), default=False validator=attr.validators.instance_of(bool), default=False
) )
complexity = attr.ib( complexity = attr.ib(
validator=attr.validators.instance_of((int, float)), default=1.0 validator=attr.validators.instance_of(
(float, int) # type: ignore[arg-type] # noqa
),
default=1.0,
) )
complexity_error = attr.ib( complexity_error = attr.ib(
validator=attr.validators.instance_of(str), validator=attr.validators.instance_of(str),
@ -366,7 +370,7 @@ class ServerConfig(Config):
"cleanup_extremities_with_dummy_events", True "cleanup_extremities_with_dummy_events", True
) )
def has_tls_listener(self): def has_tls_listener(self) -> bool:
return any(l["tls"] for l in self.listeners) return any(l["tls"] for l in self.listeners)
def generate_config_section( def generate_config_section(

View File

@ -59,8 +59,8 @@ class ServerNoticesConfig(Config):
None if server notices are not enabled. None if server notices are not enabled.
""" """
def __init__(self): def __init__(self, *args):
super(ServerNoticesConfig, self).__init__() super(ServerNoticesConfig, self).__init__(*args)
self.server_notices_mxid = None self.server_notices_mxid = None
self.server_notices_mxid_display_name = None self.server_notices_mxid_display_name = None
self.server_notices_mxid_avatar_url = None self.server_notices_mxid_avatar_url = None

View File

@ -765,6 +765,10 @@ class PublicRoomList(BaseFederationServlet):
else: else:
network_tuple = ThirdPartyInstanceID(None, None) network_tuple = ThirdPartyInstanceID(None, None)
if limit == 0:
# zero is a special value which corresponds to no limit.
limit = None
data = await maybeDeferred( data = await maybeDeferred(
self.handler.get_local_public_room_list, self.handler.get_local_public_room_list,
limit, limit,
@ -800,6 +804,10 @@ class PublicRoomList(BaseFederationServlet):
if search_filter is None: if search_filter is None:
logger.warning("Nonefilter") logger.warning("Nonefilter")
if limit == 0:
# zero is a special value which corresponds to no limit.
limit = None
data = await self.handler.get_local_public_room_list( data = await self.handler.get_local_public_room_list(
limit=limit, limit=limit,
since_token=since_token, since_token=since_token,

View File

@ -120,6 +120,10 @@ class DeactivateAccountHandler(BaseHandler):
# parts users from rooms (if it isn't already running) # parts users from rooms (if it isn't already running)
self._start_user_parting() self._start_user_parting()
# Reject all pending invites for the user, so that the user doesn't show up in the
# "invited" section of rooms' members list.
yield self._reject_pending_invites_for_user(user_id)
# Remove all information on the user from the account_validity table. # Remove all information on the user from the account_validity table.
if self._account_validity_enabled: if self._account_validity_enabled:
yield self.store.delete_account_validity_for_user(user_id) yield self.store.delete_account_validity_for_user(user_id)
@ -129,6 +133,39 @@ class DeactivateAccountHandler(BaseHandler):
return identity_server_supports_unbinding return identity_server_supports_unbinding
@defer.inlineCallbacks
def _reject_pending_invites_for_user(self, user_id):
"""Reject pending invites addressed to a given user ID.
Args:
user_id (str): The user ID to reject pending invites for.
"""
user = UserID.from_string(user_id)
pending_invites = yield self.store.get_invited_rooms_for_user(user_id)
for room in pending_invites:
try:
yield self._room_member_handler.update_membership(
create_requester(user),
user,
room.room_id,
"leave",
ratelimit=False,
require_consent=False,
)
logger.info(
"Rejected invite for deactivated user %r in room %r",
user_id,
room.room_id,
)
except Exception:
logger.exception(
"Failed to reject invite for user %r in room %r:"
" ignoring and continuing",
user_id,
room.room_id,
)
def _start_user_parting(self): def _start_user_parting(self):
""" """
Start the process that goes through the table of users Start the process that goes through the table of users

View File

@ -21,11 +21,15 @@ import logging
import urllib import urllib
from canonicaljson import json from canonicaljson import json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.error import TimeoutError from twisted.internet.error import TimeoutError
from synapse.api.errors import ( from synapse.api.errors import (
AuthError,
CodeMessageException, CodeMessageException,
Codes, Codes,
HttpResponseException, HttpResponseException,
@ -33,12 +37,15 @@ from synapse.api.errors import (
) )
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.util.hash import sha256_and_url_safe_base64
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from ._base import BaseHandler from ._base import BaseHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
id_server_scheme = "https://"
class IdentityHandler(BaseHandler): class IdentityHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
@ -557,6 +564,352 @@ class IdentityHandler(BaseHandler):
logger.warning("Error contacting msisdn account_threepid_delegate: %s", e) logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
raise SynapseError(400, "Error contacting the identity server") raise SynapseError(400, "Error contacting the identity server")
@defer.inlineCallbacks
def lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
id_access_token (str|None): The access token to authenticate to the identity
server with
Returns:
str|None: the matrix ID of the 3pid, or None if it is not recognized.
"""
if id_access_token is not None:
try:
results = yield self._lookup_3pid_v2(
id_server, id_access_token, medium, address
)
return results
except Exception as e:
# Catch HttpResponseExcept for a non-200 response code
# Check if this identity server does not know about v2 lookups
if isinstance(e, HttpResponseException) and e.code == 404:
# This is an old identity server that does not yet support v2 lookups
logger.warning(
"Attempted v2 lookup on v1 identity server %s. Falling "
"back to v1",
id_server,
)
else:
logger.warning("Error when looking up hashing details: %s", e)
return None
return (yield self._lookup_3pid_v1(id_server, medium, address))
@defer.inlineCallbacks
def _lookup_3pid_v1(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
{"medium": medium, "address": address},
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
yield self._verify_any_signature(data, id_server)
return data["mxid"]
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except IOError as e:
logger.warning("Error from v1 identity server lookup: %s" % (e,))
return None
@defer.inlineCallbacks
def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
id_access_token (str): The access token to authenticate to the identity server with
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised.
"""
# Check what hashing details are supported by this identity server
try:
hash_details = yield self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token},
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
if not isinstance(hash_details, dict):
logger.warning(
"Got non-dict object when checking hash details of %s%s: %s",
id_server_scheme,
id_server,
hash_details,
)
raise SynapseError(
400,
"Non-dict object from %s%s during v2 hash_details request: %s"
% (id_server_scheme, id_server, hash_details),
)
# Extract information from hash_details
supported_lookup_algorithms = hash_details.get("algorithms")
lookup_pepper = hash_details.get("lookup_pepper")
if (
not supported_lookup_algorithms
or not isinstance(supported_lookup_algorithms, list)
or not lookup_pepper
or not isinstance(lookup_pepper, str)
):
raise SynapseError(
400,
"Invalid hash details received from identity server %s%s: %s"
% (id_server_scheme, id_server, hash_details),
)
# Check if any of the supported lookup algorithms are present
if LookupAlgorithm.SHA256 in supported_lookup_algorithms:
# Perform a hashed lookup
lookup_algorithm = LookupAlgorithm.SHA256
# Hash address, medium and the pepper with sha256
to_hash = "%s %s %s" % (address, medium, lookup_pepper)
lookup_value = sha256_and_url_safe_base64(to_hash)
elif LookupAlgorithm.NONE in supported_lookup_algorithms:
# Perform a non-hashed lookup
lookup_algorithm = LookupAlgorithm.NONE
# Combine together plaintext address and medium
lookup_value = "%s %s" % (address, medium)
else:
logger.warning(
"None of the provided lookup algorithms of %s are supported: %s",
id_server,
supported_lookup_algorithms,
)
raise SynapseError(
400,
"Provided identity server does not support any v2 lookup "
"algorithms that this homeserver supports.",
)
# Authenticate with identity server given the access token from the client
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
lookup_results = yield self.blacklisting_http_client.post_json_get_json(
"%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
{
"addresses": [lookup_value],
"algorithm": lookup_algorithm,
"pepper": lookup_pepper,
},
headers=headers,
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except Exception as e:
logger.warning("Error when performing a v2 3pid lookup: %s", e)
raise SynapseError(
500, "Unknown error occurred during identity server lookup"
)
# Check for a mapping from what we looked up to an MXID
if "mappings" not in lookup_results or not isinstance(
lookup_results["mappings"], dict
):
logger.warning("No results from 3pid lookup")
return None
# Return the MXID if it's available, or None otherwise
mxid = lookup_results["mappings"].get(lookup_value)
return mxid
@defer.inlineCallbacks
def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
try:
key_data = yield self.blacklisting_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s"
% (id_server_scheme, server_hostname, key_name)
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
if "public_key" not in key_data:
raise AuthError(
401, "No public key named %s from %s" % (key_name, server_hostname)
)
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(
key_name, decode_base64(key_data["public_key"])
),
)
return
@defer.inlineCallbacks
def ask_id_server_for_third_party_invite(
self,
requester,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url,
id_access_token=None,
):
"""
Asks an identity server for a third party invite.
Args:
requester (Requester)
id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email".
address (str): The third party address being invited.
room_id (str): The ID of the room to which the user is invited.
inviter_user_id (str): The user ID of the inviter.
room_alias (str): An alias for the room, for cosmetic notifications.
room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
room_join_rules (str): The join rules of the email (e.g. "public").
room_name (str): The m.room.name of the room.
inviter_display_name (str): The current display name of the
inviter.
inviter_avatar_url (str): The URL of the inviter's avatar.
id_access_token (str|None): The access token to authenticate to the identity
server with
Returns:
A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
data = None
base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
if id_access_token:
key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
id_server_scheme,
id_server,
)
# Attempt a v2 lookup
url = base_url + "/v2/store-invite"
try:
data = yield self.blacklisting_http_client.post_json_get_json(
url,
invite_config,
{"Authorization": create_id_access_token_header(id_access_token)},
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
if e.code != 404:
logger.info("Failed to POST %s with JSON: %s", url, e)
raise e
if data is None:
key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme,
id_server,
)
url = base_url + "/api/v1/store-invite"
try:
data = yield self.blacklisting_http_client.post_json_get_json(
url, invite_config
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning(
"Error trying to call /store-invite on %s%s: %s",
id_server_scheme,
id_server,
e,
)
if data is None:
# Some identity servers may only support application/x-www-form-urlencoded
# types. This is especially true with old instances of Sydent, see
# https://github.com/matrix-org/sydent/pull/170
try:
data = yield self.blacklisting_http_client.post_urlencoded_get_json(
url, invite_config
)
except HttpResponseException as e:
logger.warning(
"Error calling /store-invite on %s%s with fallback "
"encoding: %s",
id_server_scheme,
id_server,
e,
)
raise e
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": key_validity_url,
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
return token, public_keys, fallback_public_key, display_name
def create_id_access_token_header(id_access_token): def create_id_access_token_header(id_access_token):
"""Create an Authorization header for passing to SimpleHttpClient as the header value """Create an Authorization header for passing to SimpleHttpClient as the header value

View File

@ -16,8 +16,7 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from six import PY3, iteritems from six import iteritems
from six.moves import range
import msgpack import msgpack
from unpaddedbase64 import decode_base64, encode_base64 from unpaddedbase64 import decode_base64, encode_base64
@ -27,7 +26,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, HttpResponseException from synapse.api.errors import Codes, HttpResponseException
from synapse.types import ThirdPartyInstanceID from synapse.types import ThirdPartyInstanceID
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -37,7 +35,6 @@ logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
# This is used to indicate we should only return rooms published to the main list. # This is used to indicate we should only return rooms published to the main list.
EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
@ -73,6 +70,8 @@ class RoomListHandler(BaseHandler):
This can be (None, None) to indicate the main list, or a particular This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one. appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists. Setting to None returns all public rooms across all lists.
from_federation (bool): true iff the request comes from the federation
API
""" """
if not self.enable_room_list_search: if not self.enable_room_list_search:
return defer.succeed({"chunk": [], "total_room_count_estimate": 0}) return defer.succeed({"chunk": [], "total_room_count_estimate": 0})
@ -134,239 +133,109 @@ class RoomListHandler(BaseHandler):
from_federation (bool): Whether this request originated from a from_federation (bool): Whether this request originated from a
federating server or a client. Used for room filtering. federating server or a client. Used for room filtering.
timeout (int|None): Amount of seconds to wait for a response before timeout (int|None): Amount of seconds to wait for a response before
timing out. timing out. TODO
""" """
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
since_token = None
rooms_to_order_value = {} # Pagination tokens work by storing the room ID sent in the last batch,
rooms_to_num_joined = {} # plus the direction (forwards or backwards). Next batch tokens always
# go forwards, prev batch tokens always go backwards.
newly_visible = []
newly_unpublished = []
if since_token:
stream_token = since_token.stream_ordering
current_public_id = yield self.store.get_current_public_room_stream_id()
public_room_stream_id = since_token.public_room_stream_id
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
public_room_stream_id, current_public_id, network_tuple=network_tuple
)
else:
stream_token = yield self.store.get_room_max_stream_ordering()
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
room_ids = yield self.store.get_public_room_ids_at_stream_id(
public_room_stream_id, network_tuple=network_tuple
)
# We want to return rooms in a particular order: the number of joined
# users. We then arbitrarily use the room_id as a tie breaker.
@defer.inlineCallbacks
def get_order_for_room(room_id):
# Most of the rooms won't have changed between the since token and
# now (especially if the since token is "now"). So, we can ask what
# the current users are in a room (that will hit a cache) and then
# check if the room has changed since the since token. (We have to
# do it in that order to avoid races).
# If things have changed then fall back to getting the current state
# at the since token.
joined_users = yield self.store.get_users_in_room(room_id)
if self.store.has_room_changed_since(room_id, stream_token):
latest_event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_token
)
if not latest_event_ids:
return
joined_users = yield self.state_handler.get_current_users_in_room(
room_id, latest_event_ids
)
num_joined_users = len(joined_users)
rooms_to_num_joined[room_id] = num_joined_users
if num_joined_users == 0:
return
# We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
logger.info(
"Getting ordering for %i rooms since %s", len(room_ids), stream_token
)
yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
sorted_rooms = [room_id for room_id, _ in sorted_entries]
# `sorted_rooms` should now be a list of all public room ids that is
# stable across pagination. Therefore, we can use indices into this
# list as our pagination tokens.
# Filter out rooms that we don't want to return
rooms_to_scan = [
r
for r in sorted_rooms
if r not in newly_unpublished and rooms_to_num_joined[r] > 0
]
total_room_count = len(rooms_to_scan)
if since_token: if since_token:
# Filter out rooms we've already returned previously batch_token = RoomListNextBatch.from_token(since_token)
# `since_token.current_limit` is the index of the last room we
# sent down, so we exclude it and everything before/after it. last_room_id = batch_token.last_room_id
if since_token.direction_is_forward: forwards = batch_token.direction_is_forward
rooms_to_scan = rooms_to_scan[since_token.current_limit + 1 :]
else: else:
rooms_to_scan = rooms_to_scan[: since_token.current_limit] batch_token = None
rooms_to_scan.reverse()
logger.info("After sorting and filtering, %i rooms remain", len(rooms_to_scan)) last_room_id = None
forwards = True
# _append_room_entry_to_chunk will append to chunk but will stop if # we request one more than wanted to see if there are more pages to come
# len(chunk) > limit probing_limit = limit + 1 if limit is not None else None
#
# Normally we will generate enough results on the first iteration here,
# but if there is a search filter, _append_room_entry_to_chunk may
# filter some results out, in which case we loop again.
#
# We don't want to scan over the entire range either as that
# would potentially waste a lot of work.
#
# XXX if there is no limit, we may end up DoSing the server with
# calls to get_current_state_ids for every single room on the
# server. Surely we should cap this somehow?
#
if limit:
step = limit + 1
else:
# step cannot be zero
step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
chunk = [] results = yield self.store.get_largest_public_rooms(
for i in range(0, len(rooms_to_scan), step): network_tuple,
if timeout and self.clock.time() > timeout:
raise Exception("Timed out searching room directory")
batch = rooms_to_scan[i : i + step]
logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(
lambda r: self._append_room_entry_to_chunk(
r,
rooms_to_num_joined[r],
chunk,
limit,
search_filter, search_filter,
from_federation=from_federation, probing_limit,
), last_room_id=last_room_id,
batch, forwards=forwards,
5, ignore_non_federatable=from_federation,
) )
logger.info("Now %i rooms in result", len(chunk))
if len(chunk) >= limit + 1:
break
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"])) def build_room_entry(room):
entry = {
"room_id": room["room_id"],
"name": room["name"],
"topic": room["topic"],
"canonical_alias": room["canonical_alias"],
"num_joined_members": room["joined_members"],
"avatar_url": room["avatar"],
"world_readable": room["history_visibility"] == "world_readable",
"guest_can_join": room["guest_access"] == "can_join",
}
# Work out the new limit of the batch for pagination, or None if we # Filter out Nones rather omit the field altogether
# know there are no more results that would be returned. return {k: v for k, v in entry.items() if v is not None}
# i.e., [since_token.current_limit..new_limit] is the batch of rooms
# we've returned (or the reverse if we paginated backwards)
# We tried to pull out limit + 1 rooms above, so if we have <= limit
# then we know there are no more results to return
new_limit = None
if chunk and (not limit or len(chunk) > limit):
if not since_token or since_token.direction_is_forward: results = [build_room_entry(r) for r in results]
if limit:
chunk = chunk[:limit] response = {}
last_room_id = chunk[-1]["room_id"] num_results = len(results)
if limit is not None:
more_to_come = num_results == probing_limit
# Depending on direction we trim either the front or back.
if forwards:
results = results[:limit]
else: else:
if limit: results = results[-limit:]
chunk = chunk[-limit:] else:
last_room_id = chunk[0]["room_id"] more_to_come = False
new_limit = sorted_rooms.index(last_room_id) if num_results > 0:
final_room_id = results[-1]["room_id"]
initial_room_id = results[0]["room_id"]
results = {"chunk": chunk, "total_room_count_estimate": total_room_count} if forwards:
if batch_token:
if since_token: # If there was a token given then we assume that there
results["new_rooms"] = bool(newly_visible) # must be previous results.
response["prev_batch"] = RoomListNextBatch(
if not since_token or since_token.direction_is_forward: last_room_id=initial_room_id, direction_is_forward=False
if new_limit is not None:
results["next_batch"] = RoomListNextBatch(
stream_ordering=stream_token,
public_room_stream_id=public_room_stream_id,
current_limit=new_limit,
direction_is_forward=True,
).to_token() ).to_token()
if since_token: if more_to_come:
results["prev_batch"] = since_token.copy_and_replace( response["next_batch"] = RoomListNextBatch(
direction_is_forward=False, last_room_id=final_room_id, direction_is_forward=True
current_limit=since_token.current_limit + 1,
).to_token() ).to_token()
else: else:
if new_limit is not None: if batch_token:
results["prev_batch"] = RoomListNextBatch( response["next_batch"] = RoomListNextBatch(
stream_ordering=stream_token, last_room_id=final_room_id, direction_is_forward=True
public_room_stream_id=public_room_stream_id,
current_limit=new_limit,
direction_is_forward=False,
).to_token() ).to_token()
if since_token: if more_to_come:
results["next_batch"] = since_token.copy_and_replace( response["prev_batch"] = RoomListNextBatch(
direction_is_forward=True, last_room_id=initial_room_id, direction_is_forward=False
current_limit=since_token.current_limit - 1,
).to_token() ).to_token()
return results for room in results:
# populate search result entries with additional fields, namely
# 'aliases'
room_id = room["room_id"]
@defer.inlineCallbacks aliases = yield self.store.get_aliases_for_room(room_id)
def _append_room_entry_to_chunk( if aliases:
self, room["aliases"] = aliases
room_id,
num_joined_users,
chunk,
limit,
search_filter,
from_federation=False,
):
"""Generate the entry for a room in the public room list and append it
to the `chunk` if it matches the search filter
Args: response["chunk"] = results
room_id (str): The ID of the room.
num_joined_users (int): The number of joined users in the room.
chunk (list)
limit (int|None): Maximum amount of rooms to display. Function will
return if length of chunk is greater than limit + 1.
search_filter (dict|None)
from_federation (bool): Whether this request originated from a
federating server or a client. Used for room filtering.
"""
if limit and len(chunk) > limit + 1:
# We've already got enough, so lets just drop it.
return
result = yield self.generate_room_entry(room_id, num_joined_users) response["total_room_count_estimate"] = yield self.store.count_public_rooms(
if not result: network_tuple, ignore_non_federatable=from_federation
return )
if from_federation and not result.get("m.federate", True): return response
# This is a room that other servers cannot join. Do not show them
# this room.
return
if _matches_room_entry(result, search_filter):
chunk.append(result)
@cachedInlineCallbacks(num_args=1, cache_context=True) @cachedInlineCallbacks(num_args=1, cache_context=True)
def generate_room_entry( def generate_room_entry(
@ -581,32 +450,18 @@ class RoomListNextBatch(
namedtuple( namedtuple(
"RoomListNextBatch", "RoomListNextBatch",
( (
"stream_ordering", # stream_ordering of the first public room list "last_room_id", # The room_id to get rooms after/before
"public_room_stream_id", # public room stream id for first public room list
"current_limit", # The number of previous rooms returned
"direction_is_forward", # Bool if this is a next_batch, false if prev_batch "direction_is_forward", # Bool if this is a next_batch, false if prev_batch
), ),
) )
): ):
KEY_DICT = {"last_room_id": "r", "direction_is_forward": "d"}
KEY_DICT = {
"stream_ordering": "s",
"public_room_stream_id": "p",
"current_limit": "n",
"direction_is_forward": "d",
}
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()} REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@classmethod @classmethod
def from_token(cls, token): def from_token(cls, token):
if PY3:
# The argument raw=False is only available on new versions of
# msgpack, and only really needed on Python 3. Gate it behind
# a PY3 check to avoid causing issues on Debian-packaged versions.
decoded = msgpack.loads(decode_base64(token), raw=False) decoded = msgpack.loads(decode_base64(token), raw=False)
else:
decoded = msgpack.loads(decode_base64(token))
return RoomListNextBatch( return RoomListNextBatch(
**{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()} **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
) )

View File

@ -20,29 +20,19 @@ import logging
from six.moves import http_client from six.moves import http_client
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.error import TimeoutError
from synapse import types from synapse import types
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, HttpResponseException, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.handlers.identity import LookupAlgorithm, create_id_access_token_header
from synapse.http.client import SimpleHttpClient
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room from synapse.util.distributor import user_joined_room, user_left_room
from synapse.util.hash import sha256_and_url_safe_base64
from ._base import BaseHandler from ._base import BaseHandler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
id_server_scheme = "https://"
class RoomMemberHandler(object): class RoomMemberHandler(object):
# TODO(paul): This handler currently contains a messy conflation of # TODO(paul): This handler currently contains a messy conflation of
@ -63,14 +53,10 @@ class RoomMemberHandler(object):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.config = hs.config self.config = hs.config
# We create a blacklisting instance of SimpleHttpClient for contacting identity
# servers specified by clients
self.simple_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
)
self.federation_handler = hs.get_handlers().federation_handler self.federation_handler = hs.get_handlers().federation_handler
self.directory_handler = hs.get_handlers().directory_handler self.directory_handler = hs.get_handlers().directory_handler
self.identity_handler = hs.get_handlers().identity_handler
self.registration_handler = hs.get_registration_handler() self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
@ -231,8 +217,8 @@ class RoomMemberHandler(object):
self.copy_room_tags_and_direct_to_room( self.copy_room_tags_and_direct_to_room(
predecessor["room_id"], room_id, user_id predecessor["room_id"], room_id, user_id
) )
# Move over old push rules # Copy over push rules
self.store.move_push_rules_from_room_to_room_for_user( yield self.store.copy_push_rules_from_room_to_room_for_user(
predecessor["room_id"], room_id, user_id predecessor["room_id"], room_id, user_id
) )
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
@ -702,7 +688,9 @@ class RoomMemberHandler(object):
403, "Looking up third-party identifiers is denied from this server" 403, "Looking up third-party identifiers is denied from this server"
) )
invitee = yield self._lookup_3pid(id_server, medium, address, id_access_token) invitee = yield self.identity_handler.lookup_3pid(
id_server, medium, address, id_access_token
)
if invitee: if invitee:
yield self.update_membership( yield self.update_membership(
@ -720,211 +708,6 @@ class RoomMemberHandler(object):
id_access_token=id_access_token, id_access_token=id_access_token,
) )
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address, id_access_token=None):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
id_access_token (str|None): The access token to authenticate to the identity
server with
Returns:
str|None: the matrix ID of the 3pid, or None if it is not recognized.
"""
if id_access_token is not None:
try:
results = yield self._lookup_3pid_v2(
id_server, id_access_token, medium, address
)
return results
except Exception as e:
# Catch HttpResponseExcept for a non-200 response code
# Check if this identity server does not know about v2 lookups
if isinstance(e, HttpResponseException) and e.code == 404:
# This is an old identity server that does not yet support v2 lookups
logger.warning(
"Attempted v2 lookup on v1 identity server %s. Falling "
"back to v1",
id_server,
)
else:
logger.warning("Error when looking up hashing details: %s", e)
return None
return (yield self._lookup_3pid_v1(id_server, medium, address))
@defer.inlineCallbacks
def _lookup_3pid_v1(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server using v1 lookup.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
{"medium": medium, "address": address},
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
yield self._verify_any_signature(data, id_server)
return data["mxid"]
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except IOError as e:
logger.warning("Error from v1 identity server lookup: %s" % (e,))
return None
@defer.inlineCallbacks
def _lookup_3pid_v2(self, id_server, id_access_token, medium, address):
"""Looks up a 3pid in the passed identity server using v2 lookup.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
id_access_token (str): The access token to authenticate to the identity server with
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised.
"""
# Check what hashing details are supported by this identity server
try:
hash_details = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/v2/hash_details" % (id_server_scheme, id_server),
{"access_token": id_access_token},
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
if not isinstance(hash_details, dict):
logger.warning(
"Got non-dict object when checking hash details of %s%s: %s",
id_server_scheme,
id_server,
hash_details,
)
raise SynapseError(
400,
"Non-dict object from %s%s during v2 hash_details request: %s"
% (id_server_scheme, id_server, hash_details),
)
# Extract information from hash_details
supported_lookup_algorithms = hash_details.get("algorithms")
lookup_pepper = hash_details.get("lookup_pepper")
if (
not supported_lookup_algorithms
or not isinstance(supported_lookup_algorithms, list)
or not lookup_pepper
or not isinstance(lookup_pepper, str)
):
raise SynapseError(
400,
"Invalid hash details received from identity server %s%s: %s"
% (id_server_scheme, id_server, hash_details),
)
# Check if any of the supported lookup algorithms are present
if LookupAlgorithm.SHA256 in supported_lookup_algorithms:
# Perform a hashed lookup
lookup_algorithm = LookupAlgorithm.SHA256
# Hash address, medium and the pepper with sha256
to_hash = "%s %s %s" % (address, medium, lookup_pepper)
lookup_value = sha256_and_url_safe_base64(to_hash)
elif LookupAlgorithm.NONE in supported_lookup_algorithms:
# Perform a non-hashed lookup
lookup_algorithm = LookupAlgorithm.NONE
# Combine together plaintext address and medium
lookup_value = "%s %s" % (address, medium)
else:
logger.warning(
"None of the provided lookup algorithms of %s are supported: %s",
id_server,
supported_lookup_algorithms,
)
raise SynapseError(
400,
"Provided identity server does not support any v2 lookup "
"algorithms that this homeserver supports.",
)
# Authenticate with identity server given the access token from the client
headers = {"Authorization": create_id_access_token_header(id_access_token)}
try:
lookup_results = yield self.simple_http_client.post_json_get_json(
"%s%s/_matrix/identity/v2/lookup" % (id_server_scheme, id_server),
{
"addresses": [lookup_value],
"algorithm": lookup_algorithm,
"pepper": lookup_pepper,
},
headers=headers,
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except Exception as e:
logger.warning("Error when performing a v2 3pid lookup: %s", e)
raise SynapseError(
500, "Unknown error occurred during identity server lookup"
)
# Check for a mapping from what we looked up to an MXID
if "mappings" not in lookup_results or not isinstance(
lookup_results["mappings"], dict
):
logger.warning("No results from 3pid lookup")
return None
# Return the MXID if it's available, or None otherwise
mxid = lookup_results["mappings"].get(lookup_value)
return mxid
@defer.inlineCallbacks
def _verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
try:
key_data = yield self.simple_http_client.get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s"
% (id_server_scheme, server_hostname, key_name)
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
if "public_key" not in key_data:
raise AuthError(
401, "No public key named %s from %s" % (key_name, server_hostname)
)
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(
key_name, decode_base64(key_data["public_key"])
),
)
return
@defer.inlineCallbacks @defer.inlineCallbacks
def _make_and_store_3pid_invite( def _make_and_store_3pid_invite(
self, self,
@ -971,7 +754,7 @@ class RoomMemberHandler(object):
room_avatar_url = room_avatar_event.content.get("url", "") room_avatar_url = room_avatar_event.content.get("url", "")
token, public_keys, fallback_public_key, display_name = ( token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite( yield self.identity_handler.ask_id_server_for_third_party_invite(
requester=requester, requester=requester,
id_server=id_server, id_server=id_server,
medium=medium, medium=medium,
@ -1007,147 +790,6 @@ class RoomMemberHandler(object):
txn_id=txn_id, txn_id=txn_id,
) )
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
requester,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url,
id_access_token=None,
):
"""
Asks an identity server for a third party invite.
Args:
requester (Requester)
id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email".
address (str): The third party address being invited.
room_id (str): The ID of the room to which the user is invited.
inviter_user_id (str): The user ID of the inviter.
room_alias (str): An alias for the room, for cosmetic notifications.
room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
room_join_rules (str): The join rules of the email (e.g. "public").
room_name (str): The m.room.name of the room.
inviter_display_name (str): The current display name of the
inviter.
inviter_avatar_url (str): The URL of the inviter's avatar.
id_access_token (str|None): The access token to authenticate to the identity
server with
Returns:
A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
# Add the identity service access token to the JSON body and use the v2
# Identity Service endpoints if id_access_token is present
data = None
base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
if id_access_token:
key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
id_server_scheme,
id_server,
)
# Attempt a v2 lookup
url = base_url + "/v2/store-invite"
try:
data = yield self.simple_http_client.post_json_get_json(
url,
invite_config,
{"Authorization": create_id_access_token_header(id_access_token)},
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
if e.code != 404:
logger.info("Failed to POST %s with JSON: %s", url, e)
raise e
if data is None:
key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme,
id_server,
)
url = base_url + "/api/v1/store-invite"
try:
data = yield self.simple_http_client.post_json_get_json(
url, invite_config
)
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
logger.warning(
"Error trying to call /store-invite on %s%s: %s",
id_server_scheme,
id_server,
e,
)
if data is None:
# Some identity servers may only support application/x-www-form-urlencoded
# types. This is especially true with old instances of Sydent, see
# https://github.com/matrix-org/sydent/pull/170
try:
data = yield self.simple_http_client.post_urlencoded_get_json(
url, invite_config
)
except HttpResponseException as e:
logger.warning(
"Error calling /store-invite on %s%s with fallback "
"encoding: %s",
id_server_scheme,
id_server,
e,
)
raise e
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": key_validity_url,
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
return token, public_keys, fallback_public_key, display_name
@defer.inlineCallbacks @defer.inlineCallbacks
def _is_host_in_room(self, current_state_ids): def _is_host_in_room(self, current_state_ids):
# Have we just created the room, and is this about to be the very # Have we just created the room, and is this about to be the very

View File

@ -327,7 +327,7 @@ class SimpleHttpClient(object):
Args: Args:
uri (str): uri (str):
args (dict[str, str|List[str]]): query params args (dict[str, str|List[str]]): query params
headers (dict[str, List[str]]|None): If not None, a map from headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header header name to a list of values for that header
Returns: Returns:
@ -371,7 +371,7 @@ class SimpleHttpClient(object):
Args: Args:
uri (str): uri (str):
post_json (object): post_json (object):
headers (dict[str, List[str]]|None): If not None, a map from headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header header name to a list of values for that header
Returns: Returns:
@ -414,7 +414,7 @@ class SimpleHttpClient(object):
None. None.
**Note**: The value of each key is assumed to be an iterable **Note**: The value of each key is assumed to be an iterable
and *not* a string. and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header header name to a list of values for that header
Returns: Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@ -438,7 +438,7 @@ class SimpleHttpClient(object):
None. None.
**Note**: The value of each key is assumed to be an iterable **Note**: The value of each key is assumed to be an iterable
and *not* a string. and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header header name to a list of values for that header
Returns: Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@ -482,7 +482,7 @@ class SimpleHttpClient(object):
None. None.
**Note**: The value of each key is assumed to be an iterable **Note**: The value of each key is assumed to be an iterable
and *not* a string. and *not* a string.
headers (dict[str, List[str]]|None): If not None, a map from headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header header name to a list of values for that header
Returns: Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@ -516,7 +516,7 @@ class SimpleHttpClient(object):
Args: Args:
url (str): The URL to GET url (str): The URL to GET
output_stream (file): File to write the response body to. output_stream (file): File to write the response body to.
headers (dict[str, List[str]]|None): If not None, a map from headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header header name to a list of values for that header
Returns: Returns:
A (int,dict,string,int) tuple of the file length, dict of the response A (int,dict,string,int) tuple of the file length, dict of the response

View File

@ -170,6 +170,7 @@ import inspect
import logging import logging
import re import re
from functools import wraps from functools import wraps
from typing import Dict
from canonicaljson import json from canonicaljson import json
@ -547,7 +548,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
return return
span = opentracing.tracer.active_span span = opentracing.tracer.active_span
carrier = {} carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items(): for key, value in carrier.items():
@ -584,7 +585,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
span = opentracing.tracer.active_span span = opentracing.tracer.active_span
carrier = {} carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items(): for key, value in carrier.items():
@ -639,7 +640,7 @@ def get_active_span_text_map(destination=None):
if destination and not whitelisted_homeserver(destination): if destination and not whitelisted_homeserver(destination):
return {} return {}
carrier = {} carrier = {} # type: Dict[str, str]
opentracing.tracer.inject( opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
) )
@ -653,7 +654,7 @@ def active_span_context_as_string():
Returns: Returns:
The active span context encoded as a string. The active span context encoded as a string.
""" """
carrier = {} carrier = {} # type: Dict[str, str]
if opentracing: if opentracing:
opentracing.tracer.inject( opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier

View File

@ -119,7 +119,11 @@ def trace_function(f):
logger = logging.getLogger(name) logger = logging.getLogger(name)
level = logging.DEBUG level = logging.DEBUG
s = inspect.currentframe().f_back frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back
to_print = [ to_print = [
"\t%s:%s %s. Args: args=%s, kwargs=%s" "\t%s:%s %s. Args: args=%s, kwargs=%s"
@ -144,7 +148,7 @@ def trace_function(f):
pathname=pathname, pathname=pathname,
lineno=lineno, lineno=lineno,
msg=msg, msg=msg,
args=None, args=tuple(),
exc_info=None, exc_info=None,
) )
@ -157,7 +161,12 @@ def trace_function(f):
def get_previous_frames(): def get_previous_frames():
s = inspect.currentframe().f_back.f_back
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back.f_back
to_return = [] to_return = []
while s: while s:
if s.f_globals["__name__"].startswith("synapse"): if s.f_globals["__name__"].startswith("synapse"):
@ -174,7 +183,10 @@ def get_previous_frames():
def get_previous_frame(ignore=[]): def get_previous_frame(ignore=[]):
s = inspect.currentframe().f_back.f_back frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back.f_back
while s: while s:
if s.f_globals["__name__"].startswith("synapse"): if s.f_globals["__name__"].startswith("synapse"):

View File

@ -125,7 +125,7 @@ class InFlightGauge(object):
) )
# Counts number of in flight blocks for a given set of label values # Counts number of in flight blocks for a given set of label values
self._registrations = {} self._registrations = {} # type: Dict
# Protects access to _registrations # Protects access to _registrations
self._lock = threading.Lock() self._lock = threading.Lock()
@ -226,7 +226,7 @@ class BucketCollector(object):
# Fetch the data -- this must be synchronous! # Fetch the data -- this must be synchronous!
data = self.data_collector() data = self.data_collector()
buckets = {} buckets = {} # type: Dict[float, int]
res = [] res = []
for x in data.keys(): for x in data.keys():

View File

@ -36,9 +36,9 @@ from twisted.web.resource import Resource
try: try:
from prometheus_client.samples import Sample from prometheus_client.samples import Sample
except ImportError: except ImportError:
Sample = namedtuple( Sample = namedtuple( # type: ignore[no-redef] # noqa
"Sample", ["name", "labels", "value", "timestamp", "exemplar"] "Sample", ["name", "labels", "value", "timestamp", "exemplar"]
) # type: ignore )
CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")

View File

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Set from typing import List, Set
from pkg_resources import ( from pkg_resources import (
DistributionNotFound, DistributionNotFound,
@ -73,6 +73,7 @@ REQUIREMENTS = [
"netaddr>=0.7.18", "netaddr>=0.7.18",
"Jinja2>=2.9", "Jinja2>=2.9",
"bleach>=1.4.3", "bleach>=1.4.3",
"typing-extensions>=3.7.4",
] ]
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
@ -144,7 +145,11 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency) deps_needed.append(dependency)
errors.append( errors.append(
"Needed %s, got %s==%s" "Needed %s, got %s==%s"
% (dependency, e.dist.project_name, e.dist.version) % (
dependency,
e.dist.project_name, # type: ignore[attr-defined] # noqa
e.dist.version, # type: ignore[attr-defined] # noqa
)
) )
except DistributionNotFound: except DistributionNotFound:
deps_needed.append(dependency) deps_needed.append(dependency)
@ -159,7 +164,7 @@ def check_requirements(for_feature=None):
if not for_feature: if not for_feature:
# Check the optional dependencies are up to date. We allow them to not be # Check the optional dependencies are up to date. We allow them to not be
# installed. # installed.
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str]
for dependency in OPTS: for dependency in OPTS:
try: try:
@ -168,7 +173,11 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency) deps_needed.append(dependency)
errors.append( errors.append(
"Needed optional %s, got %s==%s" "Needed optional %s, got %s==%s"
% (dependency, e.dist.project_name, e.dist.version) % (
dependency,
e.dist.project_name, # type: ignore[attr-defined] # noqa
e.dist.version, # type: ignore[attr-defined] # noqa
)
) )
except DistributionNotFound: except DistributionNotFound:
# If it's not found, we don't care # If it's not found, we don't care

View File

@ -39,6 +39,7 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
) )
from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
@ -81,6 +82,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
) )
def on_PUT(self, request, txn_id): def on_PUT(self, request, txn_id):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request) return self.txns.fetch_or_execute_request(request, self.on_POST, request)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -181,6 +183,9 @@ class RoomStateEventRestServlet(TransactionRestServlet):
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
if txn_id:
set_tag("txn_id", txn_id)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
event_dict = { event_dict = {
@ -209,6 +214,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
ret = {} ret = {}
if event: if event:
set_tag("event_id", event.event_id)
ret = {"event_id": event.event_id} ret = {"event_id": event.event_id}
return 200, ret return 200, ret
@ -244,12 +250,15 @@ class RoomSendEventRestServlet(TransactionRestServlet):
requester, event_dict, txn_id=txn_id requester, event_dict, txn_id=txn_id
) )
set_tag("event_id", event.event_id)
return 200, {"event_id": event.event_id} return 200, {"event_id": event.event_id}
def on_GET(self, request, room_id, event_type, txn_id): def on_GET(self, request, room_id, event_type, txn_id):
return 200, "Not implemented" return 200, "Not implemented"
def on_PUT(self, request, room_id, event_type, txn_id): def on_PUT(self, request, room_id, event_type, txn_id):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_type, txn_id request, self.on_POST, request, room_id, event_type, txn_id
) )
@ -310,6 +319,8 @@ class JoinRoomAliasServlet(TransactionRestServlet):
return 200, {"room_id": room_id} return 200, {"room_id": room_id}
def on_PUT(self, request, room_identifier, txn_id): def on_PUT(self, request, room_identifier, txn_id):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id request, self.on_POST, request, room_identifier, txn_id
) )
@ -350,6 +361,10 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit = parse_integer(request, "limit", 0) limit = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since", None) since_token = parse_string(request, "since", None)
if limit == 0:
# zero is a special value which corresponds to no limit.
limit = None
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server: if server:
data = yield handler.get_remote_public_room_list( data = yield handler.get_remote_public_room_list(
@ -387,6 +402,10 @@ class PublicRoomListRestServlet(TransactionRestServlet):
else: else:
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id) network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
if limit == 0:
# zero is a special value which corresponds to no limit.
limit = None
handler = self.hs.get_room_list_handler() handler = self.hs.get_room_list_handler()
if server: if server:
data = yield handler.get_remote_public_room_list( data = yield handler.get_remote_public_room_list(
@ -655,6 +674,8 @@ class RoomForgetRestServlet(TransactionRestServlet):
return 200, {} return 200, {}
def on_PUT(self, request, room_id, txn_id): def on_PUT(self, request, room_id, txn_id):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, txn_id request, self.on_POST, request, room_id, txn_id
) )
@ -738,6 +759,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return True return True
def on_PUT(self, request, room_id, membership_action, txn_id): def on_PUT(self, request, room_id, membership_action, txn_id):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, membership_action, txn_id request, self.on_POST, request, room_id, membership_action, txn_id
) )
@ -771,9 +794,12 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
txn_id=txn_id, txn_id=txn_id,
) )
set_tag("event_id", event.event_id)
return 200, {"event_id": event.event_id} return 200, {"event_id": event.event_id}
def on_PUT(self, request, room_id, event_id, txn_id): def on_PUT(self, request, room_id, event_id, txn_id):
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request( return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_id, txn_id request, self.on_POST, request, room_id, event_id, txn_id
) )

View File

@ -129,66 +129,6 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
return 200, ret return 200, ret
class MsisdnPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/msisdn/requestToken$")
def __init__(self, hs):
super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
self.hs = hs
self.datastore = self.hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_dict(
body, ["client_secret", "country", "phone_number", "send_attempt"]
)
client_secret = body["client_secret"]
country = body["country"]
phone_number = body["phone_number"]
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
msisdn = phone_number_to_msisdn(country, phone_number)
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
existing_user_id = yield self.datastore.get_user_id_by_threepid(
"msisdn", msisdn
)
if existing_user_id is None:
raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
if not self.hs.config.account_threepid_delegate_msisdn:
logger.warn(
"No upstream msisdn account_threepid_delegate configured on the server to "
"handle this request"
)
raise SynapseError(
400,
"Password reset by phone number is not supported on this homeserver",
)
ret = yield self.identity_handler.requestMsisdnToken(
self.hs.config.account_threepid_delegate_msisdn,
country,
phone_number,
client_secret,
send_attempt,
next_link,
)
return 200, ret
class PasswordResetSubmitTokenServlet(RestServlet): class PasswordResetSubmitTokenServlet(RestServlet):
"""Handles 3PID validation token submission""" """Handles 3PID validation token submission"""
@ -301,9 +241,7 @@ class PasswordRestServlet(RestServlet):
else: else:
requester = None requester = None
result, params, _ = yield self.auth_handler.check_auth( result, params, _ = yield self.auth_handler.check_auth(
[[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]], [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
body,
self.hs.get_ip_from_request(request),
) )
if LoginType.EMAIL_IDENTITY in result: if LoginType.EMAIL_IDENTITY in result:
@ -843,7 +781,6 @@ class WhoamiRestServlet(RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server) EmailPasswordRequestTokenRestServlet(hs).register(http_server)
MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordResetSubmitTokenServlet(hs).register(http_server) PasswordResetSubmitTokenServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server)

View File

@ -183,8 +183,8 @@ class PushRulesWorkerStore(
return results return results
@defer.inlineCallbacks @defer.inlineCallbacks
def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule): def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
"""Move a single push rule from one room to another for a specific user. """Copy a single push rule from one room to another for a specific user.
Args: Args:
new_room_id (str): ID of the new room. new_room_id (str): ID of the new room.
@ -209,14 +209,11 @@ class PushRulesWorkerStore(
actions=rule["actions"], actions=rule["actions"],
) )
# Delete push rule for the old room
yield self.delete_push_rule(user_id, rule["rule_id"])
@defer.inlineCallbacks @defer.inlineCallbacks
def move_push_rules_from_room_to_room_for_user( def copy_push_rules_from_room_to_room_for_user(
self, old_room_id, new_room_id, user_id self, old_room_id, new_room_id, user_id
): ):
"""Move all of the push rules from one room to another for a specific """Copy all of the push rules from one room to another for a specific
user. user.
Args: Args:
@ -227,15 +224,14 @@ class PushRulesWorkerStore(
# Retrieve push rules for this user # Retrieve push rules for this user
user_push_rules = yield self.get_push_rules_for_user(user_id) user_push_rules = yield self.get_push_rules_for_user(user_id)
# Get rules relating to the old room, move them to the new room, then # Get rules relating to the old room and copy them to the new room
# delete them from the old room
for rule in user_push_rules: for rule in user_push_rules:
conditions = rule.get("conditions", []) conditions = rule.get("conditions", [])
if any( if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id) (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions for c in conditions
): ):
self.move_push_rule_from_room_to_room(new_room_id, user_id, rule) yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@defer.inlineCallbacks @defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context): def bulk_get_push_rules_for_room(self, event, context):

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -63,103 +64,176 @@ class RoomWorkerStore(SQLBaseStore):
desc="get_public_room_ids", desc="get_public_room_ids",
) )
@cached(num_args=2, max_entries=100) def count_public_rooms(self, network_tuple, ignore_non_federatable):
def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): """Counts the number of public rooms as tracked in the room_stats_current
"""Get pulbic rooms for a particular list, or across all lists. and room_stats_state table.
Args: Args:
stream_id (int) network_tuple (ThirdPartyInstanceID|None)
network_tuple (ThirdPartyInstanceID): The list to use (None, None) ignore_non_federatable (bool): If true filters out non-federatable rooms
means the main list, None means all lsits.
""" """
return self.runInteraction(
"get_public_room_ids_at_stream_id",
self.get_public_room_ids_at_stream_id_txn,
stream_id,
network_tuple=network_tuple,
)
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple): def _count_public_rooms_txn(txn):
return { query_args = []
rm
for rm, vis in self.get_published_at_stream_id_txn( if network_tuple:
txn, stream_id, network_tuple=network_tuple if network_tuple.appservice_id:
).items() published_sql = """
if vis SELECT room_id from appservice_room_list
WHERE appservice_id = ? AND network_id = ?
"""
query_args.append(network_tuple.appservice_id)
query_args.append(network_tuple.network_id)
else:
published_sql = """
SELECT room_id FROM rooms WHERE is_public
"""
else:
published_sql = """
SELECT room_id FROM rooms WHERE is_public
UNION SELECT room_id from appservice_room_list
"""
sql = """
SELECT
COALESCE(COUNT(*), 0)
FROM (
%(published_sql)s
) published
INNER JOIN room_stats_state USING (room_id)
INNER JOIN room_stats_current USING (room_id)
WHERE
(
join_rules = 'public' OR history_visibility = 'world_readable'
)
AND joined_members > 0
""" % {
"published_sql": published_sql
} }
def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): txn.execute(sql, query_args)
if network_tuple: return txn.fetchone()[0]
# We want to get from a particular list. No aggregation required.
return self.runInteraction("count_public_rooms", _count_public_rooms_txn)
@defer.inlineCallbacks
def get_largest_public_rooms(
self,
network_tuple,
search_filter,
limit,
last_room_id,
forwards,
ignore_non_federatable=False,
):
"""Gets the largest public rooms (where largest is in terms of joined
members, as tracked in the statistics table).
Args:
network_tuple (ThirdPartyInstanceID|None):
search_filter (dict|None):
limit (int|None): Maxmimum number of rows to return, unlimited otherwise.
last_room_id (str|None): if present, a room ID which bounds the
result set, and is always *excluded* from the result set.
forwards (bool): true iff going forwards, going backwards otherwise
ignore_non_federatable (bool): If true filters out non-federatable rooms.
Returns:
Rooms in order: biggest number of joined users first.
We then arbitrarily use the room_id as a tie breaker.
sql = """
SELECT room_id, visibility FROM public_room_list_stream
INNER JOIN (
SELECT room_id, max(stream_id) AS stream_id
FROM public_room_list_stream
WHERE stream_id <= ? %s
GROUP BY room_id
) grouped USING (room_id, stream_id)
""" """
if network_tuple.appservice_id is not None: where_clauses = []
txn.execute( query_args = []
sql % ("AND appservice_id = ? AND network_id = ?",),
(stream_id, network_tuple.appservice_id, network_tuple.network_id), if last_room_id:
if forwards:
where_clauses.append("room_id < ?")
else:
where_clauses.append("? < room_id")
query_args += [last_room_id]
if search_filter and search_filter.get("generic_search_term", None):
search_term = "%" + search_filter["generic_search_term"] + "%"
where_clauses.append(
"""
(
name LIKE ?
OR topic LIKE ?
OR canonical_alias LIKE ?
) )
else: """
txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,)) )
return dict(txn) query_args += [search_term, search_term, search_term]
else:
# We want to get from all lists, so we need to aggregate the results
logger.info("Executing full list") if network_tuple:
if network_tuple.appservice_id:
sql = """ published_sql = """
SELECT room_id, visibility SELECT room_id from appservice_room_list
FROM public_room_list_stream WHERE appservice_id = ? AND network_id = ?
INNER JOIN ( """
SELECT query_args.append(network_tuple.appservice_id)
room_id, max(stream_id) AS stream_id, appservice_id, query_args.append(network_tuple.network_id)
network_id else:
FROM public_room_list_stream published_sql = """
WHERE stream_id <= ? SELECT room_id FROM rooms WHERE is_public
GROUP BY room_id, appservice_id, network_id """
) grouped USING (room_id, stream_id) else:
published_sql = """
SELECT room_id FROM rooms WHERE is_public
UNION SELECT room_id from appservice_room_list
""" """
txn.execute(sql, (stream_id,)) where_clause = ""
if where_clauses:
where_clause = " AND " + " AND ".join(where_clauses)
results = {} sql = """
# A room is visible if its visible on any list. SELECT
for room_id, visibility in txn: room_id, name, topic, canonical_alias, joined_members,
results[room_id] = bool(visibility) or results.get(room_id, False) avatar, history_visibility, joined_members, guest_access
FROM (
%(published_sql)s
) published
INNER JOIN room_stats_state USING (room_id)
INNER JOIN room_stats_current USING (room_id)
WHERE
(
join_rules = 'public' OR history_visibility = 'world_readable'
)
AND joined_members > 0
%(where_clause)s
ORDER BY joined_members %(dir)s, room_id %(dir)s
""" % {
"published_sql": published_sql,
"where_clause": where_clause,
"dir": "DESC" if forwards else "ASC",
}
if limit is not None:
query_args.append(limit)
sql += """
LIMIT ?
"""
def _get_largest_public_rooms_txn(txn):
txn.execute(sql, query_args)
results = self.cursor_to_dict(txn)
if not forwards:
results.reverse()
return results return results
def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple): ret_val = yield self.runInteraction(
def get_public_room_changes_txn(txn): "get_largest_public_rooms", _get_largest_public_rooms_txn
then_rooms = self.get_public_room_ids_at_stream_id_txn(
txn, prev_stream_id, network_tuple
)
now_rooms_dict = self.get_published_at_stream_id_txn(
txn, new_stream_id, network_tuple
)
now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis)
now_rooms_not_visible = set(
rm for rm, vis in now_rooms_dict.items() if not vis
)
newly_visible = now_rooms_visible - then_rooms
newly_unpublished = now_rooms_not_visible & then_rooms
return newly_visible, newly_unpublished
return self.runInteraction(
"get_public_room_changes", get_public_room_changes_txn
) )
defer.returnValue(ret_val)
@cached(max_entries=10000) @cached(max_entries=10000)
def is_room_blocked(self, room_id): def is_room_blocked(self, room_id):

View File

@ -0,0 +1,20 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- these tables are never used.
DROP TABLE IF EXISTS room_names;
DROP TABLE IF EXISTS topics;
DROP TABLE IF EXISTS history_visibility;
DROP TABLE IF EXISTS guest_access;

View File

@ -0,0 +1,16 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id);

View File

@ -318,6 +318,7 @@ class StreamToken(
) )
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
START = None # type: StreamToken
@classmethod @classmethod
def from_string(cls, string): def from_string(cls, string):
@ -402,7 +403,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
followed by the "stream_ordering" id of the event it comes after. followed by the "stream_ordering" id of the event it comes after.
""" """
__slots__ = [] __slots__ = [] # type: list
@classmethod @classmethod
def parse(cls, string): def parse(cls, string):

View File

@ -13,9 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections import collections
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, Sequence, Set, Union
from six.moves import range from six.moves import range
@ -213,7 +215,9 @@ class Linearizer(object):
# the first element is the number of things executing, and # the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the # the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing. # things blocked from executing.
self.key_to_defer = {} self.key_to_defer = (
{}
) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
def queue(self, key): def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly. # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@ -340,10 +344,10 @@ class ReadWriteLock(object):
def __init__(self): def __init__(self):
# Latest readers queued # Latest readers queued
self.key_to_current_readers = {} self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
# Latest writer queued # Latest writer queued
self.key_to_current_writer = {} self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks @defer.inlineCallbacks
def read(self, key): def read(self, key):

View File

@ -16,6 +16,7 @@
import logging import logging
import os import os
from typing import Dict
import six import six
from six.moves import intern from six.moves import intern
@ -37,7 +38,7 @@ def get_cache_factor_for(cache_name):
caches_by_name = {} caches_by_name = {}
collectors_by_name = {} collectors_by_name = {} # type: Dict
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])

View File

@ -18,10 +18,12 @@ import inspect
import logging import logging
import threading import threading
from collections import namedtuple from collections import namedtuple
from typing import Any, cast
from six import itervalues from six import itervalues
from prometheus_client import Gauge from prometheus_client import Gauge
from typing_extensions import Protocol
from twisted.internet import defer from twisted.internet import defer
@ -37,6 +39,18 @@ from . import register_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class _CachedFunction(Protocol):
invalidate = None # type: Any
invalidate_all = None # type: Any
invalidate_many = None # type: Any
prefill = None # type: Any
cache = None # type: Any
num_args = None # type: Any
def __name__(self):
...
cache_pending_metric = Gauge( cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending", "synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache", "Number of lookups currently pending for this cache",
@ -245,7 +259,9 @@ class Cache(object):
class _CacheDescriptorBase(object): class _CacheDescriptorBase(object):
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): def __init__(
self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
):
self.orig = orig self.orig = orig
if inlineCallbacks: if inlineCallbacks:
@ -404,7 +420,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return tuple(get_cache_key_gen(args, kwargs)) return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def _wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate() # If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated # whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None) invalidate_callback = kwargs.pop("on_invalidate", None)
@ -440,6 +456,8 @@ class CacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(observer) return make_deferred_yieldable(observer)
wrapped = cast(_CachedFunction, _wrapped)
if self.num_args == 1: if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0]) wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val) wrapped.prefill = lambda key, val: cache.prefill(key[0], val)

View File

@ -1,3 +1,5 @@
from typing import Dict
from six import itervalues from six import itervalues
SENTINEL = object() SENTINEL = object()
@ -12,7 +14,7 @@ class TreeCache(object):
def __init__(self): def __init__(self):
self.size = 0 self.size = 0
self.root = {} self.root = {} # type: Dict
def __setitem__(self, key, value): def __setitem__(self, key, value):
return self.set(key, value) return self.set(key, value)

View File

@ -54,5 +54,5 @@ def load_python_module(location: str):
if spec is None: if spec is None:
raise Exception("Unable to load module at %s" % (location,)) raise Exception("Unable to load module at %s" % (location,))
mod = importlib.util.module_from_spec(spec) mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) spec.loader.exec_module(mod) # type: ignore
return mod return mod

View File

@ -1,39 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.handlers.room_list import RoomListNextBatch
import tests.unittest
import tests.utils
class RoomListTestCase(tests.unittest.TestCase):
""" Tests RoomList's RoomListNextBatch. """
def setUp(self):
pass
def test_check_read_batch_tokens(self):
batch_token = RoomListNextBatch(
stream_ordering="abcdef",
public_room_stream_id="123",
current_limit=20,
direction_is_forward=True,
).to_token()
next_batch = RoomListNextBatch.from_token(batch_token)
self.assertEquals(next_batch.stream_ordering, "abcdef")
self.assertEquals(next_batch.public_room_stream_id, "123")
self.assertEquals(next_batch.current_limit, 20)
self.assertEquals(next_batch.direction_is_forward, True)

View File

@ -23,8 +23,8 @@ from email.parser import Parser
import pkg_resources import pkg_resources
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import LoginType from synapse.api.constants import LoginType, Membership
from synapse.rest.client.v1 import login from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import account, register from synapse.rest.client.v2_alpha import account, register
from tests import unittest from tests import unittest
@ -244,16 +244,66 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource, synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets, login.register_servlets,
account.register_servlets, account.register_servlets,
room.register_servlets,
] ]
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver() self.hs = self.setup_test_homeserver()
return hs return self.hs
def test_deactivate_account(self): def test_deactivate_account(self):
user_id = self.register_user("kermit", "test") user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test") tok = self.login("kermit", "test")
self.deactivate(user_id, tok)
store = self.hs.get_datastore()
# Check that the user has been marked as deactivated.
self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
# Check that this access token has been invalidated.
request, channel = self.make_request("GET", "account/whoami")
self.render(request)
self.assertEqual(request.code, 401)
@unittest.INFO
def test_pending_invites(self):
"""Tests that deactivating a user rejects every pending invite for them."""
store = self.hs.get_datastore()
inviter_id = self.register_user("inviter", "test")
inviter_tok = self.login("inviter", "test")
invitee_id = self.register_user("invitee", "test")
invitee_tok = self.login("invitee", "test")
# Make @inviter:test invite @invitee:test in a new room.
room_id = self.helper.create_room_as(inviter_id, tok=inviter_tok)
self.helper.invite(
room=room_id, src=inviter_id, targ=invitee_id, tok=inviter_tok
)
# Make sure the invite is here.
pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
self.assertEqual(len(pending_invites), 1, pending_invites)
self.assertEqual(pending_invites[0].room_id, room_id, pending_invites)
# Deactivate @invitee:test.
self.deactivate(invitee_id, invitee_tok)
# Check that the invite isn't there anymore.
pending_invites = self.get_success(store.get_invited_rooms_for_user(invitee_id))
self.assertEqual(len(pending_invites), 0, pending_invites)
# Check that the membership of @invitee:test in the room is now "leave".
memberships = self.get_success(
store.get_rooms_for_user_where_membership_is(invitee_id, [Membership.LEAVE])
)
self.assertEqual(len(memberships), 1, memberships)
self.assertEqual(memberships[0].room_id, room_id, memberships)
def deactivate(self, user_id, tok):
request_data = json.dumps( request_data = json.dumps(
{ {
"auth": { "auth": {
@ -269,13 +319,3 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
) )
self.render(request) self.render(request)
self.assertEqual(request.code, 200) self.assertEqual(request.code, 200)
store = self.hs.get_datastore()
# Check that the user has been marked as deactivated.
self.assertTrue(self.get_success(store.get_user_deactivated_status(user_id)))
# Check that this access token has been invalidated.
request, channel = self.make_request("GET", "account/whoami")
self.render(request)
self.assertEqual(request.code, 401)