170 lines
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			170 lines
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Python
		
	
	
# -*- coding: utf-8 -*-
 | 
						|
# Copyright 2019 New Vector Ltd
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
import logging
from typing import List, Tuple

from canonicaljson import json

import synapse.server
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.logging.opentracing import (
    extract_text_map,
    set_tag,
    start_active_span_follows_from,
    tags,
    whitelisted_homeserver,
)
from synapse.util.metrics import measure_func
 | 
						|
 | 
						|
# Module-level logger, named after this module so its output can be filtered
# per module.
logger = logging.getLogger(__name__)
 | 
						|
 | 
						|
 | 
						|
class TransactionManager:
    """Helper class which handles building and sending transactions.

    A single instance is shared between PerDestinationQueue objects.
    """

    def __init__(self, hs: "synapse.server.HomeServer") -> None:
        """Collect the collaborators needed to build and send transactions.

        Args:
            hs: the homeserver, which supplies the hostname, clock, datastore
                and federation transport client used by this class.
        """
        self._server_name = hs.hostname
        self.clock = hs.get_clock()  # nb must be called this for @measure_func
        self._store = hs.get_datastore()
        self._transaction_actions = TransactionActions(self._store)
        self._transport_layer = hs.get_federation_transport_client()

        # HACK to get unique tx id: seed the counter from the current time in
        # milliseconds and bump it by one for every transaction we build.
        self._next_txn_id = int(self.clock.time_msec())

    @measure_func("_send_new_transaction")
    async def send_new_transaction(
        self,
        destination: str,
        pending_pdus: List[Tuple[EventBase, int]],
        pending_edus: List[Edu],
    ) -> bool:
        """Build one transaction from the pending PDUs and EDUs and send it.

        Args:
            destination: name of the remote server the transaction is for.
            pending_pdus: `(event, order)` pairs. NB: this list is sorted in
                place by `order` before the events are extracted from it.
                (The order is presumably an integer; only its comparability
                is relied on here.)
            pending_edus: EDUs to include in the transaction. NB: an EDU may
                have its opentracing context stripped in place.

        Returns:
            True if the remote answered with a 200; False if it answered with
            a non-200 code that was not re-raised.

        Raises:
            HttpResponseException: if the remote answered with 401, 404, 429
                or any code >= 500.
        """

        # Make a transaction-sending opentracing span. This span follows on from
        # all the edus in that transaction. This needs to be done since there is
        # no active span here, so if the edus were not received by the remote the
        # span would have no causality and it would be forgotten.
        # The contexts are collected eagerly into a list, since the same pass
        # over the edus may also strip the context from each of them.
        span_contexts = []
        keep_destination = whitelisted_homeserver(destination)

        for edu in pending_edus:
            context = edu.get_context()
            if context:
                span_contexts.append(extract_text_map(json.loads(context)))
            # NOTE(review): the context is stripped when `keep_destination` is
            # truthy, which reads as inverted relative to the variable name --
            # confirm against whitelisted_homeserver/strip_context semantics.
            if keep_destination:
                edu.strip_context()

        with start_active_span_follows_from("send_transaction", span_contexts):

            # Sort based on the order field (the second element of each pair),
            # then keep only the events themselves.
            pending_pdus.sort(key=lambda t: t[1])
            pdus = [x[0] for x in pending_pdus]

            logger.debug("TX [%s] _attempt_new_transaction", destination)

            txn_id = str(self._next_txn_id)

            logger.debug(
                "TX [%s] {%s} Attempting new transaction (pdus: %d, edus: %d)",
                destination,
                txn_id,
                len(pdus),
                len(pending_edus),
            )

            transaction = Transaction.create_new(
                origin_server_ts=int(self.clock.time_msec()),
                transaction_id=txn_id,
                origin=self._server_name,
                destination=destination,
                pdus=pdus,
                edus=pending_edus,
            )

            self._next_txn_id += 1

            logger.info(
                "TX [%s] {%s} Sending transaction [%s], (PDUs: %d, EDUs: %d)",
                destination,
                txn_id,
                transaction.transaction_id,
                len(pdus),
                len(pending_edus),
            )

            # Actually send the transaction

            # FIXME (erikj): This is a bit of a hack to make the Pdu age
            # keys work
            def json_data_cb():
                """Return the transaction dict with `age_ts` turned into `age`.

                For every PDU carrying an absolute `age_ts`, store the
                relative `unsigned.age` (time of this call minus `age_ts`)
                and drop `age_ts`.
                """
                data = transaction.get_dict()
                now = int(self.clock.time_msec())
                for p in data.get("pdus", ()):
                    if "age_ts" in p:
                        unsigned = p.setdefault("unsigned", {})
                        unsigned["age"] = now - int(p["age_ts"])
                        del p["age_ts"]
                return data

            try:
                response = await self._transport_layer.send_transaction(
                    transaction, json_data_cb
                )
            except HttpResponseException as e:
                code = e.code
                response = e.response

                if e.code in (401, 404, 429) or 500 <= e.code:
                    logger.info(
                        "TX [%s] {%s} got %d response", destination, txn_id, code
                    )
                    # Bare raise re-raises the active exception unchanged.
                    raise
            else:
                code = 200

            logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)

            success = code == 200

            if success:
                # The remote accepted the transaction as a whole, but may still
                # report an error against individual events.
                for e_id, r in response.get("pdus", {}).items():
                    if "error" in r:
                        logger.warning(
                            "TX [%s] {%s} Remote returned error for %s: %s",
                            destination,
                            txn_id,
                            e_id,
                            r,
                        )
            else:
                for p in pdus:
                    logger.warning(
                        "TX [%s] {%s} Failed to send event %s",
                        destination,
                        txn_id,
                        p.event_id,
                    )

            set_tag(tags.ERROR, not success)
            return success
 |