Implement new replace_state and changed prev_state

`prev_state` is now a list of previous state ids, similiar to
prev_events. `replace_state` now points to what we think was replaced.
pull/12/head
Erik Johnston 2014-11-06 15:10:55 +00:00
parent 3791b75000
commit 4317c8e583
13 changed files with 219 additions and 127 deletions

View File

@ -60,6 +60,7 @@ class SynapseEvent(JsonEncodedObject):
"age_ts",
"prev_content",
"prev_state",
"replaces_state",
"redacted_because",
"origin_server_ts",
]

View File

@ -147,10 +147,7 @@ class DirectoryHandler(BaseHandler):
content={"aliases": aliases},
)
snapshot = yield self.store.snapshot_room(
room_id=room_id,
user_id=user_id,
)
snapshot = yield self.store.snapshot_room(event)
yield self._on_new_room_event(
event, snapshot, extra_users=[user_id], suppress_auth=True

View File

@ -313,9 +313,7 @@ class FederationHandler(BaseHandler):
state_key=user_id,
)
snapshot = yield self.store.snapshot_room(
event.room_id, event.user_id,
)
snapshot = yield self.store.snapshot_room(event)
snapshot.fill_out_prev_events(event)
yield self.state_handler.annotate_state_groups(event)

View File

@ -81,7 +81,7 @@ class MessageHandler(BaseHandler):
user = self.hs.parse_userid(event.user_id)
assert user.is_mine, "User must be our own: %s" % (user,)
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
snapshot = yield self.store.snapshot_room(event)
yield self._on_new_room_event(
event, snapshot, suppress_auth=suppress_auth
@ -141,12 +141,7 @@ class MessageHandler(BaseHandler):
SynapseError if something went wrong.
"""
snapshot = yield self.store.snapshot_room(
event.room_id,
event.user_id,
state_type=event.type,
state_key=event.state_key,
)
snapshot = yield self.store.snapshot_room(event)
yield self._on_new_room_event(event, snapshot)
@ -214,7 +209,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def send_feedback(self, event):
snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
snapshot = yield self.store.snapshot_room(event)
# store message in db
yield self._on_new_room_event(event, snapshot)

View File

@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import Membership
from synapse.api.events.room import RoomMemberEvent
from ._base import BaseHandler
@ -196,10 +195,7 @@ class ProfileHandler(BaseHandler):
)
for j in joins:
snapshot = yield self.store.snapshot_room(
j.room_id, j.state_key, RoomMemberEvent.TYPE,
j.state_key
)
snapshot = yield self.store.snapshot_room(j)
content = {
"membership": j.content["membership"],

View File

@ -122,10 +122,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def handle_event(event):
snapshot = yield self.store.snapshot_room(
room_id=room_id,
user_id=user_id,
)
snapshot = yield self.store.snapshot_room(event)
logger.debug("Event: %s", event)
@ -364,10 +361,8 @@ class RoomMemberHandler(BaseHandler):
"""
target_user_id = event.state_key
snapshot = yield self.store.snapshot_room(
event.room_id, event.user_id,
RoomMemberEvent.TYPE, target_user_id
)
snapshot = yield self.store.snapshot_room(event)
## TODO(markjh): get prev state from snapshot.
prev_state = yield self.store.get_room_member(
target_user_id, event.room_id
@ -442,10 +437,7 @@ class RoomMemberHandler(BaseHandler):
content=content,
)
snapshot = yield self.store.snapshot_room(
room_id, joinee.to_string(), RoomMemberEvent.TYPE,
joinee.to_string()
)
snapshot = yield self.store.snapshot_room(new_event)
yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)

View File

@ -138,7 +138,7 @@ class RoomStateEventRestServlet(RestServlet):
raise SynapseError(
404, "Event not found.", errcode=Codes.NOT_FOUND
)
defer.returnValue((200, data[0].get_dict()["content"]))
defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key):

View File

@ -45,40 +45,6 @@ class StateHandler(object):
self.server_name = hs.hostname
self.hs = hs
@defer.inlineCallbacks
@log_function
def handle_new_event(self, event, snapshot):
""" Given an event this works out if a) we have sufficient power level
to update the state and b) works out what the prev_state should be.
Returns:
Deferred: Resolved with a boolean indicating if we successfully
updated the state.
Raised:
AuthError
"""
# This needs to be done in a transaction.
if not hasattr(event, "state_key"):
return
# Now I need to fill out the prev state and work out if it has auth
# (w.r.t. to power levels)
snapshot.fill_out_prev_events(event)
yield self.annotate_state_groups(event)
if event.old_state_events:
current_state = event.old_state_events.get(
(event.type, event.state_key)
)
if current_state:
event.prev_state = current_state.event_id
defer.returnValue(True)
@defer.inlineCallbacks
@log_function
def annotate_state_groups(self, event, old_state=None):
@ -111,7 +77,10 @@ class StateHandler(object):
event.old_state_events = copy.deepcopy(new_state)
if hasattr(event, "state_key"):
new_state[(event.type, event.state_key)] = event
key = (event.type, event.state_key)
if key in new_state:
event.replaces_state = new_state[key].event_id
new_state[key] = event
event.state_group = None
event.state_events = new_state

View File

@ -242,8 +242,8 @@ class DataStore(RoomMemberStore, RoomStore,
"state_key": event.state_key,
}
if hasattr(event, "prev_state"):
vals["prev_state"] = event.prev_state
if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state
self._simple_insert_txn(txn, "state_events", vals)
@ -258,6 +258,40 @@ class DataStore(RoomMemberStore, RoomStore,
}
)
for e_id, h in event.prev_state:
self._simple_insert_txn(
txn,
table="event_edges",
values={
"event_id": event.event_id,
"prev_event_id": e_id,
"room_id": event.room_id,
"is_state": 1,
},
or_ignore=True,
)
if not backfilled:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
}
)
for prev_state_id, _ in event.prev_state:
self._simple_delete_txn(
txn,
table="state_forward_extremities",
keyvalues={
"event_id": prev_state_id,
}
)
for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn(
@ -357,7 +391,7 @@ class DataStore(RoomMemberStore, RoomStore,
],
)
def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
def snapshot_room(self, event):
"""Snapshot the room for an update by a user
Args:
room_id (synapse.types.RoomId): The room to snapshot.
@ -368,16 +402,29 @@ class DataStore(RoomMemberStore, RoomStore,
synapse.storage.Snapshot: A snapshot of the state of the room.
"""
def _snapshot(txn):
membership_state = self._get_room_member(txn, user_id, room_id)
prev_events = self._get_latest_events_in_room(txn, room_id)
prev_events = self._get_latest_events_in_room(
txn,
event.room_id
)
prev_state = None
state_key = None
if hasattr(event, "state_key"):
state_key = event.state_key
prev_state = self._get_latest_state_in_room(
txn,
event.room_id,
type=event.type,
state_key=state_key,
)
return Snapshot(
store=self,
room_id=room_id,
user_id=user_id,
room_id=event.room_id,
user_id=event.user_id,
prev_events=prev_events,
membership_state=membership_state,
state_type=state_type,
prev_state=prev_state,
state_type=event.type,
state_key=state_key,
)
@ -400,30 +447,29 @@ class Snapshot(object):
"""
def __init__(self, store, room_id, user_id, prev_events,
membership_state, state_type=None, state_key=None,
prev_state_pdu=None):
prev_state, state_type=None, state_key=None):
self.store = store
self.room_id = room_id
self.user_id = user_id
self.prev_events = prev_events
self.membership_state = membership_state
self.prev_state = prev_state
self.state_type = state_type
self.state_key = state_key
self.prev_state_pdu = prev_state_pdu
def fill_out_prev_events(self, event):
if hasattr(event, "prev_events"):
return
if not hasattr(event, "prev_events"):
event.prev_events = [
(event_id, hashes)
for event_id, hashes, _ in self.prev_events
]
event.prev_events = [
(event_id, hashes)
for event_id, hashes, _ in self.prev_events
]
if self.prev_events:
event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
else:
event.depth = 0
if self.prev_events:
event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
else:
event.depth = 0
if not hasattr(event, "prev_state") and self.prev_state is not None:
event.prev_state = self.prev_state
def schema_path(schema):

View File

@ -245,7 +245,6 @@ class SQLBaseStore(object):
return [r[0] for r in txn.fetchall()]
def _simple_select_onecol(self, table, keyvalues, retcol):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@ -273,17 +272,30 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
return self.runInteraction(
"_simple_select_list",
self._simple_select_list_txn,
table, keyvalues, retcols
)
def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
txn : Transaction object
table : string giving the table name
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
"""
sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k) for k in keyvalues)
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
def func(txn):
txn.execute(sql, keyvalues.values())
return self.cursor_to_dict(txn)
return self.runInteraction("_simple_select_list", func)
txn.execute(sql, keyvalues.values())
return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None):
@ -417,6 +429,10 @@ class SQLBaseStore(object):
d.pop("topological_ordering", None)
d.pop("processed", None)
d["origin_server_ts"] = d.pop("ts", 0)
replaces_state = d.pop("prev_state", None)
if replaces_state:
d["replaces_state"] = replaces_state
d.update(json.loads(row_dict["unrecognized_keys"]))
d["content"] = json.loads(d["content"])
@ -450,16 +466,32 @@ class SQLBaseStore(object):
k: encode_base64(v) for k, v in signatures.items()
}
ev.prev_events = self._get_prev_events(txn, ev.event_id)
prevs = self._get_prev_events_and_state(txn, ev.event_id)
if hasattr(ev, "prev_state"):
# Load previous state_content.
# TODO: Should we be pulling this out above?
cursor = txn.execute(select_event_sql, (ev.prev_state,))
prevs = self.cursor_to_dict(cursor)
if prevs:
prev = self._parse_event_from_row(prevs[0])
ev.prev_content = prev.content
ev.prev_events = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 0
]
if hasattr(ev, "state_key"):
ev.prev_state = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 1
]
if hasattr(ev, "replaces_state"):
# Load previous state_content.
# FIXME (erikj): Handle multiple prev_states.
cursor = txn.execute(
select_event_sql,
(ev.replaces_state,)
)
prevs = self.cursor_to_dict(cursor)
if prevs:
prev = self._parse_event_from_row(prevs[0])
ev.prev_content = prev.content
if not hasattr(ev, "redacted"):
logger.debug("Doesn't have redacted key: %s", ev)

View File

@ -69,19 +69,21 @@ class EventFederationStore(SQLBaseStore):
return results
def _get_prev_events(self, txn, event_id):
prev_ids = self._simple_select_onecol_txn(
def _get_latest_state_in_room(self, txn, room_id, type, state_key):
event_ids = self._simple_select_onecol_txn(
txn,
table="event_edges",
table="state_forward_extremities",
keyvalues={
"event_id": event_id,
"room_id": room_id,
"type": type,
"state_key": state_key,
},
retcol="prev_event_id",
retcol="event_id",
)
results = []
for prev_event_id in prev_ids:
hashes = self._get_event_reference_hashes_txn(txn, prev_event_id)
for event_id in event_ids:
hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
@ -90,6 +92,53 @@ class EventFederationStore(SQLBaseStore):
return results
def _get_prev_events(self, txn, event_id):
results = self._get_prev_events_and_state(
txn,
event_id,
is_state=0,
)
return [(e_id, h, ) for e_id, h, _ in results]
def _get_prev_state(self, txn, event_id):
results = self._get_prev_events_and_state(
txn,
event_id,
is_state=1,
)
return [(e_id, h, ) for e_id, h, _ in results]
def _get_prev_events_and_state(self, txn, event_id, is_state=None):
keyvalues = {
"event_id": event_id,
}
if is_state is not None:
keyvalues["is_state"] = is_state
res = self._simple_select_list_txn(
txn,
table="event_edges",
keyvalues=keyvalues,
retcols=["prev_event_id", "is_state"],
)
results = []
for d in res:
hashes = self._get_event_reference_hashes_txn(
txn,
d["prev_event_id"]
)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
}
results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
return results
def get_min_depth(self, room_id):
return self.runInteraction(
"get_min_depth",
@ -135,6 +184,7 @@ class EventFederationStore(SQLBaseStore):
"event_id": event_id,
"prev_event_id": e_id,
"room_id": room_id,
"is_state": 0,
},
or_ignore=True,
)

View File

@ -1,7 +1,7 @@
CREATE TABLE IF NOT EXISTS event_forward_extremities(
event_id TEXT,
room_id TEXT,
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
);
@ -10,8 +10,8 @@ CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);
CREATE TABLE IF NOT EXISTS event_backward_extremities(
event_id TEXT,
room_id TEXT,
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
);
@ -20,10 +20,11 @@ CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id
CREATE TABLE IF NOT EXISTS event_edges(
event_id TEXT,
prev_event_id TEXT,
room_id TEXT,
CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id)
event_id TEXT NOT NULL,
prev_event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
is_state INTEGER NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state)
);
CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
@ -31,8 +32,8 @@ CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);
CREATE TABLE IF NOT EXISTS room_depth(
room_id TEXT,
min_depth INTEGER,
room_id TEXT NOT NULL,
min_depth INTEGER NOT NULL,
CONSTRAINT uniqueness UNIQUE (room_id)
);
@ -40,10 +41,25 @@ CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);
create TABLE IF NOT EXISTS event_destinations(
event_id TEXT,
destination TEXT,
event_id TEXT NOT NULL,
destination TEXT NOT NULL,
delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
);
CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);
CREATE TABLE IF NOT EXISTS state_forward_extremities(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
type TEXT NOT NULL,
state_key TEXT NOT NULL,
CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
);
CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities(
room_id, type, state_key
);
CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id);

View File

@ -80,7 +80,7 @@ class JsonEncodedObject(object):
def get_full_dict(self):
d = {
k: v for (k, v) in self.__dict__.items()
k: _encode(v) for (k, v) in self.__dict__.items()
if k in self.valid_keys or k in self.internal_keys
}
d.update(self.unrecognized_keys)