Add StateGroupStorage interface
parent
b7fe62b766
commit
5db03535d5
|
@ -30,6 +30,7 @@ stored in `synapse.storage.schema`.
|
||||||
from synapse.storage.data_stores import DataStores
|
from synapse.storage.data_stores import DataStores
|
||||||
from synapse.storage.data_stores.main import DataStore
|
from synapse.storage.data_stores.main import DataStore
|
||||||
from synapse.storage.persist_events import EventsPersistenceStorage
|
from synapse.storage.persist_events import EventsPersistenceStorage
|
||||||
|
from synapse.storage.state import StateGroupStorage
|
||||||
|
|
||||||
__all__ = ["DataStores", "DataStore"]
|
__all__ = ["DataStores", "DataStore"]
|
||||||
|
|
||||||
|
@ -45,6 +46,7 @@ class Storage(object):
|
||||||
self.main = stores.main
|
self.main = stores.main
|
||||||
|
|
||||||
self.persistence = EventsPersistenceStorage(hs, stores)
|
self.persistence = EventsPersistenceStorage(hs, stores)
|
||||||
|
self.state = StateGroupStorage(hs, stores)
|
||||||
|
|
||||||
|
|
||||||
def are_all_users_on_domain(txn, database_engine, domain):
|
def are_all_users_on_domain(txn, database_engine, domain):
|
||||||
|
|
|
@ -550,7 +550,7 @@ class EventsPersistenceStorage(object):
|
||||||
|
|
||||||
if missing_event_ids:
|
if missing_event_ids:
|
||||||
# Now pull out the state groups for any missing events from DB
|
# Now pull out the state groups for any missing events from DB
|
||||||
event_to_groups = yield self.state_store._get_state_group_for_events(
|
event_to_groups = yield self.main_store._get_state_group_for_events(
|
||||||
missing_event_ids
|
missing_event_ids
|
||||||
)
|
)
|
||||||
event_id_to_state_group.update(event_to_groups)
|
event_id_to_state_group.update(event_to_groups)
|
||||||
|
|
|
@ -19,6 +19,8 @@ from six import iteritems, itervalues
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -322,3 +324,233 @@ class StateFilter(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
return member_filter, non_member_filter
|
return member_filter, non_member_filter
|
||||||
|
|
||||||
|
|
||||||
|
class StateGroupStorage(object):
|
||||||
|
"""High level interface to fetching state for event.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hs, stores):
|
||||||
|
self.stores = stores
|
||||||
|
|
||||||
|
def get_state_group_delta(self, state_group):
|
||||||
|
"""Given a state group try to return a previous group and a delta between
|
||||||
|
the old and the new.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(prev_group, delta_ids), where both may be None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.stores.main.get_state_group_delta(state_group)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_groups_ids(self, _room_id, event_ids):
|
||||||
|
"""Get the event IDs of all the state for the state groups for the given events
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_room_id (str): id of the room for these events
|
||||||
|
event_ids (iterable[str]): ids of the events
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[dict[int, dict[tuple[str, str], str]]]:
|
||||||
|
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||||
|
"""
|
||||||
|
if not event_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
||||||
|
|
||||||
|
groups = set(itervalues(event_to_groups))
|
||||||
|
group_to_state = yield self.stores.main._get_state_for_groups(groups)
|
||||||
|
|
||||||
|
return group_to_state
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_ids_for_group(self, state_group):
|
||||||
|
"""Get the event IDs of all the state in the given state group
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_group (int)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
|
||||||
|
"""
|
||||||
|
group_to_state = yield self._get_state_for_groups((state_group,))
|
||||||
|
|
||||||
|
return group_to_state[state_group]
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_groups(self, room_id, event_ids):
|
||||||
|
""" Get the state groups for the given list of event_ids
|
||||||
|
Returns:
|
||||||
|
Deferred[dict[int, list[EventBase]]]:
|
||||||
|
dict of state_group_id -> list of state events.
|
||||||
|
"""
|
||||||
|
if not event_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
|
||||||
|
|
||||||
|
state_event_map = yield self.stores.main.get_events(
|
||||||
|
[
|
||||||
|
ev_id
|
||||||
|
for group_ids in itervalues(group_to_ids)
|
||||||
|
for ev_id in itervalues(group_ids)
|
||||||
|
],
|
||||||
|
get_prev_content=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
group: [
|
||||||
|
state_event_map[v]
|
||||||
|
for v in itervalues(event_id_map)
|
||||||
|
if v in state_event_map
|
||||||
|
]
|
||||||
|
for group, event_id_map in iteritems(group_to_ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_state_groups_from_groups(self, groups, state_filter):
|
||||||
|
"""Returns the state groups for a given set of groups, filtering on
|
||||||
|
types of state events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
groups(list[int]): list of state group IDs to query
|
||||||
|
state_filter (StateFilter): The state filter used to fetch state
|
||||||
|
from the database.
|
||||||
|
Returns:
|
||||||
|
Deferred[dict[int, dict[tuple[str, str], str]]]:
|
||||||
|
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self.stores.main._get_state_groups_from_groups(groups, state_filter)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
|
||||||
|
"""Given a list of event_ids and type tuples, return a list of state
|
||||||
|
dicts for each event.
|
||||||
|
Args:
|
||||||
|
event_ids (list[string])
|
||||||
|
state_filter (StateFilter): The state filter used to fetch state
|
||||||
|
from the database.
|
||||||
|
Returns:
|
||||||
|
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
|
||||||
|
"""
|
||||||
|
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
||||||
|
|
||||||
|
groups = set(itervalues(event_to_groups))
|
||||||
|
group_to_state = yield self.stores.main._get_state_for_groups(
|
||||||
|
groups, state_filter
|
||||||
|
)
|
||||||
|
|
||||||
|
state_event_map = yield self.stores.main.get_events(
|
||||||
|
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
|
||||||
|
get_prev_content=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
event_to_state = {
|
||||||
|
event_id: {
|
||||||
|
k: state_event_map[v]
|
||||||
|
for k, v in iteritems(group_to_state[group])
|
||||||
|
if v in state_event_map
|
||||||
|
}
|
||||||
|
for event_id, group in iteritems(event_to_groups)
|
||||||
|
}
|
||||||
|
|
||||||
|
return {event: event_to_state[event] for event in event_ids}
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
|
||||||
|
"""
|
||||||
|
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||||
|
of the state events (as opposed to the events themselves)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_ids(list(str)): events whose state should be returned
|
||||||
|
state_filter (StateFilter): The state filter used to fetch state
|
||||||
|
from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A deferred dict from event_id -> (type, state_key) -> event_id
|
||||||
|
"""
|
||||||
|
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
||||||
|
|
||||||
|
groups = set(itervalues(event_to_groups))
|
||||||
|
group_to_state = yield self.stores.main._get_state_for_groups(
|
||||||
|
groups, state_filter
|
||||||
|
)
|
||||||
|
|
||||||
|
event_to_state = {
|
||||||
|
event_id: group_to_state[group]
|
||||||
|
for event_id, group in iteritems(event_to_groups)
|
||||||
|
}
|
||||||
|
|
||||||
|
return {event: event_to_state[event] for event in event_ids}
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
|
||||||
|
"""
|
||||||
|
Get the state dict corresponding to a particular event
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id(str): event whose state should be returned
|
||||||
|
state_filter (StateFilter): The state filter used to fetch state
|
||||||
|
from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A deferred dict from (type, state_key) -> state_event
|
||||||
|
"""
|
||||||
|
state_map = yield self.get_state_for_events([event_id], state_filter)
|
||||||
|
return state_map[event_id]
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
|
||||||
|
"""
|
||||||
|
Get the state dict corresponding to a particular event
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id(str): event whose state should be returned
|
||||||
|
state_filter (StateFilter): The state filter used to fetch state
|
||||||
|
from the database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A deferred dict from (type, state_key) -> state_event
|
||||||
|
"""
|
||||||
|
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
|
||||||
|
return state_map[event_id]
|
||||||
|
|
||||||
|
def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
|
||||||
|
"""Gets the state at each of a list of state groups, optionally
|
||||||
|
filtering by type/state_key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
groups (iterable[int]): list of state groups for which we want
|
||||||
|
to get the state.
|
||||||
|
state_filter (StateFilter): The state filter used to fetch state
|
||||||
|
from the database.
|
||||||
|
Returns:
|
||||||
|
Deferred[dict[int, dict[tuple[str, str], str]]]:
|
||||||
|
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||||
|
"""
|
||||||
|
return self.stores.main._get_state_for_groups(groups, state_filter)
|
||||||
|
|
||||||
|
def store_state_group(
|
||||||
|
self, event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||||
|
):
|
||||||
|
"""Store a new set of state, returning a newly assigned state group.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event_id (str): The event ID for which the state was calculated
|
||||||
|
room_id (str)
|
||||||
|
prev_group (int|None): A previous state group for the room, optional.
|
||||||
|
delta_ids (dict|None): The delta between state at `prev_group` and
|
||||||
|
`current_state_ids`, if `prev_group` was given. Same format as
|
||||||
|
`current_state_ids`.
|
||||||
|
current_state_ids (dict): The state to store. Map of (type, state_key)
|
||||||
|
to event_id.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[int]: The state group ID
|
||||||
|
"""
|
||||||
|
return self.stores.main.store_state_group(
|
||||||
|
event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||||
|
)
|
||||||
|
|
Loading…
Reference in New Issue