Sanitize TransactionStore

pull/123/head
Erik Johnston 2015-03-23 13:43:21 +00:00
parent f6583796fe
commit 278149f533
2 changed files with 104 additions and 87 deletions

View File

@ -179,7 +179,7 @@ class FederationHandler(BaseHandler):
# it's probably a good idea to mark it as not in retry-state # it's probably a good idea to mark it as not in retry-state
# for sending (although this is a bit of a leap) # for sending (although this is a bit of a leap)
retry_timings = yield self.store.get_destination_retry_timings(origin) retry_timings = yield self.store.get_destination_retry_timings(origin)
if (retry_timings and retry_timings.retry_last_ts): if retry_timings and retry_timings["retry_last_ts"]:
self.store.set_destination_retry_timings(origin, 0, 0) self.store.set_destination_retry_timings(origin, 0, 0)
room = yield self.store.get_room(event.room_id) room = yield self.store.get_room(event.room_id)

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, Table, cached from ._base import SQLBaseStore, cached
from collections import namedtuple from collections import namedtuple
@ -84,13 +84,18 @@ class TransactionStore(SQLBaseStore):
def _set_received_txn_response(self, txn, transaction_id, origin, code, def _set_received_txn_response(self, txn, transaction_id, origin, code,
response_json): response_json):
query = ( self._simple_update_one_txn(
"UPDATE %s " txn,
"SET response_code = ?, response_json = ? " table=ReceivedTransactionsTable.table_name,
"WHERE transaction_id = ? AND origin = ?" keyvalues={
) % ReceivedTransactionsTable.table_name "transaction_id": transaction_id,
"origin": origin,
txn.execute(query, (code, response_json, transaction_id, origin)) },
updatevalues={
"response_code": code,
"response_json": response_json,
}
)
def prep_send_transaction(self, transaction_id, destination, def prep_send_transaction(self, transaction_id, destination,
origin_server_ts): origin_server_ts):
@ -121,38 +126,32 @@ class TransactionStore(SQLBaseStore):
# First we find out what the prev_txns should be. # First we find out what the prev_txns should be.
# Since we know that we are only sending one transaction at a time, # Since we know that we are only sending one transaction at a time,
# we can simply take the last one. # we can simply take the last one.
query = "%s ORDER BY id DESC LIMIT 1" % ( query = (
SentTransactions.select_statement("destination = ?"), "SELECT * FROM sent_transactions"
" WHERE destination = ?"
" ORDER BY id DESC LIMIT 1"
) )
txn.execute(query, (destination,)) txn.execute(query, (destination,))
results = SentTransactions.decode_results(txn.fetchall()) results = self.cursor_to_dict(txn)
prev_txns = [r.transaction_id for r in results] prev_txns = [r["transaction_id"] for r in results]
# Actually add the new transaction to the sent_transactions table. # Actually add the new transaction to the sent_transactions table.
query = SentTransactions.insert_statement() self._simple_insert_txn(
txn.execute(query, SentTransactions.EntryType( txn,
self.get_next_stream_id(), table=SentTransactions.table_name,
transaction_id=transaction_id, values={
destination=destination, "transaction_id": self.get_next_stream_id(),
ts=origin_server_ts, "destination": destination,
response_code=0, "ts": origin_server_ts,
response_json=None "response_code": 0,
)) "response_json": None,
}
)
# Update the tx id -> pdu id mapping # TODO Update the tx id -> pdu id mapping
# values = [
# (transaction_id, destination, pdu[0], pdu[1])
# for pdu in pdu_list
# ]
#
# logger.debug("Inserting: %s", repr(values))
#
# query = TransactionsToPduTable.insert_statement()
# txn.executemany(query, values)
return prev_txns return prev_txns
@ -171,15 +170,20 @@ class TransactionStore(SQLBaseStore):
transaction_id, destination, code, response_dict transaction_id, destination, code, response_dict
) )
def _delivered_txn(cls, txn, transaction_id, destination, def _delivered_txn(self, txn, transaction_id, destination,
code, response_json): code, response_json):
query = ( self._simple_update_one_txn(
"UPDATE %s " txn,
"SET response_code = ?, response_json = ? " table=SentTransactions.table_name,
"WHERE transaction_id = ? AND destination = ?" keyvalues={
) % SentTransactions.table_name "transaction_id": transaction_id,
"destination": destination,
txn.execute(query, (code, response_json, transaction_id, destination)) },
updatevalues={
"response_code": code,
"response_json": response_json,
}
)
def get_transactions_after(self, transaction_id, destination): def get_transactions_after(self, transaction_id, destination):
"""Get all transactions after a given local transaction_id. """Get all transactions after a given local transaction_id.
@ -189,25 +193,26 @@ class TransactionStore(SQLBaseStore):
destination (str) destination (str)
Returns: Returns:
list: A list of `ReceivedTransactionsTable.EntryType` list: A list of dicts
""" """
return self.runInteraction( return self.runInteraction(
"get_transactions_after", "get_transactions_after",
self._get_transactions_after, transaction_id, destination self._get_transactions_after, transaction_id, destination
) )
def _get_transactions_after(cls, txn, transaction_id, destination): def _get_transactions_after(self, txn, transaction_id, destination):
where = ( query = (
"destination = ? AND id > (select id FROM %s WHERE " "SELECT * FROM sent_transactions"
"transaction_id = ? AND destination = ?)" " WHERE destination = ? AND id >"
) % ( " ("
SentTransactions.table_name " SELECT id FROM sent_transactions"
" WHERE transaction_id = ? AND destination = ?"
" )"
) )
query = SentTransactions.select_statement(where)
txn.execute(query, (destination, transaction_id, destination)) txn.execute(query, (destination, transaction_id, destination))
return ReceivedTransactionsTable.decode_results(txn.fetchall()) return self.cursor_to_dict(txn)
@cached() @cached()
def get_destination_retry_timings(self, destination): def get_destination_retry_timings(self, destination):
@ -218,19 +223,24 @@ class TransactionStore(SQLBaseStore):
Returns: Returns:
None if not retrying None if not retrying
Otherwise a DestinationsTable.EntryType for the retry scheme Otherwise a dict for the retry scheme
""" """
return self.runInteraction( return self.runInteraction(
"get_destination_retry_timings", "get_destination_retry_timings",
self._get_destination_retry_timings, destination) self._get_destination_retry_timings, destination)
def _get_destination_retry_timings(cls, txn, destination): def _get_destination_retry_timings(self, txn, destination):
query = DestinationsTable.select_statement("destination = ?") result = self._simple_select_one_txn(
txn.execute(query, (destination,)) txn,
result = txn.fetchall() table=DestinationsTable.table_name,
if result: keyvalues={
result = DestinationsTable.decode_single_result(result) "destination": destination,
if result.retry_last_ts > 0: },
retcols=DestinationsTable.fields,
allow_none=True,
)
if result["retry_last_ts"] > 0:
return result return result
else: else:
return None return None
@ -249,11 +259,11 @@ class TransactionStore(SQLBaseStore):
# As this is the new value, we might as well prefill the cache # As this is the new value, we might as well prefill the cache
self.get_destination_retry_timings.prefill( self.get_destination_retry_timings.prefill(
destination, destination,
DestinationsTable.EntryType( {
destination, "destination": destination,
retry_last_ts, "retry_last_ts": retry_last_ts,
retry_interval "retry_interval": retry_interval
) },
) )
# XXX: we could chose to not bother persisting this if our cache thinks # XXX: we could chose to not bother persisting this if our cache thinks
@ -270,18 +280,27 @@ class TransactionStore(SQLBaseStore):
retry_last_ts, retry_interval): retry_last_ts, retry_interval):
query = ( query = (
"REPLACE INTO %s " "INSERT INTO destinations"
" (destination, retry_last_ts, retry_interval)" " (destination, retry_last_ts, retry_interval)"
" VALUES (?, ?, ?)" " VALUES (?, ?, ?)"
) % DestinationsTable.table_name " ON DUPLICATE KEY UPDATE"
" retry_last_ts=?, retry_interval=?"
)
txn.execute(query, (destination, retry_last_ts, retry_interval)) txn.execute(
query,
(
destination,
retry_last_ts, retry_interval,
retry_last_ts, retry_interval,
)
)
def get_destinations_needing_retry(self): def get_destinations_needing_retry(self):
"""Get all destinations which are due a retry for sending a transaction. """Get all destinations which are due a retry for sending a transaction.
Returns: Returns:
list: A list of `DestinationsTable.EntryType` list: A list of dicts
""" """
return self.runInteraction( return self.runInteraction(
@ -289,14 +308,17 @@ class TransactionStore(SQLBaseStore):
self._get_destinations_needing_retry self._get_destinations_needing_retry
) )
def _get_destinations_needing_retry(cls, txn): def _get_destinations_needing_retry(self, txn):
where = "retry_last_ts > 0 and retry_next_ts < now()" query = (
query = DestinationsTable.select_statement(where) "SELECT * FROM destinations"
txn.execute(query) " WHERE retry_last_ts > 0 and retry_next_ts < ?"
return DestinationsTable.decode_results(txn.fetchall()) )
txn.execute(query, (self._clock.time_msec(),))
return self.cursor_to_dict(txn)
class ReceivedTransactionsTable(Table): class ReceivedTransactionsTable(object):
table_name = "received_transactions" table_name = "received_transactions"
fields = [ fields = [
@ -308,10 +330,8 @@ class ReceivedTransactionsTable(Table):
"has_been_referenced", "has_been_referenced",
] ]
EntryType = namedtuple("ReceivedTransactionsEntry", fields)
class SentTransactions(object):
class SentTransactions(Table):
table_name = "sent_transactions" table_name = "sent_transactions"
fields = [ fields = [
@ -326,7 +346,7 @@ class SentTransactions(Table):
EntryType = namedtuple("SentTransactionsEntry", fields) EntryType = namedtuple("SentTransactionsEntry", fields)
class TransactionsToPduTable(Table): class TransactionsToPduTable(object):
table_name = "transaction_id_to_pdu" table_name = "transaction_id_to_pdu"
fields = [ fields = [
@ -336,10 +356,8 @@ class TransactionsToPduTable(Table):
"pdu_origin", "pdu_origin",
] ]
EntryType = namedtuple("TransactionsToPduEntry", fields)
class DestinationsTable(object):
class DestinationsTable(Table):
table_name = "destinations" table_name = "destinations"
fields = [ fields = [
@ -348,4 +366,3 @@ class DestinationsTable(Table):
"retry_interval", "retry_interval",
] ]
EntryType = namedtuple("DestinationsEntry", fields)