diff --git a/synapse/state.py b/synapse/state.py index e69282860a..0cc1344d51 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -174,7 +174,9 @@ class StateHandler(object): n = new_branch[-1] c = current_branch[-1] - if n.pdu_id == c.pdu_id and n.origin == c.origin: + common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin + + if common_ancestor: # We found a common ancestor! if len(current_branch) == 1: @@ -185,10 +187,12 @@ class StateHandler(object): # We didn't find a common ancestor. This is probably fine. pass - result = self._do_conflict_res(new_branch, current_branch) + result = self._do_conflict_res( + new_branch, current_branch, common_ancestor + ) defer.returnValue(result) - def _do_conflict_res(self, new_branch, current_branch): + def _do_conflict_res(self, new_branch, current_branch, common_ancestor): conflict_res = [ self._do_power_level_conflict_res, self._do_chain_length_conflict_res, @@ -196,7 +200,9 @@ class StateHandler(object): ] for algo in conflict_res: - new_res, curr_res = algo(new_branch, current_branch) + new_res, curr_res = algo( + new_branch, current_branch, common_ancestor + ) if new_res < curr_res: defer.returnValue(False) @@ -205,23 +211,26 @@ class StateHandler(object): raise Exception("Conflict resolution failed.") - def _do_power_level_conflict_res(self, new_branch, current_branch): + def _do_power_level_conflict_res(self, new_branch, current_branch, + common_ancestor): max_power_new = max( - new_branch[:-1], + new_branch[:-1] if common_ancestor else new_branch, key=lambda t: t.power_level ).power_level max_power_current = max( - current_branch[:-1], + current_branch[:-1] if common_ancestor else current_branch, key=lambda t: t.power_level ).power_level return (max_power_new, max_power_current) - def _do_chain_length_conflict_res(self, new_branch, current_branch): + def _do_chain_length_conflict_res(self, new_branch, current_branch, + common_ancestor): return (len(new_branch), len(current_branch)) - def _do_hash_conflict_res(self, new_branch, current_branch): + def _do_hash_conflict_res(self, new_branch, current_branch, + common_ancestor): new_str = "".join([p.pdu_id + p.origin for p in new_branch]) c_str = "".join([p.pdu_id + p.origin for p in current_branch]) diff --git a/tests/test_state.py b/tests/test_state.py index 4512475ebd..a9fc3fb85c 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -440,6 +440,30 @@ class StateTestCase(unittest.TestCase): self.assertEqual(1, self.persistence.update_current_state.call_count) + @defer.inlineCallbacks + def test_no_common_ancestor(self): + # We do a direct overwriting of the old state, i.e., the new state + # points to the old state. + + old_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 5) + new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", None, 10) + + self.persistence.get_unresolved_state_tree.return_value = ( + (ReturnType([new_pdu], [old_pdu]), None) + ) + + is_new = yield self.state.handle_new_state(new_pdu) + + self.assertTrue(is_new) + + self.persistence.get_unresolved_state_tree.assert_called_once_with( + new_pdu + ) + + self.assertEqual(1, self.persistence.update_current_state.call_count) + + self.assertFalse(self.replication.get_pdu.called) + @defer.inlineCallbacks def test_new_event(self): event = Mock()