From aa9e47e14426b243ebd1f7e9be78fe200e599306 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Thu, 3 Aug 2023 17:18:48 +0200 Subject: [PATCH] Add logging of invalid mxids when persisting events --- synapse/storage/databases/main/events.py | 6 ++++- synapse/types/__init__.py | 33 +++++++++++++++++++++--- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index bd3f14fb71..da8cbc9642 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -51,7 +51,7 @@ from synapse.storage.databases.main.search import SearchEntry from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.storage.util.sequence import SequenceGenerator -from synapse.types import JsonDict, StateMap, StrCollection, get_domain_from_id +from synapse.types import JsonDict, StateMap, StrCollection, UserID, get_domain_from_id from synapse.util import json_encoder from synapse.util.iterutils import batch_iter, sorted_topologically from synapse.util.stringutils import non_null_str_or_none @@ -393,6 +393,10 @@ class PersistEventsStore: # Once the txn completes, invalidate all of the relevant caches. Note that we do this # up here because it captures all the events_and_contexts before any are removed. for event, _ in events_and_contexts: + sender = UserID.from_string(event.sender) + # The result of `validate` is not used yet because for now we only want to + # log invalid mxids in the wild. + sender.validate(allow_historical_mxids=True) self.store.invalidate_get_event_cache_after_txn(txn, event.event_id) if event.redacts: self.store.invalidate_get_event_cache_after_txn(txn, event.redacts) diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index fdfd465c8d..dffd98863d 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc +import logging import re import string from typing import ( @@ -62,6 +63,9 @@ if TYPE_CHECKING: from synapse.storage.databases.main import DataStore, PurgeEventsStore from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore +logger = logging.getLogger(__name__) + + # Define a state map type from type/state_key to T (usually an event ID or # event) T = TypeVar("T") @@ -326,6 +330,20 @@ class UserID(DomainSpecificString): SIGIL = "@" + def validate(self, allow_historical_mxids: Optional[bool] = False) -> bool: + is_valid = True + if len(self.to_string().encode("utf-8")) > 255: + logger.warn( + f"User ID {self.to_string()} has more than 255 bytes and is invalid per the spec" + ) + is_valid = False + if contains_invalid_mxid_characters(self.localpart, allow_historical_mxids): + logger.warn( + f"localpart of User ID {self.to_string()} contains invalid characters per the spec" + ) + is_valid = False + return is_valid + @attr.s(slots=True, frozen=True, repr=False) class RoomAlias(DomainSpecificString): @@ -352,22 +370,31 @@ MXID_LOCALPART_ALLOWED_CHARACTERS = set( "_-./=+" + string.ascii_lowercase + string.digits ) +ASCII_PRINTABLE_CHARACTERS = set(string.printable) + # Guest user IDs are purely numeric. GUEST_USER_ID_PATTERN = re.compile(r"^\d+$") -def contains_invalid_mxid_characters(localpart: str) -> bool: +def contains_invalid_mxid_characters( + localpart: str, allow_historical_mxids: Optional[bool] = False +) -> bool: """Check for characters not allowed in an mxid or groupid localpart Args: localpart: the localpart to be checked - use_extended_character_set: True to use the extended allowed characters + allow_legacy_mxids: True to use the extended allowed characters from MSC4009. Returns: True if there are any naughty characters """ - return any(c not in MXID_LOCALPART_ALLOWED_CHARACTERS for c in localpart) + + if allow_historical_mxids: + allowed_characters = ASCII_PRINTABLE_CHARACTERS + else: + allowed_characters = MXID_LOCALPART_ALLOWED_CHARACTERS + return any(c not in allowed_characters for c in localpart) UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")