Reduce the amount of state we pull from the DB (#12811)
parent
6b46c3eb3d
commit
e3163e2e11
|
@ -0,0 +1 @@
|
|||
Reduce the amount of state we pull from the DB.
|
|
@ -29,12 +29,11 @@ from synapse.api.errors import (
|
|||
MissingClientTokenError,
|
||||
)
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.events import EventBase
|
||||
from synapse.http import get_request_user_agent
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
|
||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||
from synapse.types import Requester, StateMap, UserID, create_requester
|
||||
from synapse.types import Requester, UserID, create_requester
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
||||
|
||||
|
@ -61,8 +60,8 @@ class Auth:
|
|||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastores().main
|
||||
self.state = hs.get_state_handler()
|
||||
self._account_validity_handler = hs.get_account_validity_handler()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
|
||||
10000, "token_cache"
|
||||
|
@ -79,9 +78,8 @@ class Auth:
|
|||
self,
|
||||
room_id: str,
|
||||
user_id: str,
|
||||
current_state: Optional[StateMap[EventBase]] = None,
|
||||
allow_departed_users: bool = False,
|
||||
) -> EventBase:
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
"""Check if the user is in the room, or was at some point.
|
||||
Args:
|
||||
room_id: The room to check.
|
||||
|
@ -99,29 +97,28 @@ class Auth:
|
|||
Raises:
|
||||
AuthError if the user is/was not in the room.
|
||||
Returns:
|
||||
Membership event for the user if the user was in the
|
||||
room. This will be the join event if they are currently joined to
|
||||
the room. This will be the leave event if they have left the room.
|
||||
The current membership of the user in the room and the
|
||||
membership event ID of the user.
|
||||
"""
|
||||
if current_state:
|
||||
member = current_state.get((EventTypes.Member, user_id), None)
|
||||
else:
|
||||
member = await self.state.get_current_state(
|
||||
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
|
||||
)
|
||||
|
||||
if member:
|
||||
membership = member.membership
|
||||
(
|
||||
membership,
|
||||
member_event_id,
|
||||
) = await self.store.get_local_current_membership_for_user_in_room(
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
)
|
||||
|
||||
if membership:
|
||||
if membership == Membership.JOIN:
|
||||
return member
|
||||
return membership, member_event_id
|
||||
|
||||
# XXX this looks totally bogus. Why do we not allow users who have been banned,
|
||||
# or those who were members previously and have been re-invited?
|
||||
if allow_departed_users and membership == Membership.LEAVE:
|
||||
forgot = await self.store.did_forget(user_id, room_id)
|
||||
if not forgot:
|
||||
return member
|
||||
return membership, member_event_id
|
||||
|
||||
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
|
||||
|
||||
|
@ -602,8 +599,11 @@ class Auth:
|
|||
# We currently require the user is a "moderator" in the room. We do this
|
||||
# by checking if they would (theoretically) be able to change the
|
||||
# m.room.canonical_alias events
|
||||
power_level_event = await self.state.get_current_state(
|
||||
room_id, EventTypes.PowerLevels, ""
|
||||
|
||||
power_level_event = (
|
||||
await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.PowerLevels, ""
|
||||
)
|
||||
)
|
||||
|
||||
auth_events = {}
|
||||
|
@ -693,12 +693,11 @@ class Auth:
|
|||
# * The user is a non-guest user, and was ever in the room
|
||||
# * The user is a guest user, and has joined the room
|
||||
# else it will throw.
|
||||
member_event = await self.check_user_in_room(
|
||||
return await self.check_user_in_room(
|
||||
room_id, user_id, allow_departed_users=allow_departed_users
|
||||
)
|
||||
return member_event.membership, member_event.event_id
|
||||
except AuthError:
|
||||
visibility = await self.state.get_current_state(
|
||||
visibility = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||
)
|
||||
if (
|
||||
|
|
|
@ -53,6 +53,7 @@ class FederationBase:
|
|||
self.spam_checker = hs.get_spam_checker()
|
||||
self.store = hs.get_datastores().main
|
||||
self._clock = hs.get_clock()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
async def _check_sigs_and_hash(
|
||||
self, room_version: RoomVersion, pdu: EventBase
|
||||
|
|
|
@ -1223,14 +1223,10 @@ class FederationServer(FederationBase):
|
|||
Raises:
|
||||
AuthError if the server does not match the ACL
|
||||
"""
|
||||
state_ids = await self._state_storage_controller.get_current_state_ids(room_id)
|
||||
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
|
||||
|
||||
if not acl_event_id:
|
||||
return
|
||||
|
||||
acl_event = await self.store.get_event(acl_event_id)
|
||||
if server_matches_acl_event(server_name, acl_event):
|
||||
acl_event = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.ServerACL, ""
|
||||
)
|
||||
if not acl_event or server_matches_acl_event(server_name, acl_event):
|
||||
return
|
||||
|
||||
raise AuthError(code=403, msg="Server is banned from room")
|
||||
|
|
|
@ -320,7 +320,7 @@ class DirectoryHandler:
|
|||
Raises:
|
||||
ShadowBanError if the requester has been shadow-banned.
|
||||
"""
|
||||
alias_event = await self.state.get_current_state(
|
||||
alias_event = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.CanonicalAlias, ""
|
||||
)
|
||||
|
||||
|
|
|
@ -371,7 +371,7 @@ class FederationHandler:
|
|||
# First we try hosts that are already in the room
|
||||
# TODO: HEURISTIC ALERT.
|
||||
|
||||
curr_state = await self.state_handler.get_current_state(room_id)
|
||||
curr_state = await self._storage_controllers.state.get_current_state(room_id)
|
||||
|
||||
curr_domains = get_domains_from_state(curr_state)
|
||||
|
||||
|
|
|
@ -1584,9 +1584,11 @@ class FederationEventHandler:
|
|||
if guest_access == GuestAccess.CAN_JOIN:
|
||||
return
|
||||
|
||||
current_state_map = await self._state_handler.get_current_state(event.room_id)
|
||||
current_state = list(current_state_map.values())
|
||||
await self._get_room_member_handler().kick_guest_users(current_state)
|
||||
current_state = await self._storage_controllers.state.get_current_state(
|
||||
event.room_id
|
||||
)
|
||||
current_state_list = list(current_state.values())
|
||||
await self._get_room_member_handler().kick_guest_users(current_state_list)
|
||||
|
||||
async def _check_for_soft_fail(
|
||||
self,
|
||||
|
@ -1614,6 +1616,9 @@ class FederationEventHandler:
|
|||
room_version = await self._store.get_room_version_id(event.room_id)
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
|
||||
# The event types we want to pull from the "current" state.
|
||||
auth_types = auth_types_for_event(room_version_obj, event)
|
||||
|
||||
# Calculate the "current state".
|
||||
if state_ids is not None:
|
||||
# If we're explicitly given the state then we won't have all the
|
||||
|
@ -1643,8 +1648,10 @@ class FederationEventHandler:
|
|||
)
|
||||
)
|
||||
else:
|
||||
current_state_ids = await self._state_handler.get_current_state_ids(
|
||||
event.room_id, latest_event_ids=extrem_ids
|
||||
current_state_ids = (
|
||||
await self._state_storage_controller.get_current_state_ids(
|
||||
event.room_id, StateFilter.from_types(auth_types)
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
|
@ -1654,7 +1661,6 @@ class FederationEventHandler:
|
|||
)
|
||||
|
||||
# Now check if event pass auth against said current state
|
||||
auth_types = auth_types_for_event(room_version_obj, event)
|
||||
current_state_ids_list = [
|
||||
e for k, e in current_state_ids.items() if k in auth_types
|
||||
]
|
||||
|
|
|
@ -190,7 +190,7 @@ class InitialSyncHandler:
|
|||
if event.membership == Membership.JOIN:
|
||||
room_end_token = now_token.room_key
|
||||
deferred_room_state = run_in_background(
|
||||
self.state_handler.get_current_state, event.room_id
|
||||
self._state_storage_controller.get_current_state, event.room_id
|
||||
)
|
||||
elif event.membership == Membership.LEAVE:
|
||||
room_end_token = RoomStreamToken(
|
||||
|
@ -407,7 +407,9 @@ class InitialSyncHandler:
|
|||
membership: str,
|
||||
is_peeking: bool,
|
||||
) -> JsonDict:
|
||||
current_state = await self.state.get_current_state(room_id=room_id)
|
||||
current_state = await self._storage_controllers.state.get_current_state(
|
||||
room_id=room_id
|
||||
)
|
||||
|
||||
# TODO: These concurrently
|
||||
time_now = self.clock.time_msec()
|
||||
|
|
|
@ -125,7 +125,9 @@ class MessageHandler:
|
|||
)
|
||||
|
||||
if membership == Membership.JOIN:
|
||||
data = await self.state.get_current_state(room_id, event_type, state_key)
|
||||
data = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, event_type, state_key
|
||||
)
|
||||
elif membership == Membership.LEAVE:
|
||||
key = (event_type, state_key)
|
||||
# If the membership is not JOIN, then the event ID should exist.
|
||||
|
|
|
@ -1333,6 +1333,7 @@ class TimestampLookupHandler:
|
|||
self.store = hs.get_datastores().main
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.federation_client = hs.get_federation_client()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
async def get_event_for_timestamp(
|
||||
self,
|
||||
|
@ -1406,7 +1407,9 @@ class TimestampLookupHandler:
|
|||
)
|
||||
|
||||
# Find other homeservers from the given state in the room
|
||||
curr_state = await self.state_handler.get_current_state(room_id)
|
||||
curr_state = await self._storage_controllers.state.get_current_state(
|
||||
room_id
|
||||
)
|
||||
curr_domains = get_domains_from_state(curr_state)
|
||||
likely_domains = [
|
||||
domain for domain, depth in curr_domains if domain != self.server_name
|
||||
|
|
|
@ -1401,7 +1401,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
txn_id: Optional[str],
|
||||
id_access_token: Optional[str] = None,
|
||||
) -> int:
|
||||
room_state = await self.state_handler.get_current_state(room_id)
|
||||
room_state = await self._storage_controllers.state.get_current_state(
|
||||
room_id,
|
||||
StateFilter.from_types(
|
||||
[
|
||||
(EventTypes.Member, user.to_string()),
|
||||
(EventTypes.CanonicalAlias, ""),
|
||||
(EventTypes.Name, ""),
|
||||
(EventTypes.Create, ""),
|
||||
(EventTypes.JoinRules, ""),
|
||||
(EventTypes.RoomAvatar, ""),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
inviter_display_name = ""
|
||||
inviter_avatar_url = ""
|
||||
|
@ -1797,7 +1809,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||
async def forget(self, user: UserID, room_id: str) -> None:
|
||||
user_id = user.to_string()
|
||||
|
||||
member = await self.state_handler.get_current_state(
|
||||
member = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
|
||||
)
|
||||
membership = member.membership if member else None
|
||||
|
|
|
@ -348,7 +348,7 @@ class SearchHandler:
|
|||
state_results = {}
|
||||
if include_state:
|
||||
for room_id in {e.room_id for e in search_result.allowed_events}:
|
||||
state = await self.state_handler.get_current_state(room_id)
|
||||
state = await self._storage_controllers.state.get_current_state(room_id)
|
||||
state_results[room_id] = list(state.values())
|
||||
|
||||
aggregations = await self._relations_handler.get_bundled_aggregations(
|
||||
|
|
|
@ -681,7 +681,7 @@ class Notifier:
|
|||
return joined_room_ids, True
|
||||
|
||||
async def _is_world_readable(self, room_id: str) -> bool:
|
||||
state = await self.state_handler.get_current_state(
|
||||
state = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||
)
|
||||
if state and "history_visibility" in state.content:
|
||||
|
|
|
@ -34,6 +34,7 @@ from synapse.rest.admin._base import (
|
|||
assert_user_is_admin,
|
||||
)
|
||||
from synapse.storage.databases.main.room import RoomSortOrder
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import JsonDict, RoomID, UserID, create_requester
|
||||
from synapse.util import json_decoder
|
||||
|
||||
|
@ -448,7 +449,8 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
|||
super().__init__(hs)
|
||||
self.auth = hs.get_auth()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_POST(
|
||||
|
@ -490,8 +492,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
|||
)
|
||||
|
||||
# send invite if room has "JoinRules.INVITE"
|
||||
room_state = await self.state_handler.get_current_state(room_id)
|
||||
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
|
||||
join_rules_event = (
|
||||
await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.JoinRules, ""
|
||||
)
|
||||
)
|
||||
if join_rules_event:
|
||||
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
|
||||
# update_membership with an action of "invite" can raise a
|
||||
|
@ -536,6 +541,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
|||
super().__init__(hs)
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
self._state_storage_controller = hs.get_storage_controllers().state
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
@ -553,12 +559,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
|||
user_to_add = content.get("user_id", requester.user.to_string())
|
||||
|
||||
# Figure out which local users currently have power in the room, if any.
|
||||
room_state = await self.state_handler.get_current_state(room_id)
|
||||
if not room_state:
|
||||
filtered_room_state = await self._state_storage_controller.get_current_state(
|
||||
room_id,
|
||||
StateFilter.from_types(
|
||||
[
|
||||
(EventTypes.Create, ""),
|
||||
(EventTypes.PowerLevels, ""),
|
||||
(EventTypes.JoinRules, ""),
|
||||
(EventTypes.Member, user_to_add),
|
||||
]
|
||||
),
|
||||
)
|
||||
if not filtered_room_state:
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
|
||||
|
||||
create_event = room_state[(EventTypes.Create, "")]
|
||||
power_levels = room_state.get((EventTypes.PowerLevels, ""))
|
||||
create_event = filtered_room_state[(EventTypes.Create, "")]
|
||||
power_levels = filtered_room_state.get((EventTypes.PowerLevels, ""))
|
||||
|
||||
if power_levels is not None:
|
||||
# We pick the local user with the highest power.
|
||||
|
@ -634,7 +650,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
|||
|
||||
# Now we check if the user we're granting admin rights to is already in
|
||||
# the room. If not and it's not a public room we invite them.
|
||||
member_event = room_state.get((EventTypes.Member, user_to_add))
|
||||
member_event = filtered_room_state.get((EventTypes.Member, user_to_add))
|
||||
is_joined = False
|
||||
if member_event:
|
||||
is_joined = member_event.content["membership"] in (
|
||||
|
@ -645,7 +661,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
|||
if is_joined:
|
||||
return HTTPStatus.OK, {}
|
||||
|
||||
join_rules = room_state.get((EventTypes.JoinRules, ""))
|
||||
join_rules = filtered_room_state.get((EventTypes.JoinRules, ""))
|
||||
is_public = False
|
||||
if join_rules:
|
||||
is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
|
||||
|
|
|
@ -650,6 +650,7 @@ class RoomEventServlet(RestServlet):
|
|||
self.clock = hs.get_clock()
|
||||
self._store = hs.get_datastores().main
|
||||
self._state = hs.get_state_handler()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.event_handler = hs.get_event_handler()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self._relations_handler = hs.get_relations_handler()
|
||||
|
@ -673,8 +674,10 @@ class RoomEventServlet(RestServlet):
|
|||
if include_unredacted_content and not await self.auth.is_server_admin(
|
||||
requester.user
|
||||
):
|
||||
power_level_event = await self._state.get_current_state(
|
||||
room_id, EventTypes.PowerLevels, ""
|
||||
power_level_event = (
|
||||
await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.PowerLevels, ""
|
||||
)
|
||||
)
|
||||
|
||||
auth_events = {}
|
||||
|
|
|
@ -36,6 +36,7 @@ class ResourceLimitsServerNotices:
|
|||
def __init__(self, hs: "HomeServer"):
|
||||
self._server_notices_manager = hs.get_server_notices_manager()
|
||||
self._store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._auth = hs.get_auth()
|
||||
self._config = hs.config
|
||||
self._resouce_limited = False
|
||||
|
@ -178,8 +179,10 @@ class ResourceLimitsServerNotices:
|
|||
currently_blocked = False
|
||||
pinned_state_event = None
|
||||
try:
|
||||
pinned_state_event = await self._state.get_current_state(
|
||||
room_id, event_type=EventTypes.Pinned
|
||||
pinned_state_event = (
|
||||
await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, event_type=EventTypes.Pinned, state_key=""
|
||||
)
|
||||
)
|
||||
except AuthError:
|
||||
# The user has yet to join the server notices room
|
||||
|
|
|
@ -32,13 +32,11 @@ from typing import (
|
|||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
from prometheus_client import Counter, Histogram
|
||||
from typing_extensions import Literal
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
|
||||
|
@ -132,85 +130,20 @@ class StateHandler:
|
|||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
@overload
|
||||
async def get_current_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: Literal[None] = None,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> StateMap[EventBase]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get_current_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> Optional[EventBase]:
|
||||
...
|
||||
|
||||
async def get_current_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: Optional[str] = None,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> Union[Optional[EventBase], StateMap[EventBase]]:
|
||||
"""Retrieves the current state for the room. This is done by
|
||||
calling `get_latest_events_in_room` to get the leading edges of the
|
||||
event graph and then resolving any of the state conflicts.
|
||||
|
||||
This is equivalent to getting the state of an event that were to send
|
||||
next before receiving any new events.
|
||||
|
||||
Returns:
|
||||
If `event_type` is specified, then the method returns only the one
|
||||
event (or None) with that `event_type` and `state_key`.
|
||||
|
||||
Otherwise, a map from (type, state_key) to event.
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
assert latest_event_ids is not None
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state")
|
||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
state = ret.state
|
||||
|
||||
if event_type:
|
||||
event_id = state.get((event_type, state_key))
|
||||
event = None
|
||||
if event_id:
|
||||
event = await self.store.get_event(event_id, allow_none=True)
|
||||
return event
|
||||
|
||||
state_map = await self.store.get_events(
|
||||
list(state.values()), get_prev_content=False
|
||||
)
|
||||
return {
|
||||
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
|
||||
}
|
||||
|
||||
async def get_current_state_ids(
|
||||
self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
|
||||
self,
|
||||
room_id: str,
|
||||
latest_event_ids: Collection[str],
|
||||
) -> StateMap[str]:
|
||||
"""Get the current state, or the state at a set of events, for a room
|
||||
|
||||
Args:
|
||||
room_id:
|
||||
latest_event_ids: if given, the forward extremities to resolve. If
|
||||
None, we look them up from the database (via a cache).
|
||||
latest_event_ids: The forward extremities to resolve.
|
||||
|
||||
Returns:
|
||||
the state dict, mapping from (event_type, state_key) -> event_id
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
assert latest_event_ids is not None
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state_ids")
|
||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
return ret.state
|
||||
|
|
|
@ -455,3 +455,30 @@ class StateStorageController:
|
|||
return await self.stores.main.get_partial_current_state_deltas(
|
||||
prev_stream_id, max_stream_id
|
||||
)
|
||||
|
||||
async def get_current_state(
|
||||
self, room_id: str, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[EventBase]:
|
||||
"""Same as `get_current_state_ids` but also fetches the events"""
|
||||
state_map_ids = await self.get_current_state_ids(room_id, state_filter)
|
||||
|
||||
event_map = await self.stores.main.get_events(list(state_map_ids.values()))
|
||||
|
||||
state_map = {}
|
||||
for key, event_id in state_map_ids.items():
|
||||
event = event_map.get(event_id)
|
||||
if event:
|
||||
state_map[key] = event
|
||||
|
||||
return state_map
|
||||
|
||||
async def get_current_state_event(
|
||||
self, room_id: str, event_type: str, state_key: str
|
||||
) -> Optional[EventBase]:
|
||||
"""Get the current state event for the given type/state_key."""
|
||||
|
||||
key = (event_type, state_key)
|
||||
state_map = await self.get_current_state(
|
||||
room_id, StateFilter.from_types((key,))
|
||||
)
|
||||
return state_map.get(key)
|
||||
|
|
|
@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
|||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
||||
super().prepare(reactor, clock, hs)
|
||||
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
# create the room
|
||||
creator_user_id = self.register_user("kermit", "test")
|
||||
tok = self.login("kermit", "test")
|
||||
|
@ -207,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
|||
|
||||
# the room should show that the new user is a member
|
||||
r = self.get_success(
|
||||
self.hs.get_state_handler().get_current_state(self._room_id)
|
||||
self._storage_controllers.state.get_current_state(self._room_id)
|
||||
)
|
||||
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
|
||||
|
||||
|
@ -258,7 +260,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
|||
|
||||
# the room should show that the new user is a member
|
||||
r = self.get_success(
|
||||
self.hs.get_state_handler().get_current_state(self._room_id)
|
||||
self._storage_controllers.state.get_current_state(self._room_id)
|
||||
)
|
||||
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
|
||||
|
||||
|
|
|
@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||
self.store = hs.get_datastores().main
|
||||
self.handler = hs.get_directory_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
# Create user
|
||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||
|
@ -335,7 +336,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
|||
def _get_canonical_alias(self):
|
||||
"""Get the canonical alias state of the room."""
|
||||
return self.get_success(
|
||||
self.state_handler.get_current_state(
|
||||
self._storage_controllers.state.get_current_state_event(
|
||||
self.room_id, EventTypes.CanonicalAlias, ""
|
||||
)
|
||||
)
|
||||
|
|
|
@ -32,6 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
def prepare(self, reactor, clock, homeserver):
|
||||
self.state = self.hs.get_state_handler()
|
||||
self._persistence = self.hs.get_storage_controllers().persistence
|
||||
self._state_storage_controller = self.hs.get_storage_controllers().state
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
self.register_user("user", "pass")
|
||||
|
@ -104,7 +105,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
state_before_gap = self.get_success(
|
||||
self.state.get_current_state_ids(self.room_id)
|
||||
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
@ -137,7 +138,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
# setting. The state resolution across the old and new event will then
|
||||
# include it, and so the resolved state won't match the new state.
|
||||
state_before_gap = dict(
|
||||
self.get_success(self.state.get_current_state_ids(self.room_id))
|
||||
self.get_success(
|
||||
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||
)
|
||||
)
|
||||
state_before_gap.pop(("m.room.history_visibility", ""))
|
||||
|
||||
|
@ -181,7 +184,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
state_before_gap = self.get_success(
|
||||
self.state.get_current_state_ids(self.room_id)
|
||||
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
@ -213,7 +216,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
state_before_gap = self.get_success(
|
||||
self.state.get_current_state_ids(self.room_id)
|
||||
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
@ -255,7 +258,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
state_before_gap = self.get_success(
|
||||
self.state.get_current_state_ids(self.room_id)
|
||||
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
@ -299,7 +302,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
state_before_gap = self.get_success(
|
||||
self.state.get_current_state_ids(self.room_id)
|
||||
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
@ -335,7 +338,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
state_before_gap = self.get_success(
|
||||
self.state.get_current_state_ids(self.room_id)
|
||||
self._state_storage_controller.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
|
|
@ -102,9 +102,10 @@ class PurgeTests(HomeserverTestCase):
|
|||
first = self.helper.send(self.room_id, body="test1")
|
||||
|
||||
# Get the current room state.
|
||||
state_handler = self.hs.get_state_handler()
|
||||
create_event = self.get_success(
|
||||
state_handler.get_current_state(self.room_id, "m.room.create", "")
|
||||
self._storage_controllers.state.get_current_state_event(
|
||||
self.room_id, "m.room.create", ""
|
||||
)
|
||||
)
|
||||
self.assertIsNotNone(create_event)
|
||||
|
||||
|
|
|
@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
|||
# Room events need the full datastore, for persist_event() and
|
||||
# get_room_state()
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage = hs.get_storage_controllers()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.event_factory = hs.get_event_factory()
|
||||
|
||||
self.room = RoomID.from_string("!abcde:test")
|
||||
|
@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
|||
|
||||
def inject_room_event(self, **kwargs):
|
||||
self.get_success(
|
||||
self._storage.persistence.persist_event(
|
||||
self._storage_controllers.persistence.persist_event(
|
||||
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
|
||||
)
|
||||
)
|
||||
|
@ -101,7 +101,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
state = self.get_success(
|
||||
self.store.get_current_state(room_id=self.room.to_string())
|
||||
self._storage_controllers.state.get_current_state(
|
||||
room_id=self.room.to_string()
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(1, len(state))
|
||||
|
@ -118,7 +120,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
|
|||
)
|
||||
|
||||
state = self.get_success(
|
||||
self.store.get_current_state(room_id=self.room.to_string())
|
||||
self._storage_controllers.state.get_current_state(
|
||||
room_id=self.room.to_string()
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(1, len(state))
|
||||
|
|
Loading…
Reference in New Issue