385 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			385 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
| # -*- coding: utf-8 -*-
 | |
| # Copyright 2014-2016 OpenMarket Ltd
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| 
 | |
| from twisted.internet import defer
 | |
| 
 | |
| from .persistence import TransactionActions
 | |
| from .units import Transaction
 | |
| 
 | |
| from synapse.api.errors import HttpResponseException
 | |
| from synapse.util.logutils import log_function
 | |
| from synapse.util.logcontext import PreserveLoggingContext
 | |
| from synapse.util.retryutils import (
 | |
|     get_retry_limiter, NotRetryingDestination,
 | |
| )
 | |
| import synapse.metrics
 | |
| 
 | |
| import logging
 | |
| 
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
| metrics = synapse.metrics.get_metrics_for(__name__)
 | |
| 
 | |
| 
 | |
class TransactionQueue(object):
    """Queues outbound federation traffic, one transaction per destination.

    This class makes sure we only have one transaction in flight at
    a time for a given destination.

    It batches pending PDUs, EDUs and failures into single transactions.
    """

    def __init__(self, hs, transport_layer):
        """
        Args:
            hs: The homeserver object. Supplies ``hostname``,
                ``get_datastore()`` and ``get_clock()``.
            transport_layer: Object whose ``send_transaction`` method is
                used to deliver each transaction.
        """
        self.server_name = hs.hostname

        self.store = hs.get_datastore()
        self.transaction_actions = TransactionActions(self.store)

        self.transport_layer = transport_layer

        self._clock = hs.get_clock()

        # Is a mapping from destination -> marker value. A destination is
        # present only while it has a transaction in flight.
        self.pending_transactions = {}

        metrics.register_callback(
            "pending_destinations",
            lambda: len(self.pending_transactions),
        )

        # Is a mapping from destination -> list of
        # tuple(pending pdus, deferred, order)
        self.pending_pdus_by_dest = pdus = {}
        # destination -> list of tuple(edu, deferred)
        self.pending_edus_by_dest = edus = {}

        metrics.register_callback(
            "pending_pdus",
            lambda: sum(map(len, pdus.values())),
        )
        metrics.register_callback(
            "pending_edus",
            lambda: sum(map(len, edus.values())),
        )

        # destination -> list of tuple(failure, deferred)
        self.pending_failures_by_dest = {}

        # HACK to get unique tx id
        self._next_txn_id = int(self._clock.time_msec())

    def can_send_to(self, destination):
        """Can we send messages to the given server?

        We can't send messages to ourselves. If we are running on localhost
        then we can only federate with other servers running on localhost.
        Otherwise we only federate with servers on a public domain.

        Args:
            destination(str): The server we are possibly trying to send to.
        Returns:
            bool: True if we can send to the server.
        """

        if destination == self.server_name:
            return False
        if self.server_name.startswith("localhost"):
            return destination.startswith("localhost")
        else:
            return not destination.startswith("localhost")

    def _new_item_deferred(self, destination, description):
        """Create the deferred that tracks delivery of one queued item.

        Building the errback in its own function gives every item its own
        ``destination`` binding. Defining the closure inside a loop would
        make all items share the loop variable's final value.

        Args:
            destination(str): Server the item is queued for (used in logs).
            description(str): Kind of item ("pdu", "edu" or "failure").
        Returns:
            Deferred: with a logging errback already attached.
        """
        deferred = defer.Deferred()

        def log_failure(f):
            logger.warning(
                "Failed to send %s to %s: %s",
                description, destination, f.value,
            )

        deferred.addErrback(log_failure)
        return deferred

    def _start_transaction_for(self, destination, deferred):
        """Try to start a transaction, failing `deferred` if that blows up.

        Args:
            destination(str): Server to attempt a transaction to.
            deferred(Deferred): The queued item's deferred. It is errbacked
                if the attempt itself fails and it has not fired already.
        """
        def chain(failure):
            if not deferred.called:
                deferred.errback(failure)

        with PreserveLoggingContext():
            self._attempt_new_transaction(destination).addErrback(chain)

    @defer.inlineCallbacks
    def enqueue_pdu(self, pdu, destinations, order):
        """Queue a PDU for every destination we are allowed to send to.

        Args:
            pdu: The PDU to send.
            destinations: Iterable of destination server names.
            order: Sort key; PDUs in a transaction are sent sorted by it.
        Returns:
            Deferred: fires once the per-destination deferreds have all
            fired (errors are consumed).
        """
        # We loop through all destinations to see whether we already have
        # a transaction in progress. If we do, stick it in the pending_pdus
        # table and we'll get back to it later.

        destinations = set(
            dest for dest in destinations if self.can_send_to(dest)
        )

        logger.debug("Sending to: %s", str(destinations))

        if not destinations:
            return

        deferreds = []

        for destination in destinations:
            deferred = self._new_item_deferred(destination, "pdu")
            self.pending_pdus_by_dest.setdefault(destination, []).append(
                (pdu, deferred, order)
            )

            self._start_transaction_for(destination, deferred)

            deferreds.append(deferred)

        yield defer.DeferredList(deferreds, consumeErrors=True)

    # NO inlineCallbacks
    def enqueue_edu(self, edu):
        """Queue an EDU for its ``destination``.

        Args:
            edu: The EDU to send; its ``destination`` attribute is read.
        Returns:
            Deferred tracking the EDU, or None if we cannot send to the
            destination.
        """
        destination = edu.destination

        if not self.can_send_to(destination):
            return

        deferred = self._new_item_deferred(destination, "edu")
        self.pending_edus_by_dest.setdefault(destination, []).append(
            (edu, deferred)
        )

        self._start_transaction_for(destination, deferred)

        return deferred

    @defer.inlineCallbacks
    def enqueue_failure(self, failure, destination):
        """Queue a failure report for `destination`.

        Args:
            failure: Object with a ``get_dict()`` method; its dict is sent
                in the transaction's ``pdu_failures``.
            destination(str): Server to report the failure to.
        Returns:
            Deferred: fires once the failure's own deferred has fired, or
            immediately if we cannot send to the destination.
        """
        # The literal "localhost" check is deliberate: can_send_to() would
        # allow it when our own server name starts with "localhost".
        if destination == self.server_name or destination == "localhost":
            return

        if not self.can_send_to(destination):
            return

        deferred = self._new_item_deferred(destination, "failure")
        self.pending_failures_by_dest.setdefault(destination, []).append(
            (failure, deferred)
        )

        self._start_transaction_for(destination, deferred)

        yield deferred

    @defer.inlineCallbacks
    @log_function
    def _attempt_new_transaction(self, destination):
        """Send everything queued for `destination` as one transaction.

        Does nothing if a transaction to `destination` is already in flight
        or nothing is queued. Otherwise it claims the queued items, sends
        them, fires their deferreds and finally re-runs itself to pick up
        anything queued in the meantime.

        Errors are logged rather than raised, because nothing listens on
        the deferred this function returns.

        Args:
            destination(str): Server to send the queued items to.
        """
        if destination in self.pending_transactions:
            # XXX: pending_transactions can get stuck on by a never-ending
            # request at which point pending_pdus_by_dest just keeps growing.
            # we need application-layer timeouts of some flavour of these
            # requests
            logger.debug(
                "TX [%s] Transaction already in progress",
                destination
            )
            return

        # list of (pending_pdu, deferred, order)
        pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
        pending_edus = self.pending_edus_by_dest.pop(destination, [])
        pending_failures = self.pending_failures_by_dest.pop(destination, [])

        if pending_pdus:
            logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
                         destination, len(pending_pdus))

        if not pending_pdus and not pending_edus and not pending_failures:
            logger.debug("TX [%s] Nothing to send", destination)
            return

        # Built before the ``try`` so the error handler can always fail the
        # waiting deferreds, even if preparing the transaction blows up.
        deferreds = [
            x[1]
            for x in pending_pdus + pending_edus + pending_failures
        ]

        try:
            self.pending_transactions[destination] = 1

            logger.debug("TX [%s] _attempt_new_transaction", destination)

            # Sort based on the order field
            pending_pdus.sort(key=lambda t: t[2])

            pdus = [x[0] for x in pending_pdus]
            edus = [x[0] for x in pending_edus]
            failures = [x[0].get_dict() for x in pending_failures]
            # Rebuild so that callbacks fire in the order items are sent.
            deferreds = [
                x[1]
                for x in pending_pdus + pending_edus + pending_failures
            ]

            # Claim the id before yielding: another destination's attempt
            # may run while we wait and must not be handed the same id.
            txn_id = str(self._next_txn_id)
            self._next_txn_id += 1

            limiter = yield get_retry_limiter(
                destination,
                self._clock,
                self.store,
            )

            logger.debug(
                "TX [%s] {%s} Attempting new transaction"
                " (pdus: %d, edus: %d, failures: %d)",
                destination, txn_id,
                len(pending_pdus),
                len(pending_edus),
                len(pending_failures)
            )

            logger.debug("TX [%s] Persisting transaction...", destination)

            transaction = Transaction.create_new(
                origin_server_ts=int(self._clock.time_msec()),
                transaction_id=txn_id,
                origin=self.server_name,
                destination=destination,
                pdus=pdus,
                edus=edus,
                pdu_failures=failures,
            )

            yield self.transaction_actions.prepare_to_send(transaction)

            logger.debug("TX [%s] Persisted transaction", destination)
            logger.info(
                "TX [%s] {%s} Sending transaction [%s],"
                " (PDUs: %d, EDUs: %d, failures: %d)",
                destination, txn_id,
                transaction.transaction_id,
                len(pending_pdus),
                len(pending_edus),
                len(pending_failures),
            )

            with limiter:
                # Actually send the transaction

                # FIXME (erikj): This is a bit of a hack to make the Pdu age
                # keys work
                def json_data_cb():
                    data = transaction.get_dict()
                    now = int(self._clock.time_msec())
                    if "pdus" in data:
                        for p in data["pdus"]:
                            if "age_ts" in p:
                                unsigned = p.setdefault("unsigned", {})
                                unsigned["age"] = now - int(p["age_ts"])
                                del p["age_ts"]
                    return data

                try:
                    response = yield self.transport_layer.send_transaction(
                        transaction, json_data_cb
                    )
                    code = 200

                    if response:
                        for e_id, r in response.get("pdus", {}).items():
                            if "error" in r:
                                logger.warning(
                                    "Transaction returned error for %s: %s",
                                    e_id, r,
                                )
                except HttpResponseException as e:
                    code = e.code
                    response = e.response

                logger.info(
                    "TX [%s] {%s} got %d response",
                    destination, txn_id, code
                )

                logger.debug("TX [%s] Sent transaction", destination)
                logger.debug("TX [%s] Marking as delivered...", destination)

            yield self.transaction_actions.delivered(
                transaction, code, response
            )

            logger.debug("TX [%s] Marked as delivered", destination)

            logger.debug("TX [%s] Yielding to callbacks...", destination)

            for deferred in deferreds:
                if code == 200:
                    deferred.callback(None)
                else:
                    deferred.errback(RuntimeError("Got status %d" % code))

                # Ensures we don't continue until all callbacks on that
                # deferred have fired
                try:
                    yield deferred
                except Exception:
                    # A failing callback must not stop the remaining
                    # deferreds from firing. Deliberately not a bare
                    # ``except``, which would also swallow GeneratorExit.
                    pass

            logger.debug("TX [%s] Yielded to callbacks", destination)
        except NotRetryingDestination:
            # NOTE(review): the batch is dropped and its deferreds are left
            # unfired here -- confirm no caller waits on them.
            logger.info(
                "TX [%s] not ready for retry yet - "
                "dropping transaction for now",
                destination,
            )
        except Exception as e:
            # We capture this here as there is nothing actually listening
            # for this function's deferred to finish.
            logger.warning(
                "TX [%s] Problem in _attempt_transaction: %s",
                destination,
                e,
            )

            # Fail whatever has not fired yet so waiters are not left
            # hanging.
            for deferred in deferreds:
                if not deferred.called:
                    deferred.errback(e)

        finally:
            # We want to be *very* sure we delete this after we stop processing
            self.pending_transactions.pop(destination, None)

            # Check to see if there is anything else to send.
            self._attempt_new_transaction(destination)
 |