# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tests import unittest
from twisted.internet import defer

from synapse.state import StateHandler

from mock import Mock


class StateTestCase(unittest.TestCase):
    """Exercise StateHandler.compute_event_context against a mocked datastore.

    Each test builds mock events with create_event(), optionally primes
    ``store.get_state_groups`` with one or more state groups, and then checks
    the ``current_state`` and ``state_group`` of the returned context.
    """

    # (event type, state_key) pairs for the state fixture used by every test.
    _OLD_STATE_KEYS = (
        ("test1", "1"),
        ("test1", "2"),
        ("test2", ""),
    )

    # Second fixture for the conflict tests; it shares only ("test1", "1")
    # with _OLD_STATE_KEYS, so both fixtures together span five distinct keys.
    _OTHER_STATE_KEYS = (
        ("test1", "1"),
        ("test3", "2"),
        ("test4", ""),
    )

    def setUp(self):
        """Create a StateHandler wired to a mock homeserver and datastore."""
        self.store = Mock(
            spec_set=[
                "get_state_groups",
                "add_event_hashes",
            ]
        )
        hs = Mock(spec=["get_datastore"])
        hs.get_datastore.return_value = self.store

        self.state = StateHandler(hs)
        # Counter used by create_event() to hand out unique event ids.
        self.event_id = 0

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        """A non-state event with explicit old_state keeps that state."""
        event = self.create_event(type="test_message", name="event")
        old_state = self._create_state_events(self._OLD_STATE_KEYS)

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )

        self._assert_keys_match_events(context.current_state)
        self.assertEqual(
            set(old_state), set(context.current_state.values())
        )
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        """A state event with explicit old_state keeps that state."""
        event = self.create_event(type="state", state_key="", name="event")
        old_state = self._create_state_events(self._OLD_STATE_KEYS)

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )

        self._assert_keys_match_events(context.current_state)
        self.assertEqual(
            set(old_state), set(context.current_state.values())
        )
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        """With one stored group, a non-state event reuses that group."""
        event = self.create_event(type="test_message", name="event")
        event.prev_events = []
        old_state = self._create_state_events(self._OLD_STATE_KEYS)

        group_name = "group_name_1"
        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_keys_match_events(context.current_state)
        self.assertEqual(
            set(e.event_id for e in old_state),
            set(e.event_id for e in context.current_state.values())
        )
        self.assertEqual(group_name, context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        """With one stored group, a state event gets no state_group."""
        event = self.create_event(type="state", state_key="", name="event")
        event.prev_events = []
        old_state = self._create_state_events(self._OLD_STATE_KEYS)

        group_name = "group_name_1"
        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_keys_match_events(context.current_state)
        self.assertEqual(
            set(e.event_id for e in old_state),
            set(e.event_id for e in context.current_state.values())
        )
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        """Two stored groups are merged for a non-state event."""
        event = self.create_event(type="test_message", name="event")
        event.prev_events = []
        self.store.get_state_groups.return_value = (
            self._create_conflicting_groups()
        )

        context = yield self.state.compute_event_context(event)

        self.assertEqual(len(context.current_state), 5)
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        """Two stored groups are merged for a state event."""
        event = self.create_event(type="test4", state_key="", name="event")
        event.prev_events = []
        self.store.get_state_groups.return_value = (
            self._create_conflicting_groups()
        )

        context = yield self.state.compute_event_context(event)

        self.assertEqual(len(context.current_state), 5)
        self.assertIsNone(context.state_group)

    def _create_state_events(self, keys):
        """Return one mock state event per (event type, state_key) pair."""
        return [
            self.create_event(type=event_type, state_key=state_key)
            for event_type, state_key in keys
        ]

    def _create_conflicting_groups(self):
        """Return two named state groups that overlap on a single key."""
        old_state_1 = self._create_state_events(self._OLD_STATE_KEYS)
        old_state_2 = self._create_state_events(self._OTHER_STATE_KEYS)
        return {
            "group_name_1": old_state_1,
            "group_name_2": old_state_2,
        }

    def _assert_keys_match_events(self, current_state):
        """Check each (type, state_key) key agrees with the event it maps to."""
        for (event_type, state_key), state_event in current_state.items():
            self.assertEqual(event_type, state_event.type)
            self.assertEqual(state_key, state_event.state_key)

    def create_event(self, name=None, type=None, state_key=None):
        """Build a mock event with a fresh, unique event_id.

        Args:
            name: Name given to the Mock. When falsy, it is derived from the
                type and state_key as "<type-state_key>" or "<type>".
            type: Value assigned to ``event.type``. The parameter shadows the
                builtin, but the name is kept because callers pass it by
                keyword.
            state_key: When not None, assigned to ``event.state_key`` and
                ``event.is_state()`` returns True. When None, no state_key
                attribute is set and ``is_state()`` returns False.

        Returns:
            The configured Mock event.
        """
        self.event_id += 1
        event_id = str(self.event_id)

        if not name:
            if state_key is not None:
                name = "<%s-%s>" % (type, state_key)
            else:
                name = "<%s>" % (type, )

        # spec=[] stops the mock from auto-creating attributes, so only the
        # fields assigned below are readable on the event.
        event = Mock(name=name, spec=[])
        event.type = type

        if state_key is not None:
            event.state_key = state_key
        event.event_id = event_id

        event.is_state = lambda: (state_key is not None)
        event.unsigned = {}

        event.user_id = "@user_id:example.com"
        event.room_id = "!room_id:example.com"

        return event