Merge pull request #3638 from matrix-org/rav/refactor_federation_client_exception_handling
Factor out exception handling in federation_clientpull/3421/merge
						commit
						bdae8f2e68
					
				| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Factor out exception handling in federation_client
 | 
			
		||||
| 
						 | 
				
			
			@ -48,6 +48,13 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t
 | 
			
		|||
PDU_RETRY_TIME_MS = 1 * 60 * 1000
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InvalidResponseError(RuntimeError):
 | 
			
		||||
    """Helper for _try_destination_list: indicates that the server returned a response
 | 
			
		||||
    we couldn't parse
 | 
			
		||||
    """
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class FederationClient(FederationBase):
 | 
			
		||||
    def __init__(self, hs):
 | 
			
		||||
        super(FederationClient, self).__init__(hs)
 | 
			
		||||
| 
						 | 
				
			
			@ -458,6 +465,61 @@ class FederationClient(FederationBase):
 | 
			
		|||
        defer.returnValue(signed_auth)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _try_destination_list(self, description, destinations, callback):
 | 
			
		||||
        """Try an operation on a series of servers, until it succeeds
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            description (unicode): description of the operation we're doing, for logging
 | 
			
		||||
 | 
			
		||||
            destinations (Iterable[unicode]): list of server_names to try
 | 
			
		||||
 | 
			
		||||
            callback (callable):  Function to run for each server. Passed a single
 | 
			
		||||
                argument: the server_name to try. May return a deferred.
 | 
			
		||||
 | 
			
		||||
                If the callback raises a CodeMessageException with a 300/400 code,
 | 
			
		||||
                attempts to perform the operation stop immediately and the exception is
 | 
			
		||||
                reraised.
 | 
			
		||||
 | 
			
		||||
                Otherwise, if the callback raises an Exception the error is logged and the
 | 
			
		||||
                next server tried. Normally the stacktrace is logged but this is
 | 
			
		||||
                suppressed if the exception is an InvalidResponseError.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            The [Deferred] result of callback, if it succeeds
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            CodeMessageException if the chosen remote server returns a 300/400 code.
 | 
			
		||||
 | 
			
		||||
            RuntimeError if no servers were reachable.
 | 
			
		||||
        """
 | 
			
		||||
        for destination in destinations:
 | 
			
		||||
            if destination == self.server_name:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
                res = yield callback(destination)
 | 
			
		||||
                defer.returnValue(res)
 | 
			
		||||
            except InvalidResponseError as e:
 | 
			
		||||
                logger.warn(
 | 
			
		||||
                    "Failed to %s via %s: %s",
 | 
			
		||||
                    description, destination, e,
 | 
			
		||||
                )
 | 
			
		||||
            except CodeMessageException as e:
 | 
			
		||||
                if not 500 <= e.code < 600:
 | 
			
		||||
                    raise
 | 
			
		||||
                else:
 | 
			
		||||
                    logger.warn(
 | 
			
		||||
                        "Failed to %s via %s: %i %s",
 | 
			
		||||
                        description, destination, e.code, e.message,
 | 
			
		||||
                    )
 | 
			
		||||
            except Exception:
 | 
			
		||||
                logger.warn(
 | 
			
		||||
                    "Failed to %s via %s",
 | 
			
		||||
                    description, destination, exc_info=1,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        raise RuntimeError("Failed to %s via any server", description)
 | 
			
		||||
 | 
			
		||||
    def make_membership_event(self, destinations, room_id, user_id, membership,
 | 
			
		||||
                              content={},):
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -492,11 +554,9 @@ class FederationClient(FederationBase):
 | 
			
		|||
                "make_membership_event called with membership='%s', must be one of %s" %
 | 
			
		||||
                (membership, ",".join(valid_memberships))
 | 
			
		||||
            )
 | 
			
		||||
        for destination in destinations:
 | 
			
		||||
            if destination == self.server_name:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
        @defer.inlineCallbacks
 | 
			
		||||
        def send_request(destination):
 | 
			
		||||
            ret = yield self.transport_layer.make_membership_event(
 | 
			
		||||
                destination, room_id, user_id, membership
 | 
			
		||||
            )
 | 
			
		||||
| 
						 | 
				
			
			@ -518,24 +578,11 @@ class FederationClient(FederationBase):
 | 
			
		|||
            defer.returnValue(
 | 
			
		||||
                (destination, ev)
 | 
			
		||||
            )
 | 
			
		||||
                break
 | 
			
		||||
            except CodeMessageException as e:
 | 
			
		||||
                if not 500 <= e.code < 600:
 | 
			
		||||
                    raise
 | 
			
		||||
                else:
 | 
			
		||||
                    logger.warn(
 | 
			
		||||
                        "Failed to make_%s via %s: %s",
 | 
			
		||||
                        membership, destination, e.message
 | 
			
		||||
                    )
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                logger.warn(
 | 
			
		||||
                    "Failed to make_%s via %s: %s",
 | 
			
		||||
                    membership, destination, e.message
 | 
			
		||||
 | 
			
		||||
        return self._try_destination_list(
 | 
			
		||||
            "make_" + membership, destinations, send_request,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        raise RuntimeError("Failed to send to any server.")
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def send_join(self, destinations, pdu):
 | 
			
		||||
        """Sends a join event to one of a list of homeservers.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -558,11 +605,8 @@ class FederationClient(FederationBase):
 | 
			
		|||
            Fails with a ``RuntimeError`` if no servers were reachable.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        for destination in destinations:
 | 
			
		||||
            if destination == self.server_name:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
        @defer.inlineCallbacks
 | 
			
		||||
        def send_request(destination):
 | 
			
		||||
            time_now = self._clock.time_msec()
 | 
			
		||||
            _, content = yield self.transport_layer.send_join(
 | 
			
		||||
                destination=destination,
 | 
			
		||||
| 
						 | 
				
			
			@ -624,21 +668,7 @@ class FederationClient(FederationBase):
 | 
			
		|||
                "auth_chain": signed_auth,
 | 
			
		||||
                "origin": destination,
 | 
			
		||||
            })
 | 
			
		||||
            except CodeMessageException as e:
 | 
			
		||||
                if not 500 <= e.code < 600:
 | 
			
		||||
                    raise
 | 
			
		||||
                else:
 | 
			
		||||
                    logger.exception(
 | 
			
		||||
                        "Failed to send_join via %s: %s",
 | 
			
		||||
                        destination, e.message
 | 
			
		||||
                    )
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                logger.exception(
 | 
			
		||||
                    "Failed to send_join via %s: %s",
 | 
			
		||||
                    destination, e.message
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        raise RuntimeError("Failed to send to any server.")
 | 
			
		||||
        return self._try_destination_list("send_join", destinations, send_request)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def send_invite(self, destination, room_id, event_id, pdu):
 | 
			
		||||
| 
						 | 
				
			
			@ -663,7 +693,6 @@ class FederationClient(FederationBase):
 | 
			
		|||
 | 
			
		||||
        defer.returnValue(pdu)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def send_leave(self, destinations, pdu):
 | 
			
		||||
        """Sends a leave event to one of a list of homeservers.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -681,15 +710,12 @@ class FederationClient(FederationBase):
 | 
			
		|||
            Deferred: resolves to None.
 | 
			
		||||
 | 
			
		||||
            Fails with a ``CodeMessageException`` if the chosen remote server
 | 
			
		||||
            returns a non-200 code.
 | 
			
		||||
            returns a 300/400 code.
 | 
			
		||||
 | 
			
		||||
            Fails with a ``RuntimeError`` if no servers were reachable.
 | 
			
		||||
        """
 | 
			
		||||
        for destination in destinations:
 | 
			
		||||
            if destination == self.server_name:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            try:
 | 
			
		||||
        @defer.inlineCallbacks
 | 
			
		||||
        def send_request(destination):
 | 
			
		||||
            time_now = self._clock.time_msec()
 | 
			
		||||
            _, content = yield self.transport_layer.send_leave(
 | 
			
		||||
                destination=destination,
 | 
			
		||||
| 
						 | 
				
			
			@ -700,15 +726,8 @@ class FederationClient(FederationBase):
 | 
			
		|||
 | 
			
		||||
            logger.debug("Got content: %s", content)
 | 
			
		||||
            defer.returnValue(None)
 | 
			
		||||
            except CodeMessageException:
 | 
			
		||||
                raise
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                logger.exception(
 | 
			
		||||
                    "Failed to send_leave via %s: %s",
 | 
			
		||||
                    destination, e.message
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        raise RuntimeError("Failed to send to any server.")
 | 
			
		||||
        return self._try_destination_list("send_leave", destinations, send_request)
 | 
			
		||||
 | 
			
		||||
    def get_public_rooms(self, destination, limit=None, since_token=None,
 | 
			
		||||
                         search_filter=None, include_all_networks=False,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue