Merge branch 'develop' into babolivier/msc3026
commit
592d6305fd
|
@ -0,0 +1 @@
|
||||||
|
Checks if passwords are allowed before setting it for the user.
|
|
@ -0,0 +1 @@
|
||||||
|
Improve performance of federation catch up by sending events the latest events in the room to the remote, rather than just the last event sent by the local server.
|
|
@ -0,0 +1 @@
|
||||||
|
Add initial experimental support for a "space summary" API.
|
|
@ -0,0 +1 @@
|
||||||
|
In the `federation_client` commandline client, stop automatically adding the URL prefix, so that servlets on other prefixes can be tested.
|
|
@ -0,0 +1 @@
|
||||||
|
In the `federation_client` commandline client, handle inline `signing_key`s in `homeserver.yaml`.
|
|
@ -22,8 +22,8 @@ import sys
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from urllib import parse as urlparse
|
from urllib import parse as urlparse
|
||||||
|
|
||||||
import nacl.signing
|
|
||||||
import requests
|
import requests
|
||||||
|
import signedjson.key
|
||||||
import signedjson.types
|
import signedjson.types
|
||||||
import srvlookup
|
import srvlookup
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -44,18 +44,6 @@ def encode_base64(input_bytes):
|
||||||
return output_string
|
return output_string
|
||||||
|
|
||||||
|
|
||||||
def decode_base64(input_string):
|
|
||||||
"""Decode a base64 string to bytes inferring padding from the length of the
|
|
||||||
string."""
|
|
||||||
|
|
||||||
input_bytes = input_string.encode("ascii")
|
|
||||||
input_len = len(input_bytes)
|
|
||||||
padding = b"=" * (3 - ((input_len + 3) % 4))
|
|
||||||
output_len = 3 * ((input_len + 2) // 4) + (input_len + 2) % 4 - 2
|
|
||||||
output_bytes = base64.b64decode(input_bytes + padding)
|
|
||||||
return output_bytes[:output_len]
|
|
||||||
|
|
||||||
|
|
||||||
def encode_canonical_json(value):
|
def encode_canonical_json(value):
|
||||||
return json.dumps(
|
return json.dumps(
|
||||||
value,
|
value,
|
||||||
|
@ -88,42 +76,6 @@ def sign_json(
|
||||||
return json_object
|
return json_object
|
||||||
|
|
||||||
|
|
||||||
NACL_ED25519 = "ed25519"
|
|
||||||
|
|
||||||
|
|
||||||
def decode_signing_key_base64(algorithm, version, key_base64):
|
|
||||||
"""Decode a base64 encoded signing key
|
|
||||||
Args:
|
|
||||||
algorithm (str): The algorithm the key is for (currently "ed25519").
|
|
||||||
version (str): Identifies this key out of the keys for this entity.
|
|
||||||
key_base64 (str): Base64 encoded bytes of the key.
|
|
||||||
Returns:
|
|
||||||
A SigningKey object.
|
|
||||||
"""
|
|
||||||
if algorithm == NACL_ED25519:
|
|
||||||
key_bytes = decode_base64(key_base64)
|
|
||||||
key = nacl.signing.SigningKey(key_bytes)
|
|
||||||
key.version = version
|
|
||||||
key.alg = NACL_ED25519
|
|
||||||
return key
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported algorithm %s" % (algorithm,))
|
|
||||||
|
|
||||||
|
|
||||||
def read_signing_keys(stream):
|
|
||||||
"""Reads a list of keys from a stream
|
|
||||||
Args:
|
|
||||||
stream : A stream to iterate for keys.
|
|
||||||
Returns:
|
|
||||||
list of SigningKey objects.
|
|
||||||
"""
|
|
||||||
keys = []
|
|
||||||
for line in stream:
|
|
||||||
algorithm, version, key_base64 = line.split()
|
|
||||||
keys.append(decode_signing_key_base64(algorithm, version, key_base64))
|
|
||||||
return keys
|
|
||||||
|
|
||||||
|
|
||||||
def request(
|
def request(
|
||||||
method: Optional[str],
|
method: Optional[str],
|
||||||
origin_name: str,
|
origin_name: str,
|
||||||
|
@ -223,23 +175,28 @@ def main():
|
||||||
parser.add_argument("--body", help="Data to send as the body of the HTTP request")
|
parser.add_argument("--body", help="Data to send as the body of the HTTP request")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"path", help="request path. We will add '/_matrix/federation/v1/' to this."
|
"path", help="request path, including the '/_matrix/federation/...' prefix."
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if not args.server_name or not args.signing_key_path:
|
args.signing_key = None
|
||||||
|
if args.signing_key_path:
|
||||||
|
with open(args.signing_key_path) as f:
|
||||||
|
args.signing_key = f.readline()
|
||||||
|
|
||||||
|
if not args.server_name or not args.signing_key:
|
||||||
read_args_from_config(args)
|
read_args_from_config(args)
|
||||||
|
|
||||||
with open(args.signing_key_path) as f:
|
algorithm, version, key_base64 = args.signing_key.split()
|
||||||
key = read_signing_keys(f)[0]
|
key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64)
|
||||||
|
|
||||||
result = request(
|
result = request(
|
||||||
args.method,
|
args.method,
|
||||||
args.server_name,
|
args.server_name,
|
||||||
key,
|
key,
|
||||||
args.destination,
|
args.destination,
|
||||||
"/_matrix/federation/v1/" + args.path,
|
args.path,
|
||||||
content=args.body,
|
content=args.body,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -255,10 +212,16 @@ def main():
|
||||||
def read_args_from_config(args):
|
def read_args_from_config(args):
|
||||||
with open(args.config, "r") as fh:
|
with open(args.config, "r") as fh:
|
||||||
config = yaml.safe_load(fh)
|
config = yaml.safe_load(fh)
|
||||||
|
|
||||||
if not args.server_name:
|
if not args.server_name:
|
||||||
args.server_name = config["server_name"]
|
args.server_name = config["server_name"]
|
||||||
if not args.signing_key_path:
|
|
||||||
args.signing_key_path = config["signing_key_path"]
|
if not args.signing_key:
|
||||||
|
if "signing_key" in config:
|
||||||
|
args.signing_key = config["signing_key"]
|
||||||
|
else:
|
||||||
|
with open(config["signing_key_path"]) as f:
|
||||||
|
args.signing_key = f.readline()
|
||||||
|
|
||||||
|
|
||||||
class MatrixConnectionAdapter(HTTPAdapter):
|
class MatrixConnectionAdapter(HTTPAdapter):
|
||||||
|
|
|
@ -101,6 +101,9 @@ class EventTypes:
|
||||||
|
|
||||||
Dummy = "org.matrix.dummy_event"
|
Dummy = "org.matrix.dummy_event"
|
||||||
|
|
||||||
|
MSC1772_SPACE_CHILD = "org.matrix.msc1772.space.child"
|
||||||
|
MSC1772_SPACE_PARENT = "org.matrix.msc1772.space.parent"
|
||||||
|
|
||||||
|
|
||||||
class EduTypes:
|
class EduTypes:
|
||||||
Presence = "m.presence"
|
Presence = "m.presence"
|
||||||
|
@ -161,6 +164,9 @@ class EventContentFields:
|
||||||
# cf https://github.com/matrix-org/matrix-doc/pull/2228
|
# cf https://github.com/matrix-org/matrix-doc/pull/2228
|
||||||
SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
|
SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
|
||||||
|
|
||||||
|
# cf https://github.com/matrix-org/matrix-doc/pull/1772
|
||||||
|
MSC1772_ROOM_TYPE = "org.matrix.msc1772.type"
|
||||||
|
|
||||||
|
|
||||||
class RoomEncryptionAlgorithms:
|
class RoomEncryptionAlgorithms:
|
||||||
MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
|
MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
|
||||||
|
|
|
@ -27,5 +27,7 @@ class ExperimentalConfig(Config):
|
||||||
|
|
||||||
# MSC2858 (multiple SSO identity providers)
|
# MSC2858 (multiple SSO identity providers)
|
||||||
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
|
self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
|
||||||
|
# Spaces (MSC1772, MSC2946, etc)
|
||||||
|
self.spaces_enabled = experimental.get("spaces_enabled", False) # type: bool
|
||||||
# MSC3026 (busy presence state)
|
# MSC3026 (busy presence state)
|
||||||
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
|
self.msc3026_enabled = experimental.get("msc3026_enabled", False) # type: bool
|
||||||
|
|
|
@ -35,7 +35,7 @@ from twisted.internet import defer
|
||||||
from twisted.internet.abstract import isIPAddress
|
from twisted.internet.abstract import isIPAddress
|
||||||
from twisted.python import failure
|
from twisted.python import failure
|
||||||
|
|
||||||
from synapse.api.constants import EduTypes, EventTypes, Membership
|
from synapse.api.constants import EduTypes, EventTypes
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
Codes,
|
Codes,
|
||||||
|
@ -63,7 +63,7 @@ from synapse.replication.http.federation import (
|
||||||
ReplicationFederationSendEduRestServlet,
|
ReplicationFederationSendEduRestServlet,
|
||||||
ReplicationGetQueryRestServlet,
|
ReplicationGetQueryRestServlet,
|
||||||
)
|
)
|
||||||
from synapse.types import JsonDict, get_domain_from_id
|
from synapse.types import JsonDict
|
||||||
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
|
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
|
||||||
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
||||||
from synapse.util.caches.response_cache import ResponseCache
|
from synapse.util.caches.response_cache import ResponseCache
|
||||||
|
@ -727,27 +727,6 @@ class FederationServer(FederationBase):
|
||||||
if the event was unacceptable for any other reason (eg, too large,
|
if the event was unacceptable for any other reason (eg, too large,
|
||||||
too many prev_events, couldn't find the prev_events)
|
too many prev_events, couldn't find the prev_events)
|
||||||
"""
|
"""
|
||||||
# check that it's actually being sent from a valid destination to
|
|
||||||
# workaround bug #1753 in 0.18.5 and 0.18.6
|
|
||||||
if origin != get_domain_from_id(pdu.sender):
|
|
||||||
# We continue to accept join events from any server; this is
|
|
||||||
# necessary for the federation join dance to work correctly.
|
|
||||||
# (When we join over federation, the "helper" server is
|
|
||||||
# responsible for sending out the join event, rather than the
|
|
||||||
# origin. See bug #1893. This is also true for some third party
|
|
||||||
# invites).
|
|
||||||
if not (
|
|
||||||
pdu.type == "m.room.member"
|
|
||||||
and pdu.content
|
|
||||||
and pdu.content.get("membership", None)
|
|
||||||
in (Membership.JOIN, Membership.INVITE)
|
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
"Discarding PDU %s from invalid origin %s", pdu.event_id, origin
|
|
||||||
)
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
|
|
||||||
|
|
||||||
# We've already checked that we know the room version by this point
|
# We've already checked that we know the room version by this point
|
||||||
room_version = await self.store.get_room_version(pdu.room_id)
|
room_version = await self.store.get_room_version(pdu.room_id)
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
|
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
|
@ -77,6 +77,7 @@ class PerDestinationQueue:
|
||||||
self._transaction_manager = transaction_manager
|
self._transaction_manager = transaction_manager
|
||||||
self._instance_name = hs.get_instance_name()
|
self._instance_name = hs.get_instance_name()
|
||||||
self._federation_shard_config = hs.config.worker.federation_shard_config
|
self._federation_shard_config = hs.config.worker.federation_shard_config
|
||||||
|
self._state = hs.get_state_handler()
|
||||||
|
|
||||||
self._should_send_on_this_instance = True
|
self._should_send_on_this_instance = True
|
||||||
if not self._federation_shard_config.should_handle(
|
if not self._federation_shard_config.should_handle(
|
||||||
|
@ -415,22 +416,95 @@ class PerDestinationQueue:
|
||||||
"This should not happen." % event_ids
|
"This should not happen." % event_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if logger.isEnabledFor(logging.INFO):
|
# We send transactions with events from one room only, as its likely
|
||||||
rooms = [p.room_id for p in catchup_pdus]
|
# that the remote will have to do additional processing, which may
|
||||||
logger.info("Catching up rooms to %s: %r", self._destination, rooms)
|
# take some time. It's better to give it small amounts of work
|
||||||
|
# rather than risk the request timing out and repeatedly being
|
||||||
|
# retried, and not making any progress.
|
||||||
|
#
|
||||||
|
# Note: `catchup_pdus` will have exactly one PDU per room.
|
||||||
|
for pdu in catchup_pdus:
|
||||||
|
# The PDU from the DB will be the last PDU in the room from
|
||||||
|
# *this server* that wasn't sent to the remote. However, other
|
||||||
|
# servers may have sent lots of events since then, and we want
|
||||||
|
# to try and tell the remote only about the *latest* events in
|
||||||
|
# the room. This is so that it doesn't get inundated by events
|
||||||
|
# from various parts of the DAG, which all need to be processed.
|
||||||
|
#
|
||||||
|
# Note: this does mean that in large rooms a server coming back
|
||||||
|
# online will get sent the same events from all the different
|
||||||
|
# servers, but the remote will correctly deduplicate them and
|
||||||
|
# handle it only once.
|
||||||
|
|
||||||
await self._transaction_manager.send_new_transaction(
|
# Step 1, fetch the current extremities
|
||||||
self._destination, catchup_pdus, []
|
extrems = await self._store.get_prev_events_for_room(pdu.room_id)
|
||||||
)
|
|
||||||
|
|
||||||
sent_transactions_counter.inc()
|
if pdu.event_id in extrems:
|
||||||
final_pdu = catchup_pdus[-1]
|
# If the event is in the extremities, then great! We can just
|
||||||
self._last_successful_stream_ordering = cast(
|
# use that without having to do further checks.
|
||||||
int, final_pdu.internal_metadata.stream_ordering
|
room_catchup_pdus = [pdu]
|
||||||
)
|
else:
|
||||||
await self._store.set_destination_last_successful_stream_ordering(
|
# If not, fetch the extremities and figure out which we can
|
||||||
self._destination, self._last_successful_stream_ordering
|
# send.
|
||||||
)
|
extrem_events = await self._store.get_events_as_list(extrems)
|
||||||
|
|
||||||
|
new_pdus = []
|
||||||
|
for p in extrem_events:
|
||||||
|
# We pulled this from the DB, so it'll be non-null
|
||||||
|
assert p.internal_metadata.stream_ordering
|
||||||
|
|
||||||
|
# Filter out events that happened before the remote went
|
||||||
|
# offline
|
||||||
|
if (
|
||||||
|
p.internal_metadata.stream_ordering
|
||||||
|
< self._last_successful_stream_ordering
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Filter out events where the server is not in the room,
|
||||||
|
# e.g. it may have left/been kicked. *Ideally* we'd pull
|
||||||
|
# out the kick and send that, but it's a rare edge case
|
||||||
|
# so we don't bother for now (the server that sent the
|
||||||
|
# kick should send it out if its online).
|
||||||
|
hosts = await self._state.get_hosts_in_room_at_events(
|
||||||
|
p.room_id, [p.event_id]
|
||||||
|
)
|
||||||
|
if self._destination not in hosts:
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_pdus.append(p)
|
||||||
|
|
||||||
|
# If we've filtered out all the extremities, fall back to
|
||||||
|
# sending the original event. This should ensure that the
|
||||||
|
# server gets at least some of missed events (especially if
|
||||||
|
# the other sending servers are up).
|
||||||
|
if new_pdus:
|
||||||
|
room_catchup_pdus = new_pdus
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Catching up rooms to %s: %r", self._destination, pdu.room_id
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._transaction_manager.send_new_transaction(
|
||||||
|
self._destination, room_catchup_pdus, []
|
||||||
|
)
|
||||||
|
|
||||||
|
sent_transactions_counter.inc()
|
||||||
|
|
||||||
|
# We pulled this from the DB, so it'll be non-null
|
||||||
|
assert pdu.internal_metadata.stream_ordering
|
||||||
|
|
||||||
|
# Note that we mark the last successful stream ordering as that
|
||||||
|
# from the *original* PDU, rather than the PDU(s) we actually
|
||||||
|
# send. This is because we use it to mark our position in the
|
||||||
|
# queue of missed PDUs to process.
|
||||||
|
self._last_successful_stream_ordering = (
|
||||||
|
pdu.internal_metadata.stream_ordering
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._store.set_destination_last_successful_stream_ordering(
|
||||||
|
self._destination, self._last_successful_stream_ordering
|
||||||
|
)
|
||||||
|
|
||||||
def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
|
def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
|
||||||
if not self._pending_rrs:
|
if not self._pending_rrs:
|
||||||
|
|
|
@ -41,7 +41,7 @@ class SetPasswordHandler(BaseHandler):
|
||||||
logout_devices: bool,
|
logout_devices: bool,
|
||||||
requester: Optional[Requester] = None,
|
requester: Optional[Requester] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if not self.hs.config.password_localdb_enabled:
|
if not self._auth_handler.can_change_password():
|
||||||
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
|
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -0,0 +1,199 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
import logging
|
||||||
|
from collections import deque
|
||||||
|
from typing import TYPE_CHECKING, Iterable, List, Optional, Set
|
||||||
|
|
||||||
|
from synapse.api.constants import EventContentFields, EventTypes, HistoryVisibility
|
||||||
|
from synapse.api.errors import AuthError
|
||||||
|
from synapse.events import EventBase
|
||||||
|
from synapse.events.utils import format_event_for_client_v2
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# number of rooms to return. We'll stop once we hit this limit.
|
||||||
|
# TODO: allow clients to reduce this with a request param.
|
||||||
|
MAX_ROOMS = 50
|
||||||
|
|
||||||
|
# max number of events to return per room.
|
||||||
|
MAX_ROOMS_PER_SPACE = 50
|
||||||
|
|
||||||
|
|
||||||
|
class SpaceSummaryHandler:
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self._clock = hs.get_clock()
|
||||||
|
self._auth = hs.get_auth()
|
||||||
|
self._room_list_handler = hs.get_room_list_handler()
|
||||||
|
self._state_handler = hs.get_state_handler()
|
||||||
|
self._store = hs.get_datastore()
|
||||||
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
|
|
||||||
|
async def get_space_summary(
|
||||||
|
self,
|
||||||
|
requester: str,
|
||||||
|
room_id: str,
|
||||||
|
suggested_only: bool = False,
|
||||||
|
max_rooms_per_space: Optional[int] = None,
|
||||||
|
) -> JsonDict:
|
||||||
|
"""
|
||||||
|
Implementation of the space summary API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requester: user id of the user making this request
|
||||||
|
|
||||||
|
room_id: room id to start the summary at
|
||||||
|
|
||||||
|
suggested_only: whether we should only return children with the "suggested"
|
||||||
|
flag set.
|
||||||
|
|
||||||
|
max_rooms_per_space: an optional limit on the number of child rooms we will
|
||||||
|
return. This does not apply to the root room (ie, room_id), and
|
||||||
|
is overridden by ROOMS_PER_SPACE_LIMIT.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
summary dict to return
|
||||||
|
"""
|
||||||
|
# first of all, check that the user is in the room in question (or it's
|
||||||
|
# world-readable)
|
||||||
|
await self._auth.check_user_in_room_or_world_readable(room_id, requester)
|
||||||
|
|
||||||
|
# the queue of rooms to process
|
||||||
|
room_queue = deque((room_id,))
|
||||||
|
|
||||||
|
processed_rooms = set() # type: Set[str]
|
||||||
|
|
||||||
|
rooms_result = [] # type: List[JsonDict]
|
||||||
|
events_result = [] # type: List[JsonDict]
|
||||||
|
|
||||||
|
now = self._clock.time_msec()
|
||||||
|
|
||||||
|
while room_queue and len(rooms_result) < MAX_ROOMS:
|
||||||
|
room_id = room_queue.popleft()
|
||||||
|
logger.debug("Processing room %s", room_id)
|
||||||
|
processed_rooms.add(room_id)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._auth.check_user_in_room_or_world_readable(
|
||||||
|
room_id, requester
|
||||||
|
)
|
||||||
|
except AuthError:
|
||||||
|
logger.info(
|
||||||
|
"user %s cannot view room %s, omitting from summary",
|
||||||
|
requester,
|
||||||
|
room_id,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
room_entry = await self._build_room_entry(room_id)
|
||||||
|
rooms_result.append(room_entry)
|
||||||
|
|
||||||
|
# look for child rooms/spaces.
|
||||||
|
child_events = await self._get_child_events(room_id)
|
||||||
|
|
||||||
|
if suggested_only:
|
||||||
|
# we only care about suggested children
|
||||||
|
child_events = filter(_is_suggested_child_event, child_events)
|
||||||
|
|
||||||
|
# The client-specified max_rooms_per_space limit doesn't apply to the
|
||||||
|
# room_id specified in the request, so we ignore it if this is the
|
||||||
|
# first room we are processing. Otherwise, apply any client-specified
|
||||||
|
# limit, capping to our built-in limit.
|
||||||
|
if max_rooms_per_space is not None and len(processed_rooms) > 1:
|
||||||
|
max_rooms = min(MAX_ROOMS_PER_SPACE, max_rooms_per_space)
|
||||||
|
else:
|
||||||
|
max_rooms = MAX_ROOMS_PER_SPACE
|
||||||
|
|
||||||
|
for edge_event in itertools.islice(child_events, max_rooms):
|
||||||
|
edge_room_id = edge_event.state_key
|
||||||
|
|
||||||
|
events_result.append(
|
||||||
|
await self._event_serializer.serialize_event(
|
||||||
|
edge_event,
|
||||||
|
time_now=now,
|
||||||
|
event_format=format_event_for_client_v2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# if we haven't yet visited the target of this link, add it to the queue
|
||||||
|
if edge_room_id not in processed_rooms:
|
||||||
|
room_queue.append(edge_room_id)
|
||||||
|
|
||||||
|
return {"rooms": rooms_result, "events": events_result}
|
||||||
|
|
||||||
|
async def _build_room_entry(self, room_id: str) -> JsonDict:
|
||||||
|
"""Generate en entry suitable for the 'rooms' list in the summary response"""
|
||||||
|
stats = await self._store.get_room_with_stats(room_id)
|
||||||
|
|
||||||
|
# currently this should be impossible because we call
|
||||||
|
# check_user_in_room_or_world_readable on the room before we get here, so
|
||||||
|
# there should always be an entry
|
||||||
|
assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
|
||||||
|
|
||||||
|
current_state_ids = await self._store.get_current_state_ids(room_id)
|
||||||
|
create_event = await self._store.get_event(
|
||||||
|
current_state_ids[(EventTypes.Create, "")]
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: update once MSC1772 lands
|
||||||
|
room_type = create_event.content.get(EventContentFields.MSC1772_ROOM_TYPE)
|
||||||
|
|
||||||
|
entry = {
|
||||||
|
"room_id": stats["room_id"],
|
||||||
|
"name": stats["name"],
|
||||||
|
"topic": stats["topic"],
|
||||||
|
"canonical_alias": stats["canonical_alias"],
|
||||||
|
"num_joined_members": stats["joined_members"],
|
||||||
|
"avatar_url": stats["avatar"],
|
||||||
|
"world_readable": (
|
||||||
|
stats["history_visibility"] == HistoryVisibility.WORLD_READABLE
|
||||||
|
),
|
||||||
|
"guest_can_join": stats["guest_access"] == "can_join",
|
||||||
|
"room_type": room_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Filter out Nones – rather omit the field altogether
|
||||||
|
room_entry = {k: v for k, v in entry.items() if v is not None}
|
||||||
|
|
||||||
|
return room_entry
|
||||||
|
|
||||||
|
async def _get_child_events(self, room_id: str) -> Iterable[EventBase]:
|
||||||
|
# look for child rooms/spaces.
|
||||||
|
current_state_ids = await self._store.get_current_state_ids(room_id)
|
||||||
|
|
||||||
|
events = await self._store.get_events_as_list(
|
||||||
|
[
|
||||||
|
event_id
|
||||||
|
for key, event_id in current_state_ids.items()
|
||||||
|
# TODO: update once MSC1772 lands
|
||||||
|
if key[0] == EventTypes.MSC1772_SPACE_CHILD
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# filter out any events without a "via" (which implies it has been redacted)
|
||||||
|
return (e for e in events if e.content.get("via"))
|
||||||
|
|
||||||
|
|
||||||
|
def _is_suggested_child_event(edge_event: EventBase) -> bool:
|
||||||
|
suggested = edge_event.content.get("suggested")
|
||||||
|
if isinstance(suggested, bool) and suggested:
|
||||||
|
return True
|
||||||
|
logger.debug("Ignorning not-suggested child %s", edge_event.state_key)
|
||||||
|
return False
|
|
@ -271,7 +271,7 @@ class UserRestServletV2(RestServlet):
|
||||||
elif not deactivate and user["deactivated"]:
|
elif not deactivate and user["deactivated"]:
|
||||||
if (
|
if (
|
||||||
"password" not in body
|
"password" not in body
|
||||||
and self.hs.config.password_localdb_enabled
|
and self.auth_handler.can_change_password()
|
||||||
):
|
):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400, "Must provide a password to re-activate an account."
|
400, "Must provide a password to re-activate an account."
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
from urllib import parse as urlparse
|
from urllib import parse as urlparse
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
|
@ -35,16 +35,25 @@ from synapse.events.utils import format_event_for_client_v2
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet,
|
RestServlet,
|
||||||
assert_params_in_dict,
|
assert_params_in_dict,
|
||||||
|
parse_boolean,
|
||||||
parse_integer,
|
parse_integer,
|
||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
parse_string,
|
parse_string,
|
||||||
)
|
)
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.logging.opentracing import set_tag
|
from synapse.logging.opentracing import set_tag
|
||||||
from synapse.rest.client.transactions import HttpTransactionCache
|
from synapse.rest.client.transactions import HttpTransactionCache
|
||||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||||
from synapse.storage.state import StateFilter
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
|
from synapse.types import (
|
||||||
|
JsonDict,
|
||||||
|
RoomAlias,
|
||||||
|
RoomID,
|
||||||
|
StreamToken,
|
||||||
|
ThirdPartyInstanceID,
|
||||||
|
UserID,
|
||||||
|
)
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
from synapse.util.stringutils import parse_and_validate_server_name, random_string
|
from synapse.util.stringutils import parse_and_validate_server_name, random_string
|
||||||
|
|
||||||
|
@ -987,7 +996,58 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server, is_worker=False):
|
class RoomSpaceSummaryRestServlet(RestServlet):
|
||||||
|
PATTERNS = (
|
||||||
|
re.compile(
|
||||||
|
"^/_matrix/client/unstable/org.matrix.msc2946"
|
||||||
|
"/rooms/(?P<room_id>[^/]*)/spaces$"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
super().__init__()
|
||||||
|
self._auth = hs.get_auth()
|
||||||
|
self._space_summary_handler = hs.get_space_summary_handler()
|
||||||
|
|
||||||
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, room_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self._auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
|
return 200, await self._space_summary_handler.get_space_summary(
|
||||||
|
requester.user.to_string(),
|
||||||
|
room_id,
|
||||||
|
suggested_only=parse_boolean(request, "suggested_only", default=False),
|
||||||
|
max_rooms_per_space=parse_integer(request, "max_rooms_per_space"),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_POST(
|
||||||
|
self, request: SynapseRequest, room_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
requester = await self._auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
suggested_only = content.get("suggested_only", False)
|
||||||
|
if not isinstance(suggested_only, bool):
|
||||||
|
raise SynapseError(
|
||||||
|
400, "'suggested_only' must be a boolean", Codes.BAD_JSON
|
||||||
|
)
|
||||||
|
|
||||||
|
max_rooms_per_space = content.get("max_rooms_per_space")
|
||||||
|
if max_rooms_per_space is not None and not isinstance(max_rooms_per_space, int):
|
||||||
|
raise SynapseError(
|
||||||
|
400, "'max_rooms_per_space' must be an integer", Codes.BAD_JSON
|
||||||
|
)
|
||||||
|
|
||||||
|
return 200, await self._space_summary_handler.get_space_summary(
|
||||||
|
requester.user.to_string(),
|
||||||
|
room_id,
|
||||||
|
suggested_only=suggested_only,
|
||||||
|
max_rooms_per_space=max_rooms_per_space,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs: "HomeServer", http_server, is_worker=False):
|
||||||
RoomStateEventRestServlet(hs).register(http_server)
|
RoomStateEventRestServlet(hs).register(http_server)
|
||||||
RoomMemberListRestServlet(hs).register(http_server)
|
RoomMemberListRestServlet(hs).register(http_server)
|
||||||
JoinedRoomMemberListRestServlet(hs).register(http_server)
|
JoinedRoomMemberListRestServlet(hs).register(http_server)
|
||||||
|
@ -1001,6 +1061,9 @@ def register_servlets(hs, http_server, is_worker=False):
|
||||||
RoomTypingRestServlet(hs).register(http_server)
|
RoomTypingRestServlet(hs).register(http_server)
|
||||||
RoomEventContextServlet(hs).register(http_server)
|
RoomEventContextServlet(hs).register(http_server)
|
||||||
|
|
||||||
|
if hs.config.experimental.spaces_enabled:
|
||||||
|
RoomSpaceSummaryRestServlet(hs).register(http_server)
|
||||||
|
|
||||||
# Some servlets only get registered for the main process.
|
# Some servlets only get registered for the main process.
|
||||||
if not is_worker:
|
if not is_worker:
|
||||||
RoomCreateRestServlet(hs).register(http_server)
|
RoomCreateRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -100,6 +100,7 @@ from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHand
|
||||||
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
|
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
|
||||||
from synapse.handlers.search import SearchHandler
|
from synapse.handlers.search import SearchHandler
|
||||||
from synapse.handlers.set_password import SetPasswordHandler
|
from synapse.handlers.set_password import SetPasswordHandler
|
||||||
|
from synapse.handlers.space_summary import SpaceSummaryHandler
|
||||||
from synapse.handlers.sso import SsoHandler
|
from synapse.handlers.sso import SsoHandler
|
||||||
from synapse.handlers.stats import StatsHandler
|
from synapse.handlers.stats import StatsHandler
|
||||||
from synapse.handlers.sync import SyncHandler
|
from synapse.handlers.sync import SyncHandler
|
||||||
|
@ -732,6 +733,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
def get_account_data_handler(self) -> AccountDataHandler:
|
def get_account_data_handler(self) -> AccountDataHandler:
|
||||||
return AccountDataHandler(self)
|
return AccountDataHandler(self)
|
||||||
|
|
||||||
|
@cache_in_self
|
||||||
|
def get_space_summary_handler(self) -> SpaceSummaryHandler:
|
||||||
|
return SpaceSummaryHandler(self)
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_external_cache(self) -> ExternalCache:
|
def get_external_cache(self) -> ExternalCache:
|
||||||
return ExternalCache(self)
|
return ExternalCache(self)
|
||||||
|
|
|
@ -1210,6 +1210,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||||
self._invalidate_cache_and_stream(
|
self._invalidate_cache_and_stream(
|
||||||
txn, self.get_user_deactivated_status, (user_id,)
|
txn, self.get_user_deactivated_status, (user_id,)
|
||||||
)
|
)
|
||||||
|
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
|
||||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
|
|
|
@ -2,6 +2,7 @@ from typing import List, Tuple
|
||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.federation.sender import PerDestinationQueue, TransactionManager
|
from synapse.federation.sender import PerDestinationQueue, TransactionManager
|
||||||
from synapse.federation.units import Edu
|
from synapse.federation.units import Edu
|
||||||
|
@ -421,3 +422,51 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
|
||||||
self.assertNotIn("zzzerver", woken)
|
self.assertNotIn("zzzerver", woken)
|
||||||
# - all destinations are woken exactly once; they appear once in woken.
|
# - all destinations are woken exactly once; they appear once in woken.
|
||||||
self.assertCountEqual(woken, server_names[:-1])
|
self.assertCountEqual(woken, server_names[:-1])
|
||||||
|
|
||||||
|
@override_config({"send_federation": True})
|
||||||
|
def test_not_latest_event(self):
|
||||||
|
"""Test that we send the latest event in the room even if its not ours."""
|
||||||
|
|
||||||
|
per_dest_queue, sent_pdus = self.make_fake_destination_queue()
|
||||||
|
|
||||||
|
# Make a room with a local user, and two servers. One will go offline
|
||||||
|
# and one will send some events.
|
||||||
|
self.register_user("u1", "you the one")
|
||||||
|
u1_token = self.login("u1", "you the one")
|
||||||
|
room_1 = self.helper.create_room_as("u1", tok=u1_token)
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
|
||||||
|
)
|
||||||
|
event_1 = self.get_success(
|
||||||
|
event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join")
|
||||||
|
)
|
||||||
|
|
||||||
|
# First we send something from the local server, so that we notice the
|
||||||
|
# remote is down and go into catchup mode.
|
||||||
|
self.helper.send(room_1, "you hear me!!", tok=u1_token)
|
||||||
|
|
||||||
|
# Now simulate us receiving an event from the still online remote.
|
||||||
|
event_2 = self.get_success(
|
||||||
|
event_injection.inject_event(
|
||||||
|
self.hs,
|
||||||
|
type=EventTypes.Message,
|
||||||
|
sender="@user:host3",
|
||||||
|
room_id=room_1,
|
||||||
|
content={"msgtype": "m.text", "body": "Hello"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
self.hs.get_datastore().set_destination_last_successful_stream_ordering(
|
||||||
|
"host2", event_1.internal_metadata.stream_ordering
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_success(per_dest_queue._catch_up_transmission_loop())
|
||||||
|
|
||||||
|
# We expect only the last message from the remote, event_2, to have been
|
||||||
|
# sent, rather than the last *local* event that was sent.
|
||||||
|
self.assertEqual(len(sent_pdus), 1)
|
||||||
|
self.assertEqual(sent_pdus[0].event_id, event_2.event_id)
|
||||||
|
self.assertFalse(per_dest_queue._catching_up)
|
||||||
|
|
|
@ -1003,12 +1003,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
|
# create users and get access tokens
|
||||||
|
# regardless of whether password login or SSO is allowed
|
||||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||||
self.admin_user_tok = self.login("admin", "pass")
|
self.admin_user_tok = self.get_success(
|
||||||
|
self.auth_handler.get_access_token_for_user_id(
|
||||||
|
self.admin_user, device_id=None, valid_until_ms=None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.other_user = self.register_user("user", "pass", displayname="User")
|
self.other_user = self.register_user("user", "pass", displayname="User")
|
||||||
self.other_user_token = self.login("user", "pass")
|
self.other_user_token = self.get_success(
|
||||||
|
self.auth_handler.get_access_token_for_user_id(
|
||||||
|
self.other_user, device_id=None, valid_until_ms=None
|
||||||
|
)
|
||||||
|
)
|
||||||
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
|
self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
|
||||||
self.other_user
|
self.other_user
|
||||||
)
|
)
|
||||||
|
@ -1081,7 +1092,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||||
self.assertEqual(True, channel.json_body["admin"])
|
self.assertTrue(channel.json_body["admin"])
|
||||||
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
|
@ -1096,9 +1107,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||||
self.assertEqual(True, channel.json_body["admin"])
|
self.assertTrue(channel.json_body["admin"])
|
||||||
self.assertEqual(False, channel.json_body["is_guest"])
|
self.assertFalse(channel.json_body["is_guest"])
|
||||||
self.assertEqual(False, channel.json_body["deactivated"])
|
self.assertFalse(channel.json_body["deactivated"])
|
||||||
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
||||||
|
|
||||||
def test_create_user(self):
|
def test_create_user(self):
|
||||||
|
@ -1130,7 +1141,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||||
self.assertEqual(False, channel.json_body["admin"])
|
self.assertFalse(channel.json_body["admin"])
|
||||||
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
|
@ -1145,10 +1156,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||||
self.assertEqual(False, channel.json_body["admin"])
|
self.assertFalse(channel.json_body["admin"])
|
||||||
self.assertEqual(False, channel.json_body["is_guest"])
|
self.assertFalse(channel.json_body["is_guest"])
|
||||||
self.assertEqual(False, channel.json_body["deactivated"])
|
self.assertFalse(channel.json_body["deactivated"])
|
||||||
self.assertEqual(False, channel.json_body["shadow_banned"])
|
self.assertFalse(channel.json_body["shadow_banned"])
|
||||||
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
|
@ -1197,7 +1208,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||||
self.assertEqual(False, channel.json_body["admin"])
|
self.assertFalse(channel.json_body["admin"])
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
|
{"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
|
||||||
|
@ -1237,7 +1248,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
# Admin user is not blocked by mau anymore
|
# Admin user is not blocked by mau anymore
|
||||||
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||||
self.assertEqual(False, channel.json_body["admin"])
|
self.assertFalse(channel.json_body["admin"])
|
||||||
|
|
||||||
@override_config(
|
@override_config(
|
||||||
{
|
{
|
||||||
|
@ -1429,24 +1440,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual("@user:test", channel.json_body["name"])
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
self.assertEqual(False, channel.json_body["deactivated"])
|
self.assertFalse(channel.json_body["deactivated"])
|
||||||
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
|
self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
|
||||||
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
|
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
|
||||||
self.assertEqual("User", channel.json_body["displayname"])
|
self.assertEqual("User", channel.json_body["displayname"])
|
||||||
|
|
||||||
# Deactivate user
|
# Deactivate user
|
||||||
body = json.dumps({"deactivated": True})
|
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
self.url_other_user,
|
self.url_other_user,
|
||||||
access_token=self.admin_user_tok,
|
access_token=self.admin_user_tok,
|
||||||
content=body.encode(encoding="utf_8"),
|
content={"deactivated": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual("@user:test", channel.json_body["name"])
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
self.assertEqual(True, channel.json_body["deactivated"])
|
self.assertTrue(channel.json_body["deactivated"])
|
||||||
|
self.assertIsNone(channel.json_body["password_hash"])
|
||||||
self.assertEqual(0, len(channel.json_body["threepids"]))
|
self.assertEqual(0, len(channel.json_body["threepids"]))
|
||||||
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
|
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
|
||||||
self.assertEqual("User", channel.json_body["displayname"])
|
self.assertEqual("User", channel.json_body["displayname"])
|
||||||
|
@ -1461,7 +1471,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual("@user:test", channel.json_body["name"])
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
self.assertEqual(True, channel.json_body["deactivated"])
|
self.assertTrue(channel.json_body["deactivated"])
|
||||||
|
self.assertIsNone(channel.json_body["password_hash"])
|
||||||
self.assertEqual(0, len(channel.json_body["threepids"]))
|
self.assertEqual(0, len(channel.json_body["threepids"]))
|
||||||
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
|
self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
|
||||||
self.assertEqual("User", channel.json_body["displayname"])
|
self.assertEqual("User", channel.json_body["displayname"])
|
||||||
|
@ -1478,41 +1489,37 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertTrue(profile["display_name"] == "User")
|
self.assertTrue(profile["display_name"] == "User")
|
||||||
|
|
||||||
# Deactivate user
|
# Deactivate user
|
||||||
body = json.dumps({"deactivated": True})
|
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
self.url_other_user,
|
self.url_other_user,
|
||||||
access_token=self.admin_user_tok,
|
access_token=self.admin_user_tok,
|
||||||
content=body.encode(encoding="utf_8"),
|
content={"deactivated": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual("@user:test", channel.json_body["name"])
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
self.assertEqual(True, channel.json_body["deactivated"])
|
self.assertTrue(channel.json_body["deactivated"])
|
||||||
|
|
||||||
# is not in user directory
|
# is not in user directory
|
||||||
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
|
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
|
||||||
self.assertTrue(profile is None)
|
self.assertIsNone(profile)
|
||||||
|
|
||||||
# Set new displayname user
|
# Set new displayname user
|
||||||
body = json.dumps({"displayname": "Foobar"})
|
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
self.url_other_user,
|
self.url_other_user,
|
||||||
access_token=self.admin_user_tok,
|
access_token=self.admin_user_tok,
|
||||||
content=body.encode(encoding="utf_8"),
|
content={"displayname": "Foobar"},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual("@user:test", channel.json_body["name"])
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
self.assertEqual(True, channel.json_body["deactivated"])
|
self.assertTrue(channel.json_body["deactivated"])
|
||||||
self.assertEqual("Foobar", channel.json_body["displayname"])
|
self.assertEqual("Foobar", channel.json_body["displayname"])
|
||||||
|
|
||||||
# is not in user directory
|
# is not in user directory
|
||||||
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
|
profile = self.get_success(self.store.get_user_in_directory(self.other_user))
|
||||||
self.assertTrue(profile is None)
|
self.assertIsNone(profile)
|
||||||
|
|
||||||
def test_reactivate_user(self):
|
def test_reactivate_user(self):
|
||||||
"""
|
"""
|
||||||
|
@ -1520,24 +1527,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Deactivate the user.
|
# Deactivate the user.
|
||||||
channel = self.make_request(
|
self._deactivate_user("@user:test")
|
||||||
"PUT",
|
|
||||||
self.url_other_user,
|
|
||||||
access_token=self.admin_user_tok,
|
|
||||||
content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
|
|
||||||
)
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
|
||||||
self._is_erased("@user:test", False)
|
|
||||||
d = self.store.mark_user_erased("@user:test")
|
|
||||||
self.assertIsNone(self.get_success(d))
|
|
||||||
self._is_erased("@user:test", True)
|
|
||||||
|
|
||||||
# Attempt to reactivate the user (without a password).
|
# Attempt to reactivate the user (without a password).
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
self.url_other_user,
|
self.url_other_user,
|
||||||
access_token=self.admin_user_tok,
|
access_token=self.admin_user_tok,
|
||||||
content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
|
content={"deactivated": False},
|
||||||
)
|
)
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
@ -1546,22 +1543,76 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
"PUT",
|
"PUT",
|
||||||
self.url_other_user,
|
self.url_other_user,
|
||||||
access_token=self.admin_user_tok,
|
access_token=self.admin_user_tok,
|
||||||
content=json.dumps({"deactivated": False, "password": "foo"}).encode(
|
content={"deactivated": False, "password": "foo"},
|
||||||
encoding="utf_8"
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
|
||||||
|
|
||||||
# Get user
|
|
||||||
channel = self.make_request(
|
|
||||||
"GET",
|
|
||||||
self.url_other_user,
|
|
||||||
access_token=self.admin_user_tok,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual("@user:test", channel.json_body["name"])
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
self.assertEqual(False, channel.json_body["deactivated"])
|
self.assertFalse(channel.json_body["deactivated"])
|
||||||
|
self.assertIsNotNone(channel.json_body["password_hash"])
|
||||||
|
self._is_erased("@user:test", False)
|
||||||
|
|
||||||
|
@override_config({"password_config": {"localdb_enabled": False}})
|
||||||
|
def test_reactivate_user_localdb_disabled(self):
|
||||||
|
"""
|
||||||
|
Test reactivating another user when using SSO.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Deactivate the user.
|
||||||
|
self._deactivate_user("@user:test")
|
||||||
|
|
||||||
|
# Reactivate the user with a password
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url_other_user,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
content={"deactivated": False, "password": "foo"},
|
||||||
|
)
|
||||||
|
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
# Reactivate the user without a password.
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url_other_user,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
content={"deactivated": False},
|
||||||
|
)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
|
self.assertFalse(channel.json_body["deactivated"])
|
||||||
|
self.assertIsNone(channel.json_body["password_hash"])
|
||||||
|
self._is_erased("@user:test", False)
|
||||||
|
|
||||||
|
@override_config({"password_config": {"enabled": False}})
|
||||||
|
def test_reactivate_user_password_disabled(self):
|
||||||
|
"""
|
||||||
|
Test reactivating another user when using SSO.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Deactivate the user.
|
||||||
|
self._deactivate_user("@user:test")
|
||||||
|
|
||||||
|
# Reactivate the user with a password
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url_other_user,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
content={"deactivated": False, "password": "foo"},
|
||||||
|
)
|
||||||
|
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
# Reactivate the user without a password.
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
self.url_other_user,
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
content={"deactivated": False},
|
||||||
|
)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
|
self.assertFalse(channel.json_body["deactivated"])
|
||||||
|
self.assertIsNone(channel.json_body["password_hash"])
|
||||||
self._is_erased("@user:test", False)
|
self._is_erased("@user:test", False)
|
||||||
|
|
||||||
def test_set_user_as_admin(self):
|
def test_set_user_as_admin(self):
|
||||||
|
@ -1570,18 +1621,16 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Set a user as an admin
|
# Set a user as an admin
|
||||||
body = json.dumps({"admin": True})
|
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
self.url_other_user,
|
self.url_other_user,
|
||||||
access_token=self.admin_user_tok,
|
access_token=self.admin_user_tok,
|
||||||
content=body.encode(encoding="utf_8"),
|
content={"admin": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual("@user:test", channel.json_body["name"])
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
self.assertEqual(True, channel.json_body["admin"])
|
self.assertTrue(channel.json_body["admin"])
|
||||||
|
|
||||||
# Get user
|
# Get user
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
|
@ -1592,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
self.assertEqual("@user:test", channel.json_body["name"])
|
self.assertEqual("@user:test", channel.json_body["name"])
|
||||||
self.assertEqual(True, channel.json_body["admin"])
|
self.assertTrue(channel.json_body["admin"])
|
||||||
|
|
||||||
def test_accidental_deactivation_prevention(self):
|
def test_accidental_deactivation_prevention(self):
|
||||||
"""
|
"""
|
||||||
|
@ -1602,13 +1651,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
url = "/_synapse/admin/v2/users/@bob:test"
|
url = "/_synapse/admin/v2/users/@bob:test"
|
||||||
|
|
||||||
# Create user
|
# Create user
|
||||||
body = json.dumps({"password": "abc123"})
|
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
url,
|
url,
|
||||||
access_token=self.admin_user_tok,
|
access_token=self.admin_user_tok,
|
||||||
content=body.encode(encoding="utf_8"),
|
content={"password": "abc123"},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
@ -1628,13 +1675,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(0, channel.json_body["deactivated"])
|
self.assertEqual(0, channel.json_body["deactivated"])
|
||||||
|
|
||||||
# Change password (and use a str for deactivate instead of a bool)
|
# Change password (and use a str for deactivate instead of a bool)
|
||||||
body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops!
|
|
||||||
|
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"PUT",
|
"PUT",
|
||||||
url,
|
url,
|
||||||
access_token=self.admin_user_tok,
|
access_token=self.admin_user_tok,
|
||||||
content=body.encode(encoding="utf_8"),
|
content={"password": "abc123", "deactivated": "false"},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
@ -1653,7 +1698,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
# Ensure they're still alive
|
# Ensure they're still alive
|
||||||
self.assertEqual(0, channel.json_body["deactivated"])
|
self.assertEqual(0, channel.json_body["deactivated"])
|
||||||
|
|
||||||
def _is_erased(self, user_id, expect):
|
def _is_erased(self, user_id: str, expect: bool) -> None:
|
||||||
"""Assert that the user is erased or not"""
|
"""Assert that the user is erased or not"""
|
||||||
d = self.store.is_user_erased(user_id)
|
d = self.store.is_user_erased(user_id)
|
||||||
if expect:
|
if expect:
|
||||||
|
@ -1661,6 +1706,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||||
else:
|
else:
|
||||||
self.assertFalse(self.get_success(d))
|
self.assertFalse(self.get_success(d))
|
||||||
|
|
||||||
|
def _deactivate_user(self, user_id: str) -> None:
|
||||||
|
"""Deactivate user and set as erased"""
|
||||||
|
|
||||||
|
# Deactivate the user.
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
"/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
|
||||||
|
access_token=self.admin_user_tok,
|
||||||
|
content={"deactivated": True},
|
||||||
|
)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertTrue(channel.json_body["deactivated"])
|
||||||
|
self.assertIsNone(channel.json_body["password_hash"])
|
||||||
|
self._is_erased(user_id, False)
|
||||||
|
d = self.store.mark_user_erased(user_id)
|
||||||
|
self.assertIsNone(self.get_success(d))
|
||||||
|
self._is_erased(user_id, True)
|
||||||
|
|
||||||
|
|
||||||
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue