Construct the EventContext in the state handler rather than constructing one and then immediately calling state_handler.annotate_context_with_state
parent
3c7857e49b
commit
c3eae8a88c
|
@ -20,8 +20,6 @@ from synapse.util.async import run_on_reactor
|
||||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||||
from synapse.api.constants import Membership, EventTypes
|
from synapse.api.constants import Membership, EventTypes
|
||||||
|
|
||||||
from synapse.events.snapshot import EventContext
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,15 +75,10 @@ class BaseHandler(object):
|
||||||
|
|
||||||
state_handler = self.state_handler
|
state_handler = self.state_handler
|
||||||
|
|
||||||
context = EventContext()
|
context = yield state_handler.compute_event_context(builder)
|
||||||
ret = yield state_handler.annotate_context_with_state(
|
|
||||||
builder,
|
|
||||||
context,
|
|
||||||
)
|
|
||||||
prev_state = ret
|
|
||||||
|
|
||||||
if builder.is_state():
|
if builder.is_state():
|
||||||
builder.prev_state = prev_state
|
builder.prev_state = context.prev_state_events
|
||||||
|
|
||||||
yield self.auth.add_auth_events(builder, context)
|
yield self.auth.add_auth_events(builder, context)
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,6 @@
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
from synapse.events.snapshot import EventContext
|
|
||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError, FederationError, SynapseError, StoreError,
|
AuthError, FederationError, SynapseError, StoreError,
|
||||||
|
@ -260,8 +259,7 @@ class FederationHandler(BaseHandler):
|
||||||
event = pdu
|
event = pdu
|
||||||
|
|
||||||
# FIXME (erikj): Not sure this actually works :/
|
# FIXME (erikj): Not sure this actually works :/
|
||||||
context = EventContext()
|
context = yield self.state_handler.compute_event_context(event)
|
||||||
yield self.state_handler.annotate_context_with_state(event, context)
|
|
||||||
|
|
||||||
events.append((event, context))
|
events.append((event, context))
|
||||||
|
|
||||||
|
@ -555,8 +553,7 @@ class FederationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
context = EventContext()
|
context = yield self.state_handler.compute_event_context(event)
|
||||||
yield self.state_handler.annotate_context_with_state(event, context)
|
|
||||||
|
|
||||||
yield self.store.persist_event(
|
yield self.store.persist_event(
|
||||||
event,
|
event,
|
||||||
|
@ -688,11 +685,8 @@ class FederationHandler(BaseHandler):
|
||||||
event.event_id, event.signatures,
|
event.event_id, event.signatures,
|
||||||
)
|
)
|
||||||
|
|
||||||
context = EventContext()
|
context = yield self.state_handler.compute_event_context(
|
||||||
yield self.state_handler.annotate_context_with_state(
|
event, old_state=state
|
||||||
event,
|
|
||||||
context,
|
|
||||||
old_state=state
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
|
@ -19,6 +19,7 @@ from twisted.internet import defer
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
from synapse.events.snapshot import EventContext
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
|
@ -70,7 +71,7 @@ class StateHandler(object):
|
||||||
defer.returnValue(res[1].values())
|
defer.returnValue(res[1].values())
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def annotate_context_with_state(self, event, context, old_state=None):
|
def compute_event_context(self, event, old_state=None):
|
||||||
""" Fills out the context with the `current state` of the graph. The
|
""" Fills out the context with the `current state` of the graph. The
|
||||||
`current state` here is defined to be the state of the event graph
|
`current state` here is defined to be the state of the event graph
|
||||||
just before the event - i.e. it never includes `event`
|
just before the event - i.e. it never includes `event`
|
||||||
|
@ -80,8 +81,11 @@ class StateHandler(object):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (EventBase)
|
event (EventBase)
|
||||||
context (EventContext)
|
Returns:
|
||||||
|
an EventContext
|
||||||
"""
|
"""
|
||||||
|
context = EventContext()
|
||||||
|
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
if old_state:
|
if old_state:
|
||||||
|
@ -107,7 +111,8 @@ class StateHandler(object):
|
||||||
if replaces.event_id != event.event_id: # Paranoia check
|
if replaces.event_id != event.event_id: # Paranoia check
|
||||||
event.unsigned["replaces_state"] = replaces.event_id
|
event.unsigned["replaces_state"] = replaces.event_id
|
||||||
|
|
||||||
defer.returnValue([])
|
context.prev_state_events = []
|
||||||
|
defer.returnValue(context)
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
ret = yield self.resolve_state_groups(
|
ret = yield self.resolve_state_groups(
|
||||||
|
@ -145,7 +150,8 @@ class StateHandler(object):
|
||||||
else:
|
else:
|
||||||
context.auth_events = {}
|
context.auth_events = {}
|
||||||
|
|
||||||
defer.returnValue(prev_state)
|
context.prev_state_events = prev_state
|
||||||
|
defer.returnValue(context)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
|
|
|
@ -34,7 +34,7 @@ class FederationTestCase(unittest.TestCase):
|
||||||
self.mock_config.signing_key = [MockKey()]
|
self.mock_config.signing_key = [MockKey()]
|
||||||
|
|
||||||
self.state_handler = NonCallableMock(spec_set=[
|
self.state_handler = NonCallableMock(spec_set=[
|
||||||
"annotate_context_with_state",
|
"compute_event_context",
|
||||||
])
|
])
|
||||||
|
|
||||||
self.auth = NonCallableMock(spec_set=[
|
self.auth = NonCallableMock(spec_set=[
|
||||||
|
@ -91,11 +91,12 @@ class FederationTestCase(unittest.TestCase):
|
||||||
self.datastore.get_room.return_value = defer.succeed(True)
|
self.datastore.get_room.return_value = defer.succeed(True)
|
||||||
self.auth.check_host_in_room.return_value = defer.succeed(True)
|
self.auth.check_host_in_room.return_value = defer.succeed(True)
|
||||||
|
|
||||||
def annotate(ev, context, old_state=None):
|
def annotate(ev, old_state=None):
|
||||||
|
context = Mock()
|
||||||
context.current_state = {}
|
context.current_state = {}
|
||||||
context.auth_events = {}
|
context.auth_events = {}
|
||||||
return defer.succeed(False)
|
return defer.succeed(context)
|
||||||
self.state_handler.annotate_context_with_state.side_effect = annotate
|
self.state_handler.compute_event_context.side_effect = annotate
|
||||||
|
|
||||||
yield self.handlers.federation_handler.on_receive_pdu(
|
yield self.handlers.federation_handler.on_receive_pdu(
|
||||||
"fo", pdu, False
|
"fo", pdu, False
|
||||||
|
@ -109,15 +110,12 @@ class FederationTestCase(unittest.TestCase):
|
||||||
context=ANY,
|
context=ANY,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.state_handler.annotate_context_with_state.assert_called_once_with(
|
self.state_handler.compute_event_context.assert_called_once_with(
|
||||||
ANY,
|
ANY, old_state=None,
|
||||||
ANY,
|
|
||||||
old_state=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.auth.check.assert_called_once_with(ANY, auth_events={})
|
self.auth.check.assert_called_once_with(ANY, auth_events={})
|
||||||
|
|
||||||
self.notifier.on_new_room_event.assert_called_once_with(
|
self.notifier.on_new_room_event.assert_called_once_with(
|
||||||
ANY,
|
ANY, extra_users=[]
|
||||||
extra_users=[]
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -60,7 +60,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
"check_host_in_room",
|
"check_host_in_room",
|
||||||
]),
|
]),
|
||||||
state_handler=NonCallableMock(spec_set=[
|
state_handler=NonCallableMock(spec_set=[
|
||||||
"annotate_context_with_state",
|
"compute_event_context",
|
||||||
"get_current_state",
|
"get_current_state",
|
||||||
]),
|
]),
|
||||||
config=self.mock_config,
|
config=self.mock_config,
|
||||||
|
@ -110,7 +110,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
defer.succeed([])
|
defer.succeed([])
|
||||||
)
|
)
|
||||||
|
|
||||||
def annotate(_, ctx):
|
def annotate(_):
|
||||||
|
ctx = Mock()
|
||||||
ctx.current_state = {
|
ctx.current_state = {
|
||||||
(EventTypes.Member, "@alice:green"): self._create_member(
|
(EventTypes.Member, "@alice:green"): self._create_member(
|
||||||
user_id="@alice:green",
|
user_id="@alice:green",
|
||||||
|
@ -121,10 +122,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
ctx.prev_state_events = []
|
||||||
|
|
||||||
return defer.succeed(True)
|
return defer.succeed(ctx)
|
||||||
|
|
||||||
self.state_handler.annotate_context_with_state.side_effect = annotate
|
self.state_handler.compute_event_context.side_effect = annotate
|
||||||
|
|
||||||
def add_auth(_, ctx):
|
def add_auth(_, ctx):
|
||||||
ctx.auth_events = ctx.current_state[
|
ctx.auth_events = ctx.current_state[
|
||||||
|
@ -146,8 +148,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
|
|
||||||
yield room_handler.change_membership(event, context)
|
yield room_handler.change_membership(event, context)
|
||||||
|
|
||||||
self.state_handler.annotate_context_with_state.assert_called_once_with(
|
self.state_handler.compute_event_context.assert_called_once_with(
|
||||||
builder, context
|
builder
|
||||||
)
|
)
|
||||||
|
|
||||||
self.auth.add_auth_events.assert_called_once_with(
|
self.auth.add_auth_events.assert_called_once_with(
|
||||||
|
@ -189,7 +191,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
defer.succeed([])
|
defer.succeed([])
|
||||||
)
|
)
|
||||||
|
|
||||||
def annotate(_, ctx):
|
def annotate(_):
|
||||||
|
ctx = Mock()
|
||||||
ctx.current_state = {
|
ctx.current_state = {
|
||||||
(EventTypes.Member, "@bob:red"): self._create_member(
|
(EventTypes.Member, "@bob:red"): self._create_member(
|
||||||
user_id="@bob:red",
|
user_id="@bob:red",
|
||||||
|
@ -197,10 +200,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
membership=Membership.INVITE
|
membership=Membership.INVITE
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
ctx.prev_state_events = []
|
||||||
|
|
||||||
return defer.succeed(True)
|
return defer.succeed(ctx)
|
||||||
|
|
||||||
self.state_handler.annotate_context_with_state.side_effect = annotate
|
self.state_handler.compute_event_context.side_effect = annotate
|
||||||
|
|
||||||
def add_auth(_, ctx):
|
def add_auth(_, ctx):
|
||||||
ctx.auth_events = ctx.current_state[
|
ctx.auth_events = ctx.current_state[
|
||||||
|
@ -262,7 +266,8 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
defer.succeed([])
|
defer.succeed([])
|
||||||
)
|
)
|
||||||
|
|
||||||
def annotate(_, ctx):
|
def annotate(_):
|
||||||
|
ctx = Mock()
|
||||||
ctx.current_state = {
|
ctx.current_state = {
|
||||||
(EventTypes.Member, "@bob:red"): self._create_member(
|
(EventTypes.Member, "@bob:red"): self._create_member(
|
||||||
user_id="@bob:red",
|
user_id="@bob:red",
|
||||||
|
@ -270,10 +275,11 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
|
||||||
membership=Membership.JOIN
|
membership=Membership.JOIN
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
ctx.prev_state_events = []
|
||||||
|
|
||||||
return defer.succeed(True)
|
return defer.succeed(ctx)
|
||||||
|
|
||||||
self.state_handler.annotate_context_with_state.side_effect = annotate
|
self.state_handler.compute_event_context.side_effect = annotate
|
||||||
|
|
||||||
def add_auth(_, ctx):
|
def add_auth(_, ctx):
|
||||||
ctx.auth_events = ctx.current_state[
|
ctx.auth_events = ctx.current_state[
|
||||||
|
|
|
@ -38,7 +38,6 @@ class StateTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_annotate_with_old_message(self):
|
def test_annotate_with_old_message(self):
|
||||||
event = self.create_event(type="test_message", name="event")
|
event = self.create_event(type="test_message", name="event")
|
||||||
context = Mock()
|
|
||||||
|
|
||||||
old_state = [
|
old_state = [
|
||||||
self.create_event(type="test1", state_key="1"),
|
self.create_event(type="test1", state_key="1"),
|
||||||
|
@ -46,8 +45,8 @@ class StateTestCase(unittest.TestCase):
|
||||||
self.create_event(type="test2", state_key=""),
|
self.create_event(type="test2", state_key=""),
|
||||||
]
|
]
|
||||||
|
|
||||||
yield self.state.annotate_context_with_state(
|
context = yield self.state.compute_event_context(
|
||||||
event, context, old_state=old_state
|
event, old_state=old_state
|
||||||
)
|
)
|
||||||
|
|
||||||
for k, v in context.current_state.items():
|
for k, v in context.current_state.items():
|
||||||
|
@ -64,7 +63,6 @@ class StateTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_annotate_with_old_state(self):
|
def test_annotate_with_old_state(self):
|
||||||
event = self.create_event(type="state", state_key="", name="event")
|
event = self.create_event(type="state", state_key="", name="event")
|
||||||
context = Mock()
|
|
||||||
|
|
||||||
old_state = [
|
old_state = [
|
||||||
self.create_event(type="test1", state_key="1"),
|
self.create_event(type="test1", state_key="1"),
|
||||||
|
@ -72,8 +70,8 @@ class StateTestCase(unittest.TestCase):
|
||||||
self.create_event(type="test2", state_key=""),
|
self.create_event(type="test2", state_key=""),
|
||||||
]
|
]
|
||||||
|
|
||||||
yield self.state.annotate_context_with_state(
|
context = yield self.state.compute_event_context(
|
||||||
event, context, old_state=old_state
|
event, old_state=old_state
|
||||||
)
|
)
|
||||||
|
|
||||||
for k, v in context.current_state.items():
|
for k, v in context.current_state.items():
|
||||||
|
@ -92,7 +90,6 @@ class StateTestCase(unittest.TestCase):
|
||||||
def test_trivial_annotate_message(self):
|
def test_trivial_annotate_message(self):
|
||||||
event = self.create_event(type="test_message", name="event")
|
event = self.create_event(type="test_message", name="event")
|
||||||
event.prev_events = []
|
event.prev_events = []
|
||||||
context = Mock()
|
|
||||||
|
|
||||||
old_state = [
|
old_state = [
|
||||||
self.create_event(type="test1", state_key="1"),
|
self.create_event(type="test1", state_key="1"),
|
||||||
|
@ -106,7 +103,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
group_name: old_state,
|
group_name: old_state,
|
||||||
}
|
}
|
||||||
|
|
||||||
yield self.state.annotate_context_with_state(event, context)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
for k, v in context.current_state.items():
|
for k, v in context.current_state.items():
|
||||||
type, state_key = k
|
type, state_key = k
|
||||||
|
@ -124,7 +121,6 @@ class StateTestCase(unittest.TestCase):
|
||||||
def test_trivial_annotate_state(self):
|
def test_trivial_annotate_state(self):
|
||||||
event = self.create_event(type="state", state_key="", name="event")
|
event = self.create_event(type="state", state_key="", name="event")
|
||||||
event.prev_events = []
|
event.prev_events = []
|
||||||
context = Mock()
|
|
||||||
|
|
||||||
old_state = [
|
old_state = [
|
||||||
self.create_event(type="test1", state_key="1"),
|
self.create_event(type="test1", state_key="1"),
|
||||||
|
@ -138,7 +134,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
group_name: old_state,
|
group_name: old_state,
|
||||||
}
|
}
|
||||||
|
|
||||||
yield self.state.annotate_context_with_state(event, context)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
for k, v in context.current_state.items():
|
for k, v in context.current_state.items():
|
||||||
type, state_key = k
|
type, state_key = k
|
||||||
|
@ -156,7 +152,6 @@ class StateTestCase(unittest.TestCase):
|
||||||
def test_resolve_message_conflict(self):
|
def test_resolve_message_conflict(self):
|
||||||
event = self.create_event(type="test_message", name="event")
|
event = self.create_event(type="test_message", name="event")
|
||||||
event.prev_events = []
|
event.prev_events = []
|
||||||
context = Mock()
|
|
||||||
|
|
||||||
old_state_1 = [
|
old_state_1 = [
|
||||||
self.create_event(type="test1", state_key="1"),
|
self.create_event(type="test1", state_key="1"),
|
||||||
|
@ -178,7 +173,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
group_name_2: old_state_2,
|
group_name_2: old_state_2,
|
||||||
}
|
}
|
||||||
|
|
||||||
yield self.state.annotate_context_with_state(event, context)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state), 5)
|
self.assertEqual(len(context.current_state), 5)
|
||||||
|
|
||||||
|
@ -188,7 +183,6 @@ class StateTestCase(unittest.TestCase):
|
||||||
def test_resolve_state_conflict(self):
|
def test_resolve_state_conflict(self):
|
||||||
event = self.create_event(type="test4", state_key="", name="event")
|
event = self.create_event(type="test4", state_key="", name="event")
|
||||||
event.prev_events = []
|
event.prev_events = []
|
||||||
context = Mock()
|
|
||||||
|
|
||||||
old_state_1 = [
|
old_state_1 = [
|
||||||
self.create_event(type="test1", state_key="1"),
|
self.create_event(type="test1", state_key="1"),
|
||||||
|
@ -210,7 +204,7 @@ class StateTestCase(unittest.TestCase):
|
||||||
group_name_2: old_state_2,
|
group_name_2: old_state_2,
|
||||||
}
|
}
|
||||||
|
|
||||||
yield self.state.annotate_context_with_state(event, context)
|
context = yield self.state.compute_event_context(event)
|
||||||
|
|
||||||
self.assertEqual(len(context.current_state), 5)
|
self.assertEqual(len(context.current_state), 5)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue