Convert a synapse.events to async/await. (#7949)
parent
5f65e62681
commit
8553f46498
|
@ -1 +1 @@
|
||||||
Convert push to async/await.
|
Convert various parts of the codebase to async/await.
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Convert various parts of the codebase to async/await.
|
|
@ -1 +1 @@
|
||||||
Convert groups and visibility code to async / await.
|
Convert various parts of the codebase to async/await.
|
||||||
|
|
|
@ -82,7 +82,7 @@ class Auth(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
|
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
|
||||||
prev_state_ids = yield context.get_prev_state_ids()
|
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||||
auth_events_ids = yield self.compute_auth_events(
|
auth_events_ids = yield self.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=True
|
event, prev_state_ids, for_verification=True
|
||||||
)
|
)
|
||||||
|
|
|
@ -17,8 +17,6 @@ from typing import Optional
|
||||||
import attr
|
import attr
|
||||||
from nacl.signing import SigningKey
|
from nacl.signing import SigningKey
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import MAX_DEPTH
|
from synapse.api.constants import MAX_DEPTH
|
||||||
from synapse.api.errors import UnsupportedRoomVersionError
|
from synapse.api.errors import UnsupportedRoomVersionError
|
||||||
from synapse.api.room_versions import (
|
from synapse.api.room_versions import (
|
||||||
|
@ -95,31 +93,30 @@ class EventBuilder(object):
|
||||||
def is_state(self):
|
def is_state(self):
|
||||||
return self._state_key is not None
|
return self._state_key is not None
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def build(self, prev_event_ids):
|
||||||
def build(self, prev_event_ids):
|
|
||||||
"""Transform into a fully signed and hashed event
|
"""Transform into a fully signed and hashed event
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prev_event_ids (list[str]): The event IDs to use as the prev events
|
prev_event_ids (list[str]): The event IDs to use as the prev events
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[FrozenEvent]
|
FrozenEvent
|
||||||
"""
|
"""
|
||||||
|
|
||||||
state_ids = yield defer.ensureDeferred(
|
state_ids = await self._state.get_current_state_ids(
|
||||||
self._state.get_current_state_ids(self.room_id, prev_event_ids)
|
self.room_id, prev_event_ids
|
||||||
)
|
)
|
||||||
auth_ids = yield self._auth.compute_auth_events(self, state_ids)
|
auth_ids = await self._auth.compute_auth_events(self, state_ids)
|
||||||
|
|
||||||
format_version = self.room_version.event_format
|
format_version = self.room_version.event_format
|
||||||
if format_version == EventFormatVersions.V1:
|
if format_version == EventFormatVersions.V1:
|
||||||
auth_events = yield self._store.add_event_hashes(auth_ids)
|
auth_events = await self._store.add_event_hashes(auth_ids)
|
||||||
prev_events = yield self._store.add_event_hashes(prev_event_ids)
|
prev_events = await self._store.add_event_hashes(prev_event_ids)
|
||||||
else:
|
else:
|
||||||
auth_events = auth_ids
|
auth_events = auth_ids
|
||||||
prev_events = prev_event_ids
|
prev_events = prev_event_ids
|
||||||
|
|
||||||
old_depth = yield self._store.get_max_depth_of(prev_event_ids)
|
old_depth = await self._store.get_max_depth_of(prev_event_ids)
|
||||||
depth = old_depth + 1
|
depth = old_depth + 1
|
||||||
|
|
||||||
# we cap depth of generated events, to ensure that they are not
|
# we cap depth of generated events, to ensure that they are not
|
||||||
|
|
|
@ -12,17 +12,19 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.types import StateMap
|
from synapse.types import StateMap
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.storage.data_stores.main import DataStore
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
class EventContext:
|
class EventContext:
|
||||||
|
@ -129,8 +131,7 @@ class EventContext:
|
||||||
delta_ids=delta_ids,
|
delta_ids=delta_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def serialize(self, event: EventBase, store: "DataStore") -> dict:
|
||||||
def serialize(self, event, store):
|
|
||||||
"""Converts self to a type that can be serialized as JSON, and then
|
"""Converts self to a type that can be serialized as JSON, and then
|
||||||
deserialized by `deserialize`
|
deserialized by `deserialize`
|
||||||
|
|
||||||
|
@ -146,7 +147,7 @@ class EventContext:
|
||||||
# the prev_state_ids, so if we're a state event we include the event
|
# the prev_state_ids, so if we're a state event we include the event
|
||||||
# id that we replaced in the state.
|
# id that we replaced in the state.
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
prev_state_ids = yield self.get_prev_state_ids()
|
prev_state_ids = await self.get_prev_state_ids()
|
||||||
prev_state_id = prev_state_ids.get((event.type, event.state_key))
|
prev_state_id = prev_state_ids.get((event.type, event.state_key))
|
||||||
else:
|
else:
|
||||||
prev_state_id = None
|
prev_state_id = None
|
||||||
|
@ -214,8 +215,7 @@ class EventContext:
|
||||||
|
|
||||||
return self._state_group
|
return self._state_group
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_current_state_ids(self) -> Optional[StateMap[str]]:
|
||||||
def get_current_state_ids(self):
|
|
||||||
"""
|
"""
|
||||||
Gets the room state map, including this event - ie, the state in ``state_group``
|
Gets the room state map, including this event - ie, the state in ``state_group``
|
||||||
|
|
||||||
|
@ -224,32 +224,31 @@ class EventContext:
|
||||||
``rejected`` is set.
|
``rejected`` is set.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[(str, str), str]|None]: Returns None if state_group
|
Returns None if state_group is None, which happens when the associated
|
||||||
is None, which happens when the associated event is an outlier.
|
event is an outlier.
|
||||||
|
|
||||||
Maps a (type, state_key) to the event ID of the state event matching
|
Maps a (type, state_key) to the event ID of the state event matching
|
||||||
this tuple.
|
this tuple.
|
||||||
"""
|
"""
|
||||||
if self.rejected:
|
if self.rejected:
|
||||||
raise RuntimeError("Attempt to access state_ids of rejected event")
|
raise RuntimeError("Attempt to access state_ids of rejected event")
|
||||||
|
|
||||||
yield self._ensure_fetched()
|
await self._ensure_fetched()
|
||||||
return self._current_state_ids
|
return self._current_state_ids
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_prev_state_ids(self):
|
||||||
def get_prev_state_ids(self):
|
|
||||||
"""
|
"""
|
||||||
Gets the room state map, excluding this event.
|
Gets the room state map, excluding this event.
|
||||||
|
|
||||||
For a non-state event, this will be the same as get_current_state_ids().
|
For a non-state event, this will be the same as get_current_state_ids().
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[(str, str), str]|None]: Returns None if state_group
|
dict[(str, str), str]|None: Returns None if state_group
|
||||||
is None, which happens when the associated event is an outlier.
|
is None, which happens when the associated event is an outlier.
|
||||||
Maps a (type, state_key) to the event ID of the state event matching
|
Maps a (type, state_key) to the event ID of the state event matching
|
||||||
this tuple.
|
this tuple.
|
||||||
"""
|
"""
|
||||||
yield self._ensure_fetched()
|
await self._ensure_fetched()
|
||||||
return self._prev_state_ids
|
return self._prev_state_ids
|
||||||
|
|
||||||
def get_cached_current_state_ids(self):
|
def get_cached_current_state_ids(self):
|
||||||
|
@ -269,8 +268,8 @@ class EventContext:
|
||||||
|
|
||||||
return self._current_state_ids
|
return self._current_state_ids
|
||||||
|
|
||||||
def _ensure_fetched(self):
|
async def _ensure_fetched(self):
|
||||||
return defer.succeed(None)
|
return None
|
||||||
|
|
||||||
|
|
||||||
@attr.s(slots=True)
|
@attr.s(slots=True)
|
||||||
|
@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext):
|
||||||
_event_state_key = attr.ib(default=None)
|
_event_state_key = attr.ib(default=None)
|
||||||
_fetching_state_deferred = attr.ib(default=None)
|
_fetching_state_deferred = attr.ib(default=None)
|
||||||
|
|
||||||
def _ensure_fetched(self):
|
async def _ensure_fetched(self):
|
||||||
if not self._fetching_state_deferred:
|
if not self._fetching_state_deferred:
|
||||||
self._fetching_state_deferred = run_in_background(self._fill_out_state)
|
self._fetching_state_deferred = run_in_background(self._fill_out_state)
|
||||||
|
|
||||||
return make_deferred_yieldable(self._fetching_state_deferred)
|
return await make_deferred_yieldable(self._fetching_state_deferred)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _fill_out_state(self):
|
||||||
def _fill_out_state(self):
|
|
||||||
"""Called to populate the _current_state_ids and _prev_state_ids
|
"""Called to populate the _current_state_ids and _prev_state_ids
|
||||||
attributes by loading from the database.
|
attributes by loading from the database.
|
||||||
"""
|
"""
|
||||||
if self.state_group is None:
|
if self.state_group is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
|
self._current_state_ids = await self._storage.state.get_state_ids_for_group(
|
||||||
self.state_group
|
self.state_group
|
||||||
)
|
)
|
||||||
if self._event_state_key is not None:
|
if self._event_state_key is not None:
|
||||||
|
|
|
@ -13,7 +13,9 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
from synapse.events import EventBase
|
||||||
|
from synapse.events.snapshot import EventContext
|
||||||
|
from synapse.types import Requester
|
||||||
|
|
||||||
|
|
||||||
class ThirdPartyEventRules(object):
|
class ThirdPartyEventRules(object):
|
||||||
|
@ -39,76 +41,79 @@ class ThirdPartyEventRules(object):
|
||||||
config=config, http_client=hs.get_simple_http_client()
|
config=config, http_client=hs.get_simple_http_client()
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def check_event_allowed(
|
||||||
def check_event_allowed(self, event, context):
|
self, event: EventBase, context: EventContext
|
||||||
|
) -> bool:
|
||||||
"""Check if a provided event should be allowed in the given context.
|
"""Check if a provided event should be allowed in the given context.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (synapse.events.EventBase): The event to be checked.
|
event: The event to be checked.
|
||||||
context (synapse.events.snapshot.EventContext): The context of the event.
|
context: The context of the event.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[bool]: True if the event should be allowed, False if not.
|
True if the event should be allowed, False if not.
|
||||||
"""
|
"""
|
||||||
if self.third_party_rules is None:
|
if self.third_party_rules is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids()
|
prev_state_ids = await context.get_prev_state_ids()
|
||||||
|
|
||||||
# Retrieve the state events from the database.
|
# Retrieve the state events from the database.
|
||||||
state_events = {}
|
state_events = {}
|
||||||
for key, event_id in prev_state_ids.items():
|
for key, event_id in prev_state_ids.items():
|
||||||
state_events[key] = yield self.store.get_event(event_id, allow_none=True)
|
state_events[key] = await self.store.get_event(event_id, allow_none=True)
|
||||||
|
|
||||||
ret = yield self.third_party_rules.check_event_allowed(event, state_events)
|
ret = await self.third_party_rules.check_event_allowed(event, state_events)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_create_room(
|
||||||
def on_create_room(self, requester, config, is_requester_admin):
|
self, requester: Requester, config: dict, is_requester_admin: bool
|
||||||
|
) -> bool:
|
||||||
"""Intercept requests to create room to allow, deny or update the
|
"""Intercept requests to create room to allow, deny or update the
|
||||||
request config.
|
request config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
requester (Requester)
|
requester
|
||||||
config (dict): The creation config from the client.
|
config: The creation config from the client.
|
||||||
is_requester_admin (bool): If the requester is an admin
|
is_requester_admin: If the requester is an admin
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[bool]: Whether room creation is allowed or denied.
|
Whether room creation is allowed or denied.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.third_party_rules is None:
|
if self.third_party_rules is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
ret = yield self.third_party_rules.on_create_room(
|
ret = await self.third_party_rules.on_create_room(
|
||||||
requester, config, is_requester_admin
|
requester, config, is_requester_admin
|
||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def check_threepid_can_be_invited(
|
||||||
def check_threepid_can_be_invited(self, medium, address, room_id):
|
self, medium: str, address: str, room_id: str
|
||||||
|
) -> bool:
|
||||||
"""Check if a provided 3PID can be invited in the given room.
|
"""Check if a provided 3PID can be invited in the given room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
medium (str): The 3PID's medium.
|
medium: The 3PID's medium.
|
||||||
address (str): The 3PID's address.
|
address: The 3PID's address.
|
||||||
room_id (str): The room we want to invite the threepid to.
|
room_id: The room we want to invite the threepid to.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[bool], True if the 3PID can be invited, False if not.
|
True if the 3PID can be invited, False if not.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.third_party_rules is None:
|
if self.third_party_rules is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
state_ids = yield self.store.get_filtered_current_state_ids(room_id)
|
state_ids = await self.store.get_filtered_current_state_ids(room_id)
|
||||||
room_state_events = yield self.store.get_events(state_ids.values())
|
room_state_events = await self.store.get_events(state_ids.values())
|
||||||
|
|
||||||
state_events = {}
|
state_events = {}
|
||||||
for key, event_id in state_ids.items():
|
for key, event_id in state_ids.items():
|
||||||
state_events[key] = room_state_events[event_id]
|
state_events[key] = room_state_events[event_id]
|
||||||
|
|
||||||
ret = yield self.third_party_rules.check_threepid_can_be_invited(
|
ret = await self.third_party_rules.check_threepid_can_be_invited(
|
||||||
medium, address, state_events
|
medium, address, state_events
|
||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
|
@ -18,8 +18,6 @@ from typing import Any, Mapping, Union
|
||||||
|
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, RelationTypes
|
from synapse.api.constants import EventTypes, RelationTypes
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.api.room_versions import RoomVersion
|
from synapse.api.room_versions import RoomVersion
|
||||||
|
@ -337,8 +335,9 @@ class EventClientSerializer(object):
|
||||||
hs.config.experimental_msc1849_support_enabled
|
hs.config.experimental_msc1849_support_enabled
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def serialize_event(
|
||||||
def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
|
self, event, time_now, bundle_aggregations=True, **kwargs
|
||||||
|
):
|
||||||
"""Serializes a single event.
|
"""Serializes a single event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -348,7 +347,7 @@ class EventClientSerializer(object):
|
||||||
**kwargs: Arguments to pass to `serialize_event`
|
**kwargs: Arguments to pass to `serialize_event`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict]: The serialized event
|
dict: The serialized event
|
||||||
"""
|
"""
|
||||||
# To handle the case of presence events and the like
|
# To handle the case of presence events and the like
|
||||||
if not isinstance(event, EventBase):
|
if not isinstance(event, EventBase):
|
||||||
|
@ -363,8 +362,8 @@ class EventClientSerializer(object):
|
||||||
if not event.internal_metadata.is_redacted() and (
|
if not event.internal_metadata.is_redacted() and (
|
||||||
self.experimental_msc1849_support_enabled and bundle_aggregations
|
self.experimental_msc1849_support_enabled and bundle_aggregations
|
||||||
):
|
):
|
||||||
annotations = yield self.store.get_aggregation_groups_for_event(event_id)
|
annotations = await self.store.get_aggregation_groups_for_event(event_id)
|
||||||
references = yield self.store.get_relations_for_event(
|
references = await self.store.get_relations_for_event(
|
||||||
event_id, RelationTypes.REFERENCE, direction="f"
|
event_id, RelationTypes.REFERENCE, direction="f"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -378,7 +377,7 @@ class EventClientSerializer(object):
|
||||||
|
|
||||||
edit = None
|
edit = None
|
||||||
if event.type == EventTypes.Message:
|
if event.type == EventTypes.Message:
|
||||||
edit = yield self.store.get_applicable_edit(event_id)
|
edit = await self.store.get_applicable_edit(event_id)
|
||||||
|
|
||||||
if edit:
|
if edit:
|
||||||
# If there is an edit replace the content, preserving existing
|
# If there is an edit replace the content, preserving existing
|
||||||
|
|
|
@ -2470,7 +2470,7 @@ class FederationHandler(BaseHandler):
|
||||||
}
|
}
|
||||||
|
|
||||||
current_state_ids = await context.get_current_state_ids()
|
current_state_ids = await context.get_current_state_ids()
|
||||||
current_state_ids = dict(current_state_ids)
|
current_state_ids = dict(current_state_ids) # type: ignore
|
||||||
|
|
||||||
current_state_ids.update(state_updates)
|
current_state_ids.update(state_updates)
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||||
"""
|
"""
|
||||||
event_payloads = []
|
event_payloads = []
|
||||||
for event, context in event_and_contexts:
|
for event, context in event_and_contexts:
|
||||||
serialized_context = yield context.serialize(event, store)
|
serialized_context = yield defer.ensureDeferred(
|
||||||
|
context.serialize(event, store)
|
||||||
|
)
|
||||||
|
|
||||||
event_payloads.append(
|
event_payloads.append(
|
||||||
{
|
{
|
||||||
|
|
|
@ -77,7 +77,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||||
extra_users (list(UserID)): Any extra users to notify about event
|
extra_users (list(UserID)): Any extra users to notify about event
|
||||||
"""
|
"""
|
||||||
|
|
||||||
serialized_context = yield context.serialize(event, store)
|
serialized_context = yield defer.ensureDeferred(context.serialize(event, store))
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"event": event.get_pdu_json(),
|
"event": event.get_pdu_json(),
|
||||||
|
|
|
@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def build(self, prev_event_ids):
|
def build(self, prev_event_ids):
|
||||||
built_event = yield self._base_builder.build(prev_event_ids)
|
built_event = yield defer.ensureDeferred(
|
||||||
|
self._base_builder.build(prev_event_ids)
|
||||||
|
)
|
||||||
|
|
||||||
built_event._event_id = self._event_id
|
built_event._event_id = self._event_id
|
||||||
built_event._dict["event_id"] = self._event_id
|
built_event._dict["event_id"] = self._event_id
|
||||||
|
|
|
@ -213,7 +213,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
ctx_c = context_store["C"]
|
ctx_c = context_store["C"]
|
||||||
ctx_d = context_store["D"]
|
ctx_d = context_store["D"]
|
||||||
|
|
||||||
prev_state_ids = yield ctx_d.get_prev_state_ids()
|
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
|
||||||
self.assertEqual(2, len(prev_state_ids))
|
self.assertEqual(2, len(prev_state_ids))
|
||||||
|
|
||||||
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
|
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
|
||||||
|
@ -259,7 +259,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
ctx_c = context_store["C"]
|
ctx_c = context_store["C"]
|
||||||
ctx_d = context_store["D"]
|
ctx_d = context_store["D"]
|
||||||
|
|
||||||
prev_state_ids = yield ctx_d.get_prev_state_ids()
|
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
|
||||||
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
|
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
|
||||||
|
|
||||||
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
|
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
|
||||||
|
@ -318,7 +318,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
ctx_c = context_store["C"]
|
ctx_c = context_store["C"]
|
||||||
ctx_e = context_store["E"]
|
ctx_e = context_store["E"]
|
||||||
|
|
||||||
prev_state_ids = yield ctx_e.get_prev_state_ids()
|
prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
|
||||||
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
|
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
|
||||||
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
|
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
|
||||||
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
|
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
|
||||||
|
@ -393,7 +393,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
ctx_b = context_store["B"]
|
ctx_b = context_store["B"]
|
||||||
ctx_d = context_store["D"]
|
ctx_d = context_store["D"]
|
||||||
|
|
||||||
prev_state_ids = yield ctx_d.get_prev_state_ids()
|
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
|
||||||
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
|
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
|
||||||
|
|
||||||
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
|
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
|
||||||
|
@ -425,7 +425,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
self.state.compute_event_context(event, old_state=old_state)
|
self.state.compute_event_context(event, old_state=old_state)
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids()
|
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||||
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
|
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
|
||||||
|
|
||||||
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
||||||
|
@ -450,7 +450,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
self.state.compute_event_context(event, old_state=old_state)
|
self.state.compute_event_context(event, old_state=old_state)
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids()
|
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||||
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
|
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
|
||||||
|
|
||||||
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
||||||
|
@ -519,7 +519,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
|
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
|
||||||
|
|
||||||
prev_state_ids = yield context.get_prev_state_ids()
|
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||||
|
|
||||||
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
|
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue