Type annotations for `test_v2` (#12985)
parent
04ca3a52f6
commit
97053c9406
|
@ -0,0 +1 @@
|
||||||
|
Add type annotations to `tests.state.test_v2`.
|
4
mypy.ini
4
mypy.ini
|
@ -56,7 +56,6 @@ exclude = (?x)
|
||||||
|tests/rest/media/v1/test_media_storage.py
|
|tests/rest/media/v1/test_media_storage.py
|
||||||
|tests/server.py
|
|tests/server.py
|
||||||
|tests/server_notices/test_resource_limits_server_notices.py
|
|tests/server_notices/test_resource_limits_server_notices.py
|
||||||
|tests/state/test_v2.py
|
|
||||||
|tests/test_metrics.py
|
|tests/test_metrics.py
|
||||||
|tests/test_server.py
|
|tests/test_server.py
|
||||||
|tests/test_state.py
|
|tests/test_state.py
|
||||||
|
@ -115,6 +114,9 @@ disallow_untyped_defs = False
|
||||||
[mypy-tests.handlers.test_user_directory]
|
[mypy-tests.handlers.test_user_directory]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
[mypy-tests.state.test_profile]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-tests.storage.test_profile]
|
[mypy-tests.storage.test_profile]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
|
|
@ -17,12 +17,14 @@ import itertools
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
List,
|
||||||
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
|
@ -30,33 +32,58 @@ from typing import (
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal, Protocol
|
||||||
|
|
||||||
import synapse.state
|
|
||||||
from synapse import event_auth
|
from synapse import event_auth
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.api.room_versions import RoomVersion
|
from synapse.api.room_versions import RoomVersion
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.types import MutableStateMap, StateMap
|
from synapse.types import MutableStateMap, StateMap
|
||||||
from synapse.util import Clock
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Clock(Protocol):
|
||||||
|
# This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests.
|
||||||
|
# We only ever sleep(0) though, so that other async functions can make forward
|
||||||
|
# progress without waiting for stateres to complete.
|
||||||
|
def sleep(self, duration_ms: float) -> Awaitable[None]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class StateResolutionStore(Protocol):
|
||||||
|
# This is usually synapse.state.StateResolutionStore, but it's replaced with a
|
||||||
|
# TestStateResolutionStore in tests.
|
||||||
|
def get_events(
|
||||||
|
self, event_ids: Collection[str], allow_rejected: bool = False
|
||||||
|
) -> Awaitable[Dict[str, EventBase]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_auth_chain_difference(
|
||||||
|
self, room_id: str, state_sets: List[Set[str]]
|
||||||
|
) -> Awaitable[Set[str]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# We want to await to the reactor occasionally during state res when dealing
|
# We want to await to the reactor occasionally during state res when dealing
|
||||||
# with large data sets, so that we don't exhaust the reactor. This is done by
|
# with large data sets, so that we don't exhaust the reactor. This is done by
|
||||||
# awaiting to reactor during loops every N iterations.
|
# awaiting to reactor during loops every N iterations.
|
||||||
_AWAIT_AFTER_ITERATIONS = 100
|
_AWAIT_AFTER_ITERATIONS = 100
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"resolve_events_with_store",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
async def resolve_events_with_store(
|
async def resolve_events_with_store(
|
||||||
clock: Clock,
|
clock: Clock,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
room_version: RoomVersion,
|
room_version: RoomVersion,
|
||||||
state_sets: Sequence[StateMap[str]],
|
state_sets: Sequence[StateMap[str]],
|
||||||
event_map: Optional[Dict[str, EventBase]],
|
event_map: Optional[Dict[str, EventBase]],
|
||||||
state_res_store: "synapse.state.StateResolutionStore",
|
state_res_store: StateResolutionStore,
|
||||||
) -> StateMap[str]:
|
) -> StateMap[str]:
|
||||||
"""Resolves the state using the v2 state resolution algorithm
|
"""Resolves the state using the v2 state resolution algorithm
|
||||||
|
|
||||||
|
@ -194,7 +221,7 @@ async def _get_power_level_for_sender(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
event_map: Dict[str, EventBase],
|
event_map: Dict[str, EventBase],
|
||||||
state_res_store: "synapse.state.StateResolutionStore",
|
state_res_store: StateResolutionStore,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Return the power level of the sender of the given event according to
|
"""Return the power level of the sender of the given event according to
|
||||||
their auth events.
|
their auth events.
|
||||||
|
@ -243,9 +270,9 @@ async def _get_power_level_for_sender(
|
||||||
|
|
||||||
async def _get_auth_chain_difference(
|
async def _get_auth_chain_difference(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
state_sets: Sequence[StateMap[str]],
|
state_sets: Sequence[Mapping[Any, str]],
|
||||||
event_map: Dict[str, EventBase],
|
event_map: Dict[str, EventBase],
|
||||||
state_res_store: "synapse.state.StateResolutionStore",
|
state_res_store: StateResolutionStore,
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
"""Compare the auth chains of each state set and return the set of events
|
"""Compare the auth chains of each state set and return the set of events
|
||||||
that only appear in some but not all of the auth chains.
|
that only appear in some but not all of the auth chains.
|
||||||
|
@ -406,7 +433,7 @@ async def _add_event_and_auth_chain_to_graph(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
event_map: Dict[str, EventBase],
|
event_map: Dict[str, EventBase],
|
||||||
state_res_store: "synapse.state.StateResolutionStore",
|
state_res_store: StateResolutionStore,
|
||||||
auth_diff: Set[str],
|
auth_diff: Set[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Helper function for _reverse_topological_power_sort that add the event
|
"""Helper function for _reverse_topological_power_sort that add the event
|
||||||
|
@ -440,7 +467,7 @@ async def _reverse_topological_power_sort(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_ids: Iterable[str],
|
event_ids: Iterable[str],
|
||||||
event_map: Dict[str, EventBase],
|
event_map: Dict[str, EventBase],
|
||||||
state_res_store: "synapse.state.StateResolutionStore",
|
state_res_store: StateResolutionStore,
|
||||||
auth_diff: Set[str],
|
auth_diff: Set[str],
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Returns a list of the event_ids sorted by reverse topological ordering,
|
"""Returns a list of the event_ids sorted by reverse topological ordering,
|
||||||
|
@ -501,7 +528,7 @@ async def _iterative_auth_checks(
|
||||||
event_ids: List[str],
|
event_ids: List[str],
|
||||||
base_state: StateMap[str],
|
base_state: StateMap[str],
|
||||||
event_map: Dict[str, EventBase],
|
event_map: Dict[str, EventBase],
|
||||||
state_res_store: "synapse.state.StateResolutionStore",
|
state_res_store: StateResolutionStore,
|
||||||
) -> MutableStateMap[str]:
|
) -> MutableStateMap[str]:
|
||||||
"""Sequentially apply auth checks to each event in given list, updating the
|
"""Sequentially apply auth checks to each event in given list, updating the
|
||||||
state as it goes along.
|
state as it goes along.
|
||||||
|
@ -570,7 +597,7 @@ async def _mainline_sort(
|
||||||
event_ids: List[str],
|
event_ids: List[str],
|
||||||
resolved_power_event_id: Optional[str],
|
resolved_power_event_id: Optional[str],
|
||||||
event_map: Dict[str, EventBase],
|
event_map: Dict[str, EventBase],
|
||||||
state_res_store: "synapse.state.StateResolutionStore",
|
state_res_store: StateResolutionStore,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""Returns a sorted list of event_ids sorted by mainline ordering based on
|
"""Returns a sorted list of event_ids sorted by mainline ordering based on
|
||||||
the given event resolved_power_event_id
|
the given event resolved_power_event_id
|
||||||
|
@ -639,7 +666,7 @@ async def _get_mainline_depth_for_event(
|
||||||
event: EventBase,
|
event: EventBase,
|
||||||
mainline_map: Dict[str, int],
|
mainline_map: Dict[str, int],
|
||||||
event_map: Dict[str, EventBase],
|
event_map: Dict[str, EventBase],
|
||||||
state_res_store: "synapse.state.StateResolutionStore",
|
state_res_store: StateResolutionStore,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Get the mainline depths for the given event based on the mainline map
|
"""Get the mainline depths for the given event based on the mainline map
|
||||||
|
|
||||||
|
@ -683,7 +710,7 @@ async def _get_event(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
event_map: Dict[str, EventBase],
|
event_map: Dict[str, EventBase],
|
||||||
state_res_store: "synapse.state.StateResolutionStore",
|
state_res_store: StateResolutionStore,
|
||||||
allow_none: Literal[False] = False,
|
allow_none: Literal[False] = False,
|
||||||
) -> EventBase:
|
) -> EventBase:
|
||||||
...
|
...
|
||||||
|
@ -694,7 +721,7 @@ async def _get_event(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
event_map: Dict[str, EventBase],
|
event_map: Dict[str, EventBase],
|
||||||
state_res_store: "synapse.state.StateResolutionStore",
|
state_res_store: StateResolutionStore,
|
||||||
allow_none: Literal[True],
|
allow_none: Literal[True],
|
||||||
) -> Optional[EventBase]:
|
) -> Optional[EventBase]:
|
||||||
...
|
...
|
||||||
|
@ -704,7 +731,7 @@ async def _get_event(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
event_map: Dict[str, EventBase],
|
event_map: Dict[str, EventBase],
|
||||||
state_res_store: "synapse.state.StateResolutionStore",
|
state_res_store: StateResolutionStore,
|
||||||
allow_none: bool = False,
|
allow_none: bool = False,
|
||||||
) -> Optional[EventBase]:
|
) -> Optional[EventBase]:
|
||||||
"""Helper function to look up event in event_map, falling back to looking
|
"""Helper function to look up event in event_map, falling back to looking
|
||||||
|
|
|
@ -13,7 +13,17 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
from typing import List
|
from typing import (
|
||||||
|
Collection,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -22,13 +32,13 @@ from twisted.internet import defer
|
||||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
from synapse.event_auth import auth_types_for_event
|
from synapse.event_auth import auth_types_for_event
|
||||||
from synapse.events import make_event_from_dict
|
from synapse.events import EventBase, make_event_from_dict
|
||||||
from synapse.state.v2 import (
|
from synapse.state.v2 import (
|
||||||
_get_auth_chain_difference,
|
_get_auth_chain_difference,
|
||||||
lexicographical_topological_sort,
|
lexicographical_topological_sort,
|
||||||
resolve_events_with_store,
|
resolve_events_with_store,
|
||||||
)
|
)
|
||||||
from synapse.types import EventID
|
from synapse.types import EventID, StateMap
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
@ -48,7 +58,7 @@ ORIGIN_SERVER_TS = 0
|
||||||
|
|
||||||
|
|
||||||
class FakeClock:
|
class FakeClock:
|
||||||
def sleep(self, msec):
|
def sleep(self, msec: float) -> "defer.Deferred[None]":
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,7 +70,14 @@ class FakeEvent:
|
||||||
as domain.
|
as domain.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, id, sender, type, state_key, content):
|
def __init__(
|
||||||
|
self,
|
||||||
|
id: str,
|
||||||
|
sender: str,
|
||||||
|
type: str,
|
||||||
|
state_key: Optional[str],
|
||||||
|
content: Mapping[str, object],
|
||||||
|
):
|
||||||
self.node_id = id
|
self.node_id = id
|
||||||
self.event_id = EventID(id, "example.com").to_string()
|
self.event_id = EventID(id, "example.com").to_string()
|
||||||
self.sender = sender
|
self.sender = sender
|
||||||
|
@ -69,12 +86,12 @@ class FakeEvent:
|
||||||
self.content = content
|
self.content = content
|
||||||
self.room_id = ROOM_ID
|
self.room_id = ROOM_ID
|
||||||
|
|
||||||
def to_event(self, auth_events, prev_events):
|
def to_event(self, auth_events: List[str], prev_events: List[str]) -> EventBase:
|
||||||
"""Given the auth_events and prev_events, convert to a Frozen Event
|
"""Given the auth_events and prev_events, convert to a Frozen Event
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
auth_events (list[str]): list of event_ids
|
auth_events: list of event_ids
|
||||||
prev_events (list[str]): list of event_ids
|
prev_events: list of event_ids
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
FrozenEvent
|
FrozenEvent
|
||||||
|
@ -164,7 +181,7 @@ INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"]
|
||||||
|
|
||||||
|
|
||||||
class StateTestCase(unittest.TestCase):
|
class StateTestCase(unittest.TestCase):
|
||||||
def test_ban_vs_pl(self):
|
def test_ban_vs_pl(self) -> None:
|
||||||
events = [
|
events = [
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="PA",
|
id="PA",
|
||||||
|
@ -202,7 +219,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.do_check(events, edges, expected_state_ids)
|
self.do_check(events, edges, expected_state_ids)
|
||||||
|
|
||||||
def test_join_rule_evasion(self):
|
def test_join_rule_evasion(self) -> None:
|
||||||
events = [
|
events = [
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="JR",
|
id="JR",
|
||||||
|
@ -226,7 +243,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.do_check(events, edges, expected_state_ids)
|
self.do_check(events, edges, expected_state_ids)
|
||||||
|
|
||||||
def test_offtopic_pl(self):
|
def test_offtopic_pl(self) -> None:
|
||||||
events = [
|
events = [
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="PA",
|
id="PA",
|
||||||
|
@ -257,7 +274,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.do_check(events, edges, expected_state_ids)
|
self.do_check(events, edges, expected_state_ids)
|
||||||
|
|
||||||
def test_topic_basic(self):
|
def test_topic_basic(self) -> None:
|
||||||
events = [
|
events = [
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
|
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
|
||||||
|
@ -297,7 +314,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.do_check(events, edges, expected_state_ids)
|
self.do_check(events, edges, expected_state_ids)
|
||||||
|
|
||||||
def test_topic_reset(self):
|
def test_topic_reset(self) -> None:
|
||||||
events = [
|
events = [
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
|
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
|
||||||
|
@ -327,7 +344,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.do_check(events, edges, expected_state_ids)
|
self.do_check(events, edges, expected_state_ids)
|
||||||
|
|
||||||
def test_topic(self):
|
def test_topic(self) -> None:
|
||||||
events = [
|
events = [
|
||||||
FakeEvent(
|
FakeEvent(
|
||||||
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
|
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
|
||||||
|
@ -380,7 +397,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.do_check(events, edges, expected_state_ids)
|
self.do_check(events, edges, expected_state_ids)
|
||||||
|
|
||||||
def test_mainline_sort(self):
|
def test_mainline_sort(self) -> None:
|
||||||
"""Tests that the mainline ordering works correctly."""
|
"""Tests that the mainline ordering works correctly."""
|
||||||
|
|
||||||
events = [
|
events = [
|
||||||
|
@ -434,22 +451,26 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.do_check(events, edges, expected_state_ids)
|
self.do_check(events, edges, expected_state_ids)
|
||||||
|
|
||||||
def do_check(self, events, edges, expected_state_ids):
|
def do_check(
|
||||||
|
self,
|
||||||
|
events: List[FakeEvent],
|
||||||
|
edges: List[List[str]],
|
||||||
|
expected_state_ids: List[str],
|
||||||
|
) -> None:
|
||||||
"""Take a list of events and edges and calculate the state of the
|
"""Take a list of events and edges and calculate the state of the
|
||||||
graph at END, and asserts it matches `expected_state_ids`
|
graph at END, and asserts it matches `expected_state_ids`
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
events (list[FakeEvent])
|
events
|
||||||
edges (list[list[str]]): A list of chains of event edges, e.g.
|
edges: A list of chains of event edges, e.g.
|
||||||
`[[A, B, C]]` are edges A->B and B->C.
|
`[[A, B, C]]` are edges A->B and B->C.
|
||||||
expected_state_ids (list[str]): The expected state at END, (excluding
|
expected_state_ids: The expected state at END, (excluding
|
||||||
the keys that haven't changed since START).
|
the keys that haven't changed since START).
|
||||||
"""
|
"""
|
||||||
# We want to sort the events into topological order for processing.
|
# We want to sort the events into topological order for processing.
|
||||||
graph = {}
|
graph: Dict[str, Set[str]] = {}
|
||||||
|
|
||||||
# node_id -> FakeEvent
|
fake_event_map: Dict[str, FakeEvent] = {}
|
||||||
fake_event_map = {}
|
|
||||||
|
|
||||||
for ev in itertools.chain(INITIAL_EVENTS, events):
|
for ev in itertools.chain(INITIAL_EVENTS, events):
|
||||||
graph[ev.node_id] = set()
|
graph[ev.node_id] = set()
|
||||||
|
@ -462,10 +483,8 @@ class StateTestCase(unittest.TestCase):
|
||||||
for a, b in pairwise(edge_list):
|
for a, b in pairwise(edge_list):
|
||||||
graph[a].add(b)
|
graph[a].add(b)
|
||||||
|
|
||||||
# event_id -> FrozenEvent
|
event_map: Dict[str, EventBase] = {}
|
||||||
event_map = {}
|
state_at_event: Dict[str, StateMap[str]] = {}
|
||||||
# node_id -> state
|
|
||||||
state_at_event = {}
|
|
||||||
|
|
||||||
# We copy the map as the sort consumes the graph
|
# We copy the map as the sort consumes the graph
|
||||||
graph_copy = {k: set(v) for k, v in graph.items()}
|
graph_copy = {k: set(v) for k, v in graph.items()}
|
||||||
|
@ -496,7 +515,16 @@ class StateTestCase(unittest.TestCase):
|
||||||
if fake_event.state_key is not None:
|
if fake_event.state_key is not None:
|
||||||
state_after[(fake_event.type, fake_event.state_key)] = event_id
|
state_after[(fake_event.type, fake_event.state_key)] = event_id
|
||||||
|
|
||||||
auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event))
|
# This type ignore is a bit sad. Things we have tried:
|
||||||
|
# 1. Define a `GenericEvent` Protocol satisfied by FakeEvent, EventBase and
|
||||||
|
# EventBuilder. But this is Hard because the relevant attributes are
|
||||||
|
# DictProperty[T] descriptors on EventBase but normal Ts on FakeEvent.
|
||||||
|
# 2. Define a `GenericEvent` Protocol describing `FakeEvent` only, and
|
||||||
|
# change this function to accept Union[Event, EventBase, EventBuilder].
|
||||||
|
# This seems reasonable to me, but mypy isn't happy. I think that's
|
||||||
|
# a mypy bug, see https://github.com/python/mypy/issues/5570
|
||||||
|
# Instead, resort to a type-ignore.
|
||||||
|
auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event)) # type: ignore[arg-type]
|
||||||
|
|
||||||
auth_events = []
|
auth_events = []
|
||||||
for key in auth_types:
|
for key in auth_types:
|
||||||
|
@ -530,8 +558,14 @@ class StateTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class LexicographicalTestCase(unittest.TestCase):
|
class LexicographicalTestCase(unittest.TestCase):
|
||||||
def test_simple(self):
|
def test_simple(self) -> None:
|
||||||
graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}}
|
graph: Dict[str, Set[str]] = {
|
||||||
|
"l": {"o"},
|
||||||
|
"m": {"n", "o"},
|
||||||
|
"n": {"o"},
|
||||||
|
"o": set(),
|
||||||
|
"p": {"o"},
|
||||||
|
}
|
||||||
|
|
||||||
res = list(lexicographical_topological_sort(graph, key=lambda x: x))
|
res = list(lexicographical_topological_sort(graph, key=lambda x: x))
|
||||||
|
|
||||||
|
@ -539,7 +573,7 @@ class LexicographicalTestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
|
||||||
class SimpleParamStateTestCase(unittest.TestCase):
|
class SimpleParamStateTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self) -> None:
|
||||||
# We build up a simple DAG.
|
# We build up a simple DAG.
|
||||||
|
|
||||||
event_map = {}
|
event_map = {}
|
||||||
|
@ -627,7 +661,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_event_map_none(self):
|
def test_event_map_none(self) -> None:
|
||||||
# Test that we correctly handle passing `None` as the event_map
|
# Test that we correctly handle passing `None` as the event_map
|
||||||
|
|
||||||
state_d = resolve_events_with_store(
|
state_d = resolve_events_with_store(
|
||||||
|
@ -649,7 +683,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
|
||||||
events.
|
events.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_simple(self):
|
def test_simple(self) -> None:
|
||||||
# Test getting the auth difference for a simple chain with a single
|
# Test getting the auth difference for a simple chain with a single
|
||||||
# unpersisted event:
|
# unpersisted event:
|
||||||
#
|
#
|
||||||
|
@ -695,7 +729,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(difference, {c.event_id})
|
self.assertEqual(difference, {c.event_id})
|
||||||
|
|
||||||
def test_multiple_unpersisted_chain(self):
|
def test_multiple_unpersisted_chain(self) -> None:
|
||||||
# Test getting the auth difference for a simple chain with multiple
|
# Test getting the auth difference for a simple chain with multiple
|
||||||
# unpersisted events:
|
# unpersisted events:
|
||||||
#
|
#
|
||||||
|
@ -752,7 +786,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(difference, {d.event_id, c.event_id})
|
self.assertEqual(difference, {d.event_id, c.event_id})
|
||||||
|
|
||||||
def test_unpersisted_events_different_sets(self):
|
def test_unpersisted_events_different_sets(self) -> None:
|
||||||
# Test getting the auth difference for with multiple unpersisted events
|
# Test getting the auth difference for with multiple unpersisted events
|
||||||
# in different branches:
|
# in different branches:
|
||||||
#
|
#
|
||||||
|
@ -820,7 +854,10 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
|
||||||
self.assertEqual(difference, {d.event_id, e.event_id})
|
self.assertEqual(difference, {d.event_id, e.event_id})
|
||||||
|
|
||||||
|
|
||||||
def pairwise(iterable):
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def pairwise(iterable: Iterable[T]) -> Iterable[Tuple[T, T]]:
|
||||||
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
|
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
|
||||||
a, b = itertools.tee(iterable)
|
a, b = itertools.tee(iterable)
|
||||||
next(b, None)
|
next(b, None)
|
||||||
|
@ -829,24 +866,26 @@ def pairwise(iterable):
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
class TestStateResolutionStore:
|
class TestStateResolutionStore:
|
||||||
event_map = attr.ib()
|
event_map: Dict[str, EventBase] = attr.ib()
|
||||||
|
|
||||||
def get_events(self, event_ids, allow_rejected=False):
|
def get_events(
|
||||||
|
self, event_ids: Collection[str], allow_rejected: bool = False
|
||||||
|
) -> "defer.Deferred[Dict[str, EventBase]]":
|
||||||
"""Get events from the database
|
"""Get events from the database
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_ids (list): The event_ids of the events to fetch
|
event_ids: The event_ids of the events to fetch
|
||||||
allow_rejected (bool): If True return rejected events.
|
allow_rejected: If True return rejected events.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
|
Dict from event_id to event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return defer.succeed(
|
return defer.succeed(
|
||||||
{eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
|
{eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
|
def _get_auth_chain(self, event_ids: Iterable[str]) -> List[str]:
|
||||||
"""Gets the full auth chain for a set of events (including rejected
|
"""Gets the full auth chain for a set of events (including rejected
|
||||||
events).
|
events).
|
||||||
|
|
||||||
|
@ -880,7 +919,9 @@ class TestStateResolutionStore:
|
||||||
|
|
||||||
return list(result)
|
return list(result)
|
||||||
|
|
||||||
def get_auth_chain_difference(self, room_id, auth_sets):
|
def get_auth_chain_difference(
|
||||||
|
self, room_id: str, auth_sets: List[Set[str]]
|
||||||
|
) -> "defer.Deferred[Set[str]]":
|
||||||
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
|
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
|
||||||
|
|
||||||
common = set(chains[0]).intersection(*chains[1:])
|
common = set(chains[0]).intersection(*chains[1:])
|
||||||
|
|
Loading…
Reference in New Issue