Convert all namedtuples to attrs. (#11665)

To improve type hints throughout the code.
pull/11674/head
Patrick Cloke 2021-12-30 13:47:12 -05:00 committed by GitHub
parent 07a3b5daba
commit cbd82d0b2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 231 additions and 206 deletions

1
changelog.d/11665.misc Normal file
View File

@ -0,0 +1 @@
Convert `namedtuples` to `attrs`.

View File

@ -351,8 +351,7 @@ class Filter:
True if the event matches the filter.
"""
# We usually get the full "events" as dictionaries coming through,
# except for presence which actually gets passed around as its own
# namedtuple type.
# except for presence which actually gets passed around as its own type.
if isinstance(event, UserPresenceState):
user_id = event.user_id
field_matchers = {

View File

@ -14,10 +14,11 @@
import logging
import os
from collections import namedtuple
from typing import Dict, List, Tuple
from urllib.request import getproxies_environment # type: ignore
import attr
from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST, generate_ip_set
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import JsonDict
@ -44,18 +45,20 @@ THUMBNAIL_SIZE_YAML = """\
HTTP_PROXY_SET_WARNING = """\
The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured."""
ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
)
MediaStorageProviderConfig = namedtuple(
"MediaStorageProviderConfig",
(
"store_local", # Whether to store newly uploaded local files
"store_remote", # Whether to store newly downloaded remote files
"store_synchronous", # Whether to wait for successful storage for local uploads
),
)
@attr.s(frozen=True, slots=True, auto_attribs=True)
class ThumbnailRequirement:
width: int
height: int
method: str
media_type: str
@attr.s(frozen=True, slots=True, auto_attribs=True)
class MediaStorageProviderConfig:
store_local: bool # Whether to store newly uploaded local files
store_remote: bool # Whether to store newly downloaded remote files
store_synchronous: bool # Whether to wait for successful storage for local uploads
def parse_thumbnail_requirements(
@ -66,11 +69,10 @@ def parse_thumbnail_requirements(
method, and thumbnail media type to precalculate
Args:
thumbnail_sizes(list): List of dicts with "width", "height", and
"method" keys
thumbnail_sizes: List of dicts with "width", "height", and "method" keys
Returns:
Dictionary mapping from media type string to list of
ThumbnailRequirement tuples.
Dictionary mapping from media type string to list of ThumbnailRequirement.
"""
requirements: Dict[str, List[ThumbnailRequirement]] = {}
for size in thumbnail_sizes:

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import namedtuple
from typing import TYPE_CHECKING
from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
@ -104,10 +103,6 @@ class FederationBase:
return pdu
class PduToCheckSig(namedtuple("PduToCheckSig", ["pdu", "sender_domain", "deferreds"])):
pass
async def _check_sigs_on_pdu(
keyring: Keyring, room_version: RoomVersion, pdu: EventBase
) -> None:

View File

@ -30,7 +30,6 @@ Events are replicated via a separate events stream.
"""
import logging
from collections import namedtuple
from typing import (
TYPE_CHECKING,
Dict,
@ -43,6 +42,7 @@ from typing import (
Type,
)
import attr
from sortedcontainers import SortedDict
from synapse.api.presence import UserPresenceState
@ -382,13 +382,11 @@ class BaseFederationRow:
raise NotImplementedError()
class PresenceDestinationsRow(
BaseFederationRow,
namedtuple(
"PresenceDestinationsRow",
("state", "destinations"), # UserPresenceState # list[str]
),
):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceDestinationsRow(BaseFederationRow):
state: UserPresenceState
destinations: List[str]
TypeId = "pd"
@staticmethod
@ -404,17 +402,15 @@ class PresenceDestinationsRow(
buff.presence_destinations.append((self.state, self.destinations))
class KeyedEduRow(
BaseFederationRow,
namedtuple(
"KeyedEduRow",
("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu
),
):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class KeyedEduRow(BaseFederationRow):
"""Streams EDUs that have an associated key that is ued to clobber. For example,
typing EDUs clobber based on room_id.
"""
key: Tuple[str, ...] # the edu key passed to send_edu
edu: Edu
TypeId = "k"
@staticmethod
@ -428,9 +424,12 @@ class KeyedEduRow(
buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EduRow(BaseFederationRow):
"""Streams EDUs that don't have keys. See KeyedEduRow"""
edu: Edu
TypeId = "e"
@staticmethod
@ -453,14 +452,14 @@ _rowtypes: Tuple[Type[BaseFederationRow], ...] = (
TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
ParsedFederationStreamData = namedtuple(
"ParsedFederationStreamData",
(
"presence_destinations", # list of tuples of UserPresenceState and destinations
"keyed_edus", # dict of destination -> { key -> Edu }
"edus", # dict of destination -> [Edu]
),
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ParsedFederationStreamData:
# list of tuples of UserPresenceState and destinations
presence_destinations: List[Tuple[UserPresenceState, List[str]]]
# dict of destination -> { key -> Edu }
keyed_edus: Dict[str, Dict[Tuple[str, ...], Edu]]
# dict of destination -> [Edu]
edus: Dict[str, List[Edu]]
def process_rows_for_federation(

View File

@ -462,9 +462,9 @@ class ApplicationServicesHandler:
Args:
room_alias: The room alias to query.
Returns:
namedtuple: with keys "room_id" and "servers" or None if no
association can be found.
RoomAliasMapping or None if no association can be found.
"""
room_alias_str = room_alias.to_string()
services = self.store.get_app_services()

View File

@ -278,13 +278,15 @@ class DirectoryHandler:
users = await self.store.get_users_in_room(room_id)
extra_servers = {get_domain_from_id(u) for u in users}
servers = set(extra_servers) | set(servers)
servers_set = set(extra_servers) | set(servers)
# If this server is in the list of servers, return it first.
if self.server_name in servers:
servers = [self.server_name] + [s for s in servers if s != self.server_name]
if self.server_name in servers_set:
servers = [self.server_name] + [
s for s in servers_set if s != self.server_name
]
else:
servers = list(servers)
servers = list(servers_set)
return {"room_id": room_id, "servers": servers}

View File

@ -13,9 +13,9 @@
# limitations under the License.
import logging
from collections import namedtuple
from typing import TYPE_CHECKING, Any, Optional, Tuple
import attr
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
@ -474,16 +474,12 @@ class RoomListHandler:
)
class RoomListNextBatch(
namedtuple(
"RoomListNextBatch",
(
"last_joined_members", # The count to get rooms after/before
"last_room_id", # The room_id to get rooms after/before
"direction_is_forward", # Bool if this is a next_batch, false if prev_batch
),
)
):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomListNextBatch:
last_joined_members: int # The count to get rooms after/before
last_room_id: str # The room_id to get rooms after/before
direction_is_forward: bool # True if this is a next_batch, false if prev_batch
KEY_DICT = {
"last_joined_members": "m",
"last_room_id": "r",
@ -502,12 +498,12 @@ class RoomListNextBatch(
def to_token(self) -> str:
return encode_base64(
msgpack.dumps(
{self.KEY_DICT[key]: val for key, val in self._asdict().items()}
{self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()}
)
)
def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch":
return self._replace(**kwds)
return attr.evolve(self, **kwds)
def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:

View File

@ -13,9 +13,10 @@
# limitations under the License.
import logging
import random
from collections import namedtuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import attr
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import (
@ -37,7 +38,10 @@ logger = logging.getLogger(__name__)
# A tiny object useful for storing a user's membership in a room, as a mapping
# key
RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomMember:
room_id: str
user_id: str
# How often we expect remote servers to resend us presence.
@ -119,7 +123,7 @@ class FollowerTypingHandler:
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
def is_typing(self, member: RoomMember) -> bool:
return member.user_id in self._room_typing.get(member.room_id, [])
return member.user_id in self._room_typing.get(member.room_id, set())
async def _push_remote(self, member: RoomMember, typing: bool) -> None:
if not self.federation:
@ -166,9 +170,9 @@ class FollowerTypingHandler:
for row in rows:
self._room_serials[row.room_id] = token
prev_typing = set(self._room_typing.get(row.room_id, []))
prev_typing = self._room_typing.get(row.room_id, set())
now_typing = set(row.user_ids)
self._room_typing[row.room_id] = row.user_ids
self._room_typing[row.room_id] = now_typing
if self.federation:
run_as_background_process(

View File

@ -14,7 +14,6 @@
# limitations under the License.
import abc
import collections
import html
import logging
import types
@ -37,6 +36,7 @@ from typing import (
Union,
)
import attr
import jinja2
from canonicaljson import encode_canonical_json
from typing_extensions import Protocol
@ -354,9 +354,11 @@ class DirectServeJsonResource(_AsyncResource):
return_json_error(f, request)
_PathEntry = collections.namedtuple(
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _PathEntry:
pattern: Pattern
callback: ServletCallback
servlet_classname: str
class JsonResource(DirectServeJsonResource):

View File

@ -15,7 +15,6 @@
import heapq
import logging
from collections import namedtuple
from typing import (
TYPE_CHECKING,
Any,
@ -30,6 +29,7 @@ from typing import (
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -226,17 +226,14 @@ class BackfillStream(Stream):
or it went from being an outlier to not.
"""
BackfillStreamRow = namedtuple(
"BackfillStreamRow",
(
"event_id", # str
"room_id", # str
"type", # str
"state_key", # str, optional
"redacts", # str, optional
"relates_to", # str, optional
),
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class BackfillStreamRow:
event_id: str
room_id: str
type: str
state_key: Optional[str]
redacts: Optional[str]
relates_to: Optional[str]
NAME = "backfill"
ROW_TYPE = BackfillStreamRow
@ -256,18 +253,15 @@ class BackfillStream(Stream):
class PresenceStream(Stream):
PresenceStreamRow = namedtuple(
"PresenceStreamRow",
(
"user_id", # str
"state", # str
"last_active_ts", # int
"last_federation_update_ts", # int
"last_user_sync_ts", # int
"status_msg", # str
"currently_active", # bool
),
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceStreamRow:
user_id: str
state: str
last_active_ts: int
last_federation_update_ts: int
last_user_sync_ts: int
status_msg: str
currently_active: bool
NAME = "presence"
ROW_TYPE = PresenceStreamRow
@ -302,7 +296,7 @@ class PresenceFederationStream(Stream):
send.
"""
@attr.s(slots=True, auto_attribs=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceFederationStreamRow:
destination: str
user_id: str
@ -320,9 +314,10 @@ class PresenceFederationStream(Stream):
class TypingStream(Stream):
TypingStreamRow = namedtuple(
"TypingStreamRow", ("room_id", "user_ids") # str # list(str)
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TypingStreamRow:
room_id: str
user_ids: List[str]
NAME = "typing"
ROW_TYPE = TypingStreamRow
@ -348,16 +343,13 @@ class TypingStream(Stream):
class ReceiptsStream(Stream):
ReceiptsStreamRow = namedtuple(
"ReceiptsStreamRow",
(
"room_id", # str
"receipt_type", # str
"user_id", # str
"event_id", # str
"data", # dict
),
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ReceiptsStreamRow:
room_id: str
receipt_type: str
user_id: str
event_id: str
data: dict
NAME = "receipts"
ROW_TYPE = ReceiptsStreamRow
@ -374,7 +366,9 @@ class ReceiptsStream(Stream):
class PushRulesStream(Stream):
"""A user has changed their push rules"""
PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PushRulesStreamRow:
user_id: str
NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow
@ -396,10 +390,12 @@ class PushRulesStream(Stream):
class PushersStream(Stream):
"""A user has added/changed/removed a pusher"""
PushersStreamRow = namedtuple(
"PushersStreamRow",
("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PushersStreamRow:
user_id: str
app_id: str
pushkey: str
deleted: bool
NAME = "pushers"
ROW_TYPE = PushersStreamRow
@ -419,7 +415,7 @@ class CachesStream(Stream):
the cache on the workers
"""
@attr.s(slots=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class CachesStreamRow:
"""Stream to inform workers they should invalidate their cache.
@ -430,9 +426,9 @@ class CachesStream(Stream):
invalidation_ts: Timestamp of when the invalidation took place.
"""
cache_func = attr.ib(type=str)
keys = attr.ib(type=Optional[List[Any]])
invalidation_ts = attr.ib(type=int)
cache_func: str
keys: Optional[List[Any]]
invalidation_ts: int
NAME = "caches"
ROW_TYPE = CachesStreamRow
@ -451,9 +447,9 @@ class DeviceListsStream(Stream):
told about a device update.
"""
@attr.s(slots=True)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceListsStreamRow:
entity = attr.ib(type=str)
entity: str
NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow
@ -470,7 +466,9 @@ class DeviceListsStream(Stream):
class ToDeviceStream(Stream):
"""New to_device messages for a client"""
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ToDeviceStreamRow:
entity: str
NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow
@ -487,9 +485,11 @@ class ToDeviceStream(Stream):
class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room"""
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class TagAccountDataStreamRow:
user_id: str
room_id: str
data: JsonDict
NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow
@ -506,10 +506,11 @@ class TagAccountDataStream(Stream):
class AccountDataStream(Stream):
"""Global or per room account data was changed"""
AccountDataStreamRow = namedtuple(
"AccountDataStreamRow",
("user_id", "room_id", "data_type"), # str # Optional[str] # str
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class AccountDataStreamRow:
user_id: str
room_id: Optional[str]
data_type: str
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
@ -573,10 +574,12 @@ class AccountDataStream(Stream):
class GroupServerStream(Stream):
GroupsStreamRow = namedtuple(
"GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class GroupsStreamRow:
group_id: str
user_id: str
type: str
content: JsonDict
NAME = "groups"
ROW_TYPE = GroupsStreamRow
@ -593,7 +596,9 @@ class GroupServerStream(Stream):
class UserSignatureStream(Stream):
"""A user has signed their own device with their user-signing key"""
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
@attr.s(slots=True, frozen=True, auto_attribs=True)
class UserSignatureStreamRow:
user_id: str
NAME = "user_signature"
ROW_TYPE = UserSignatureStreamRow

View File

@ -12,14 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Tuple
import attr
from synapse.replication.tcp.streams._base import (
Stream,
current_token_without_instance,
make_http_update_function,
)
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@ -30,13 +32,10 @@ class FederationStream(Stream):
sending disabled.
"""
FederationStreamRow = namedtuple(
"FederationStreamRow",
(
"type", # str, the type of data as defined in the BaseFederationRows
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
),
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class FederationStreamRow:
type: str # the type of data as defined in the BaseFederationRows
data: JsonDict # serialization of a federation.send_queue.BaseFederationRow
NAME = "federation"
ROW_TYPE = FederationStreamRow

View File

@ -739,14 +739,21 @@ class MediaRepository:
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails: Dict[Tuple[int, int, str], str] = {}
for r_width, r_height, r_method, r_type in requirements:
if r_method == "crop":
thumbnails.setdefault((r_width, r_height, r_type), r_method)
elif r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
for requirement in requirements:
if requirement.method == "crop":
thumbnails.setdefault(
(requirement.width, requirement.height, requirement.media_type),
requirement.method,
)
elif requirement.method == "scale":
t_width, t_height = thumbnailer.aspect(
requirement.width, requirement.height
)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
thumbnails[(t_width, t_height, r_type)] = r_method
thumbnails[
(t_width, t_height, requirement.media_type)
] = requirement.method
# Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.items():

View File

@ -14,7 +14,7 @@
# limitations under the License.
import heapq
import logging
from collections import defaultdict, namedtuple
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
@ -69,9 +69,6 @@ state_groups_histogram = Histogram(
)
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
EVICTION_TIMEOUT_SECONDS = 60 * 60

View File

@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from typing import Iterable, List, Optional, Tuple
import attr
from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.types import RoomAlias
from synapse.util.caches.descriptors import cached
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomAliasMapping:
room_id: str
room_alias: str
servers: List[str]
class DirectoryWorkerStore(CacheInvalidationWorkerStore):

View File

@ -1976,14 +1976,17 @@ class PersistEventsStore:
txn, self.store.get_retention_policy_for_room, (event.room_id,)
)
def store_event_search_txn(self, txn, event, key, value):
def store_event_search_txn(
self, txn: LoggingTransaction, event: EventBase, key: str, value: str
) -> None:
"""Add event to the search table
Args:
txn (cursor):
event (EventBase):
key (str):
value (str):
txn: The database transaction.
event: The event being added to the search table.
key: A key describing the search value (one of "content.name",
"content.topic", or "content.body")
value: The value from the event's content.
"""
self.store.store_search_entries_txn(
txn,

View File

@ -13,11 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
from abc import abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, cast
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)
import attr
from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
@ -43,9 +54,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
RatelimitOverride = collections.namedtuple(
"RatelimitOverride", ("messages_per_second", "burst_count")
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RatelimitOverride:
messages_per_second: int
burst_count: int
class RoomSortOrder(Enum):
@ -207,6 +219,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
WHERE appservice_id = ? AND network_id = ?
"""
query_args.append(network_tuple.appservice_id)
assert network_tuple.network_id is not None
query_args.append(network_tuple.network_id)
else:
published_sql = """
@ -284,7 +297,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
"""
where_clauses = []
query_args = []
query_args: List[Union[str, int]] = []
if network_tuple:
if network_tuple.appservice_id:
@ -293,6 +306,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
WHERE appservice_id = ? AND network_id = ?
"""
query_args.append(network_tuple.appservice_id)
assert network_tuple.network_id is not None
query_args.append(network_tuple.network_id)
else:
published_sql = """

View File

@ -14,9 +14,10 @@
import logging
import re
from collections import namedtuple
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set
import attr
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
@ -33,10 +34,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
SearchEntry = namedtuple(
"SearchEntry",
["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"],
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class SearchEntry:
key: str
value: str
event_id: str
room_id: str
stream_ordering: Optional[int]
origin_server_ts: int
def _clean_value_for_search(value: str) -> str:

View File

@ -14,7 +14,6 @@
# limitations under the License.
import collections.abc
import logging
from collections import namedtuple
from typing import TYPE_CHECKING, Iterable, Optional, Set
from synapse.api.constants import EventTypes, Membership
@ -43,19 +42,6 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
class _GetStateGroupDelta(
namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
):
"""Return type of get_state_group_delta that implements __len__, which lets
us use the itrable flag when caching
"""
__slots__ = []
def __len__(self):
return len(self.delta_ids) if self.delta_ids else 0
# this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers."""

View File

@ -36,9 +36,9 @@ what sort order was used:
"""
import abc
import logging
from collections import namedtuple
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
import attr
from frozendict import frozendict
from twisted.internet import defer
@ -74,9 +74,11 @@ _TOPOLOGICAL_TOKEN = "topological"
# Used as return values for pagination APIs
_EventDictReturn = namedtuple(
"_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn:
event_id: str
topological_ordering: Optional[int]
stream_ordering: int
def generate_pagination_where_clause(
@ -825,7 +827,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
for event, row in zip(events, rows):
stream = row.stream_ordering
if topo_order and row.topological_ordering:
topo = row.topological_ordering
topo: Optional[int] = row.topological_ordering
else:
topo = None
internal = event.internal_metadata

View File

@ -15,7 +15,6 @@
import abc
import re
import string
from collections import namedtuple
from typing import (
TYPE_CHECKING,
Any,
@ -227,8 +226,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
localpart = attr.ib(type=str)
domain = attr.ib(type=str)
# Because this class is a namedtuple of strings and booleans, it is deeply
# immutable.
# Because this is a frozen class, it is deeply immutable.
def __copy__(self):
return self
@ -708,16 +706,18 @@ class PersistedEventPosition:
return RoomStreamToken(None, self.stream)
class ThirdPartyInstanceID(
namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThirdPartyInstanceID:
appservice_id: Optional[str]
network_id: Optional[str]
# Deny iteration because it will bite you if you try to create a singleton
# set by:
# users = set(user)
def __iter__(self):
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
# Because this class is a namedtuple of strings, it is deeply immutable.
# Because this class is a frozen class, it is deeply immutable.
def __copy__(self):
return self
@ -725,22 +725,18 @@ class ThirdPartyInstanceID(
return self
@classmethod
def from_string(cls, s):
def from_string(cls, s: str) -> "ThirdPartyInstanceID":
bits = s.split("|", 2)
if len(bits) != 2:
raise SynapseError(400, "Invalid ID %r" % (s,))
return cls(appservice_id=bits[0], network_id=bits[1])
def to_string(self):
def to_string(self) -> str:
return "%s|%s" % (self.appservice_id, self.network_id)
__str__ = to_string
@classmethod
def create(cls, appservice_id, network_id):
return cls(appservice_id=appservice_id, network_id=network_id)
@attr.s(slots=True)
class ReadReceipt:

View File

@ -62,7 +62,11 @@ class FederationAckTestCase(HomeserverTestCase):
"federation",
"master",
token=10,
rows=[FederationStream.FederationStreamRow(type="x", data=[1, 2, 3])],
rows=[
FederationStream.FederationStreamRow(
type="x", data={"test": [1, 2, 3]}
)
],
)
)