Reduce the amount of state we pull from the DB (#12811)

pull/12963/head
Erik Johnston 2022-06-06 11:24:12 +03:00 committed by GitHub
parent 6b46c3eb3d
commit e3163e2e11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 162 additions and 147 deletions

1
changelog.d/12811.misc Normal file
View File

@ -0,0 +1 @@
Reduce the amount of state we pull from the DB.

View File

@ -29,12 +29,11 @@ from synapse.api.errors import (
MissingClientTokenError,
)
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import Requester, StateMap, UserID, create_requester
from synapse.types import Requester, UserID, create_requester
from synapse.util.caches.lrucache import LruCache
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
@ -61,8 +60,8 @@ class Auth:
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
self.state = hs.get_state_handler()
self._account_validity_handler = hs.get_account_validity_handler()
self._storage_controllers = hs.get_storage_controllers()
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
10000, "token_cache"
@ -79,9 +78,8 @@ class Auth:
self,
room_id: str,
user_id: str,
current_state: Optional[StateMap[EventBase]] = None,
allow_departed_users: bool = False,
) -> EventBase:
) -> Tuple[str, Optional[str]]:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
@ -99,29 +97,28 @@ class Auth:
Raises:
AuthError if the user is/was not in the room.
Returns:
Membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
The current membership of the user in the room and the
membership event ID of the user.
"""
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
member = await self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
if member:
membership = member.membership
(
membership,
member_event_id,
) = await self.store.get_local_current_membership_for_user_in_room(
user_id=user_id,
room_id=room_id,
)
if membership:
if membership == Membership.JOIN:
return member
return membership, member_event_id
# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = await self.store.did_forget(user_id, room_id)
if not forgot:
return member
return membership, member_event_id
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
@ -602,8 +599,11 @@ class Auth:
# We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the
# m.room.canonical_alias events
power_level_event = await self.state.get_current_state(
room_id, EventTypes.PowerLevels, ""
power_level_event = (
await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.PowerLevels, ""
)
)
auth_events = {}
@ -693,12 +693,11 @@ class Auth:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = await self.check_user_in_room(
return await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users
)
return member_event.membership, member_event.event_id
except AuthError:
visibility = await self.state.get_current_state(
visibility = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (

View File

@ -53,6 +53,7 @@ class FederationBase:
self.spam_checker = hs.get_spam_checker()
self.store = hs.get_datastores().main
self._clock = hs.get_clock()
self._storage_controllers = hs.get_storage_controllers()
async def _check_sigs_and_hash(
self, room_version: RoomVersion, pdu: EventBase

View File

@ -1223,14 +1223,10 @@ class FederationServer(FederationBase):
Raises:
AuthError if the server does not match the ACL
"""
state_ids = await self._state_storage_controller.get_current_state_ids(room_id)
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
if not acl_event_id:
return
acl_event = await self.store.get_event(acl_event_id)
if server_matches_acl_event(server_name, acl_event):
acl_event = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.ServerACL, ""
)
if not acl_event or server_matches_acl_event(server_name, acl_event):
return
raise AuthError(code=403, msg="Server is banned from room")

View File

@ -320,7 +320,7 @@ class DirectoryHandler:
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
alias_event = await self.state.get_current_state(
alias_event = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.CanonicalAlias, ""
)

View File

@ -371,7 +371,7 @@ class FederationHandler:
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
curr_state = await self.state_handler.get_current_state(room_id)
curr_state = await self._storage_controllers.state.get_current_state(room_id)
curr_domains = get_domains_from_state(curr_state)

View File

@ -1584,9 +1584,11 @@ class FederationEventHandler:
if guest_access == GuestAccess.CAN_JOIN:
return
current_state_map = await self._state_handler.get_current_state(event.room_id)
current_state = list(current_state_map.values())
await self._get_room_member_handler().kick_guest_users(current_state)
current_state = await self._storage_controllers.state.get_current_state(
event.room_id
)
current_state_list = list(current_state.values())
await self._get_room_member_handler().kick_guest_users(current_state_list)
async def _check_for_soft_fail(
self,
@ -1614,6 +1616,9 @@ class FederationEventHandler:
room_version = await self._store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
# The event types we want to pull from the "current" state.
auth_types = auth_types_for_event(room_version_obj, event)
# Calculate the "current state".
if state_ids is not None:
# If we're explicitly given the state then we won't have all the
@ -1643,8 +1648,10 @@ class FederationEventHandler:
)
)
else:
current_state_ids = await self._state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
current_state_ids = (
await self._state_storage_controller.get_current_state_ids(
event.room_id, StateFilter.from_types(auth_types)
)
)
logger.debug(
@ -1654,7 +1661,6 @@ class FederationEventHandler:
)
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(room_version_obj, event)
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]

View File

@ -190,7 +190,7 @@ class InitialSyncHandler:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = run_in_background(
self.state_handler.get_current_state, event.room_id
self._state_storage_controller.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = RoomStreamToken(
@ -407,7 +407,9 @@ class InitialSyncHandler:
membership: str,
is_peeking: bool,
) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id)
current_state = await self._storage_controllers.state.get_current_state(
room_id=room_id
)
# TODO: These concurrently
time_now = self.clock.time_msec()

View File

@ -125,7 +125,9 @@ class MessageHandler:
)
if membership == Membership.JOIN:
data = await self.state.get_current_state(room_id, event_type, state_key)
data = await self._storage_controllers.state.get_current_state_event(
room_id, event_type, state_key
)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
# If the membership is not JOIN, then the event ID should exist.

View File

@ -1333,6 +1333,7 @@ class TimestampLookupHandler:
self.store = hs.get_datastores().main
self.state_handler = hs.get_state_handler()
self.federation_client = hs.get_federation_client()
self._storage_controllers = hs.get_storage_controllers()
async def get_event_for_timestamp(
self,
@ -1406,7 +1407,9 @@ class TimestampLookupHandler:
)
# Find other homeservers from the given state in the room
curr_state = await self.state_handler.get_current_state(room_id)
curr_state = await self._storage_controllers.state.get_current_state(
room_id
)
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
domain for domain, depth in curr_domains if domain != self.server_name

View File

@ -1401,7 +1401,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> int:
room_state = await self.state_handler.get_current_state(room_id)
room_state = await self._storage_controllers.state.get_current_state(
room_id,
StateFilter.from_types(
[
(EventTypes.Member, user.to_string()),
(EventTypes.CanonicalAlias, ""),
(EventTypes.Name, ""),
(EventTypes.Create, ""),
(EventTypes.JoinRules, ""),
(EventTypes.RoomAvatar, ""),
]
),
)
inviter_display_name = ""
inviter_avatar_url = ""
@ -1797,7 +1809,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
async def forget(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
member = await self.state_handler.get_current_state(
member = await self._storage_controllers.state.get_current_state_event(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None

View File

@ -348,7 +348,7 @@ class SearchHandler:
state_results = {}
if include_state:
for room_id in {e.room_id for e in search_result.allowed_events}:
state = await self.state_handler.get_current_state(room_id)
state = await self._storage_controllers.state.get_current_state(room_id)
state_results[room_id] = list(state.values())
aggregations = await self._relations_handler.get_bundled_aggregations(

View File

@ -681,7 +681,7 @@ class Notifier:
return joined_room_ids, True
async def _is_world_readable(self, room_id: str) -> bool:
state = await self.state_handler.get_current_state(
state = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:

View File

@ -34,6 +34,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, RoomID, UserID, create_requester
from synapse.util import json_decoder
@ -448,7 +449,8 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
super().__init__(hs)
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.is_mine = hs.is_mine
async def on_POST(
@ -490,8 +492,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
)
# send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id)
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
join_rules_event = (
await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.JoinRules, ""
)
)
if join_rules_event:
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
# update_membership with an action of "invite" can raise a
@ -536,6 +541,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
super().__init__(hs)
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self._state_storage_controller = hs.get_storage_controllers().state
self.event_creation_handler = hs.get_event_creation_handler()
self.state_handler = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
@ -553,12 +559,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
user_to_add = content.get("user_id", requester.user.to_string())
# Figure out which local users currently have power in the room, if any.
room_state = await self.state_handler.get_current_state(room_id)
if not room_state:
filtered_room_state = await self._state_storage_controller.get_current_state(
room_id,
StateFilter.from_types(
[
(EventTypes.Create, ""),
(EventTypes.PowerLevels, ""),
(EventTypes.JoinRules, ""),
(EventTypes.Member, user_to_add),
]
),
)
if not filtered_room_state:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
create_event = room_state[(EventTypes.Create, "")]
power_levels = room_state.get((EventTypes.PowerLevels, ""))
create_event = filtered_room_state[(EventTypes.Create, "")]
power_levels = filtered_room_state.get((EventTypes.PowerLevels, ""))
if power_levels is not None:
# We pick the local user with the highest power.
@ -634,7 +650,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
# Now we check if the user we're granting admin rights to is already in
# the room. If not and it's not a public room we invite them.
member_event = room_state.get((EventTypes.Member, user_to_add))
member_event = filtered_room_state.get((EventTypes.Member, user_to_add))
is_joined = False
if member_event:
is_joined = member_event.content["membership"] in (
@ -645,7 +661,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
if is_joined:
return HTTPStatus.OK, {}
join_rules = room_state.get((EventTypes.JoinRules, ""))
join_rules = filtered_room_state.get((EventTypes.JoinRules, ""))
is_public = False
if join_rules:
is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC

View File

@ -650,6 +650,7 @@ class RoomEventServlet(RestServlet):
self.clock = hs.get_clock()
self._store = hs.get_datastores().main
self._state = hs.get_state_handler()
self._storage_controllers = hs.get_storage_controllers()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
@ -673,8 +674,10 @@ class RoomEventServlet(RestServlet):
if include_unredacted_content and not await self.auth.is_server_admin(
requester.user
):
power_level_event = await self._state.get_current_state(
room_id, EventTypes.PowerLevels, ""
power_level_event = (
await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.PowerLevels, ""
)
)
auth_events = {}

View File

@ -36,6 +36,7 @@ class ResourceLimitsServerNotices:
def __init__(self, hs: "HomeServer"):
self._server_notices_manager = hs.get_server_notices_manager()
self._store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth()
self._config = hs.config
self._resouce_limited = False
@ -178,8 +179,10 @@ class ResourceLimitsServerNotices:
currently_blocked = False
pinned_state_event = None
try:
pinned_state_event = await self._state.get_current_state(
room_id, event_type=EventTypes.Pinned
pinned_state_event = (
await self._storage_controllers.state.get_current_state_event(
room_id, event_type=EventTypes.Pinned, state_key=""
)
)
except AuthError:
# The user has yet to join the server notices room

View File

@ -32,13 +32,11 @@ from typing import (
Set,
Tuple,
Union,
overload,
)
import attr
from frozendict import frozendict
from prometheus_client import Counter, Histogram
from typing_extensions import Literal
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@ -132,85 +130,20 @@ class StateHandler:
self._state_resolution_handler = hs.get_state_resolution_handler()
self._storage_controllers = hs.get_storage_controllers()
@overload
async def get_current_state(
self,
room_id: str,
event_type: Literal[None] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> StateMap[EventBase]:
...
@overload
async def get_current_state(
self,
room_id: str,
event_type: str,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Optional[EventBase]:
...
async def get_current_state(
self,
room_id: str,
event_type: Optional[str] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Union[Optional[EventBase], StateMap[EventBase]]:
"""Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
This is equivalent to getting the state of an event that were to send
next before receiving any new events.
Returns:
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
Otherwise, a map from (type, state_key) to event.
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
event = await self.store.get_event(event_id, allow_none=True)
return event
state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
return {
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
async def get_current_state_ids(
self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
self,
room_id: str,
latest_event_ids: Collection[str],
) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room
Args:
room_id:
latest_event_ids: if given, the forward extremities to resolve. If
None, we look them up from the database (via a cache).
latest_event_ids: The forward extremities to resolve.
Returns:
the state dict, mapping from (event_type, state_key) -> event_id
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return ret.state

View File

@ -455,3 +455,30 @@ class StateStorageController:
return await self.stores.main.get_partial_current_state_deltas(
prev_stream_id, max_stream_id
)
async def get_current_state(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
"""Same as `get_current_state_ids` but also fetches the events"""
state_map_ids = await self.get_current_state_ids(room_id, state_filter)
event_map = await self.stores.main.get_events(list(state_map_ids.values()))
state_map = {}
for key, event_id in state_map_ids.items():
event = event_map.get(event_id)
if event:
state_map[key] = event
return state_map
async def get_current_state_event(
self, room_id: str, event_type: str, state_key: str
) -> Optional[EventBase]:
"""Get the current state event for the given type/state_key."""
key = (event_type, state_key)
state_map = await self.get_current_state(
room_id, StateFilter.from_types((key,))
)
return state_map.get(key)

View File

@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
super().prepare(reactor, clock, hs)
self._storage_controllers = hs.get_storage_controllers()
# create the room
creator_user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
@ -207,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member
r = self.get_success(
self.hs.get_state_handler().get_current_state(self._room_id)
self._storage_controllers.state.get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
@ -258,7 +260,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member
r = self.get_success(
self.hs.get_state_handler().get_current_state(self._room_id)
self._storage_controllers.state.get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")

View File

@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
self._storage_controllers = hs.get_storage_controllers()
# Create user
self.admin_user = self.register_user("admin", "pass", admin=True)
@ -335,7 +336,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
def _get_canonical_alias(self):
"""Get the canonical alias state of the room."""
return self.get_success(
self.state_handler.get_current_state(
self._storage_controllers.state.get_current_state_event(
self.room_id, EventTypes.CanonicalAlias, ""
)
)

View File

@ -32,6 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.state = self.hs.get_state_handler()
self._persistence = self.hs.get_storage_controllers().persistence
self._state_storage_controller = self.hs.get_storage_controllers().state
self.store = self.hs.get_datastores().main
self.register_user("user", "pass")
@ -104,7 +105,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
)
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
self._state_storage_controller.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap)
@ -137,7 +138,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
# setting. The state resolution across the old and new event will then
# include it, and so the resolved state won't match the new state.
state_before_gap = dict(
self.get_success(self.state.get_current_state_ids(self.room_id))
self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
)
state_before_gap.pop(("m.room.history_visibility", ""))
@ -181,7 +184,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
)
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
self._state_storage_controller.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap)
@ -213,7 +216,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
)
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
self._state_storage_controller.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap)
@ -255,7 +258,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
)
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
self._state_storage_controller.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap)
@ -299,7 +302,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
)
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
self._state_storage_controller.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap)
@ -335,7 +338,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
)
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
self._state_storage_controller.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap)

View File

@ -102,9 +102,10 @@ class PurgeTests(HomeserverTestCase):
first = self.helper.send(self.room_id, body="test1")
# Get the current room state.
state_handler = self.hs.get_state_handler()
create_event = self.get_success(
state_handler.get_current_state(self.room_id, "m.room.create", "")
self._storage_controllers.state.get_current_state_event(
self.room_id, "m.room.create", ""
)
)
self.assertIsNotNone(create_event)

View File

@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
# Room events need the full datastore, for persist_event() and
# get_room_state()
self.store = hs.get_datastores().main
self._storage = hs.get_storage_controllers()
self._storage_controllers = hs.get_storage_controllers()
self.event_factory = hs.get_event_factory()
self.room = RoomID.from_string("!abcde:test")
@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
def inject_room_event(self, **kwargs):
self.get_success(
self._storage.persistence.persist_event(
self._storage_controllers.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
@ -101,7 +101,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
)
state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
self._storage_controllers.state.get_current_state(
room_id=self.room.to_string()
)
)
self.assertEqual(1, len(state))
@ -118,7 +120,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase):
)
state = self.get_success(
self.store.get_current_state(room_id=self.room.to_string())
self._storage_controllers.state.get_current_state(
room_id=self.room.to_string()
)
)
self.assertEqual(1, len(state))