237 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			237 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Python
		
	
	
| # -*- coding: utf-8 -*-
 | |
| # Copyright 2014 OpenMarket Ltd
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| from tests import unittest
 | |
| from twisted.internet import defer
 | |
| 
 | |
| from synapse.state import StateHandler
 | |
| 
 | |
| from mock import Mock
 | |
| 
 | |
| 
 | |
class StateTestCase(unittest.TestCase):
    """Tests for ``StateHandler.compute_event_context`` using mocked parts.

    The homeserver and datastore are ``Mock`` objects, and every event is a
    ``Mock`` produced by ``create_event``.
    """

    # (type, state_key) pairs used to build the "old state" fixtures.  The
    # two sets share ("test1", "1") and together cover five distinct keys.
    _STATE_KEYS_A = (("test1", "1"), ("test1", "2"), ("test2", ""))
    _STATE_KEYS_B = (("test1", "1"), ("test3", "2"), ("test4", ""))

    def setUp(self):
        """Create a ``StateHandler`` wired to a mocked homeserver/datastore."""
        # spec_set: the store exposes only these two attributes.
        self.store = Mock(
            spec_set=[
                "get_state_groups",
                "add_event_hashes",
            ]
        )
        hs = Mock(spec=["get_datastore"])
        hs.get_datastore.return_value = self.store

        self.state = StateHandler(hs)
        # Counter from which create_event derives unique event ids.
        self.event_id = 0

    def _create_state_events(self, keys):
        """Return one mock state event per ``(type, state_key)`` pair.

        Events are created in the order given, so their event ids are
        consecutive.
        """
        return [
            self.create_event(type=event_type, state_key=state_key)
            for event_type, state_key in keys
        ]

    def _assert_keys_match_events(self, current_state):
        """Assert every ``(type, state_key)`` key agrees with its event."""
        for (event_type, state_key), state_event in current_state.items():
            self.assertEqual(event_type, state_event.type)
            self.assertEqual(state_key, state_event.state_key)

    @defer.inlineCallbacks
    def test_annotate_with_old_message(self):
        # Non-state event with an explicit old_state: the resulting
        # current_state must be exactly old_state, with no state group.
        event = self.create_event(type="test_message", name="event")
        old_state = self._create_state_events(self._STATE_KEYS_A)

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )

        self._assert_keys_match_events(context.current_state)
        self.assertEqual(
            set(old_state), set(context.current_state.values())
        )
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_annotate_with_old_state(self):
        # State event with an explicit old_state: same expectations as the
        # message case above.
        event = self.create_event(type="state", state_key="", name="event")
        old_state = self._create_state_events(self._STATE_KEYS_A)

        context = yield self.state.compute_event_context(
            event, old_state=old_state
        )

        self._assert_keys_match_events(context.current_state)
        self.assertEqual(
            set(old_state), set(context.current_state.values())
        )
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_message(self):
        # Non-state event, store returns a single state group: the context
        # must carry that group's events and report the group's name.
        event = self.create_event(type="test_message", name="event")
        event.prev_events = []
        old_state = self._create_state_events(self._STATE_KEYS_A)

        group_name = "group_name_1"
        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_keys_match_events(context.current_state)
        self.assertEqual(
            {e.event_id for e in old_state},
            {e.event_id for e in context.current_state.values()}
        )
        self.assertEqual(group_name, context.state_group)

    @defer.inlineCallbacks
    def test_trivial_annotate_state(self):
        # State event, store returns a single state group: the context must
        # carry that group's events but no state group.
        event = self.create_event(type="state", state_key="", name="event")
        event.prev_events = []
        old_state = self._create_state_events(self._STATE_KEYS_A)

        group_name = "group_name_1"
        self.store.get_state_groups.return_value = {
            group_name: old_state,
        }

        context = yield self.state.compute_event_context(event)

        self._assert_keys_match_events(context.current_state)
        self.assertEqual(
            {e.event_id for e in old_state},
            {e.event_id for e in context.current_state.values()}
        )
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_message_conflict(self):
        # Non-state event, store returns two state groups that overlap on
        # ("test1", "1"): expect five entries and no state group.
        event = self.create_event(type="test_message", name="event")
        event.prev_events = []

        old_state_1 = self._create_state_events(self._STATE_KEYS_A)
        old_state_2 = self._create_state_events(self._STATE_KEYS_B)

        self.store.get_state_groups.return_value = {
            "group_name_1": old_state_1,
            "group_name_2": old_state_2,
        }

        context = yield self.state.compute_event_context(event)

        self.assertEqual(len(context.current_state), 5)
        self.assertIsNone(context.state_group)

    @defer.inlineCallbacks
    def test_resolve_state_conflict(self):
        # As test_resolve_message_conflict, but the new event is itself a
        # state event whose key, ("test4", ""), is present in group 2.
        event = self.create_event(type="test4", state_key="", name="event")
        event.prev_events = []

        old_state_1 = self._create_state_events(self._STATE_KEYS_A)
        old_state_2 = self._create_state_events(self._STATE_KEYS_B)

        self.store.get_state_groups.return_value = {
            "group_name_1": old_state_1,
            "group_name_2": old_state_2,
        }

        context = yield self.state.compute_event_context(event)

        self.assertEqual(len(context.current_state), 5)
        self.assertIsNone(context.state_group)

    def create_event(self, name=None, type=None, state_key=None):
        """Build a mock event with a fresh, sequential ``event_id``.

        ``name`` is the name given to the ``Mock``; when falsy it is derived
        from ``type`` and ``state_key``.  ``type`` becomes the event's
        ``type`` attribute (the parameter name shadows the builtin but is
        kept because callers pass it by keyword).  When ``state_key`` is not
        ``None`` the event gets a ``state_key`` attribute and its
        ``is_state()`` returns True; otherwise the attribute is left unset
        and ``is_state()`` returns False.

        Returns the configured ``Mock``.
        """
        self.event_id += 1
        has_state_key = state_key is not None

        if not name:
            if has_state_key:
                name = "<%s-%s>" % (type, state_key)
            else:
                name = "<%s>" % (type, )

        # spec=[]: reading an attribute that was never assigned raises
        # AttributeError rather than silently yielding a child Mock.
        event = Mock(name=name, spec=[])
        event.type = type
        if has_state_key:
            event.state_key = state_key
        event.event_id = str(self.event_id)

        event.is_state = lambda: has_state_key
        event.unsigned = {}

        event.user_id = "@user_id:example.com"
        event.room_id = "!room_id:example.com"

        return event
 |