diff --git a/synapse/state.py b/synapse/state.py index 8144fa02b4..081bc31bb5 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError from synapse.events.snapshot import EventContext from collections import namedtuple @@ -36,12 +37,16 @@ def _get_state_key_from_event(event): KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) +AuthEventTypes = (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,) + + class StateHandler(object): """ Responsible for doing state conflict resolution. """ def __init__(self, hs): self.store = hs.get_datastore() + self.hs = hs @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): @@ -210,64 +215,93 @@ class StateHandler(object): else: prev_states = [] + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in AuthEventTypes + } + try: - new_state = {} - new_state.update(unconflicted_state) - for key, events in conflicted_state.items(): - new_state[key] = self._resolve_state_events(events) + resolved_state = self._resolve_state_events( + conflicted_state, auth_events + ) except: logger.exception("Failed to resolve state") raise + new_state = unconflicted_state + new_state.update(resolved_state) + defer.returnValue((None, new_state, prev_states)) - def _get_power_level_from_event_state(self, event, user_id): - if hasattr(event, "old_state_events") and event.old_state_events: - key = (EventTypes.PowerLevels, "", ) - power_level_event = event.old_state_events.get(key) - level = None - if power_level_event: - level = power_level_event.content.get("users", {}).get( - user_id - ) - if not level: - level = power_level_event.content.get("users_default", 0) - - return level - else: - return 0 - @log_function - def _resolve_state_events(self, events): - curr_events = events + def _resolve_state_events(self, conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. - new_powers = [ - self._get_power_level_from_event_state(e, e.user_id) - for e in curr_events - ] + We resolve conflicts in the following order: + 1. power levels + 2. memberships + 3. other events. + """ + resolved_state = {} + power_key = (EventTypes.PowerLevels, "") + if power_key in conflicted_state.items(): + power_levels = conflicted_state[power_key] + resolved_state[power_key] = self._resolve_auth_events(power_levels) - new_powers = [ - int(p) if p else 0 for p in new_powers - ] + auth_events.update(resolved_state) - max_power = max(new_powers) + for key, events in conflicted_state.items(): + if key[0] == EventTypes.Member: + resolved_state[key] = self._resolve_auth_events( + events, + auth_events + ) - curr_events = [ - z[0] for z in zip(curr_events, new_powers) - if z[1] == max_power - ] + auth_events.update(resolved_state) - if not curr_events: - raise RuntimeError("Max didn't get a max?") - elif len(curr_events) == 1: - return curr_events[0] + for key, events in conflicted_state.items(): + if key not in resolved_state: + resolved_state[key] = self._resolve_normal_events( + events, auth_events + ) - # TODO: For now, just choose the one with the largest event_id. - return ( - sorted( - curr_events, - key=lambda e: hashlib.sha1( - e.event_id + e.user_id + e.room_id + e.type - ).hexdigest() - )[0] - ) + return resolved_state + + def _resolve_auth_events(self, events, auth_events): + reverse = [i for i in reversed(self._ordered_events(events))] + + auth_events = dict(auth_events) + + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event + try: + # FIXME: hs.get_auth() is bad style, but we need to do it to + # get around circular deps. + self.hs.get_auth().check(event, auth_events) + prev_event = event + except AuthError: + return prev_event + + return event + + def _resolve_normal_events(self, events, auth_events): + for event in self._ordered_events(events): + try: + # FIXME: hs.get_auth() is bad style, but we need to do it to + # get around circular deps. + self.hs.get_auth().check(event, auth_events) + return event + except AuthError: + pass + + # Use the last event (the one with the least depth) if they all fail + # the auth check. + return event + + def _ordered_events(self, events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() + + return sorted(events, key=key_func) diff --git a/tests/test_state.py b/tests/test_state.py index 98ad9e54cd..019e794aa2 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -16,11 +16,120 @@ from tests import unittest from twisted.internet import defer +from synapse.events import FrozenEvent +from synapse.api.auth import Auth +from synapse.api.constants import EventTypes, Membership from synapse.state import StateHandler from mock import Mock +_next_event_id = 1000 + + +def create_event(name=None, type=None, state_key=None, depth=2, event_id=None, + prev_events=[], **kwargs): + global _next_event_id + + if not event_id: + _next_event_id += 1 + event_id = str(_next_event_id) + + if not name: + if state_key is not None: + name = "<%s-%s, %s>" % (type, state_key, event_id,) + else: + name = "<%s, %s>" % (type, event_id,) + + d = { + "event_id": event_id, + "type": type, + "sender": "@user_id:example.com", + "room_id": "!room_id:example.com", + "depth": depth, + "prev_events": prev_events, + } + + if state_key is not None: + d["state_key"] = state_key + + d.update(kwargs) + + event = FrozenEvent(d) + + return event + + +class StateGroupStore(object): + def __init__(self): + self._event_to_state_group = {} + self._group_to_state = {} + + self._next_group = 1 + + def get_state_groups(self, event_ids): + groups = {} + for event_id in event_ids: + group = self._event_to_state_group.get(event_id) + if group: + groups[group] = self._group_to_state[group] + + return defer.succeed(groups) + + def store_state_groups(self, event, context): + if context.current_state is None: + return + + state_events = context.current_state + + if event.is_state(): + state_events[(event.type, event.state_key)] = event + + state_group = context.state_group + if not state_group: + state_group = self._next_group + self._next_group += 1 + + self._group_to_state[state_group] = state_events.values() + + self._event_to_state_group[event.event_id] = state_group + + +class DictObj(dict): + def __init__(self, **kwargs): + super(DictObj, self).__init__(kwargs) + self.__dict__ = self + + +class Graph(object): + def __init__(self, nodes, edges): + events = {} + clobbered = set(events.keys()) + + for event_id, fields in nodes.items(): + refs = edges.get(event_id) + if refs: + clobbered.difference_update(refs) + prev_events = [(r, {}) for r in refs] + else: + prev_events = [] + + events[event_id] = create_event( + event_id=event_id, + prev_events=prev_events, + **fields + ) + + self._leaves = clobbered + self._events = sorted(events.values(), key=lambda e: e.depth) + + def walk(self): + return iter(self._events) + + def get_leaves(self): + return (self._events[i] for i in self._leaves) + + class StateTestCase(unittest.TestCase): def setUp(self): self.store = Mock( @@ -29,20 +138,188 @@ class StateTestCase(unittest.TestCase): "add_event_hashes", ] ) - hs = Mock(spec=["get_datastore"]) + hs = Mock(spec=["get_datastore", "get_auth", "get_state_handler"]) hs.get_datastore.return_value = self.store + hs.get_state_handler.return_value = None + hs.get_auth.return_value = Auth(hs) self.state = StateHandler(hs) self.event_id = 0 + @defer.inlineCallbacks + def test_branch_no_conflict(self): + graph = Graph( + nodes={ + "START": DictObj( + type=EventTypes.Create, + state_key="", + depth=1, + ), + "A": DictObj( + type=EventTypes.Message, + depth=2, + ), + "B": DictObj( + type=EventTypes.Message, + depth=3, + ), + "C": DictObj( + type=EventTypes.Name, + state_key="", + depth=3, + ), + "D": DictObj( + type=EventTypes.Message, + depth=4, + ), + }, + edges={ + "A": ["START"], + "B": ["A"], + "C": ["A"], + "D": ["B", "C"] + } + ) + + store = StateGroupStore() + self.store.get_state_groups.side_effect = store.get_state_groups + + context_store = {} + + for event in graph.walk(): + context = yield self.state.compute_event_context(event) + store.store_state_groups(event, context) + context_store[event.event_id] = context + + self.assertEqual(2, len(context_store["D"].current_state)) + + @defer.inlineCallbacks + def test_branch_basic_conflict(self): + graph = Graph( + nodes={ + "START": DictObj( + type=EventTypes.Create, + state_key="creator", + content={"membership": "@user_id:example.com"}, + depth=1, + ), + "A": DictObj( + type=EventTypes.Member, + state_key="@user_id:example.com", + content={"membership": Membership.JOIN}, + membership=Membership.JOIN, + depth=2, + ), + "B": DictObj( + type=EventTypes.Name, + state_key="", + depth=3, + ), + "C": DictObj( + type=EventTypes.Name, + state_key="", + depth=4, + ), + "D": DictObj( + type=EventTypes.Message, + depth=5, + ), + }, + edges={ + "A": ["START"], + "B": ["A"], + "C": ["A"], + "D": ["B", "C"] + } + ) + + store = StateGroupStore() + self.store.get_state_groups.side_effect = store.get_state_groups + + context_store = {} + + for event in graph.walk(): + context = yield self.state.compute_event_context(event) + store.store_state_groups(event, context) + context_store[event.event_id] = context + + self.assertSetEqual( + {"START", "A", "C"}, + {e.event_id for e in context_store["D"].current_state.values()} + ) + + @defer.inlineCallbacks + def test_branch_have_banned_conflict(self): + graph = Graph( + nodes={ + "START": DictObj( + type=EventTypes.Create, + state_key="creator", + content={"membership": "@user_id:example.com"}, + depth=1, + ), + "A": DictObj( + type=EventTypes.Member, + state_key="@user_id:example.com", + content={"membership": Membership.JOIN}, + membership=Membership.JOIN, + depth=2, + ), + "B": DictObj( + type=EventTypes.Name, + state_key="", + depth=3, + ), + "C": DictObj( + type=EventTypes.Member, + state_key="@user_id_2:example.com", + content={"membership": Membership.BAN}, + membership=Membership.BAN, + depth=4, + ), + "D": DictObj( + type=EventTypes.Name, + state_key="", + depth=4, + sender="@user_id_2:example.com", + ), + "E": DictObj( + type=EventTypes.Message, + depth=5, + ), + }, + edges={ + "A": ["START"], + "B": ["A"], + "C": ["B"], + "D": ["B"], + "E": ["C", "D"] + } + ) + + store = StateGroupStore() + self.store.get_state_groups.side_effect = store.get_state_groups + + context_store = {} + + for event in graph.walk(): + context = yield self.state.compute_event_context(event) + store.store_state_groups(event, context) + context_store[event.event_id] = context + + self.assertSetEqual( + {"START", "A", "B", "C"}, + {e.event_id for e in context_store["E"].current_state.values()} + ) + @defer.inlineCallbacks def test_annotate_with_old_message(self): - event = self.create_event(type="test_message", name="event") + event = create_event(type="test_message", name="event") old_state = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] context = yield self.state.compute_event_context( @@ -62,12 +339,12 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_annotate_with_old_state(self): - event = self.create_event(type="state", state_key="", name="event") + event = create_event(type="state", state_key="", name="event") old_state = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] context = yield self.state.compute_event_context( @@ -88,13 +365,12 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_trivial_annotate_message(self): - event = self.create_event(type="test_message", name="event") - event.prev_events = [] + event = create_event(type="test_message", name="event") old_state = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] group_name = "group_name_1" @@ -119,13 +395,12 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_trivial_annotate_state(self): - event = self.create_event(type="state", state_key="", name="event") - event.prev_events = [] + event = create_event(type="state", state_key="", name="event") old_state = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] group_name = "group_name_1" @@ -150,30 +425,21 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_resolve_message_conflict(self): - event = self.create_event(type="test_message", name="event") - event.prev_events = [] + event = create_event(type="test_message", name="event") old_state_1 = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] old_state_2 = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test3", state_key="2"), - self.create_event(type="test4", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test3", state_key="2"), + create_event(type="test4", state_key=""), ] - group_name_1 = "group_name_1" - group_name_2 = "group_name_2" - - self.store.get_state_groups.return_value = { - group_name_1: old_state_1, - group_name_2: old_state_2, - } - - context = yield self.state.compute_event_context(event) + context = yield self._get_context(event, old_state_1, old_state_2) self.assertEqual(len(context.current_state), 5) @@ -181,21 +447,70 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_resolve_state_conflict(self): - event = self.create_event(type="test4", state_key="", name="event") - event.prev_events = [] + event = create_event(type="test4", state_key="", name="event") old_state_1 = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] old_state_2 = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test3", state_key="2"), - self.create_event(type="test4", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test3", state_key="2"), + create_event(type="test4", state_key=""), ] + context = yield self._get_context(event, old_state_1, old_state_2) + + self.assertEqual(len(context.current_state), 5) + + self.assertIsNone(context.state_group) + + @defer.inlineCallbacks + def test_standard_depth_conflict(self): + event = create_event(type="test4", name="event") + + member_event = create_event( + type=EventTypes.Member, + state_key="@user_id:example.com", + content={ + "membership": Membership.JOIN, + } + ) + + old_state_1 = [ + member_event, + create_event(type="test1", state_key="1", depth=1), + ] + + old_state_2 = [ + member_event, + create_event(type="test1", state_key="1", depth=2), + ] + + context = yield self._get_context(event, old_state_1, old_state_2) + + self.assertEqual(old_state_2[1], context.current_state[("test1", "1")]) + + # Reverse the depth to make sure we are actually using the depths + # during state resolution. + + old_state_1 = [ + member_event, + create_event(type="test1", state_key="1", depth=2), + ] + + old_state_2 = [ + member_event, + create_event(type="test1", state_key="1", depth=1), + ] + + context = yield self._get_context(event, old_state_1, old_state_2) + + self.assertEqual(old_state_1[1], context.current_state[("test1", "1")]) + + def _get_context(self, event, old_state_1, old_state_2): group_name_1 = "group_name_1" group_name_2 = "group_name_2" @@ -204,33 +519,4 @@ class StateTestCase(unittest.TestCase): group_name_2: old_state_2, } - context = yield self.state.compute_event_context(event) - - self.assertEqual(len(context.current_state), 5) - - self.assertIsNone(context.state_group) - - def create_event(self, name=None, type=None, state_key=None): - self.event_id += 1 - event_id = str(self.event_id) - - if not name: - if state_key is not None: - name = "<%s-%s>" % (type, state_key) - else: - name = "<%s>" % (type, ) - - event = Mock(name=name, spec=[]) - event.type = type - - if state_key is not None: - event.state_key = state_key - event.event_id = event_id - - event.is_state = lambda: (state_key is not None) - event.unsigned = {} - - event.user_id = "@user_id:example.com" - event.room_id = "!room_id:example.com" - - return event + return self.state.compute_event_context(event)