# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    TypeVar,
    Union,
)

from typing_extensions import ParamSpec

from twisted.internet.defer import CancelledError

from synapse.api.presence import UserPresenceState
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable

if TYPE_CHECKING:
    from synapse.server import HomeServer

# Signature of a module callback that is handed an iterable of presence updates
# and returns a mapping of user ID -> set of presence updates for that user.
GET_USERS_FOR_STATES_CALLBACK = Callable[
    [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
]
# Signature of a module callback that is handed a user ID.
# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]

logger = logging.getLogger(__name__)


# Type variables for the parameters and the return type of a wrapped legacy
# module method (see `load_legacy_presence_router`).
P = ParamSpec("P")
R = TypeVar("R")


def load_legacy_presence_router(hs: "HomeServer") -> None:
    """Wrapper that loads a presence router module configured using the old
    configuration, and registers the hooks they implement.

    Does nothing if no legacy presence router module class is configured.

    Args:
        hs: The homeserver. Its server config supplies the legacy module class
            and that module's config, and its module API is used both to
            construct the module and to register the module's hooks.
    """

    if hs.config.server.presence_router_module_class is None:
        return

    module = hs.config.server.presence_router_module_class
    config = hs.config.server.presence_router_config
    api = hs.get_module_api()

    presence_router = module(config=config, module_api=api)

    # The known hooks. If a module implements a method which name appears in this set,
    # we'll want to register it.
    presence_router_methods = {
        "get_users_for_states",
        "get_interested_users",
    }

    # All methods that the module provides should be async, but this wasn't enforced
    # in the old module system, so we wrap them if needed
    def async_wrapper(
        f: Optional[Callable[P, R]]
    ) -> Optional[Callable[P, Awaitable[R]]]:
        # f might be None if the callback isn't implemented by the module. In this
        # case we don't want to register a callback at all so we return None.
        if f is None:
            return None

        # Bind the narrowed, non-optional callable to its own name: type checkers
        # cannot prove that `f` stays non-None inside the inner function, and a
        # dedicated binding avoids relying on an `assert` (stripped under `-O`).
        wrapped: Callable[P, R] = f

        # Preserve the wrapped method's name and docstring so that log lines
        # which print the registered callback identify the module's method
        # rather than this anonymous wrapper.
        @functools.wraps(wrapped)
        def run(*args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
            return maybe_awaitable(wrapped(*args, **kwargs))

        return run

    # Register the hooks through the module API. Hooks the module does not
    # implement are passed as None.
    hooks: Dict[str, Optional[Callable[..., Any]]] = {
        hook: async_wrapper(getattr(presence_router, hook, None))
        for hook in presence_router_methods
    }

    api.register_presence_router_callbacks(**hooks)


class PresenceRouter:
    """
    A module that the homeserver will call upon to help route user presence updates to
    additional destinations.

    Modules register callbacks via `register_presence_router_callbacks`; the
    `get_users_for_states` and `get_interested_users` methods then run every
    registered callback and combine the results.
    """

    # Sentinel a `get_interested_users` callback may return instead of a set.
    ALL_USERS = "ALL"

    def __init__(self, hs: "HomeServer"):
        # Initially there are no callbacks
        self._get_users_for_states_callbacks: List[GET_USERS_FOR_STATES_CALLBACK] = []
        self._get_interested_users_callbacks: List[GET_INTERESTED_USERS_CALLBACK] = []

    def register_presence_router_callbacks(
        self,
        get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
        get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
    ) -> None:
        """Register a module's presence router callbacks.

        Args:
            get_users_for_states: Callback run by `get_users_for_states`.
            get_interested_users: Callback run by `get_interested_users`.

        Raises:
            RuntimeError: If exactly one of the two callbacks is provided.
        """
        # PresenceRouter modules are required to implement both of these methods
        # or neither of them as they are assumed to act in a complementary manner
        if (get_users_for_states is None) != (get_interested_users is None):
            raise RuntimeError(
                "PresenceRouter modules must register neither or both of the paired callbacks: "
                "[get_users_for_states, get_interested_users]"
            )

        # Append the methods provided to the lists of callbacks
        if get_users_for_states is not None:
            self._get_users_for_states_callbacks.append(get_users_for_states)

        if get_interested_users is not None:
            self._get_interested_users_callbacks.append(get_interested_users)

    async def get_users_for_states(
        self,
        state_updates: Iterable[UserPresenceState],
    ) -> Dict[str, Set[UserPresenceState]]:
        """
        Given an iterable of user presence updates, determine where each one
        needs to go.

        A callback that raises, or that returns anything other than a dict whose
        values are all sets, is logged and its whole result is ignored.

        Args:
            state_updates: An iterable of user presence state updates.

        Returns:
          A dictionary of user_id -> set of UserPresenceState, indicating which
          presence updates each user should receive.
        """

        # Bail out early if we don't have any callbacks to run.
        if not self._get_users_for_states_callbacks:
            # Don't include any extra destinations for presence updates
            return {}

        users_for_states: Dict[str, Set[UserPresenceState]] = {}
        # run all the callbacks for get_users_for_states and combine the results
        for callback in self._get_users_for_states_callbacks:
            try:
                # Note: result is an object here, because we don't trust modules to
                # return the types they're supposed to.
                result: object = await delay_cancellation(callback(state_updates))
            except CancelledError:
                # Cancellation must propagate rather than be swallowed by the
                # generic handler below.
                raise
            except Exception as e:
                logger.warning("Failed to run module API callback %s: %s", callback, e)
                continue

            if not isinstance(result, dict):
                logger.warning(
                    "Wrong type returned by module API callback %s: %s, expected Dict",
                    callback,
                    result,
                )
                continue

            # Validate every value before merging anything, so that a malformed
            # result is rejected as a whole instead of being partially applied
            # depending on the order of the returned dict.
            invalid_entries = [
                entries for entries in result.values() if not isinstance(entries, set)
            ]
            if invalid_entries:
                logger.warning(
                    "Wrong type returned by module API callback %s: %s, expected Set",
                    callback,
                    invalid_entries[0],
                )
                continue

            for key, new_entries in result.items():
                users_for_states.setdefault(key, set()).update(new_entries)

        return users_for_states

    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
        """
        Retrieve a list of users that `user_id` is interested in receiving the
        presence of. This will be in addition to those they share a room with.
        Optionally, the object PresenceRouter.ALL_USERS can be returned to indicate
        that this user should receive all incoming local and remote presence updates.

        Note that this method will only be called for local users, but can return users
        that are local or remote.

        A callback that raises, or that returns something that is neither
        ALL_USERS nor a set, is logged and ignored.

        Args:
            user_id: A user requesting presence updates.

        Returns:
            A set of user IDs to return presence updates for, or ALL_USERS to return all
            known updates.
        """

        # Bail out early if we don't have any callbacks to run.
        if not self._get_interested_users_callbacks:
            # Don't report any additional interested users
            return set()

        interested_users: Set[str] = set()
        # run all the callbacks for get_interested_users and combine the results
        for callback in self._get_interested_users_callbacks:
            try:
                result = await delay_cancellation(callback(user_id))
            except CancelledError:
                # Cancellation must propagate rather than be swallowed by the
                # generic handler below.
                raise
            except Exception as e:
                logger.warning("Failed to run module API callback %s: %s", callback, e)
                continue

            # If one of the callbacks returns ALL_USERS then we can stop calling all
            # of the other callbacks, since the set of interested_users is already as
            # large as it can possibly be
            if result == PresenceRouter.ALL_USERS:
                return PresenceRouter.ALL_USERS

            if not isinstance(result, set):
                logger.warning(
                    "Wrong type returned by module API callback %s: %s, expected set",
                    callback,
                    result,
                )
                continue

            # Add the new interested users to the set
            interested_users.update(result)

        return interested_users