MatrixSynapse/synapse/storage/pdu.py

933 lines
27 KiB
Python

# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore, Table, JoinHelper
from synapse.federation.units import Pdu
from synapse.util.logutils import log_function
from syutil.base64util import encode_base64
from collections import namedtuple
import logging
logger = logging.getLogger(__name__)
class PduStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
def get_pdu(self, pdu_id, origin):
"""Given a pdu_id and origin, get a PDU.
Args:
txn
pdu_id (str)
origin (str)
Returns:
PduTuple: If the pdu does not exist in the database, returns None
"""
return self.runInteraction(
self._get_pdu_tuple, pdu_id, origin
)
def _get_pdu_tuple(self, txn, pdu_id, origin):
res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
return res[0] if res else None
def _get_pdu_tuples(self, txn, pdu_id_tuples):
results = []
for pdu_id, origin in pdu_id_tuples:
txn.execute(
PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"),
(pdu_id, origin)
)
edges = [
(r.prev_pdu_id, r.prev_origin)
for r in PduEdgesTable.decode_results(txn.fetchall())
]
edge_hashes = self._get_prev_pdu_hashes_txn(txn, pdu_id, origin)
hashes = self._get_pdu_content_hashes_txn(txn, pdu_id, origin)
signatures = self._get_pdu_origin_signatures_txn(
txn, pdu_id, origin
)
query = (
"SELECT %(fields)s FROM %(pdus)s as p "
"LEFT JOIN %(state)s as s "
"ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
"WHERE p.pdu_id = ? AND p.origin = ? "
) % {
"fields": _pdu_state_joiner.get_fields(
PdusTable="p", StatePdusTable="s"),
"pdus": PdusTable.table_name,
"state": StatePdusTable.table_name,
}
txn.execute(query, (pdu_id, origin))
row = txn.fetchone()
if row:
results.append(PduTuple(
PduEntry(*row), edges, hashes, signatures, edge_hashes
))
return results
def get_current_state_for_context(self, context):
"""Get a list of PDUs that represent the current state for a given
context
Args:
context (str)
Returns:
list: A list of PduTuples
"""
return self.runInteraction(
self._get_current_state_for_context,
context
)
def _get_current_state_for_context(self, txn, context):
query = (
"SELECT pdu_id, origin FROM %s WHERE context = ?"
% CurrentStateTable.table_name
)
logger.debug("get_current_state %s, Args=%s", query, context)
txn.execute(query, (context,))
res = txn.fetchall()
logger.debug("get_current_state %d results", len(res))
return self._get_pdu_tuples(txn, res)
def _persist_pdu_txn(self, txn, prev_pdus, cols):
"""Inserts a (non-state) PDU into the database.
Args:
txn,
prev_pdus (list)
**cols: The columns to insert into the PdusTable.
"""
entry = PdusTable.EntryType(
**{k: cols.get(k, None) for k in PdusTable.fields}
)
txn.execute(PdusTable.insert_statement(), entry)
self._handle_prev_pdus(
txn, entry.outlier, entry.pdu_id, entry.origin,
prev_pdus, entry.context
)
def mark_pdu_as_processed(self, pdu_id, pdu_origin):
"""Mark a received PDU as processed.
Args:
txn
pdu_id (str)
pdu_origin (str)
"""
return self.runInteraction(
self._mark_as_processed, pdu_id, pdu_origin
)
def _mark_as_processed(self, txn, pdu_id, pdu_origin):
txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name)
def get_all_pdus_from_context(self, context):
"""Get a list of all PDUs for a given context."""
return self.runInteraction(
self._get_all_pdus_from_context, context,
)
def _get_all_pdus_from_context(self, txn, context):
query = (
"SELECT pdu_id, origin FROM %s "
"WHERE context = ?"
) % PdusTable.table_name
txn.execute(query, (context,))
return self._get_pdu_tuples(txn, txn.fetchall())
def get_backfill(self, context, pdu_list, limit):
"""Get a list of Pdus for a given topic that occured before (and
including) the pdus in pdu_list. Return a list of max size `limit`.
Args:
txn
context (str)
pdu_list (list)
limit (int)
Return:
list: A list of PduTuples
"""
return self.runInteraction(
self._get_backfill, context, pdu_list, limit
)
def _get_backfill(self, txn, context, pdu_list, limit):
logger.debug(
"backfill: %s, %s, %s",
context, repr(pdu_list), limit
)
# We seed the pdu_results with the things from the pdu_list.
pdu_results = pdu_list
front = pdu_list
query = (
"SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
"WHERE context = ? AND pdu_id = ? AND origin = ? "
"LIMIT ?"
) % {
"edges_table": PduEdgesTable.table_name,
}
# We iterate through all pdu_ids in `front` to select their previous
# pdus. These are dumped in `new_front`. We continue until we reach the
# limit *or* new_front is empty (i.e., we've run out of things to
# select
while front and len(pdu_results) < limit:
new_front = []
for pdu_id, origin in front:
logger.debug(
"_backfill_interaction: i=%s, o=%s",
pdu_id, origin
)
txn.execute(
query,
(context, pdu_id, origin, limit - len(pdu_results))
)
for row in txn.fetchall():
logger.debug(
"_backfill_interaction: got i=%s, o=%s",
*row
)
new_front.append(row)
front = new_front
pdu_results += new_front
# We also want to update the `prev_pdus` attributes before returning.
return self._get_pdu_tuples(txn, pdu_results)
def get_min_depth_for_context(self, context):
"""Get the current minimum depth for a context
Args:
txn
context (str)
"""
return self.runInteraction(
self._get_min_depth_for_context, context
)
def _get_min_depth_for_context(self, txn, context):
return self._get_min_depth_interaction(txn, context)
def _get_min_depth_interaction(self, txn, context):
txn.execute(
"SELECT min_depth FROM %s WHERE context = ?"
% ContextDepthTable.table_name,
(context,)
)
row = txn.fetchone()
return row[0] if row else None
def _update_min_depth_for_context_txn(self, txn, context, depth):
"""Update the minimum `depth` of the given context, which is the line
on which we stop backfilling backwards.
Args:
context (str)
depth (int)
"""
min_depth = self._get_min_depth_interaction(txn, context)
do_insert = depth < min_depth if min_depth else True
if do_insert:
txn.execute(
"INSERT OR REPLACE INTO %s (context, min_depth) "
"VALUES (?,?)" % ContextDepthTable.table_name,
(context, depth)
)
def _get_latest_pdus_in_context(self, txn, context):
"""Get's a list of the most current pdus for a given context. This is
used when we are sending a Pdu and need to fill out the `prev_pdus`
key
Args:
txn
context
"""
query = (
"SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
"INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
"AND f.origin = p.origin "
"WHERE f.context = ?"
) % {
"pdus": PdusTable.table_name,
"forward": PduForwardExtremitiesTable.table_name,
}
logger.debug("get_prev query: %s", query)
txn.execute(
query,
(context, )
)
results = []
for pdu_id, origin, depth in txn.fetchall():
hashes = self._get_pdu_reference_hashes_txn(txn, pdu_id, origin)
sha256_bytes = hashes["sha256"]
prev_hashes = {"sha256": encode_base64(sha256_bytes)}
results.append((pdu_id, origin, prev_hashes, depth))
return results
@defer.inlineCallbacks
def get_oldest_pdus_in_context(self, context):
"""Get a list of Pdus that we haven't backfilled beyond yet (and havent
seen). This list is used when we want to backfill backwards and is the
list we send to the remote server.
Args:
txn
context (str)
Returns:
list: A list of PduIdTuple.
"""
results = yield self._execute(
None,
"SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
% {"back": PduBackwardExtremitiesTable.table_name, },
context
)
defer.returnValue([PduIdTuple(i, o) for i, o in results])
def is_pdu_new(self, pdu_id, origin, context, depth):
"""For a given Pdu, try and figure out if it's 'new', i.e., if it's
not something we got randomly from the past, for example when we
request the current state of the room that will probably return a bunch
of pdus from before we joined.
Args:
txn
pdu_id (str)
origin (str)
context (str)
depth (int)
Returns:
bool
"""
return self.runInteraction(
self._is_pdu_new,
pdu_id=pdu_id,
origin=origin,
context=context,
depth=depth
)
def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
# If depth > min depth in back table, then we classify it as new.
# OR if there is nothing in the back table, then it kinda needs to
# be a new thing.
query = (
"SELECT min(p.depth) FROM %(edges)s as e "
"INNER JOIN %(back)s as b "
"ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
"INNER JOIN %(pdus)s as p "
"ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
"WHERE p.context = ?"
) % {
"pdus": PdusTable.table_name,
"edges": PduEdgesTable.table_name,
"back": PduBackwardExtremitiesTable.table_name,
}
txn.execute(query, (context,))
min_depth, = txn.fetchone()
if not min_depth or depth > int(min_depth):
logger.debug(
"is_new true: id=%s, o=%s, d=%s min_depth=%s",
pdu_id, origin, depth, min_depth
)
return True
# If this pdu is in the forwards table, then it also is a new one
query = (
"SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
) % {
"forward": PduForwardExtremitiesTable.table_name,
}
txn.execute(query, (pdu_id, origin))
# Did we get anything?
if txn.fetchall():
logger.debug(
"is_new true: id=%s, o=%s, d=%s was forward",
pdu_id, origin, depth
)
return True
logger.debug(
"is_new false: id=%s, o=%s, d=%s",
pdu_id, origin, depth
)
# FINE THEN. It's probably old.
return False
@staticmethod
@log_function
def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
context):
txn.executemany(
PduEdgesTable.insert_statement(),
[(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
)
# Update the extremities table if this is not an outlier.
if not outlier:
# First, we delete the new one from the forwards extremities table.
query = (
"DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
% PduForwardExtremitiesTable.table_name
)
txn.executemany(query, list(p[:2] for p in prev_pdus))
# We only insert as a forward extremety the new pdu if there are no
# other pdus that reference it as a prev pdu
query = (
"INSERT INTO %(table)s (pdu_id, origin, context) "
"SELECT ?, ?, ? WHERE NOT EXISTS ("
"SELECT 1 FROM %(pdu_edges)s WHERE "
"prev_pdu_id = ? AND prev_origin = ?"
")"
) % {
"table": PduForwardExtremitiesTable.table_name,
"pdu_edges": PduEdgesTable.table_name
}
logger.debug("query: %s", query)
txn.execute(query, (pdu_id, origin, context, pdu_id, origin))
# Insert all the prev_pdus as a backwards thing, they'll get
# deleted in a second if they're incorrect anyway.
txn.executemany(
PduBackwardExtremitiesTable.insert_statement(),
[(i, o, context) for i, o, _ in prev_pdus]
)
# Also delete from the backwards extremities table all ones that
# reference pdus that we have already seen
query = (
"DELETE FROM %(pdu_back)s WHERE EXISTS ("
"SELECT 1 FROM %(pdus)s AS pdus "
"WHERE "
"%(pdu_back)s.pdu_id = pdus.pdu_id "
"AND %(pdu_back)s.origin = pdus.origin "
"AND not pdus.outlier "
")"
) % {
"pdu_back": PduBackwardExtremitiesTable.table_name,
"pdus": PdusTable.table_name,
}
txn.execute(query)
class StatePduStore(SQLBaseStore):
"""A collection of queries for handling state PDUs.
"""
def _persist_state_txn(self, txn, prev_pdus, cols):
"""Inserts a state PDU into the database
Args:
txn,
prev_pdus (list)
**cols: The columns to insert into the PdusTable and StatePdusTable
"""
pdu_entry = PdusTable.EntryType(
**{k: cols.get(k, None) for k in PdusTable.fields}
)
state_entry = StatePdusTable.EntryType(
**{k: cols.get(k, None) for k in StatePdusTable.fields}
)
logger.debug("Inserting pdu: %s", repr(pdu_entry))
logger.debug("Inserting state: %s", repr(state_entry))
txn.execute(PdusTable.insert_statement(), pdu_entry)
txn.execute(StatePdusTable.insert_statement(), state_entry)
self._handle_prev_pdus(
txn,
pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
pdu_entry.context
)
def get_unresolved_state_tree(self, new_state_pdu):
return self.runInteraction(
self._get_unresolved_state_tree, new_state_pdu
)
@log_function
def _get_unresolved_state_tree(self, txn, new_pdu):
current = self._get_current_interaction(
txn,
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
)
ReturnType = namedtuple(
"StateReturnType", ["new_branch", "current_branch"]
)
return_value = ReturnType([new_pdu], [])
if not current:
logger.debug("get_unresolved_state_tree No current state.")
return (return_value, None)
return_value.current_branch.append(current)
enum_branches = self._enumerate_state_branches(
txn, new_pdu, current
)
missing_branch = None
for branch, prev_state, state in enum_branches:
if state:
return_value[branch].append(state)
else:
# We don't have prev_state :(
missing_branch = branch
break
return (return_value, missing_branch)
def update_current_state(self, pdu_id, origin, context, pdu_type,
state_key):
return self.runInteraction(
self._update_current_state,
pdu_id, origin, context, pdu_type, state_key
)
def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
state_key):
query = (
"INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
) % {
"curr": CurrentStateTable.table_name,
"fields": CurrentStateTable.get_fields_string(),
"qs": ", ".join(["?"] * len(CurrentStateTable.fields))
}
query_args = CurrentStateTable.EntryType(
pdu_id=pdu_id,
origin=origin,
context=context,
pdu_type=pdu_type,
state_key=state_key
)
txn.execute(query, query_args)
def get_current_state_pdu(self, context, pdu_type, state_key):
"""For a given context, pdu_type, state_key 3-tuple, return what is
currently considered the current state.
Args:
txn
context (str)
pdu_type (str)
state_key (str)
Returns:
PduEntry
"""
return self.runInteraction(
self._get_current_state_pdu, context, pdu_type, state_key
)
def _get_current_state_pdu(self, txn, context, pdu_type, state_key):
return self._get_current_interaction(txn, context, pdu_type, state_key)
def _get_current_interaction(self, txn, context, pdu_type, state_key):
logger.debug(
"_get_current_interaction %s %s %s",
context, pdu_type, state_key
)
fields = _pdu_state_joiner.get_fields(
PdusTable="p", StatePdusTable="s")
current_query = (
"SELECT %(fields)s FROM %(state)s as s "
"INNER JOIN %(pdus)s as p "
"ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
"INNER JOIN %(curr)s as c "
"ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
"WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
) % {
"fields": fields,
"curr": CurrentStateTable.table_name,
"state": StatePdusTable.table_name,
"pdus": PdusTable.table_name,
}
txn.execute(
current_query,
(context, pdu_type, state_key)
)
row = txn.fetchone()
result = PduEntry(*row) if row else None
if not result:
logger.debug("_get_current_interaction not found")
else:
logger.debug(
"_get_current_interaction found %s %s",
result.pdu_id, result.origin
)
return result
def handle_new_state(self, new_pdu):
"""Actually perform conflict resolution on the new_pdu on the
assumption we have all the pdus required to perform it.
Args:
new_pdu
Returns:
bool: True if the new_pdu clobbered the current state, False if not
"""
return self.runInteraction(
self._handle_new_state, new_pdu
)
def _handle_new_state(self, txn, new_pdu):
logger.debug(
"handle_new_state %s %s",
new_pdu.pdu_id, new_pdu.origin
)
current = self._get_current_interaction(
txn,
new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
)
is_current = False
if (not current or not current.prev_state_id
or not current.prev_state_origin):
# Oh, we don't have any state for this yet.
is_current = True
elif (current.pdu_id == new_pdu.prev_state_id
and current.origin == new_pdu.prev_state_origin):
# Oh! A direct clobber. Just do it.
is_current = True
else:
##
# Ok, now loop through until we get to a common ancestor.
max_new = int(new_pdu.power_level)
max_current = int(current.power_level)
enum_branches = self._enumerate_state_branches(
txn, new_pdu, current
)
for branch, prev_state, state in enum_branches:
if not state:
raise RuntimeError(
"Could not find state_pdu %s %s" %
(
prev_state.prev_state_id,
prev_state.prev_state_origin
)
)
if branch == 0:
max_new = max(int(state.depth), max_new)
else:
max_current = max(int(state.depth), max_current)
is_current = max_new > max_current
if is_current:
logger.debug("handle_new_state make current")
# Right, this is a new thing, so woo, just insert it.
txn.execute(
"INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
% {
"curr": CurrentStateTable.table_name,
"fields": CurrentStateTable.get_fields_string(),
"qs": ", ".join(["?"] * len(CurrentStateTable.fields))
},
CurrentStateTable.EntryType(
*(new_pdu.__dict__[k] for k in CurrentStateTable.fields)
)
)
else:
logger.debug("handle_new_state not current")
logger.debug("handle_new_state done")
return is_current
@log_function
def _enumerate_state_branches(self, txn, pdu_a, pdu_b):
branch_a = pdu_a
branch_b = pdu_b
while True:
if (branch_a.pdu_id == branch_b.pdu_id
and branch_a.origin == branch_b.origin):
# Woo! We found a common ancestor
logger.debug("_enumerate_state_branches Found common ancestor")
break
do_branch_a = (
hasattr(branch_a, "prev_state_id") and
branch_a.prev_state_id
)
do_branch_b = (
hasattr(branch_b, "prev_state_id") and
branch_b.prev_state_id
)
logger.debug(
"do_branch_a=%s, do_branch_b=%s",
do_branch_a, do_branch_b
)
if do_branch_a and do_branch_b:
do_branch_a = int(branch_a.depth) > int(branch_b.depth)
if do_branch_a:
pdu_tuple = PduIdTuple(
branch_a.prev_state_id,
branch_a.prev_state_origin
)
prev_branch = branch_a
logger.debug("getting branch_a prev %s", pdu_tuple)
branch_a = self._get_pdu_tuple(txn, *pdu_tuple)
if branch_a:
branch_a = Pdu.from_pdu_tuple(branch_a)
logger.debug("branch_a=%s", branch_a)
yield (0, prev_branch, branch_a)
if not branch_a:
break
elif do_branch_b:
pdu_tuple = PduIdTuple(
branch_b.prev_state_id,
branch_b.prev_state_origin
)
prev_branch = branch_b
logger.debug("getting branch_b prev %s", pdu_tuple)
branch_b = self._get_pdu_tuple(txn, *pdu_tuple)
if branch_b:
branch_b = Pdu.from_pdu_tuple(branch_b)
logger.debug("branch_b=%s", branch_b)
yield (1, prev_branch, branch_b)
if not branch_b:
break
else:
break
class PdusTable(Table):
table_name = "pdus"
fields = [
"pdu_id",
"origin",
"context",
"pdu_type",
"ts",
"depth",
"is_state",
"content_json",
"unrecognized_keys",
"outlier",
"have_processed",
]
EntryType = namedtuple("PdusEntry", fields)
class PduDestinationsTable(Table):
table_name = "pdu_destinations"
fields = [
"pdu_id",
"origin",
"destination",
"delivered_ts",
]
EntryType = namedtuple("PduDestinationsEntry", fields)
class PduEdgesTable(Table):
table_name = "pdu_edges"
fields = [
"pdu_id",
"origin",
"prev_pdu_id",
"prev_origin",
"context"
]
EntryType = namedtuple("PduEdgesEntry", fields)
class PduForwardExtremitiesTable(Table):
table_name = "pdu_forward_extremities"
fields = [
"pdu_id",
"origin",
"context",
]
EntryType = namedtuple("PduForwardExtremitiesEntry", fields)
class PduBackwardExtremitiesTable(Table):
table_name = "pdu_backward_extremities"
fields = [
"pdu_id",
"origin",
"context",
]
EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)
class ContextDepthTable(Table):
table_name = "context_depth"
fields = [
"context",
"min_depth",
]
EntryType = namedtuple("ContextDepthEntry", fields)
class StatePdusTable(Table):
table_name = "state_pdus"
fields = [
"pdu_id",
"origin",
"context",
"pdu_type",
"state_key",
"power_level",
"prev_state_id",
"prev_state_origin",
]
EntryType = namedtuple("StatePdusEntry", fields)
class CurrentStateTable(Table):
table_name = "current_state"
fields = [
"pdu_id",
"origin",
"context",
"pdu_type",
"state_key",
]
EntryType = namedtuple("CurrentStateEntry", fields)
_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)
# TODO: These should probably be put somewhere more sensible
PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))
PduEntry = _pdu_state_joiner.EntryType
""" We are always interested in the join of the PdusTable and StatePdusTable,
rather than just the PdusTable.
This does not include a prev_pdus key.
"""
PduTuple = namedtuple(
"PduTuple",
("pdu_entry", "prev_pdu_list", "hashes", "signatures", "edge_hashes")
)
""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
the `prev_pdus` key of a PDU.
"""