Include hashes of previous pdus when referencing them
							parent
							
								
									66104da10c
								
							
						
					
					
						commit
						bb04447c44
					
				|  | @ -65,13 +65,13 @@ class SynapseEvent(JsonEncodedObject): | |||
| 
 | ||||
|     internal_keys = [ | ||||
|         "is_state", | ||||
|         "prev_events", | ||||
|         "depth", | ||||
|         "destinations", | ||||
|         "origin", | ||||
|         "outlier", | ||||
|         "power_level", | ||||
|         "redacted", | ||||
|         "prev_pdus", | ||||
|     ] | ||||
| 
 | ||||
|     required_keys = [ | ||||
|  |  | |||
|  | @ -45,9 +45,7 @@ class PduCodec(object): | |||
|         kwargs["event_id"] = encode_event_id(pdu.pdu_id, pdu.origin) | ||||
|         kwargs["room_id"] = pdu.context | ||||
|         kwargs["etype"] = pdu.pdu_type | ||||
|         kwargs["prev_events"] = [ | ||||
|             encode_event_id(p[0], p[1]) for p in pdu.prev_pdus | ||||
|         ] | ||||
|         kwargs["prev_pdus"] = pdu.prev_pdus | ||||
| 
 | ||||
|         if hasattr(pdu, "prev_state_id") and hasattr(pdu, "prev_state_origin"): | ||||
|             kwargs["prev_state"] = encode_event_id( | ||||
|  | @ -78,11 +76,8 @@ class PduCodec(object): | |||
|         d["context"] = event.room_id | ||||
|         d["pdu_type"] = event.type | ||||
| 
 | ||||
|         if hasattr(event, "prev_events"): | ||||
|             d["prev_pdus"] = [ | ||||
|                 decode_event_id(e, self.server_name) | ||||
|                 for e in event.prev_events | ||||
|             ] | ||||
|         if hasattr(event, "prev_pdus"): | ||||
|             d["prev_pdus"] = event.prev_pdus | ||||
| 
 | ||||
|         if hasattr(event, "prev_state"): | ||||
|             d["prev_state_id"], d["prev_state_origin"] = ( | ||||
|  | @ -95,7 +90,7 @@ class PduCodec(object): | |||
|         kwargs = copy.deepcopy(event.unrecognized_keys) | ||||
|         kwargs.update({ | ||||
|             k: v for k, v in d.items() | ||||
|             if k not in ["event_id", "room_id", "type", "prev_events"] | ||||
|             if k not in ["event_id", "room_id", "type"] | ||||
|         }) | ||||
| 
 | ||||
|         if "ts" not in kwargs: | ||||
|  |  | |||
|  | @ -443,7 +443,7 @@ class ReplicationLayer(object): | |||
|             min_depth = yield self.store.get_min_depth_for_context(pdu.context) | ||||
| 
 | ||||
|             if min_depth and pdu.depth > min_depth: | ||||
|                 for pdu_id, origin in pdu.prev_pdus: | ||||
|                 for pdu_id, origin, hashes in pdu.prev_pdus: | ||||
|                     exists = yield self._get_persisted_pdu(pdu_id, origin) | ||||
| 
 | ||||
|                     if not exists: | ||||
|  |  | |||
|  | @ -141,8 +141,16 @@ class Pdu(JsonEncodedObject): | |||
|                 for kid, sig in pdu_tuple.signatures.items() | ||||
|             } | ||||
| 
 | ||||
|             prev_pdus = [] | ||||
|             for prev_pdu in pdu_tuple.prev_pdu_list: | ||||
|                 prev_hashes = pdu_tuple.edge_hashes.get(prev_pdu, {}) | ||||
|                 prev_hashes = { | ||||
|                     alg: encode_base64(hsh) for alg, hsh in prev_hashes.items() | ||||
|                 } | ||||
|                 prev_pdus.append((prev_pdu[0], prev_pdu[1], prev_hashes)) | ||||
| 
 | ||||
|             return Pdu( | ||||
|                 prev_pdus=pdu_tuple.prev_pdu_list, | ||||
|                 prev_pdus=prev_pdus, | ||||
|                 **args | ||||
|             ) | ||||
|         else: | ||||
|  |  | |||
|  | @ -72,10 +72,6 @@ class StateHandler(object): | |||
| 
 | ||||
|         snapshot.fill_out_prev_events(event) | ||||
| 
 | ||||
|         event.prev_events = [ | ||||
|             e for e in event.prev_events if e != event.event_id | ||||
|         ] | ||||
| 
 | ||||
|         current_state = snapshot.prev_state_pdu | ||||
| 
 | ||||
|         if current_state: | ||||
|  |  | |||
|  | @ -177,6 +177,14 @@ class DataStore(RoomMemberStore, RoomStore, | |||
|                 txn, pdu.pdu_id, pdu.origin, key_id, signature_bytes, | ||||
|             ) | ||||
| 
 | ||||
|         for prev_pdu_id, prev_origin, prev_hashes in pdu.prev_pdus: | ||||
|             for alg, hash_base64 in prev_hashes.items(): | ||||
|                 hash_bytes = decode_base64(hash_base64) | ||||
|                 self._store_prev_pdu_hash_txn( | ||||
|                     txn, pdu.pdu_id, pdu.origin, prev_pdu_id, prev_origin, alg, | ||||
|                     hash_bytes | ||||
|                 ) | ||||
| 
 | ||||
|         if pdu.is_state: | ||||
|             self._persist_state_txn(txn, pdu.prev_pdus, cols) | ||||
|         else: | ||||
|  | @ -352,6 +360,7 @@ class DataStore(RoomMemberStore, RoomStore, | |||
|             prev_pdus = self._get_latest_pdus_in_context( | ||||
|                 txn, room_id | ||||
|             ) | ||||
| 
 | ||||
|             if state_type is not None and state_key is not None: | ||||
|                 prev_state_pdu = self._get_current_state_pdu( | ||||
|                     txn, room_id, state_type, state_key | ||||
|  | @ -401,17 +410,16 @@ class Snapshot(object): | |||
|         self.prev_state_pdu = prev_state_pdu | ||||
| 
 | ||||
|     def fill_out_prev_events(self, event): | ||||
|         if hasattr(event, "prev_events"): | ||||
|         if hasattr(event, "prev_pdus"): | ||||
|             return | ||||
| 
 | ||||
|         es = [ | ||||
|             "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus | ||||
|         event.prev_pdus = [ | ||||
|             (p_id, origin, hashes) | ||||
|             for p_id, origin, hashes, _ in self.prev_pdus | ||||
|         ] | ||||
| 
 | ||||
|         event.prev_events = [e for e in es if e != event.event_id] | ||||
| 
 | ||||
|         if self.prev_pdus: | ||||
|             event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1 | ||||
|             event.depth = max([int(v) for _, _, _, v in self.prev_pdus]) + 1 | ||||
|         else: | ||||
|             event.depth = 0 | ||||
| 
 | ||||
|  |  | |||
|  | @ -20,10 +20,13 @@ from ._base import SQLBaseStore, Table, JoinHelper | |||
| from synapse.federation.units import Pdu | ||||
| from synapse.util.logutils import log_function | ||||
| 
 | ||||
| from syutil.base64util import encode_base64 | ||||
| 
 | ||||
| from collections import namedtuple | ||||
| 
 | ||||
| import logging | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -64,6 +67,8 @@ class PduStore(SQLBaseStore): | |||
|                 for r in PduEdgesTable.decode_results(txn.fetchall()) | ||||
|             ] | ||||
| 
 | ||||
|             edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin) | ||||
| 
 | ||||
|             hashes = self._get_pdu_hashes_txn(txn, pdu_id, origin) | ||||
|             signatures = self._get_pdu_origin_signatures_txn( | ||||
|                 txn, pdu_id, origin | ||||
|  | @ -86,7 +91,7 @@ class PduStore(SQLBaseStore): | |||
|             row = txn.fetchone() | ||||
|             if row: | ||||
|                 results.append(PduTuple( | ||||
|                     PduEntry(*row), edges, hashes, signatures | ||||
|                     PduEntry(*row), edges, hashes, signatures, edge_hashes | ||||
|                 )) | ||||
| 
 | ||||
|         return results | ||||
|  | @ -310,9 +315,14 @@ class PduStore(SQLBaseStore): | |||
|             (context, ) | ||||
|         ) | ||||
| 
 | ||||
|         results = txn.fetchall() | ||||
|         results = [] | ||||
|         for pdu_id, origin, depth in txn.fetchall(): | ||||
|             hashes = self._get_pdu_hashes_txn(txn, pdu_id, origin) | ||||
|             sha256_bytes = hashes["sha256"] | ||||
|             prev_hashes = {"sha256": encode_base64(sha256_bytes)} | ||||
|             results.append((pdu_id, origin, prev_hashes, depth)) | ||||
| 
 | ||||
|         return [(row[0], row[1], row[2]) for row in results] | ||||
|         return results | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_oldest_pdus_in_context(self, context): | ||||
|  | @ -431,7 +441,7 @@ class PduStore(SQLBaseStore): | |||
|                 "DELETE FROM %s WHERE pdu_id = ? AND origin = ?" | ||||
|                 % PduForwardExtremitiesTable.table_name | ||||
|             ) | ||||
|             txn.executemany(query, prev_pdus) | ||||
|             txn.executemany(query, list(p[:2] for p in prev_pdus)) | ||||
| 
 | ||||
|             # We only insert as a forward extremety the new pdu if there are no | ||||
|             # other pdus that reference it as a prev pdu | ||||
|  | @ -454,7 +464,7 @@ class PduStore(SQLBaseStore): | |||
|             # deleted in a second if they're incorrect anyway. | ||||
|             txn.executemany( | ||||
|                 PduBackwardExtremitiesTable.insert_statement(), | ||||
|                 [(i, o, context) for i, o in prev_pdus] | ||||
|                 [(i, o, context) for i, o, _ in prev_pdus] | ||||
|             ) | ||||
| 
 | ||||
|             # Also delete from the backwards extremities table all ones that | ||||
|  | @ -915,7 +925,7 @@ This does not include a prev_pdus key. | |||
| 
 | ||||
| PduTuple = namedtuple( | ||||
|     "PduTuple", | ||||
|     ("pdu_entry", "prev_pdu_list", "hashes", "signatures") | ||||
|     ("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes") | ||||
| ) | ||||
| """ This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent | ||||
| the `prev_pdus` key of a PDU. | ||||
|  |  | |||
|  | @ -34,3 +34,19 @@ CREATE TABLE IF NOT EXISTS pdu_origin_signatures ( | |||
| CREATE INDEX IF NOT EXISTS pdu_origin_signatures_id ON pdu_origin_signatures ( | ||||
|     pdu_id, origin | ||||
| ); | ||||
| 
 | ||||
| CREATE TABLE IF NOT EXISTS pdu_edge_hashes( | ||||
|     pdu_id TEXT, | ||||
|     origin TEXT, | ||||
|     prev_pdu_id TEXT, | ||||
|     prev_origin TEXT, | ||||
|     algorithm TEXT, | ||||
|     hash BLOB, | ||||
|     CONSTRAINT uniqueness UNIQUE ( | ||||
|         pdu_id, origin, prev_pdu_id, prev_origin, algorithm | ||||
|     ) | ||||
| ); | ||||
| 
 | ||||
| CREATE INDEX IF NOT EXISTS pdu_edge_hashes_id ON pdu_edge_hashes( | ||||
|     pdu_id, origin | ||||
| ); | ||||
|  |  | |||
|  | @ -88,3 +88,34 @@ class SignatureStore(SQLBaseStore): | |||
|             "signature": buffer(signature_bytes), | ||||
|         }) | ||||
| 
 | ||||
|     def _get_prev_pdu_hashes_txn(self, txn, pdu_id, origin): | ||||
|         """Get all the hashes for previous PDUs of a PDU | ||||
|         Args: | ||||
|             txn (cursor): | ||||
|             pdu_id (str): Id of the PDU. | ||||
|             origin (str): Origin of the PDU. | ||||
|         Returns: | ||||
|             dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes. | ||||
|         """ | ||||
|         query = ( | ||||
|             "SELECT prev_pdu_id, prev_origin, algorithm, hash" | ||||
|             " FROM pdu_edge_hashes" | ||||
|             " WHERE pdu_id = ? and origin = ?" | ||||
|         ) | ||||
|         txn.execute(query, (pdu_id, origin)) | ||||
|         results = {} | ||||
|         for prev_pdu_id, prev_origin, algorithm, hash_bytes in txn.fetchall(): | ||||
|             hashes = results.setdefault((prev_pdu_id, prev_origin), {}) | ||||
|             hashes[algorithm] = hash_bytes | ||||
|         return results | ||||
| 
 | ||||
|     def _store_prev_pdu_hash_txn(self, txn, pdu_id, origin, prev_pdu_id, | ||||
|                              prev_origin, algorithm, hash_bytes): | ||||
|         self._simple_insert_txn(txn, "pdu_edge_hashes", { | ||||
|             "pdu_id": pdu_id, | ||||
|             "origin": origin, | ||||
|             "prev_pdu_id": prev_pdu_id, | ||||
|             "prev_origin": prev_origin, | ||||
|             "algorithm": algorithm, | ||||
|             "hash": buffer(hash_bytes), | ||||
|         }) | ||||
|  |  | |||
|  | @ -41,7 +41,7 @@ def make_pdu(prev_pdus=[], **kwargs): | |||
|     } | ||||
|     pdu_fields.update(kwargs) | ||||
| 
 | ||||
|     return PduTuple(PduEntry(**pdu_fields), prev_pdus, {}, {}) | ||||
|     return PduTuple(PduEntry(**pdu_fields), prev_pdus, {}, {}, {}) | ||||
| 
 | ||||
| 
 | ||||
| class FederationTestCase(unittest.TestCase): | ||||
|  |  | |||
|  | @ -88,7 +88,7 @@ class PduCodecTestCase(unittest.TestCase): | |||
|         self.assertEquals(pdu.context, event.room_id) | ||||
|         self.assertEquals(pdu.is_state, event.is_state) | ||||
|         self.assertEquals(pdu.depth, event.depth) | ||||
|         self.assertEquals(["alice@bob.com"], event.prev_events) | ||||
|         self.assertEquals(pdu.prev_pdus, event.prev_pdus) | ||||
|         self.assertEquals(pdu.content, event.content) | ||||
| 
 | ||||
|     def test_pdu_from_event(self): | ||||
|  | @ -144,7 +144,7 @@ class PduCodecTestCase(unittest.TestCase): | |||
|         self.assertEquals(pdu.context, event.room_id) | ||||
|         self.assertEquals(pdu.is_state, event.is_state) | ||||
|         self.assertEquals(pdu.depth, event.depth) | ||||
|         self.assertEquals(["alice@bob.com"], event.prev_events) | ||||
|         self.assertEquals(pdu.prev_pdus, event.prev_pdus) | ||||
|         self.assertEquals(pdu.content, event.content) | ||||
|         self.assertEquals(pdu.state_key, event.state_key) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Mark Haines
						Mark Haines