Use an enum for direction. (#14927)

For better type safety we  use an enum instead of strings to
configure direction (backwards or forwards).
pull/14935/head
Patrick Cloke 2023-01-27 07:27:55 -05:00 committed by GitHub
parent fc35e0673f
commit 265735db9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 76 additions and 44 deletions

1
changelog.d/14927.misc Normal file
View File

@ -0,0 +1 @@
Add missing type hints.

View File

@ -17,6 +17,8 @@
"""Contains constants from the specification.""" """Contains constants from the specification."""
import enum
from typing_extensions import Final from typing_extensions import Final
# the max size of a (canonical-json-encoded) event # the max size of a (canonical-json-encoded) event
@ -290,3 +292,8 @@ class ApprovalNoticeMedium:
NONE = "org.matrix.msc3866.none" NONE = "org.matrix.msc3866.none"
EMAIL = "org.matrix.msc3866.email" EMAIL = "org.matrix.msc3866.email"
class Direction(enum.Enum):
BACKWARDS = "b"
FORWARDS = "f"

View File

@ -16,7 +16,7 @@ import abc
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from synapse.api.constants import Membership from synapse.api.constants import Direction, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -197,7 +197,7 @@ class AdminHandler:
# efficient method perhaps but it does guarantee we get everything. # efficient method perhaps but it does guarantee we get everything.
while True: while True:
events, _ = await self.store.paginate_room_events( events, _ = await self.store.paginate_room_events(
room_id, from_key, to_key, limit=100, direction="f" room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
) )
if not events: if not events:
break break

View File

@ -15,7 +15,13 @@
import logging import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership from synapse.api.constants import (
AccountDataTypes,
Direction,
EduTypes,
EventTypes,
Membership,
)
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig from synapse.events.utils import SerializeEventConfig
@ -57,7 +63,13 @@ class InitialSyncHandler:
self.validator = EventValidator() self.validator = EventValidator()
self.snapshot_cache: ResponseCache[ self.snapshot_cache: ResponseCache[
Tuple[ Tuple[
str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool str,
Optional[StreamToken],
Optional[StreamToken],
Direction,
int,
bool,
bool,
] ]
] = ResponseCache(hs.get_clock(), "initial_sync_cache") ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()

View File

@ -19,7 +19,7 @@ import attr
from twisted.python.failure import Failure from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig from synapse.events.utils import SerializeEventConfig
@ -448,7 +448,7 @@ class PaginationHandler:
if pagin_config.from_token: if pagin_config.from_token:
from_token = pagin_config.from_token from_token = pagin_config.from_token
elif pagin_config.direction == "f": elif pagin_config.direction == Direction.FORWARDS:
from_token = ( from_token = (
await self.hs.get_event_sources().get_start_token_for_pagination( await self.hs.get_event_sources().get_start_token_for_pagination(
room_id room_id
@ -476,7 +476,7 @@ class PaginationHandler:
room_id, requester, allow_departed_users=True room_id, requester, allow_departed_users=True
) )
if pagin_config.direction == "b": if pagin_config.direction == Direction.BACKWARDS:
# if we're going backwards, we might need to backfill. This # if we're going backwards, we might need to backfill. This
# requires that we have a topo token. # requires that we have a topo token.
if room_token.topological: if room_token.topological:

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, O
import attr import attr
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event from synapse.events import EventBase, relation_from_event
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
@ -413,7 +413,11 @@ class RelationsHandler:
# Attempt to find another event to use as the latest event. # Attempt to find another event to use as the latest event.
potential_events, _ = await self._main_store.get_relations_for_event( potential_events, _ = await self._main_store.get_relations_for_event(
event_id, event, room_id, RelationTypes.THREAD, direction="f" event_id,
event,
room_id,
RelationTypes.THREAD,
direction=Direction.FORWARDS,
) )
# Filter out ignored users. # Filter out ignored users.

View File

@ -30,7 +30,7 @@ from typing import (
import attr import attr
from synapse.api.constants import MAIN_TIMELINE, RelationTypes from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
@ -168,7 +168,7 @@ class RelationsWorkerStore(SQLBaseStore):
relation_type: Optional[str] = None, relation_type: Optional[str] = None,
event_type: Optional[str] = None, event_type: Optional[str] = None,
limit: int = 5, limit: int = 5,
direction: str = "b", direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None, from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
@ -181,8 +181,8 @@ class RelationsWorkerStore(SQLBaseStore):
relation_type: Only fetch events with this relation type, if given. relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given. event_type: Only fetch events with this event type, if given.
limit: Only fetch the most recent `limit` events. limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the direction: Whether to fetch the most recent first (backwards) or the
oldest first (`"f"`). oldest first (forwards).
from_token: Fetch rows from the given token, or from the start if None. from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None. to_token: Fetch rows up to the given token, or up to the end if None.

View File

@ -55,6 +55,7 @@ from typing_extensions import Literal
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Direction
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
@ -86,7 +87,6 @@ MAX_STREAM_SIZE = 1000
_STREAM_TOKEN = "stream" _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological" _TOPOLOGICAL_TOKEN = "topological"
# Used as return values for pagination APIs # Used as return values for pagination APIs
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn: class _EventDictReturn:
@ -104,7 +104,7 @@ class _EventsAround:
def generate_pagination_where_clause( def generate_pagination_where_clause(
direction: str, direction: Direction,
column_names: Tuple[str, str], column_names: Tuple[str, str],
from_token: Optional[Tuple[Optional[int], int]], from_token: Optional[Tuple[Optional[int], int]],
to_token: Optional[Tuple[Optional[int], int]], to_token: Optional[Tuple[Optional[int], int]],
@ -130,27 +130,26 @@ def generate_pagination_where_clause(
token, but include those that match the to token. token, but include those that match the to token.
Args: Args:
direction: Whether we're paginating backwards("b") or forwards ("f"). direction: Whether we're paginating backwards or forwards.
column_names: The column names to bound. Must *not* be user defined as column_names: The column names to bound. Must *not* be user defined as
these get inserted directly into the SQL statement without escapes. these get inserted directly into the SQL statement without escapes.
from_token: The start point for the pagination. This is an exclusive from_token: The start point for the pagination. This is an exclusive
minimum bound if direction is "f", and an inclusive maximum bound if minimum bound if direction is forwards, and an inclusive maximum bound if
direction is "b". direction is backwards.
to_token: The endpoint point for the pagination. This is an inclusive to_token: The endpoint point for the pagination. This is an inclusive
maximum bound if direction is "f", and an exclusive minimum bound if maximum bound if direction is forwards, and an exclusive minimum bound if
direction is "b". direction is backwards.
engine: The database engine to generate the clauses for engine: The database engine to generate the clauses for
Returns: Returns:
The sql expression The sql expression
""" """
assert direction in ("b", "f")
where_clause = [] where_clause = []
if from_token: if from_token:
where_clause.append( where_clause.append(
_make_generic_sql_bound( _make_generic_sql_bound(
bound=">=" if direction == "b" else "<", bound=">=" if direction == Direction.BACKWARDS else "<",
column_names=column_names, column_names=column_names,
values=from_token, values=from_token,
engine=engine, engine=engine,
@ -160,7 +159,7 @@ def generate_pagination_where_clause(
if to_token: if to_token:
where_clause.append( where_clause.append(
_make_generic_sql_bound( _make_generic_sql_bound(
bound="<" if direction == "b" else ">=", bound="<" if direction == Direction.BACKWARDS else ">=",
column_names=column_names, column_names=column_names,
values=to_token, values=to_token,
engine=engine, engine=engine,
@ -171,7 +170,7 @@ def generate_pagination_where_clause(
def generate_pagination_bounds( def generate_pagination_bounds(
direction: str, direction: Direction,
from_token: Optional[RoomStreamToken], from_token: Optional[RoomStreamToken],
to_token: Optional[RoomStreamToken], to_token: Optional[RoomStreamToken],
) -> Tuple[ ) -> Tuple[
@ -181,7 +180,7 @@ def generate_pagination_bounds(
Generate a start and end point for this page of events. Generate a start and end point for this page of events.
Args: Args:
direction: Whether pagination is going forwards or backwards. One of "f" or "b". direction: Whether pagination is going forwards or backwards.
from_token: The token to start pagination at, or None to start at the first value. from_token: The token to start pagination at, or None to start at the first value.
to_token: The token to end pagination at, or None to not limit the end point. to_token: The token to end pagination at, or None to not limit the end point.
@ -201,7 +200,7 @@ def generate_pagination_bounds(
# Tokens really represent positions between elements, but we use # Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence # the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities. # we have a bit of asymmetry when it comes to equalities.
if direction == "b": if direction == Direction.BACKWARDS:
order = "DESC" order = "DESC"
else: else:
order = "ASC" order = "ASC"
@ -215,7 +214,7 @@ def generate_pagination_bounds(
if from_token: if from_token:
if from_token.topological is not None: if from_token.topological is not None:
from_bound = from_token.as_historical_tuple() from_bound = from_token.as_historical_tuple()
elif direction == "b": elif direction == Direction.BACKWARDS:
from_bound = ( from_bound = (
None, None,
from_token.get_max_stream_pos(), from_token.get_max_stream_pos(),
@ -230,7 +229,7 @@ def generate_pagination_bounds(
if to_token: if to_token:
if to_token.topological is not None: if to_token.topological is not None:
to_bound = to_token.as_historical_tuple() to_bound = to_token.as_historical_tuple()
elif direction == "b": elif direction == Direction.BACKWARDS:
to_bound = ( to_bound = (
None, None,
to_token.stream, to_token.stream,
@ -245,20 +244,20 @@ def generate_pagination_bounds(
def generate_next_token( def generate_next_token(
direction: str, last_topo_ordering: int, last_stream_ordering: int direction: Direction, last_topo_ordering: int, last_stream_ordering: int
) -> RoomStreamToken: ) -> RoomStreamToken:
""" """
Generate the next room stream token based on the currently returned data. Generate the next room stream token based on the currently returned data.
Args: Args:
direction: Whether pagination is going forwards or backwards. One of "f" or "b". direction: Whether pagination is going forwards or backwards.
last_topo_ordering: The last topological ordering being returned. last_topo_ordering: The last topological ordering being returned.
last_stream_ordering: The last stream ordering being returned. last_stream_ordering: The last stream ordering being returned.
Returns: Returns:
A new RoomStreamToken to return to the client. A new RoomStreamToken to return to the client.
""" """
if direction == "b": if direction == Direction.BACKWARDS:
# Tokens are positions between events. # Tokens are positions between events.
# This token points *after* the last event in the chunk. # This token points *after* the last event in the chunk.
# We need it to point to the event before it in the chunk # We need it to point to the event before it in the chunk
@ -1201,7 +1200,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn, txn,
room_id, room_id,
before_token, before_token,
direction="b", direction=Direction.BACKWARDS,
limit=before_limit, limit=before_limit,
event_filter=event_filter, event_filter=event_filter,
) )
@ -1211,7 +1210,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn, txn,
room_id, room_id,
after_token, after_token,
direction="f", direction=Direction.FORWARDS,
limit=after_limit, limit=after_limit,
event_filter=event_filter, event_filter=event_filter,
) )
@ -1374,7 +1373,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id: str, room_id: str,
from_token: RoomStreamToken, from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None, to_token: Optional[RoomStreamToken] = None,
direction: str = "b", direction: Direction = Direction.BACKWARDS,
limit: int = -1, limit: int = -1,
event_filter: Optional[Filter] = None, event_filter: Optional[Filter] = None,
) -> Tuple[List[_EventDictReturn], RoomStreamToken]: ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
@ -1385,8 +1384,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id room_id
from_token: The token used to stream from from_token: The token used to stream from
to_token: A token which if given limits the results to only those before to_token: A token which if given limits the results to only those before
direction: Either 'b' or 'f' to indicate whether we are paginating direction: Indicates whether we are paginating forwards or backwards
forwards or backwards from `from_key`. from `from_key`.
limit: The maximum number of events to return. limit: The maximum number of events to return.
event_filter: If provided filters the events to event_filter: If provided filters the events to
those that match the filter. those that match the filter.
@ -1489,8 +1488,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
_EventDictReturn(event_id, topological_ordering, stream_ordering) _EventDictReturn(event_id, topological_ordering, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results( if _filter_results(
lower_token=to_token if direction == "b" else from_token, lower_token=to_token
upper_token=from_token if direction == "b" else to_token, if direction == Direction.BACKWARDS
else from_token,
upper_token=from_token
if direction == Direction.BACKWARDS
else to_token,
instance_name=instance_name, instance_name=instance_name,
topological_ordering=topological_ordering, topological_ordering=topological_ordering,
stream_ordering=stream_ordering, stream_ordering=stream_ordering,
@ -1514,7 +1517,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id: str, room_id: str,
from_key: RoomStreamToken, from_key: RoomStreamToken,
to_key: Optional[RoomStreamToken] = None, to_key: Optional[RoomStreamToken] = None,
direction: str = "b", direction: Direction = Direction.BACKWARDS,
limit: int = -1, limit: int = -1,
event_filter: Optional[Filter] = None, event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], RoomStreamToken]: ) -> Tuple[List[EventBase], RoomStreamToken]:
@ -1524,8 +1527,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id room_id
from_key: The token used to stream from from_key: The token used to stream from
to_key: A token which if given limits the results to only those before to_key: A token which if given limits the results to only those before
direction: Either 'b' or 'f' to indicate whether we are paginating direction: Indicates whether we are paginating forwards or backwards
forwards or backwards from `from_key`. from `from_key`.
limit: The maximum number of events to return. limit: The maximum number of events to return.
event_filter: If provided filters the events to those that match the filter. event_filter: If provided filters the events to those that match the filter.

View File

@ -16,6 +16,7 @@ from typing import Optional
import attr import attr
from synapse.api.constants import Direction
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -34,7 +35,7 @@ class PaginationConfig:
from_token: Optional[StreamToken] from_token: Optional[StreamToken]
to_token: Optional[StreamToken] to_token: Optional[StreamToken]
direction: str direction: Direction
limit: int limit: int
@classmethod @classmethod
@ -45,9 +46,13 @@ class PaginationConfig:
default_limit: int, default_limit: int,
default_dir: str = "f", default_dir: str = "f",
) -> "PaginationConfig": ) -> "PaginationConfig":
direction = parse_string( direction_str = parse_string(
request, "dir", default=default_dir, allowed_values=["f", "b"] request,
"dir",
default=default_dir,
allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value],
) )
direction = Direction(direction_str)
from_tok_str = parse_string(request, "from") from_tok_str = parse_string(request, "from")
to_tok_str = parse_string(request, "to") to_tok_str = parse_string(request, "to")