Update the MSC3083 support to verify if joins are from an authorized server. (#10254)
parent
4fb92d93ea
commit
228decfce1
|
@ -0,0 +1 @@
|
|||
Update support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083) to consider changes in the MSC around which servers can issue join events.
|
|
@ -75,6 +75,9 @@ class Codes:
|
|||
INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
|
||||
USER_DEACTIVATED = "M_USER_DEACTIVATED"
|
||||
BAD_ALIAS = "M_BAD_ALIAS"
|
||||
# For restricted join rules.
|
||||
UNABLE_AUTHORISE_JOIN = "M_UNABLE_TO_AUTHORISE_JOIN"
|
||||
UNABLE_TO_GRANT_JOIN = "M_UNABLE_TO_GRANT_JOIN"
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
|
|
|
@ -168,7 +168,7 @@ class RoomVersions:
|
|||
msc2403_knocking=False,
|
||||
)
|
||||
MSC3083 = RoomVersion(
|
||||
"org.matrix.msc3083",
|
||||
"org.matrix.msc3083.v2",
|
||||
RoomDisposition.UNSTABLE,
|
||||
EventFormatVersions.V3,
|
||||
StateResolutionVersions.V2,
|
||||
|
|
|
@ -106,6 +106,18 @@ def check(
|
|||
if not event.signatures.get(event_id_domain):
|
||||
raise AuthError(403, "Event not signed by sending server")
|
||||
|
||||
is_invite_via_allow_rule = (
|
||||
event.type == EventTypes.Member
|
||||
and event.membership == Membership.JOIN
|
||||
and "join_authorised_via_users_server" in event.content
|
||||
)
|
||||
if is_invite_via_allow_rule:
|
||||
authoriser_domain = get_domain_from_id(
|
||||
event.content["join_authorised_via_users_server"]
|
||||
)
|
||||
if not event.signatures.get(authoriser_domain):
|
||||
raise AuthError(403, "Event not signed by authorising server")
|
||||
|
||||
# Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
|
||||
#
|
||||
# 1. If type is m.room.create:
|
||||
|
@ -177,7 +189,7 @@ def check(
|
|||
# https://github.com/vector-im/vector-web/issues/1208 hopefully
|
||||
if event.type == EventTypes.ThirdPartyInvite:
|
||||
user_level = get_user_power_level(event.user_id, auth_events)
|
||||
invite_level = _get_named_level(auth_events, "invite", 0)
|
||||
invite_level = get_named_level(auth_events, "invite", 0)
|
||||
|
||||
if user_level < invite_level:
|
||||
raise AuthError(403, "You don't have permission to invite users")
|
||||
|
@ -285,8 +297,8 @@ def _is_membership_change_allowed(
|
|||
user_level = get_user_power_level(event.user_id, auth_events)
|
||||
target_level = get_user_power_level(target_user_id, auth_events)
|
||||
|
||||
# FIXME (erikj): What should we do here as the default?
|
||||
ban_level = _get_named_level(auth_events, "ban", 50)
|
||||
invite_level = get_named_level(auth_events, "invite", 0)
|
||||
ban_level = get_named_level(auth_events, "ban", 50)
|
||||
|
||||
logger.debug(
|
||||
"_is_membership_change_allowed: %s",
|
||||
|
@ -336,8 +348,6 @@ def _is_membership_change_allowed(
|
|||
elif target_in_room: # the target is already in the room.
|
||||
raise AuthError(403, "%s is already in the room." % target_user_id)
|
||||
else:
|
||||
invite_level = _get_named_level(auth_events, "invite", 0)
|
||||
|
||||
if user_level < invite_level:
|
||||
raise AuthError(403, "You don't have permission to invite users")
|
||||
elif Membership.JOIN == membership:
|
||||
|
@ -345,16 +355,41 @@ def _is_membership_change_allowed(
|
|||
# * They are not banned.
|
||||
# * They are accepting a previously sent invitation.
|
||||
# * They are already joined (it's a NOOP).
|
||||
# * The room is public or restricted.
|
||||
# * The room is public.
|
||||
# * The room is restricted and the user meets the allows rules.
|
||||
if event.user_id != target_user_id:
|
||||
raise AuthError(403, "Cannot force another user to join.")
|
||||
elif target_banned:
|
||||
raise AuthError(403, "You are banned from this room")
|
||||
elif join_rule == JoinRules.PUBLIC or (
|
||||
elif join_rule == JoinRules.PUBLIC:
|
||||
pass
|
||||
elif (
|
||||
room_version.msc3083_join_rules
|
||||
and join_rule == JoinRules.MSC3083_RESTRICTED
|
||||
):
|
||||
pass
|
||||
# This is the same as public, but the event must contain a reference
|
||||
# to the server who authorised the join. If the event does not contain
|
||||
# the proper content it is rejected.
|
||||
#
|
||||
# Note that if the caller is in the room or invited, then they do
|
||||
# not need to meet the allow rules.
|
||||
if not caller_in_room and not caller_invited:
|
||||
authorising_user = event.content.get("join_authorised_via_users_server")
|
||||
|
||||
if authorising_user is None:
|
||||
raise AuthError(403, "Join event is missing authorising user.")
|
||||
|
||||
# The authorising user must be in the room.
|
||||
key = (EventTypes.Member, authorising_user)
|
||||
member_event = auth_events.get(key)
|
||||
_check_joined_room(member_event, authorising_user, event.room_id)
|
||||
|
||||
authorising_user_level = get_user_power_level(
|
||||
authorising_user, auth_events
|
||||
)
|
||||
if authorising_user_level < invite_level:
|
||||
raise AuthError(403, "Join event authorised by invalid server.")
|
||||
|
||||
elif join_rule == JoinRules.INVITE or (
|
||||
room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
|
||||
):
|
||||
|
@ -369,7 +404,7 @@ def _is_membership_change_allowed(
|
|||
if target_banned and user_level < ban_level:
|
||||
raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
|
||||
elif target_user_id != event.user_id:
|
||||
kick_level = _get_named_level(auth_events, "kick", 50)
|
||||
kick_level = get_named_level(auth_events, "kick", 50)
|
||||
|
||||
if user_level < kick_level or user_level <= target_level:
|
||||
raise AuthError(403, "You cannot kick user %s." % target_user_id)
|
||||
|
@ -445,7 +480,7 @@ def get_send_level(
|
|||
|
||||
|
||||
def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
|
||||
power_levels_event = _get_power_level_event(auth_events)
|
||||
power_levels_event = get_power_level_event(auth_events)
|
||||
|
||||
send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
|
||||
user_level = get_user_power_level(event.user_id, auth_events)
|
||||
|
@ -485,7 +520,7 @@ def check_redaction(
|
|||
"""
|
||||
user_level = get_user_power_level(event.user_id, auth_events)
|
||||
|
||||
redact_level = _get_named_level(auth_events, "redact", 50)
|
||||
redact_level = get_named_level(auth_events, "redact", 50)
|
||||
|
||||
if user_level >= redact_level:
|
||||
return False
|
||||
|
@ -600,7 +635,7 @@ def _check_power_levels(
|
|||
)
|
||||
|
||||
|
||||
def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
|
||||
def get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
|
||||
return auth_events.get((EventTypes.PowerLevels, ""))
|
||||
|
||||
|
||||
|
@ -616,7 +651,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
|
|||
Returns:
|
||||
the user's power level in this room.
|
||||
"""
|
||||
power_level_event = _get_power_level_event(auth_events)
|
||||
power_level_event = get_power_level_event(auth_events)
|
||||
if power_level_event:
|
||||
level = power_level_event.content.get("users", {}).get(user_id)
|
||||
if not level:
|
||||
|
@ -640,8 +675,8 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
|
|||
return 0
|
||||
|
||||
|
||||
def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
|
||||
power_level_event = _get_power_level_event(auth_events)
|
||||
def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
|
||||
power_level_event = get_power_level_event(auth_events)
|
||||
|
||||
if not power_level_event:
|
||||
return default
|
||||
|
@ -728,7 +763,9 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
|
|||
return public_keys
|
||||
|
||||
|
||||
def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str, str]]:
|
||||
def auth_types_for_event(
|
||||
room_version: RoomVersion, event: Union[EventBase, EventBuilder]
|
||||
) -> Set[Tuple[str, str]]:
|
||||
"""Given an event, return a list of (EventType, StateKey) that may be
|
||||
needed to auth the event. The returned list may be a superset of what
|
||||
would actually be required depending on the full state of the room.
|
||||
|
@ -760,4 +797,12 @@ def auth_types_for_event(event: Union[EventBase, EventBuilder]) -> Set[Tuple[str
|
|||
)
|
||||
auth_types.add(key)
|
||||
|
||||
if room_version.msc3083_join_rules and membership == Membership.JOIN:
|
||||
if "join_authorised_via_users_server" in event.content:
|
||||
key = (
|
||||
EventTypes.Member,
|
||||
event.content["join_authorised_via_users_server"],
|
||||
)
|
||||
auth_types.add(key)
|
||||
|
||||
return auth_types
|
||||
|
|
|
@ -178,6 +178,34 @@ async def _check_sigs_on_pdu(
|
|||
)
|
||||
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
|
||||
|
||||
# If this is a join event for a restricted room it may have been authorised
|
||||
# via a different server from the sending server. Check those signatures.
|
||||
if (
|
||||
room_version.msc3083_join_rules
|
||||
and pdu.type == EventTypes.Member
|
||||
and pdu.membership == Membership.JOIN
|
||||
and "join_authorised_via_users_server" in pdu.content
|
||||
):
|
||||
authorising_server = get_domain_from_id(
|
||||
pdu.content["join_authorised_via_users_server"]
|
||||
)
|
||||
try:
|
||||
await keyring.verify_event_for_server(
|
||||
authorising_server,
|
||||
pdu,
|
||||
pdu.origin_server_ts if room_version.enforce_key_validity else 0,
|
||||
)
|
||||
except Exception as e:
|
||||
errmsg = (
|
||||
"event id %s: unable to verify signature for authorising server %s: %s"
|
||||
% (
|
||||
pdu.event_id,
|
||||
authorising_server,
|
||||
e,
|
||||
)
|
||||
)
|
||||
raise SynapseError(403, errmsg, Codes.FORBIDDEN)
|
||||
|
||||
|
||||
def _is_invite_via_3pid(event: EventBase) -> bool:
|
||||
return (
|
||||
|
|
|
@ -19,7 +19,6 @@ import itertools
|
|||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
|
@ -79,7 +78,15 @@ class InvalidResponseError(RuntimeError):
|
|||
we couldn't parse
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class SendJoinResult:
|
||||
# The event to persist.
|
||||
event: EventBase
|
||||
# A string giving the server the event was sent to.
|
||||
origin: str
|
||||
state: List[EventBase]
|
||||
auth_chain: List[EventBase]
|
||||
|
||||
|
||||
class FederationClient(FederationBase):
|
||||
|
@ -677,7 +684,7 @@ class FederationClient(FederationBase):
|
|||
|
||||
async def send_join(
|
||||
self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion
|
||||
) -> Dict[str, Any]:
|
||||
) -> SendJoinResult:
|
||||
"""Sends a join event to one of a list of homeservers.
|
||||
|
||||
Doing so will cause the remote server to add the event to the graph,
|
||||
|
@ -691,18 +698,38 @@ class FederationClient(FederationBase):
|
|||
did the make_join)
|
||||
|
||||
Returns:
|
||||
a dict with members ``origin`` (a string
|
||||
giving the server the event was sent to, ``state`` (?) and
|
||||
``auth_chain``.
|
||||
The result of the send join request.
|
||||
|
||||
Raises:
|
||||
SynapseError: if the chosen remote server returns a 300/400 code, or
|
||||
no servers successfully handle the request.
|
||||
"""
|
||||
|
||||
async def send_request(destination) -> Dict[str, Any]:
|
||||
async def send_request(destination) -> SendJoinResult:
|
||||
response = await self._do_send_join(room_version, destination, pdu)
|
||||
|
||||
# If an event was returned (and expected to be returned):
|
||||
#
|
||||
# * Ensure it has the same event ID (note that the event ID is a hash
|
||||
# of the event fields for versions which support MSC3083).
|
||||
# * Ensure the signatures are good.
|
||||
#
|
||||
# Otherwise, fallback to the provided event.
|
||||
if room_version.msc3083_join_rules and response.event:
|
||||
event = response.event
|
||||
|
||||
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
|
||||
pdu=event,
|
||||
origin=destination,
|
||||
outlier=True,
|
||||
room_version=room_version,
|
||||
)
|
||||
|
||||
if valid_pdu is None or event.event_id != pdu.event_id:
|
||||
raise InvalidResponseError("Returned an invalid join event")
|
||||
else:
|
||||
event = pdu
|
||||
|
||||
state = response.state
|
||||
auth_chain = response.auth_events
|
||||
|
||||
|
@ -784,11 +811,21 @@ class FederationClient(FederationBase):
|
|||
% (auth_chain_create_events,)
|
||||
)
|
||||
|
||||
return {
|
||||
"state": signed_state,
|
||||
"auth_chain": signed_auth,
|
||||
"origin": destination,
|
||||
}
|
||||
return SendJoinResult(
|
||||
event=event,
|
||||
state=signed_state,
|
||||
auth_chain=signed_auth,
|
||||
origin=destination,
|
||||
)
|
||||
|
||||
if room_version.msc3083_join_rules:
|
||||
# If the join is being authorised via allow rules, we need to send
|
||||
# the /send_join back to the same server that was originally used
|
||||
# with /make_join.
|
||||
if "join_authorised_via_users_server" in pdu.content:
|
||||
destinations = [
|
||||
get_domain_from_id(pdu.content["join_authorised_via_users_server"])
|
||||
]
|
||||
|
||||
return await self._try_destination_list("send_join", destinations, send_request)
|
||||
|
||||
|
|
|
@ -45,6 +45,7 @@ from synapse.api.errors import (
|
|||
UnsupportedRoomVersionError,
|
||||
)
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||
|
@ -64,7 +65,7 @@ from synapse.replication.http.federation import (
|
|||
ReplicationGetQueryRestServlet,
|
||||
)
|
||||
from synapse.storage.databases.main.lock import Lock
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import JsonDict, get_domain_from_id
|
||||
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
|
||||
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
@ -586,7 +587,7 @@ class FederationServer(FederationBase):
|
|||
async def on_send_join_request(
|
||||
self, origin: str, content: JsonDict, room_id: str
|
||||
) -> Dict[str, Any]:
|
||||
context = await self._on_send_membership_event(
|
||||
event, context = await self._on_send_membership_event(
|
||||
origin, content, Membership.JOIN, room_id
|
||||
)
|
||||
|
||||
|
@ -597,6 +598,7 @@ class FederationServer(FederationBase):
|
|||
|
||||
time_now = self._clock.time_msec()
|
||||
return {
|
||||
"org.matrix.msc3083.v2.event": event.get_pdu_json(),
|
||||
"state": [p.get_pdu_json(time_now) for p in state.values()],
|
||||
"auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
|
||||
}
|
||||
|
@ -681,7 +683,7 @@ class FederationServer(FederationBase):
|
|||
Returns:
|
||||
The stripped room state.
|
||||
"""
|
||||
event_context = await self._on_send_membership_event(
|
||||
_, context = await self._on_send_membership_event(
|
||||
origin, content, Membership.KNOCK, room_id
|
||||
)
|
||||
|
||||
|
@ -690,14 +692,14 @@ class FederationServer(FederationBase):
|
|||
# related to the room while the knock request is pending.
|
||||
stripped_room_state = (
|
||||
await self.store.get_stripped_room_state_from_event_context(
|
||||
event_context, self._room_prejoin_state_types
|
||||
context, self._room_prejoin_state_types
|
||||
)
|
||||
)
|
||||
return {"knock_state_events": stripped_room_state}
|
||||
|
||||
async def _on_send_membership_event(
|
||||
self, origin: str, content: JsonDict, membership_type: str, room_id: str
|
||||
) -> EventContext:
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
"""Handle an on_send_{join,leave,knock} request
|
||||
|
||||
Does some preliminary validation before passing the request on to the
|
||||
|
@ -712,7 +714,7 @@ class FederationServer(FederationBase):
|
|||
in the event
|
||||
|
||||
Returns:
|
||||
The context of the event after inserting it into the room graph.
|
||||
The event and context of the event after inserting it into the room graph.
|
||||
|
||||
Raises:
|
||||
SynapseError if there is a problem with the request, including things like
|
||||
|
@ -748,6 +750,33 @@ class FederationServer(FederationBase):
|
|||
|
||||
logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)
|
||||
|
||||
# Sign the event since we're vouching on behalf of the remote server that
|
||||
# the event is valid to be sent into the room. Currently this is only done
|
||||
# if the user is being joined via restricted join rules.
|
||||
if (
|
||||
room_version.msc3083_join_rules
|
||||
and event.membership == Membership.JOIN
|
||||
and "join_authorised_via_users_server" in event.content
|
||||
):
|
||||
# We can only authorise our own users.
|
||||
authorising_server = get_domain_from_id(
|
||||
event.content["join_authorised_via_users_server"]
|
||||
)
|
||||
if authorising_server != self.server_name:
|
||||
raise SynapseError(
|
||||
400,
|
||||
f"Cannot authorise request from resident server: {authorising_server}",
|
||||
)
|
||||
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
room_version,
|
||||
event.get_pdu_json(),
|
||||
self.hs.hostname,
|
||||
self.hs.signing_key,
|
||||
)
|
||||
)
|
||||
|
||||
event = await self._check_sigs_and_hash(room_version, event)
|
||||
|
||||
return await self.handler.on_send_membership_event(origin, event)
|
||||
|
|
|
@ -1219,8 +1219,26 @@ def _create_v2_path(path: str, *args: str) -> str:
|
|||
class SendJoinResponse:
|
||||
"""The parsed response of a `/send_join` request."""
|
||||
|
||||
# The list of auth events from the /send_join response.
|
||||
auth_events: List[EventBase]
|
||||
# The list of state from the /send_join response.
|
||||
state: List[EventBase]
|
||||
# The raw join event from the /send_join response.
|
||||
event_dict: JsonDict
|
||||
# The parsed join event from the /send_join response. This will be None if
|
||||
# "event" is not included in the response.
|
||||
event: Optional[EventBase] = None
|
||||
|
||||
|
||||
@ijson.coroutine
|
||||
def _event_parser(event_dict: JsonDict):
|
||||
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
|
||||
to add them to a given dictionary.
|
||||
"""
|
||||
|
||||
while True:
|
||||
key, value = yield
|
||||
event_dict[key] = value
|
||||
|
||||
|
||||
@ijson.coroutine
|
||||
|
@ -1246,7 +1264,8 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
|
|||
CONTENT_TYPE = "application/json"
|
||||
|
||||
def __init__(self, room_version: RoomVersion, v1_api: bool):
|
||||
self._response = SendJoinResponse([], [])
|
||||
self._response = SendJoinResponse([], [], {})
|
||||
self._room_version = room_version
|
||||
|
||||
# The V1 API has the shape of `[200, {...}]`, which we handle by
|
||||
# prefixing with `item.*`.
|
||||
|
@ -1260,12 +1279,21 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
|
|||
_event_list_parser(room_version, self._response.auth_events),
|
||||
prefix + "auth_chain.item",
|
||||
)
|
||||
self._coro_event = ijson.kvitems_coro(
|
||||
_event_parser(self._response.event_dict),
|
||||
prefix + "org.matrix.msc3083.v2.event",
|
||||
)
|
||||
|
||||
def write(self, data: bytes) -> int:
|
||||
self._coro_state.send(data)
|
||||
self._coro_auth.send(data)
|
||||
self._coro_event.send(data)
|
||||
|
||||
return len(data)
|
||||
|
||||
def finish(self) -> SendJoinResponse:
|
||||
if self._response.event_dict:
|
||||
self._response.event = make_event_from_dict(
|
||||
self._response.event_dict, self._room_version
|
||||
)
|
||||
return self._response
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Collection, List, Optional, Union
|
||||
|
||||
from synapse import event_auth
|
||||
|
@ -20,16 +21,18 @@ from synapse.api.constants import (
|
|||
Membership,
|
||||
RestrictedJoinRuleTypes,
|
||||
)
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.builder import EventBuilder
|
||||
from synapse.types import StateMap
|
||||
from synapse.types import StateMap, get_domain_from_id
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventAuthHandler:
|
||||
"""
|
||||
|
@ -39,6 +42,7 @@ class EventAuthHandler:
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
self._clock = hs.get_clock()
|
||||
self._store = hs.get_datastore()
|
||||
self._server_name = hs.hostname
|
||||
|
||||
async def check_from_context(
|
||||
self, room_version: str, event, context, do_sig_check=True
|
||||
|
@ -81,15 +85,76 @@ class EventAuthHandler:
|
|||
# introduce undesirable "state reset" behaviour.
|
||||
#
|
||||
# All of which sounds a bit tricky so we don't bother for now.
|
||||
|
||||
auth_ids = []
|
||||
for etype, state_key in event_auth.auth_types_for_event(event):
|
||||
for etype, state_key in event_auth.auth_types_for_event(
|
||||
event.room_version, event
|
||||
):
|
||||
auth_ev_id = current_state_ids.get((etype, state_key))
|
||||
if auth_ev_id:
|
||||
auth_ids.append(auth_ev_id)
|
||||
|
||||
return auth_ids
|
||||
|
||||
async def get_user_which_could_invite(
|
||||
self, room_id: str, current_state_ids: StateMap[str]
|
||||
) -> str:
|
||||
"""
|
||||
Searches the room state for a local user who has the power level necessary
|
||||
to invite other users.
|
||||
|
||||
Args:
|
||||
room_id: The room ID under search.
|
||||
current_state_ids: The current state of the room.
|
||||
|
||||
Returns:
|
||||
The MXID of the user which could issue an invite.
|
||||
|
||||
Raises:
|
||||
SynapseError if no appropriate user is found.
|
||||
"""
|
||||
power_level_event_id = current_state_ids.get((EventTypes.PowerLevels, ""))
|
||||
invite_level = 0
|
||||
users_default_level = 0
|
||||
if power_level_event_id:
|
||||
power_level_event = await self._store.get_event(power_level_event_id)
|
||||
invite_level = power_level_event.content.get("invite", invite_level)
|
||||
users_default_level = power_level_event.content.get(
|
||||
"users_default", users_default_level
|
||||
)
|
||||
users = power_level_event.content.get("users", {})
|
||||
else:
|
||||
users = {}
|
||||
|
||||
# Find the user with the highest power level.
|
||||
users_in_room = await self._store.get_users_in_room(room_id)
|
||||
# Only interested in local users.
|
||||
local_users_in_room = [
|
||||
u for u in users_in_room if get_domain_from_id(u) == self._server_name
|
||||
]
|
||||
chosen_user = max(
|
||||
local_users_in_room,
|
||||
key=lambda user: users.get(user, users_default_level),
|
||||
default=None,
|
||||
)
|
||||
|
||||
# Return the chosen if they can issue invites.
|
||||
user_power_level = users.get(chosen_user, users_default_level)
|
||||
if chosen_user and user_power_level >= invite_level:
|
||||
logger.debug(
|
||||
"Found a user who can issue invites %s with power level %d >= invite level %d",
|
||||
chosen_user,
|
||||
user_power_level,
|
||||
invite_level,
|
||||
)
|
||||
return chosen_user
|
||||
|
||||
# No user was found.
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Unable to find a user which could issue an invite",
|
||||
Codes.UNABLE_TO_GRANT_JOIN,
|
||||
)
|
||||
|
||||
async def check_host_in_room(self, room_id: str, host: str) -> bool:
|
||||
with Measure(self._clock, "check_host_in_room"):
|
||||
return await self._store.is_host_joined(room_id, host)
|
||||
|
@ -134,6 +199,18 @@ class EventAuthHandler:
|
|||
# in any of them.
|
||||
allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
|
||||
if not await self.is_user_in_rooms(allowed_rooms, user_id):
|
||||
|
||||
# If this is a remote request, the user might be in an allowed room
|
||||
# that we do not know about.
|
||||
if get_domain_from_id(user_id) != self._server_name:
|
||||
for room_id in allowed_rooms:
|
||||
if not await self._store.is_host_joined(room_id, self._server_name):
|
||||
raise SynapseError(
|
||||
400,
|
||||
f"Unable to check if {user_id} is in allowed rooms.",
|
||||
Codes.UNABLE_AUTHORISE_JOIN,
|
||||
)
|
||||
|
||||
raise AuthError(
|
||||
403,
|
||||
"You do not belong to any of the required rooms to join this room.",
|
||||
|
|
|
@ -1494,9 +1494,10 @@ class FederationHandler(BaseHandler):
|
|||
host_list, event, room_version_obj
|
||||
)
|
||||
|
||||
origin = ret["origin"]
|
||||
state = ret["state"]
|
||||
auth_chain = ret["auth_chain"]
|
||||
event = ret.event
|
||||
origin = ret.origin
|
||||
state = ret.state
|
||||
auth_chain = ret.auth_chain
|
||||
auth_chain.sort(key=lambda e: e.depth)
|
||||
|
||||
logger.debug("do_invite_join auth_chain: %s", auth_chain)
|
||||
|
@ -1676,7 +1677,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# checking the room version will check that we've actually heard of the room
|
||||
# (and return a 404 otherwise)
|
||||
room_version = await self.store.get_room_version_id(room_id)
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
# now check that we are *still* in the room
|
||||
is_in_room = await self._event_auth_handler.check_host_in_room(
|
||||
|
@ -1691,8 +1692,38 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
event_content = {"membership": Membership.JOIN}
|
||||
|
||||
# If the current room is using restricted join rules, additional information
|
||||
# may need to be included in the event content in order to efficiently
|
||||
# validate the event.
|
||||
#
|
||||
# Note that this requires the /send_join request to come back to the
|
||||
# same server.
|
||||
if room_version.msc3083_join_rules:
|
||||
state_ids = await self.store.get_current_state_ids(room_id)
|
||||
if await self._event_auth_handler.has_restricted_join_rules(
|
||||
state_ids, room_version
|
||||
):
|
||||
prev_member_event_id = state_ids.get((EventTypes.Member, user_id), None)
|
||||
# If the user is invited or joined to the room already, then
|
||||
# no additional info is needed.
|
||||
include_auth_user_id = True
|
||||
if prev_member_event_id:
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
include_auth_user_id = prev_member_event.membership not in (
|
||||
Membership.JOIN,
|
||||
Membership.INVITE,
|
||||
)
|
||||
|
||||
if include_auth_user_id:
|
||||
event_content[
|
||||
"join_authorised_via_users_server"
|
||||
] = await self._event_auth_handler.get_user_which_could_invite(
|
||||
room_id,
|
||||
state_ids,
|
||||
)
|
||||
|
||||
builder = self.event_builder_factory.new(
|
||||
room_version,
|
||||
room_version.identifier,
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
"content": event_content,
|
||||
|
@ -1710,10 +1741,13 @@ class FederationHandler(BaseHandler):
|
|||
logger.warning("Failed to create join to %s because %s", room_id, e)
|
||||
raise
|
||||
|
||||
# Ensure the user can even join the room.
|
||||
await self._check_join_restrictions(context, event)
|
||||
|
||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||
# when we get the event back in `on_send_join_request`
|
||||
await self._event_auth_handler.check_from_context(
|
||||
room_version, event, context, do_sig_check=False
|
||||
room_version.identifier, event, context, do_sig_check=False
|
||||
)
|
||||
|
||||
return event
|
||||
|
@ -1958,7 +1992,7 @@ class FederationHandler(BaseHandler):
|
|||
@log_function
|
||||
async def on_send_membership_event(
|
||||
self, origin: str, event: EventBase
|
||||
) -> EventContext:
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
"""
|
||||
We have received a join/leave/knock event for a room via send_join/leave/knock.
|
||||
|
||||
|
@ -1981,7 +2015,7 @@ class FederationHandler(BaseHandler):
|
|||
event: The member event that has been signed by the remote homeserver.
|
||||
|
||||
Returns:
|
||||
The context of the event after inserting it into the room graph.
|
||||
The event and context of the event after inserting it into the room graph.
|
||||
|
||||
Raises:
|
||||
SynapseError if the event is not accepted into the room
|
||||
|
@ -2037,7 +2071,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# all looks good, we can persist the event.
|
||||
await self._run_push_actions_and_persist_event(event, context)
|
||||
return context
|
||||
return event, context
|
||||
|
||||
async def _check_join_restrictions(
|
||||
self, context: EventContext, event: EventBase
|
||||
|
@ -2473,7 +2507,7 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
# Now check if event pass auth against said current state
|
||||
auth_types = auth_types_for_event(event)
|
||||
auth_types = auth_types_for_event(room_version_obj, event)
|
||||
current_state_ids_list = [
|
||||
e for k, e in current_state_ids.items() if k in auth_types
|
||||
]
|
||||
|
|
|
@ -16,7 +16,7 @@ import abc
|
|||
import logging
|
||||
import random
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from synapse import types
|
||||
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
|
||||
|
@ -28,6 +28,7 @@ from synapse.api.errors import (
|
|||
SynapseError,
|
||||
)
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.event_auth import get_named_level, get_power_level_event
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.types import (
|
||||
|
@ -340,16 +341,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
|
||||
if event.membership == Membership.JOIN:
|
||||
newly_joined = True
|
||||
prev_member_event = None
|
||||
if prev_member_event_id:
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
|
||||
# Check if the member should be allowed access via membership in a space.
|
||||
await self.event_auth_handler.check_restricted_join_rules(
|
||||
prev_state_ids, event.room_version, user_id, prev_member_event
|
||||
)
|
||||
|
||||
# Only rate-limit if the user actually joined the room, otherwise we'll end
|
||||
# up blocking profile updates.
|
||||
if newly_joined and ratelimit:
|
||||
|
@ -701,7 +696,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
# so don't really fit into the general auth process.
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
||||
if not is_host_in_room:
|
||||
# Check if a remote join should be performed.
|
||||
remote_join, remote_room_hosts = await self._should_perform_remote_join(
|
||||
target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
|
||||
)
|
||||
if remote_join:
|
||||
if ratelimit:
|
||||
time_now_s = self.clock.time()
|
||||
(
|
||||
|
@ -826,6 +825,106 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
outlier=outlier,
|
||||
)
|
||||
|
||||
async def _should_perform_remote_join(
|
||||
self,
|
||||
user_id: str,
|
||||
room_id: str,
|
||||
remote_room_hosts: List[str],
|
||||
content: JsonDict,
|
||||
is_host_in_room: bool,
|
||||
) -> Tuple[bool, List[str]]:
|
||||
"""
|
||||
Check whether the server should do a remote join (as opposed to a local
|
||||
join) for a user.
|
||||
|
||||
Generally a remote join is used if:
|
||||
|
||||
* The server is not yet in the room.
|
||||
* The server is in the room, the room has restricted join rules, the user
|
||||
is not joined or invited to the room, and the server does not have
|
||||
another user who is capable of issuing invites.
|
||||
|
||||
Args:
|
||||
user_id: The user joining the room.
|
||||
room_id: The room being joined.
|
||||
remote_room_hosts: A list of remote room hosts.
|
||||
content: The content to use as the event body of the join. This may
|
||||
be modified.
|
||||
is_host_in_room: True if the host is in the room.
|
||||
|
||||
Returns:
|
||||
A tuple of:
|
||||
True if a remote join should be performed. False if the join can be
|
||||
done locally.
|
||||
|
||||
A list of remote room hosts to use. This is an empty list if a
|
||||
local join is to be done.
|
||||
"""
|
||||
# If the host isn't in the room, pass through the prospective hosts.
|
||||
if not is_host_in_room:
|
||||
return True, remote_room_hosts
|
||||
|
||||
# If the host is in the room, but not one of the authorised hosts
|
||||
# for restricted join rules, a remote join must be used.
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
current_state_ids = await self.store.get_current_state_ids(room_id)
|
||||
|
||||
# If restricted join rules are not being used, a local join can always
|
||||
# be used.
|
||||
if not await self.event_auth_handler.has_restricted_join_rules(
|
||||
current_state_ids, room_version
|
||||
):
|
||||
return False, []
|
||||
|
||||
# If the user is invited to the room or already joined, the join
|
||||
# event can always be issued locally.
|
||||
prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
|
||||
prev_member_event = None
|
||||
if prev_member_event_id:
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
if prev_member_event.membership in (
|
||||
Membership.JOIN,
|
||||
Membership.INVITE,
|
||||
):
|
||||
return False, []
|
||||
|
||||
# If the local host has a user who can issue invites, then a local
|
||||
# join can be done.
|
||||
#
|
||||
# If not, generate a new list of remote hosts based on which
|
||||
# can issue invites.
|
||||
event_map = await self.store.get_events(current_state_ids.values())
|
||||
current_state = {
|
||||
state_key: event_map[event_id]
|
||||
for state_key, event_id in current_state_ids.items()
|
||||
}
|
||||
allowed_servers = get_servers_from_users(
|
||||
get_users_which_can_issue_invite(current_state)
|
||||
)
|
||||
|
||||
# If the local server is not one of allowed servers, then a remote
|
||||
# join must be done. Return the list of prospective servers based on
|
||||
# which can issue invites.
|
||||
if self.hs.hostname not in allowed_servers:
|
||||
return True, list(allowed_servers)
|
||||
|
||||
# Ensure the member should be allowed access via membership in a room.
|
||||
await self.event_auth_handler.check_restricted_join_rules(
|
||||
current_state_ids, room_version, user_id, prev_member_event
|
||||
)
|
||||
|
||||
# If this is going to be a local join, additional information must
|
||||
# be included in the event content in order to efficiently validate
|
||||
# the event.
|
||||
content[
|
||||
"join_authorised_via_users_server"
|
||||
] = await self.event_auth_handler.get_user_which_could_invite(
|
||||
room_id,
|
||||
current_state_ids,
|
||||
)
|
||||
|
||||
return False, []
|
||||
|
||||
async def transfer_room_state_on_room_upgrade(
|
||||
self, old_room_id: str, room_id: str
|
||||
) -> None:
|
||||
|
@ -1514,3 +1613,63 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||
|
||||
if membership:
|
||||
await self.store.forget(user_id, room_id)
|
||||
|
||||
|
||||
def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]:
|
||||
"""
|
||||
Return the list of users which can issue invites.
|
||||
|
||||
This is done by exploring the joined users and comparing their power levels
|
||||
to the necessyar power level to issue an invite.
|
||||
|
||||
Args:
|
||||
auth_events: state in force at this point in the room
|
||||
|
||||
Returns:
|
||||
The users which can issue invites.
|
||||
"""
|
||||
invite_level = get_named_level(auth_events, "invite", 0)
|
||||
users_default_level = get_named_level(auth_events, "users_default", 0)
|
||||
power_level_event = get_power_level_event(auth_events)
|
||||
|
||||
# Custom power-levels for users.
|
||||
if power_level_event:
|
||||
users = power_level_event.content.get("users", {})
|
||||
else:
|
||||
users = {}
|
||||
|
||||
result = []
|
||||
|
||||
# Check which members are able to invite by ensuring they're joined and have
|
||||
# the necessary power level.
|
||||
for (event_type, state_key), event in auth_events.items():
|
||||
if event_type != EventTypes.Member:
|
||||
continue
|
||||
|
||||
if event.membership != Membership.JOIN:
|
||||
continue
|
||||
|
||||
# Check if the user has a custom power level.
|
||||
if users.get(state_key, users_default_level) >= invite_level:
|
||||
result.append(state_key)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_servers_from_users(users: List[str]) -> Set[str]:
|
||||
"""
|
||||
Resolve a list of users into their servers.
|
||||
|
||||
Args:
|
||||
users: A list of users.
|
||||
|
||||
Returns:
|
||||
A set of servers.
|
||||
"""
|
||||
servers = set()
|
||||
for user in users:
|
||||
try:
|
||||
servers.add(get_domain_from_id(user))
|
||||
except SynapseError:
|
||||
pass
|
||||
return servers
|
||||
|
|
|
@ -636,16 +636,20 @@ class StateResolutionHandler:
|
|||
"""
|
||||
try:
|
||||
with Measure(self.clock, "state._resolve_events") as m:
|
||||
v = KNOWN_ROOM_VERSIONS[room_version]
|
||||
if v.state_res == StateResolutionVersions.V1:
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
if room_version_obj.state_res == StateResolutionVersions.V1:
|
||||
return await v1.resolve_events_with_store(
|
||||
room_id, state_sets, event_map, state_res_store.get_events
|
||||
room_id,
|
||||
room_version_obj,
|
||||
state_sets,
|
||||
event_map,
|
||||
state_res_store.get_events,
|
||||
)
|
||||
else:
|
||||
return await v2.resolve_events_with_store(
|
||||
self.clock,
|
||||
room_id,
|
||||
room_version,
|
||||
room_version_obj,
|
||||
state_sets,
|
||||
event_map,
|
||||
state_res_store,
|
||||
|
|
|
@ -29,7 +29,7 @@ from typing import (
|
|||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.api.room_versions import RoomVersion, RoomVersions
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
|
||||
|
@ -41,6 +41,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
|
|||
|
||||
async def resolve_events_with_store(
|
||||
room_id: str,
|
||||
room_version: RoomVersion,
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
|
||||
|
@ -104,7 +105,7 @@ async def resolve_events_with_store(
|
|||
# get the ids of the auth events which allow us to authenticate the
|
||||
# conflicted state, picking only from the unconflicting state.
|
||||
auth_events = _create_auth_events_from_maps(
|
||||
unconflicted_state, conflicted_state, state_map
|
||||
room_version, unconflicted_state, conflicted_state, state_map
|
||||
)
|
||||
|
||||
new_needed_events = set(auth_events.values())
|
||||
|
@ -132,7 +133,7 @@ async def resolve_events_with_store(
|
|||
state_map.update(state_map_new)
|
||||
|
||||
return _resolve_with_state(
|
||||
unconflicted_state, conflicted_state, auth_events, state_map
|
||||
room_version, unconflicted_state, conflicted_state, auth_events, state_map
|
||||
)
|
||||
|
||||
|
||||
|
@ -187,6 +188,7 @@ def _seperate(
|
|||
|
||||
|
||||
def _create_auth_events_from_maps(
|
||||
room_version: RoomVersion,
|
||||
unconflicted_state: StateMap[str],
|
||||
conflicted_state: StateMap[Set[str]],
|
||||
state_map: Dict[str, EventBase],
|
||||
|
@ -194,6 +196,7 @@ def _create_auth_events_from_maps(
|
|||
"""
|
||||
|
||||
Args:
|
||||
room_version: The room version.
|
||||
unconflicted_state: The unconflicted state map.
|
||||
conflicted_state: The conflicted state map.
|
||||
state_map:
|
||||
|
@ -205,7 +208,9 @@ def _create_auth_events_from_maps(
|
|||
for event_ids in conflicted_state.values():
|
||||
for event_id in event_ids:
|
||||
if event_id in state_map:
|
||||
keys = event_auth.auth_types_for_event(state_map[event_id])
|
||||
keys = event_auth.auth_types_for_event(
|
||||
room_version, state_map[event_id]
|
||||
)
|
||||
for key in keys:
|
||||
if key not in auth_events:
|
||||
auth_event_id = unconflicted_state.get(key, None)
|
||||
|
@ -215,6 +220,7 @@ def _create_auth_events_from_maps(
|
|||
|
||||
|
||||
def _resolve_with_state(
|
||||
room_version: RoomVersion,
|
||||
unconflicted_state_ids: MutableStateMap[str],
|
||||
conflicted_state_ids: StateMap[Set[str]],
|
||||
auth_event_ids: StateMap[str],
|
||||
|
@ -235,7 +241,9 @@ def _resolve_with_state(
|
|||
}
|
||||
|
||||
try:
|
||||
resolved_state = _resolve_state_events(conflicted_state, auth_events)
|
||||
resolved_state = _resolve_state_events(
|
||||
room_version, conflicted_state, auth_events
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve state")
|
||||
raise
|
||||
|
@ -248,7 +256,9 @@ def _resolve_with_state(
|
|||
|
||||
|
||||
def _resolve_state_events(
|
||||
conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
|
||||
room_version: RoomVersion,
|
||||
conflicted_state: StateMap[List[EventBase]],
|
||||
auth_events: MutableStateMap[EventBase],
|
||||
) -> StateMap[EventBase]:
|
||||
"""This is where we actually decide which of the conflicted state to
|
||||
use.
|
||||
|
@ -263,21 +273,27 @@ def _resolve_state_events(
|
|||
if POWER_KEY in conflicted_state:
|
||||
events = conflicted_state[POWER_KEY]
|
||||
logger.debug("Resolving conflicted power levels %r", events)
|
||||
resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events)
|
||||
resolved_state[POWER_KEY] = _resolve_auth_events(
|
||||
room_version, events, auth_events
|
||||
)
|
||||
|
||||
auth_events.update(resolved_state)
|
||||
|
||||
for key, events in conflicted_state.items():
|
||||
if key[0] == EventTypes.JoinRules:
|
||||
logger.debug("Resolving conflicted join rules %r", events)
|
||||
resolved_state[key] = _resolve_auth_events(events, auth_events)
|
||||
resolved_state[key] = _resolve_auth_events(
|
||||
room_version, events, auth_events
|
||||
)
|
||||
|
||||
auth_events.update(resolved_state)
|
||||
|
||||
for key, events in conflicted_state.items():
|
||||
if key[0] == EventTypes.Member:
|
||||
logger.debug("Resolving conflicted member lists %r", events)
|
||||
resolved_state[key] = _resolve_auth_events(events, auth_events)
|
||||
resolved_state[key] = _resolve_auth_events(
|
||||
room_version, events, auth_events
|
||||
)
|
||||
|
||||
auth_events.update(resolved_state)
|
||||
|
||||
|
@ -290,12 +306,14 @@ def _resolve_state_events(
|
|||
|
||||
|
||||
def _resolve_auth_events(
|
||||
events: List[EventBase], auth_events: StateMap[EventBase]
|
||||
room_version: RoomVersion, events: List[EventBase], auth_events: StateMap[EventBase]
|
||||
) -> EventBase:
|
||||
reverse = list(reversed(_ordered_events(events)))
|
||||
|
||||
auth_keys = {
|
||||
key for event in events for key in event_auth.auth_types_for_event(event)
|
||||
key
|
||||
for event in events
|
||||
for key in event_auth.auth_types_for_event(room_version, event)
|
||||
}
|
||||
|
||||
new_auth_events = {}
|
||||
|
|
|
@ -36,7 +36,7 @@ import synapse.state
|
|||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
from synapse.util import Clock
|
||||
|
@ -53,7 +53,7 @@ _AWAIT_AFTER_ITERATIONS = 100
|
|||
async def resolve_events_with_store(
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
room_version: RoomVersion,
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
|
@ -497,7 +497,7 @@ async def _reverse_topological_power_sort(
|
|||
async def _iterative_auth_checks(
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
room_version: RoomVersion,
|
||||
event_ids: List[str],
|
||||
base_state: StateMap[str],
|
||||
event_map: Dict[str, EventBase],
|
||||
|
@ -519,7 +519,6 @@ async def _iterative_auth_checks(
|
|||
Returns the final updated state
|
||||
"""
|
||||
resolved_state = dict(base_state)
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
|
||||
for idx, event_id in enumerate(event_ids, start=1):
|
||||
event = event_map[event_id]
|
||||
|
@ -538,7 +537,7 @@ async def _iterative_auth_checks(
|
|||
if ev.rejected_reason is None:
|
||||
auth_events[(ev.type, ev.state_key)] = ev
|
||||
|
||||
for key in event_auth.auth_types_for_event(event):
|
||||
for key in event_auth.auth_types_for_event(room_version, event):
|
||||
if key in resolved_state:
|
||||
ev_id = resolved_state[key]
|
||||
ev = await _get_event(room_id, ev_id, event_map, state_res_store)
|
||||
|
@ -548,7 +547,7 @@ async def _iterative_auth_checks(
|
|||
|
||||
try:
|
||||
event_auth.check(
|
||||
room_version_obj,
|
||||
room_version,
|
||||
event,
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
|
|
|
@ -484,7 +484,7 @@ class StateTestCase(unittest.TestCase):
|
|||
state_d = resolve_events_with_store(
|
||||
FakeClock(),
|
||||
ROOM_ID,
|
||||
RoomVersions.V2.identifier,
|
||||
RoomVersions.V2,
|
||||
[state_at_event[n] for n in prev_events],
|
||||
event_map=event_map,
|
||||
state_res_store=TestStateResolutionStore(event_map),
|
||||
|
@ -496,7 +496,7 @@ class StateTestCase(unittest.TestCase):
|
|||
if fake_event.state_key is not None:
|
||||
state_after[(fake_event.type, fake_event.state_key)] = event_id
|
||||
|
||||
auth_types = set(auth_types_for_event(fake_event))
|
||||
auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event))
|
||||
|
||||
auth_events = []
|
||||
for key in auth_types:
|
||||
|
@ -633,7 +633,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
|
|||
state_d = resolve_events_with_store(
|
||||
FakeClock(),
|
||||
ROOM_ID,
|
||||
RoomVersions.V2.identifier,
|
||||
RoomVersions.V2,
|
||||
[self.state_at_bob, self.state_at_charlie],
|
||||
event_map=None,
|
||||
state_res_store=TestStateResolutionStore(self.event_map),
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
|
@ -234,8 +234,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
async def build(
|
||||
self,
|
||||
prev_event_ids,
|
||||
auth_event_ids,
|
||||
prev_event_ids: List[str],
|
||||
auth_event_ids: Optional[List[str]],
|
||||
depth: Optional[int] = None,
|
||||
):
|
||||
built_event = await self._base_builder.build(
|
||||
|
|
|
@ -351,7 +351,11 @@ class EventAuthTestCase(unittest.TestCase):
|
|||
"""
|
||||
Test joining a restricted room from MSC3083.
|
||||
|
||||
This is pretty much the same test as public.
|
||||
This is similar to the public test, but has some additional checks on
|
||||
signatures.
|
||||
|
||||
The checks which care about signatures fake them by simply adding an
|
||||
object of the proper form, not generating valid signatures.
|
||||
"""
|
||||
creator = "@creator:example.com"
|
||||
pleb = "@joiner:example.com"
|
||||
|
@ -359,6 +363,7 @@ class EventAuthTestCase(unittest.TestCase):
|
|||
auth_events = {
|
||||
("m.room.create", ""): _create_event(creator),
|
||||
("m.room.member", creator): _join_event(creator),
|
||||
("m.room.power_levels", ""): _power_levels_event(creator, {"invite": 0}),
|
||||
("m.room.join_rules", ""): _join_rules_event(creator, "restricted"),
|
||||
}
|
||||
|
||||
|
@ -371,19 +376,81 @@ class EventAuthTestCase(unittest.TestCase):
|
|||
do_sig_check=False,
|
||||
)
|
||||
|
||||
# Check join.
|
||||
# A properly formatted join event should work.
|
||||
authorised_join_event = _join_event(
|
||||
pleb,
|
||||
additional_content={
|
||||
"join_authorised_via_users_server": "@creator:example.com"
|
||||
},
|
||||
)
|
||||
event_auth.check(
|
||||
RoomVersions.MSC3083,
|
||||
_join_event(pleb),
|
||||
authorised_join_event,
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
)
|
||||
|
||||
# A user cannot be force-joined to a room.
|
||||
# A join issued by a specific user works (i.e. the power level checks
|
||||
# are done properly).
|
||||
pl_auth_events = auth_events.copy()
|
||||
pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
|
||||
creator, {"invite": 100, "users": {"@inviter:foo.test": 150}}
|
||||
)
|
||||
pl_auth_events[("m.room.member", "@inviter:foo.test")] = _join_event(
|
||||
"@inviter:foo.test"
|
||||
)
|
||||
event_auth.check(
|
||||
RoomVersions.MSC3083,
|
||||
_join_event(
|
||||
pleb,
|
||||
additional_content={
|
||||
"join_authorised_via_users_server": "@inviter:foo.test"
|
||||
},
|
||||
),
|
||||
pl_auth_events,
|
||||
do_sig_check=False,
|
||||
)
|
||||
|
||||
# A join which is missing an authorised server is rejected.
|
||||
with self.assertRaises(AuthError):
|
||||
event_auth.check(
|
||||
RoomVersions.MSC3083,
|
||||
_member_event(pleb, "join", sender=creator),
|
||||
_join_event(pleb),
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
)
|
||||
|
||||
# An join authorised by a user who is not in the room is rejected.
|
||||
pl_auth_events = auth_events.copy()
|
||||
pl_auth_events[("m.room.power_levels", "")] = _power_levels_event(
|
||||
creator, {"invite": 100, "users": {"@other:example.com": 150}}
|
||||
)
|
||||
with self.assertRaises(AuthError):
|
||||
event_auth.check(
|
||||
RoomVersions.MSC3083,
|
||||
_join_event(
|
||||
pleb,
|
||||
additional_content={
|
||||
"join_authorised_via_users_server": "@other:example.com"
|
||||
},
|
||||
),
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
)
|
||||
|
||||
# A user cannot be force-joined to a room. (This uses an event which
|
||||
# *would* be valid, but is sent be a different user.)
|
||||
with self.assertRaises(AuthError):
|
||||
event_auth.check(
|
||||
RoomVersions.MSC3083,
|
||||
_member_event(
|
||||
pleb,
|
||||
"join",
|
||||
sender=creator,
|
||||
additional_content={
|
||||
"join_authorised_via_users_server": "@inviter:foo.test"
|
||||
},
|
||||
),
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
)
|
||||
|
@ -393,7 +460,7 @@ class EventAuthTestCase(unittest.TestCase):
|
|||
with self.assertRaises(AuthError):
|
||||
event_auth.check(
|
||||
RoomVersions.MSC3083,
|
||||
_join_event(pleb),
|
||||
authorised_join_event,
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
)
|
||||
|
@ -402,12 +469,13 @@ class EventAuthTestCase(unittest.TestCase):
|
|||
auth_events[("m.room.member", pleb)] = _member_event(pleb, "leave")
|
||||
event_auth.check(
|
||||
RoomVersions.MSC3083,
|
||||
_join_event(pleb),
|
||||
authorised_join_event,
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
)
|
||||
|
||||
# A user can send a join if they're in the room.
|
||||
# A user can send a join if they're in the room. (This doesn't need to
|
||||
# be authorised since the user is already joined.)
|
||||
auth_events[("m.room.member", pleb)] = _member_event(pleb, "join")
|
||||
event_auth.check(
|
||||
RoomVersions.MSC3083,
|
||||
|
@ -416,7 +484,8 @@ class EventAuthTestCase(unittest.TestCase):
|
|||
do_sig_check=False,
|
||||
)
|
||||
|
||||
# A user can accept an invite.
|
||||
# A user can accept an invite. (This doesn't need to be authorised since
|
||||
# the user was invited.)
|
||||
auth_events[("m.room.member", pleb)] = _member_event(
|
||||
pleb, "invite", sender=creator
|
||||
)
|
||||
|
@ -446,7 +515,10 @@ def _create_event(user_id: str) -> EventBase:
|
|||
|
||||
|
||||
def _member_event(
|
||||
user_id: str, membership: str, sender: Optional[str] = None
|
||||
user_id: str,
|
||||
membership: str,
|
||||
sender: Optional[str] = None,
|
||||
additional_content: Optional[dict] = None,
|
||||
) -> EventBase:
|
||||
return make_event_from_dict(
|
||||
{
|
||||
|
@ -455,14 +527,14 @@ def _member_event(
|
|||
"type": "m.room.member",
|
||||
"sender": sender or user_id,
|
||||
"state_key": user_id,
|
||||
"content": {"membership": membership},
|
||||
"content": {"membership": membership, **(additional_content or {})},
|
||||
"prev_events": [],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _join_event(user_id: str) -> EventBase:
|
||||
return _member_event(user_id, "join")
|
||||
def _join_event(user_id: str, additional_content: Optional[dict] = None) -> EventBase:
|
||||
return _member_event(user_id, "join", additional_content=additional_content)
|
||||
|
||||
|
||||
def _power_levels_event(sender: str, content: JsonDict) -> EventBase:
|
||||
|
|
Loading…
Reference in New Issue