Add StateMap type alias (#6715)

pull/6723/head
Erik Johnston 2020-01-16 13:31:22 +00:00 committed by GitHub
parent 7b14c4a018
commit d386f2f339
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 115 additions and 93 deletions

1
changelog.d/6715.misc Normal file
View File

@ -0,0 +1 @@
Add StateMap type alias to simplify types.

View File

@ -14,7 +14,6 @@
# limitations under the License.
import logging
from typing import Dict, Tuple
from six import itervalues
@ -35,7 +34,7 @@ from synapse.api.errors import (
ResourceLimitError,
)
from synapse.config.server import is_threepid_reserved
from synapse.types import UserID
from synapse.types import StateMap, UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@ -509,10 +508,7 @@ class Auth(object):
return self.store.is_server_admin(user)
def compute_auth_events(
self,
event,
current_state_ids: Dict[Tuple[str, str], str],
for_verification: bool = False,
self, event, current_state_ids: StateMap[str], for_verification: bool = False,
):
"""Given an event and current state return the list of event IDs used
to auth an event.

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple, Union
from typing import Optional, Union
from six import iteritems
@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.appservice import ApplicationService
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap
@attr.s(slots=True)
@ -106,13 +107,11 @@ class EventContext:
_state_group = attr.ib(default=None, type=Optional[int])
state_group_before_event = attr.ib(default=None, type=Optional[int])
prev_group = attr.ib(default=None, type=Optional[int])
delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
delta_ids = attr.ib(default=None, type=Optional[StateMap[str]])
app_service = attr.ib(default=None, type=Optional[ApplicationService])
_current_state_ids = attr.ib(
default=None, type=Optional[Dict[Tuple[str, str], str]]
)
_prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
_current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
_prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
@staticmethod
def with_state(

View File

@ -31,6 +31,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import StateMap
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
# This is defined in the Matrix spec and enforced by the receiver.
@ -77,7 +78,7 @@ class PerDestinationQueue(object):
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id)
# Map of (edu_type, key) -> Edu
self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu]
self._pending_edus_keyed = {} # type: StateMap[Edu]
# Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination

View File

@ -14,9 +14,11 @@
# limitations under the License.
import logging
from typing import List
from synapse.api.constants import Membership
from synapse.types import RoomStreamToken
from synapse.events import FrozenEvent
from synapse.types import RoomStreamToken, StateMap
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@ -259,35 +261,26 @@ class ExfiltrationWriter(object):
"""Interface used to specify how to write exported data.
"""
def write_events(self, room_id, events):
def write_events(self, room_id: str, events: List[FrozenEvent]):
"""Write a batch of events for a room.
Args:
room_id (str)
events (list[FrozenEvent])
"""
pass
def write_state(self, room_id, event_id, state):
def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
"""Write the state at the given event in the room.
This only gets called for backward extremities rather than for each
event.
Args:
room_id (str)
event_id (str)
state (dict[tuple[str, str], FrozenEvent])
"""
pass
def write_invite(self, room_id, event, state):
def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
"""Write an invite for the room, with associated invite state.
Args:
room_id (str)
event (FrozenEvent)
state (dict[tuple[str, str], dict]): A subset of the state at the
room_id
event
state: A subset of the state at the
invite, with a subset of the event keys (type, state_key
content and sender)
"""

View File

@ -64,7 +64,7 @@ from synapse.replication.http.federation import (
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import UserID, get_domain_from_id
from synapse.types import StateMap, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
@ -89,7 +89,7 @@ class _NewEventInfo:
event = attr.ib(type=EventBase)
state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None)
auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
def shortstr(iterable, maxitems=5):
@ -352,9 +352,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(
ours.values()
) # type: list[dict[tuple[str, str], str]]
state_maps = list(ours.values()) # type: list[StateMap[str]]
# we don't need this any more, let's delete it.
del ours
@ -1912,7 +1910,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
auth_events: Optional[Dict[Tuple[str, str], EventBase]],
auth_events: Optional[StateMap[EventBase]],
backfilled: bool,
):
"""

View File

@ -32,7 +32,15 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, Syna
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.types import (
Requester,
RoomAlias,
RoomID,
RoomStreamToken,
StateMap,
StreamToken,
UserID,
)
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache
@ -207,15 +215,19 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def _update_upgraded_room_pls(
self, requester, old_room_id, new_room_id, old_room_state,
self,
requester: Requester,
old_room_id: str,
new_room_id: str,
old_room_state: StateMap[str],
):
"""Send updated power levels in both rooms after an upgrade
Args:
requester (synapse.types.Requester): the user requesting the upgrade
old_room_id (str): the id of the room to be replaced
new_room_id (str): the id of the replacement room
old_room_state (dict[tuple[str, str], str]): the state map for the old room
requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced
new_room_id: the id of the replacement room
old_room_state: the state map for the old room
Returns:
Deferred

View File

@ -16,7 +16,7 @@
import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional
from six import iteritems, itervalues
@ -33,6 +33,7 @@ from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
@ -594,7 +595,7 @@ def _make_state_cache_entry(new_state, state_groups_ids):
def resolve_events_with_store(
room_id: str,
room_version: str,
state_sets: List[Dict[Tuple[str, str], str]],
state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
):

View File

@ -15,7 +15,7 @@
import hashlib
import logging
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional
from six import iteritems, iterkeys, itervalues
@ -26,6 +26,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.types import StateMap
logger = logging.getLogger(__name__)
@ -36,7 +37,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
@defer.inlineCallbacks
def resolve_events_with_store(
room_id: str,
state_sets: List[Dict[Tuple[str, str], str]],
state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable,
):

View File

@ -16,7 +16,7 @@
import heapq
import itertools
import logging
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
from six import iteritems, itervalues
@ -27,6 +27,7 @@ from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.types import StateMap
logger = logging.getLogger(__name__)
@ -35,7 +36,7 @@ logger = logging.getLogger(__name__)
def resolve_events_with_store(
room_id: str,
room_version: str,
state_sets: List[Dict[Tuple[str, str], str]],
state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
):
@ -393,12 +394,12 @@ def _iterative_auth_checks(
room_id (str)
room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to
base_state (dict[tuple[str, str], str]): The set of state to start with
base_state (StateMap[str]): The set of state to start with
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
Returns:
Deferred[dict[tuple[str, str], str]]: Returns the final updated state
Deferred[StateMap[str]]: Returns the final updated state
"""
resolved_state = base_state.copy()

View File

@ -165,19 +165,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
# FIXME: how should this be cached?
def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
def get_filtered_current_state_ids(
self, room_id: str, state_filter: StateFilter = StateFilter.all()
):
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
Args:
room_id (str)
state_filter (StateFilter): The state filter used to fetch state
room_id
state_filter: The state filter used to fetch state
from the database.
Returns:
Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
event ID.
defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
"""
where_clause, where_args = state_filter.make_sql_filter_clause()

View File

@ -15,6 +15,7 @@
import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Set, Tuple
from six import iteritems
from six.moves import range
@ -26,6 +27,7 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.database import Database
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@ -133,17 +135,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
@defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, state_filter):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
):
"""Returns the state groups for a given set of groups from the
database, filtering on types of state events.
Args:
groups(list[int]): list of state group IDs to query
state_filter (StateFilter): The state filter used to fetch state
groups: list of state group IDs to query
state_filter: The state filter used to fetch state
from the database.
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
results = {}
@ -199,18 +202,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
@defer.inlineCallbacks
def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
groups (iterable[int]): list of state groups for which we want
groups: list of state groups for which we want
to get the state.
state_filter (StateFilter): The state filter used to fetch state
state_filter: The state filter used to fetch state
from the database.
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
member_filter, non_member_filter = state_filter.get_member_split()
@ -268,24 +272,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state
def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
def _get_state_for_groups_using_cache(
self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
cache (DictionaryCache): the cache of group ids to state dicts which
we will pass through - either the normal state cache or the specific
members state cache.
state_filter (StateFilter): The state filter used to fetch state
from the database.
groups: list of state groups for which we want to get the state.
cache: the cache of group ids to state dicts which
we will pass through - either the normal state cache or the
specific members state cache.
state_filter: The state filter used to fetch state from the
database.
Returns:
tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
dict of state_group_id -> (dict of (type, state_key) -> event id)
of entries in the cache, and the state group ids either missing
from the cache or incomplete.
Tuple of dict of state_group_id to state map of entries in the
cache, and the state group ids either missing from the cache or
incomplete.
"""
results = {}
incomplete_groups = set()

View File

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import Iterable, List, TypeVar
from six import iteritems, itervalues
@ -22,9 +23,13 @@ import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.types import StateMap
logger = logging.getLogger(__name__)
# Used for generic functions below
T = TypeVar("T")
@attr.s(slots=True)
class StateFilter(object):
@ -233,14 +238,14 @@ class StateFilter(object):
return len(self.concrete_types())
def filter_state(self, state_dict):
def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
"""Returns the state filtered with by this StateFilter
Args:
state (dict[tuple[str, str], Any]): The state map to filter
state: The state map to filter
Returns:
dict[tuple[str, str], Any]: The filtered state map
The filtered state map
"""
if self.is_full():
return dict(state_dict)
@ -333,12 +338,12 @@ class StateGroupStorage(object):
def __init__(self, hs, stores):
self.stores = stores
def get_state_group_delta(self, state_group):
def get_state_group_delta(self, state_group: int):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]):
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
(prev_group, delta_ids)
"""
@ -353,7 +358,7 @@ class StateGroupStorage(object):
event_ids (iterable[str]): ids of the events
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
Deferred[dict[int, StateMap[str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
@ -410,17 +415,18 @@ class StateGroupStorage(object):
for group, event_id_map in iteritems(group_to_ids)
}
def _get_state_groups_from_groups(self, groups, state_filter):
def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
groups(list[int]): list of state group IDs to query
state_filter (StateFilter): The state filter used to fetch state
groups: list of state group IDs to query
state_filter: The state filter used to fetch state
from the database.
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@ -519,7 +525,9 @@ class StateGroupStorage(object):
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@ -529,8 +537,7 @@ class StateGroupStorage(object):
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
"""
return self.stores.state._get_state_for_groups(groups, state_filter)

View File

@ -17,6 +17,7 @@ import re
import string
import sys
from collections import namedtuple
from typing import Dict, Tuple, TypeVar
import attr
from signedjson.key import decode_verify_key_bytes
@ -28,7 +29,7 @@ from synapse.api.errors import SynapseError
if sys.version_info[:3] >= (3, 6, 0):
from typing import Collection
else:
from typing import Sized, Iterable, Container, TypeVar
from typing import Sized, Iterable, Container
T_co = TypeVar("T_co", covariant=True)
@ -36,6 +37,12 @@ else:
__slots__ = ()
# Define a state map type from type/state_key to T (usually an event ID or
# event)
T = TypeVar("T")
StateMap = Dict[Tuple[str, str], T]
class Requester(
namedtuple(
"Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]