Add type hints to misc. files. (#9676)

pull/9678/head
Patrick Cloke 2021-03-24 06:49:01 -04:00 committed by GitHub
parent 7e8dc9934e
commit af387cf52a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 57 additions and 54 deletions

1
changelog.d/9676.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to third party event rules and visibility modules.

View File

@ -20,8 +20,9 @@ files =
synapse/crypto, synapse/crypto,
synapse/event_auth.py, synapse/event_auth.py,
synapse/events/builder.py, synapse/events/builder.py,
synapse/events/validator.py,
synapse/events/spamcheck.py, synapse/events/spamcheck.py,
synapse/events/third_party_rules.py,
synapse/events/validator.py,
synapse/federation, synapse/federation,
synapse/groups, synapse/groups,
synapse/handlers, synapse/handlers,
@ -38,6 +39,7 @@ files =
synapse/push, synapse/push,
synapse/replication, synapse/replication,
synapse/rest, synapse/rest,
synapse/secrets.py,
synapse/server.py, synapse/server.py,
synapse/server_notices, synapse/server_notices,
synapse/spam_checker_api, synapse/spam_checker_api,
@ -71,6 +73,7 @@ files =
synapse/util/metrics.py, synapse/util/metrics.py,
synapse/util/macaroons.py, synapse/util/macaroons.py,
synapse/util/stringutils.py, synapse/util/stringutils.py,
synapse/visibility.py,
tests/replication, tests/replication,
tests/test_utils, tests/test_utils,
tests/handlers/test_password_providers.py, tests/handlers/test_password_providers.py,

View File

@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Callable, Union from typing import TYPE_CHECKING, Union
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.types import Requester, StateMap from synapse.types import Requester, StateMap
if TYPE_CHECKING:
from synapse.server import HomeServer
class ThirdPartyEventRules: class ThirdPartyEventRules:
"""Allows server admins to provide a Python module implementing an extra """Allows server admins to provide a Python module implementing an extra
@ -28,7 +31,7 @@ class ThirdPartyEventRules:
behaviours. behaviours.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.third_party_rules = None self.third_party_rules = None
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -95,10 +98,9 @@ class ThirdPartyEventRules:
if self.third_party_rules is None: if self.third_party_rules is None:
return True return True
ret = await self.third_party_rules.on_create_room( return await self.third_party_rules.on_create_room(
requester, config, is_requester_admin requester, config, is_requester_admin
) )
return ret
async def check_threepid_can_be_invited( async def check_threepid_can_be_invited(
self, medium: str, address: str, room_id: str self, medium: str, address: str, room_id: str
@ -119,10 +121,9 @@ class ThirdPartyEventRules:
state_events = await self._get_state_map_for_room(room_id) state_events = await self._get_state_map_for_room(room_id)
ret = await self.third_party_rules.check_threepid_can_be_invited( return await self.third_party_rules.check_threepid_can_be_invited(
medium, address, state_events medium, address, state_events
) )
return ret
async def check_visibility_can_be_modified( async def check_visibility_can_be_modified(
self, room_id: str, new_visibility: str self, room_id: str, new_visibility: str
@ -143,7 +144,7 @@ class ThirdPartyEventRules:
check_func = getattr( check_func = getattr(
self.third_party_rules, "check_visibility_can_be_modified", None self.third_party_rules, "check_visibility_can_be_modified", None
) )
if not check_func or not isinstance(check_func, Callable): if not check_func or not callable(check_func):
return True return True
state_events = await self._get_state_map_for_room(room_id) state_events = await self._get_state_map_for_room(room_id)

View File

@ -26,10 +26,10 @@ if sys.version_info[0:2] >= (3, 6):
import secrets import secrets
class Secrets: class Secrets:
def token_bytes(self, nbytes=32): def token_bytes(self, nbytes: int = 32) -> bytes:
return secrets.token_bytes(nbytes) return secrets.token_bytes(nbytes)
def token_hex(self, nbytes=32): def token_hex(self, nbytes: int = 32) -> str:
return secrets.token_hex(nbytes) return secrets.token_hex(nbytes)
@ -38,8 +38,8 @@ else:
import os import os
class Secrets: class Secrets:
def token_bytes(self, nbytes=32): def token_bytes(self, nbytes: int = 32) -> bytes:
return os.urandom(nbytes) return os.urandom(nbytes)
def token_hex(self, nbytes=32): def token_hex(self, nbytes: int = 32) -> str:
return binascii.hexlify(self.token_bytes(nbytes)).decode("ascii") return binascii.hexlify(self.token_bytes(nbytes)).decode("ascii")

View File

@ -449,7 +449,7 @@ class StateGroupStorage:
return self.stores.state._get_state_groups_from_groups(groups, state_filter) return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events( async def get_state_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
) -> Dict[str, StateMap[EventBase]]: ) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state """Given a list of event_ids and type tuples, return a list of state
dicts for each event. dicts for each event.
@ -485,7 +485,7 @@ class StateGroupStorage:
return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events( async def get_state_ids_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
) -> Dict[str, StateMap[str]]: ) -> Dict[str, StateMap[str]]:
""" """
Get the state dicts corresponding to a list of events, containing the event_ids Get the state dicts corresponding to a list of events, containing the event_ids

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import operator from typing import Dict, FrozenSet, List, Optional
from synapse.api.constants import ( from synapse.api.constants import (
AccountDataTypes, AccountDataTypes,
@ -21,10 +21,11 @@ from synapse.api.constants import (
HistoryVisibility, HistoryVisibility,
Membership, Membership,
) )
from synapse.events import EventBase
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.storage import Storage from synapse.storage import Storage
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import get_domain_from_id from synapse.types import StateMap, get_domain_from_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -48,32 +49,32 @@ MEMBERSHIP_PRIORITY = (
async def filter_events_for_client( async def filter_events_for_client(
storage: Storage, storage: Storage,
user_id, user_id: str,
events, events: List[EventBase],
is_peeking=False, is_peeking: bool = False,
always_include_ids=frozenset(), always_include_ids: FrozenSet[str] = frozenset(),
filter_send_to_client=True, filter_send_to_client: bool = True,
): ) -> List[EventBase]:
""" """
Check which events a user is allowed to see. If the user can see the event but its Check which events a user is allowed to see. If the user can see the event but its
sender asked for their data to be erased, prune the content of the event. sender asked for their data to be erased, prune the content of the event.
Args: Args:
storage storage
user_id(str): user id to be checked user_id: user id to be checked
events(list[synapse.events.EventBase]): sequence of events to be checked events: sequence of events to be checked
is_peeking(bool): should be True if: is_peeking: should be True if:
* the user is not currently a member of the room, and: * the user is not currently a member of the room, and:
* the user has not been a member of the room since the given * the user has not been a member of the room since the given
events events
always_include_ids (set(event_id)): set of event ids to specifically always_include_ids: set of event ids to specifically
include (unless sender is ignored) include (unless sender is ignored)
filter_send_to_client (bool): Whether we're checking an event that's going to be filter_send_to_client: Whether we're checking an event that's going to be
sent to a client. This might not always be the case since this function can sent to a client. This might not always be the case since this function can
also be called to check whether a user can see the state at a given point. also be called to check whether a user can see the state at a given point.
Returns: Returns:
list[synapse.events.EventBase] The filtered events.
""" """
# Filter out events that have been soft failed so that we don't relay them # Filter out events that have been soft failed so that we don't relay them
# to clients. # to clients.
@ -90,7 +91,7 @@ async def filter_events_for_client(
AccountDataTypes.IGNORED_USER_LIST, user_id AccountDataTypes.IGNORED_USER_LIST, user_id
) )
ignore_list = frozenset() ignore_list = frozenset() # type: FrozenSet[str]
if ignore_dict_content: if ignore_dict_content:
ignored_users_dict = ignore_dict_content.get("ignored_users", {}) ignored_users_dict = ignore_dict_content.get("ignored_users", {})
if isinstance(ignored_users_dict, dict): if isinstance(ignored_users_dict, dict):
@ -107,19 +108,18 @@ async def filter_events_for_client(
room_id room_id
] = await storage.main.get_retention_policy_for_room(room_id) ] = await storage.main.get_retention_policy_for_room(room_id)
def allowed(event): def allowed(event: EventBase) -> Optional[EventBase]:
""" """
Args: Args:
event (synapse.events.EventBase): event to check event: event to check
Returns: Returns:
None|EventBase: None if the user cannot see this event at all
None if the user cannot see this event at all
a redacted copy of the event if they can only see a redacted a redacted copy of the event if they can only see a redacted
version version
the original event if they can see it as normal. the original event if they can see it as normal.
""" """
# Only run some checks if these events aren't about to be sent to clients. This is # Only run some checks if these events aren't about to be sent to clients. This is
# because, if this is not the case, we're probably only checking if the users can # because, if this is not the case, we're probably only checking if the users can
@ -252,48 +252,46 @@ async def filter_events_for_client(
return event return event
# check each event: gives an iterable[None|EventBase] # Check each event: gives an iterable of None or (a potentially modified)
# EventBase.
filtered_events = map(allowed, events) filtered_events = map(allowed, events)
# remove the None entries # Turn it into a list and remove None entries before returning.
filtered_events = filter(operator.truth, filtered_events) return [ev for ev in filtered_events if ev]
# we turn it into a list before returning it.
return list(filtered_events)
async def filter_events_for_server( async def filter_events_for_server(
storage: Storage, storage: Storage,
server_name, server_name: str,
events, events: List[EventBase],
redact=True, redact: bool = True,
check_history_visibility_only=False, check_history_visibility_only: bool = False,
): ) -> List[EventBase]:
"""Filter a list of events based on whether given server is allowed to """Filter a list of events based on whether given server is allowed to
see them. see them.
Args: Args:
storage storage
server_name (str) server_name
events (iterable[FrozenEvent]) events
redact (bool): Whether to return a redacted version of the event, or redact: Whether to return a redacted version of the event, or
to filter them out entirely. to filter them out entirely.
check_history_visibility_only (bool): Whether to only check the check_history_visibility_only: Whether to only check the
history visibility, rather than things like if the sender has been history visibility, rather than things like if the sender has been
erased. This is used e.g. during pagination to decide whether to erased. This is used e.g. during pagination to decide whether to
backfill or not. backfill or not.
Returns Returns
list[FrozenEvent] The filtered events.
""" """
def is_sender_erased(event, erased_senders): def is_sender_erased(event: EventBase, erased_senders: Dict[str, bool]) -> bool:
if erased_senders and erased_senders[event.sender]: if erased_senders and erased_senders[event.sender]:
logger.info("Sender of %s has been erased, redacting", event.event_id) logger.info("Sender of %s has been erased, redacting", event.event_id)
return True return True
return False return False
def check_event_is_visible(event, state): def check_event_is_visible(event: EventBase, state: StateMap[EventBase]) -> bool:
history = state.get((EventTypes.RoomHistoryVisibility, ""), None) history = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if history: if history:
visibility = history.content.get( visibility = history.content.get(