From 83ce57302dab6a825f3afde11926b5404ce1c9ff Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 8 Sep 2014 19:50:46 +0100 Subject: [PATCH] Fix bug in state handling where we incorrectly identified a missing pdu. Update tests to catch this case. --- synapse/state.py | 100 +++++++++--------- synapse/storage/pdu.py | 9 +- tests/test_state.py | 233 ++++++++++++++++++++++++++++++++++++++--- 3 files changed, 271 insertions(+), 71 deletions(-) diff --git a/synapse/state.py b/synapse/state.py index 5dcff27367..e69282860a 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -134,7 +134,9 @@ class StateHandler(object): @defer.inlineCallbacks @log_function def _handle_new_state(self, new_pdu): - tree = yield self.store.get_unresolved_state_tree(new_pdu) + tree, missing_branch = yield self.store.get_unresolved_state_tree( + new_pdu + ) new_branch, current_branch = tree logger.debug( @@ -142,64 +144,17 @@ class StateHandler(object): new_branch, current_branch ) - if not current_branch: - # There is no current state - defer.returnValue(True) - return - - n = new_branch[-1] - c = current_branch[-1] - - if n.pdu_id == c.pdu_id and n.origin == c.origin: - # We have all the PDUs we need, so we can just do the conflict - # resolution. - - if len(current_branch) == 1: - # This is a direct clobber so we can just... - defer.returnValue(True) - - conflict_res = [ - self._do_power_level_conflict_res, - self._do_chain_length_conflict_res, - self._do_hash_conflict_res, - ] - - for algo in conflict_res: - new_res, curr_res = algo(new_branch, current_branch) - - if new_res < curr_res: - defer.returnValue(False) - elif new_res > curr_res: - defer.returnValue(True) - - raise Exception("Conflict resolution failed.") - - else: - # We need to ask for PDUs. - missing_prev = max( - new_branch[-1], current_branch[-1], - key=lambda x: x.depth - ) - - if not hasattr(missing_prev, "prev_state_id"): - # FIXME Hmm - # temporary fallback - for algo in conflict_res: - new_res, curr_res = algo(new_branch, current_branch) - - if new_res < curr_res: - defer.returnValue(False) - elif new_res > curr_res: - defer.returnValue(True) - return + if missing_branch is not None: + # We're missing some PDUs. Fetch them. + # TODO (erikj): Limit this. + missing_prev = tree[missing_branch][-1] pdu_id = missing_prev.prev_state_id origin = missing_prev.prev_state_origin is_missing = yield self.store.get_pdu(pdu_id, origin) is None - if not is_missing: - raise Exception("Conflict resolution failed.") + raise Exception("Conflict resolution failed") yield self._replication.get_pdu( destination=missing_prev.origin, @@ -211,6 +166,45 @@ class StateHandler(object): updated_current = yield self._handle_new_state(new_pdu) defer.returnValue(updated_current) + if not current_branch: + # There is no current state + defer.returnValue(True) + return + + n = new_branch[-1] + c = current_branch[-1] + + if n.pdu_id == c.pdu_id and n.origin == c.origin: + # We found a common ancestor! + + if len(current_branch) == 1: + # This is a direct clobber so we can just... + defer.returnValue(True) + + else: + # We didn't find a common ancestor. This is probably fine. + pass + + result = self._do_conflict_res(new_branch, current_branch) + defer.returnValue(result) + + def _do_conflict_res(self, new_branch, current_branch): + conflict_res = [ + self._do_power_level_conflict_res, + self._do_chain_length_conflict_res, + self._do_hash_conflict_res, + ] + + for algo in conflict_res: + new_res, curr_res = algo(new_branch, current_branch) + + if new_res < curr_res: + defer.returnValue(False) + elif new_res > curr_res: + defer.returnValue(True) + + raise Exception("Conflict resolution failed.") + def _do_power_level_conflict_res(self, new_branch, current_branch): max_power_new = max( new_branch[:-1], diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py index 0bf97e37ee..3cbce2d0a1 100644 --- a/synapse/storage/pdu.py +++ b/synapse/storage/pdu.py @@ -308,8 +308,8 @@ class PduStore(SQLBaseStore): @defer.inlineCallbacks def get_oldest_pdus_in_context(self, context): - """Get a list of Pdus that we haven't backfilled beyond yet (and haven't - seen). This list is used when we want to backfill backwards and is the + """Get a list of Pdus that we haven't backfilled beyond yet (and havent + seen). This list is used when we want to backfill backwards and is the list we send to the remote server. Args: @@ -524,13 +524,16 @@ class StatePduStore(SQLBaseStore): txn, new_pdu, current ) + missing_branch = None for branch, prev_state, state in enum_branches: if state: return_value[branch].append(state) else: + # We don't have prev_state :( + missing_branch = branch break - return return_value + return (return_value, missing_branch) def update_current_state(self, pdu_id, origin, context, pdu_type, state_key): diff --git a/tests/test_state.py b/tests/test_state.py index b01496c40f..4512475ebd 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -24,6 +24,8 @@ from collections import namedtuple from mock import Mock +import mock + ReturnType = namedtuple( "StateReturnType", ["new_branch", "current_branch"] @@ -54,7 +56,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu], []) + (ReturnType([new_pdu], []), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -78,7 +80,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", "A", 5) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu], [old_pdu]) + (ReturnType([new_pdu, old_pdu], [old_pdu]), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -103,7 +105,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 5) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]) + (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -128,7 +130,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 15) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]) + (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -153,7 +155,7 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 10) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]) + (ReturnType([new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1]), None) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -179,7 +181,13 @@ class StateTestCase(unittest.TestCase): new_pdu = new_fake_pdu_entry("D", "test", "mem", "x", "C", 10) self.persistence.get_unresolved_state_tree.return_value = ( - ReturnType([new_pdu, old_pdu_3, old_pdu_1], [old_pdu_2, old_pdu_1]) + ( + ReturnType( + [new_pdu, old_pdu_3, old_pdu_1], + [old_pdu_2, old_pdu_1] + ), + None + ) ) is_new = yield self.state.handle_new_state(new_pdu) @@ -200,22 +208,32 @@ class StateTestCase(unittest.TestCase): # triggering a get_pdu request # The pdu we haven't seen - old_pdu_1 = new_fake_pdu_entry("A", "test", "mem", "x", None, 10) + old_pdu_1 = new_fake_pdu_entry( + "A", "test", "mem", "x", None, 10, depth=0 + ) - old_pdu_2 = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) - new_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20) + old_pdu_2 = new_fake_pdu_entry( + "B", "test", "mem", "x", "A", 10, depth=1 + ) + new_pdu = new_fake_pdu_entry( + "C", "test", "mem", "x", "A", 20, depth=2 + ) # The return_value of `get_unresolved_state_tree`, which changes after # the call to get_pdu - tree_to_return = [ReturnType([new_pdu], [old_pdu_2])] + tree_to_return = [(ReturnType([new_pdu], [old_pdu_2]), 0)] def return_tree(p): return tree_to_return[0] - def set_return_tree(*args, **kwargs): - tree_to_return[0] = ReturnType( - [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1] + def set_return_tree(destination, pdu_origin, pdu_id, outlier=False): + tree_to_return[0] = ( + ReturnType( + [new_pdu, old_pdu_1], [old_pdu_2, old_pdu_1] + ), + None ) + return defer.succeed(None) self.persistence.get_unresolved_state_tree.side_effect = return_tree @@ -227,6 +245,13 @@ class StateTestCase(unittest.TestCase): self.assertTrue(is_new) + self.replication.get_pdu.assert_called_with( + destination=new_pdu.origin, + pdu_origin=old_pdu_1.origin, + pdu_id=old_pdu_1.pdu_id, + outlier=True + ) + self.persistence.get_unresolved_state_tree.assert_called_with( new_pdu ) @@ -237,6 +262,184 @@ class StateTestCase(unittest.TestCase): self.assertEqual(1, self.persistence.update_current_state.call_count) + @defer.inlineCallbacks + def test_missing_pdu_depth_1(self): + # We try to update state against a PDU we haven't yet seen, + # triggering a get_pdu request + + # The pdu we haven't seen + old_pdu_1 = new_fake_pdu_entry( + "A", "test", "mem", "x", None, 10, depth=0 + ) + + old_pdu_2 = new_fake_pdu_entry( + "B", "test", "mem", "x", "A", 10, depth=2 + ) + old_pdu_3 = new_fake_pdu_entry( + "C", "test", "mem", "x", "B", 10, depth=3 + ) + new_pdu = new_fake_pdu_entry( + "D", "test", "mem", "x", "A", 20, depth=4 + ) + + # The return_value of `get_unresolved_state_tree`, which changes after + # the call to get_pdu + tree_to_return = [ + ( + ReturnType([new_pdu], [old_pdu_3]), + 0 + ), + ( + ReturnType( + [new_pdu, old_pdu_1], [old_pdu_3] + ), + 1 + ), + ( + ReturnType( + [new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1] + ), + None + ), + ] + + to_return = [0] + + def return_tree(p): + return tree_to_return[to_return[0]] + + def set_return_tree(destination, pdu_origin, pdu_id, outlier=False): + to_return[0] += 1 + return defer.succeed(None) + + self.persistence.get_unresolved_state_tree.side_effect = return_tree + + self.replication.get_pdu.side_effect = set_return_tree + + self.persistence.get_pdu.return_value = None + + is_new = yield self.state.handle_new_state(new_pdu) + + self.assertTrue(is_new) + + self.assertEqual(2, self.replication.get_pdu.call_count) + + self.replication.get_pdu.assert_has_calls( + [ + mock.call( + destination=new_pdu.origin, + pdu_origin=old_pdu_1.origin, + pdu_id=old_pdu_1.pdu_id, + outlier=True + ), + mock.call( + destination=old_pdu_3.origin, + pdu_origin=old_pdu_2.origin, + pdu_id=old_pdu_2.pdu_id, + outlier=True + ), + ] + ) + + self.persistence.get_unresolved_state_tree.assert_called_with( + new_pdu + ) + + self.assertEquals( + 3, self.persistence.get_unresolved_state_tree.call_count + ) + + self.assertEqual(1, self.persistence.update_current_state.call_count) + + @defer.inlineCallbacks + def test_missing_pdu_depth_2(self): + # We try to update state against a PDU we haven't yet seen, + # triggering a get_pdu request + + # The pdu we haven't seen + old_pdu_1 = new_fake_pdu_entry( + "A", "test", "mem", "x", None, 10, depth=0 + ) + + old_pdu_2 = new_fake_pdu_entry( + "B", "test", "mem", "x", "A", 10, depth=2 + ) + old_pdu_3 = new_fake_pdu_entry( + "C", "test", "mem", "x", "B", 10, depth=3 + ) + new_pdu = new_fake_pdu_entry( + "D", "test", "mem", "x", "A", 20, depth=1 + ) + + # The return_value of `get_unresolved_state_tree`, which changes after + # the call to get_pdu + tree_to_return = [ + ( + ReturnType([new_pdu], [old_pdu_3]), + 1, + ), + ( + ReturnType( + [new_pdu], [old_pdu_3, old_pdu_2] + ), + 0, + ), + ( + ReturnType( + [new_pdu, old_pdu_1], [old_pdu_3, old_pdu_2, old_pdu_1] + ), + None + ), + ] + + to_return = [0] + + def return_tree(p): + return tree_to_return[to_return[0]] + + def set_return_tree(destination, pdu_origin, pdu_id, outlier=False): + to_return[0] += 1 + return defer.succeed(None) + + self.persistence.get_unresolved_state_tree.side_effect = return_tree + + self.replication.get_pdu.side_effect = set_return_tree + + self.persistence.get_pdu.return_value = None + + is_new = yield self.state.handle_new_state(new_pdu) + + self.assertTrue(is_new) + + self.assertEqual(2, self.replication.get_pdu.call_count) + + self.replication.get_pdu.assert_has_calls( + [ + mock.call( + destination=old_pdu_3.origin, + pdu_origin=old_pdu_2.origin, + pdu_id=old_pdu_2.pdu_id, + outlier=True + ), + mock.call( + destination=new_pdu.origin, + pdu_origin=old_pdu_1.origin, + pdu_id=old_pdu_1.pdu_id, + outlier=True + ), + ] + ) + + self.persistence.get_unresolved_state_tree.assert_called_with( + new_pdu + ) + + self.assertEquals( + 3, self.persistence.get_unresolved_state_tree.call_count + ) + + self.assertEqual(1, self.persistence.update_current_state.call_count) + @defer.inlineCallbacks def test_new_event(self): event = Mock() @@ -270,7 +473,7 @@ class StateTestCase(unittest.TestCase): def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id, - power_level): + power_level, depth=0): new_pdu = PduEntry( pdu_id=pdu_id, pdu_type=pdu_type, @@ -280,7 +483,7 @@ def new_fake_pdu_entry(pdu_id, context, pdu_type, state_key, prev_state_id, origin="example.com", context="context", ts=1405353060021, - depth=0, + depth=depth, content_json="{}", unrecognized_keys="{}", outlier=True,