Fix message duplication if something goes wrong after persisting the event (#8476)
Should fix #3365.pull/8547/head
parent
a9a8f29729
commit
b2486f6656
|
@ -0,0 +1 @@
|
|||
Fix message duplication if something goes wrong after persisting the event.
|
|
@ -2966,17 +2966,20 @@ class FederationHandler(BaseHandler):
|
|||
return result["max_stream_id"]
|
||||
else:
|
||||
assert self.storage.persistence
|
||||
max_stream_token = await self.storage.persistence.persist_events(
|
||||
|
||||
# Note that this returns the events that were persisted, which may not be
|
||||
# the same as were passed in if some were deduplicated due to transaction IDs.
|
||||
events, max_stream_token = await self.storage.persistence.persist_events(
|
||||
event_and_contexts, backfilled=backfilled
|
||||
)
|
||||
|
||||
if self._ephemeral_messages_enabled:
|
||||
for (event, context) in event_and_contexts:
|
||||
for event in events:
|
||||
# If there's an expiry timestamp on the event, schedule its expiry.
|
||||
self._message_handler.maybe_schedule_expiry(event)
|
||||
|
||||
if not backfilled: # Never notify for backfilled events
|
||||
for event, _ in event_and_contexts:
|
||||
for event in events:
|
||||
await self._notify_persisted_event(event, max_stream_token)
|
||||
|
||||
return max_stream_token.stream
|
||||
|
|
|
@ -689,7 +689,7 @@ class EventCreationHandler:
|
|||
send this event.
|
||||
|
||||
Returns:
|
||||
The event, and its stream ordering (if state event deduplication happened,
|
||||
The event, and its stream ordering (if deduplication happened,
|
||||
the previous, duplicate event).
|
||||
|
||||
Raises:
|
||||
|
@ -712,6 +712,19 @@ class EventCreationHandler:
|
|||
# extremities to pile up, which in turn leads to state resolution
|
||||
# taking longer.
|
||||
with (await self.limiter.queue(event_dict["room_id"])):
|
||||
if txn_id and requester.access_token_id:
|
||||
existing_event_id = await self.store.get_event_id_from_transaction_id(
|
||||
event_dict["room_id"],
|
||||
requester.user.to_string(),
|
||||
requester.access_token_id,
|
||||
txn_id,
|
||||
)
|
||||
if existing_event_id:
|
||||
event = await self.store.get_event(existing_event_id)
|
||||
# we know it was persisted, so must have a stream ordering
|
||||
assert event.internal_metadata.stream_ordering
|
||||
return event, event.internal_metadata.stream_ordering
|
||||
|
||||
event, context = await self.create_event(
|
||||
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
|
||||
)
|
||||
|
@ -913,10 +926,20 @@ class EventCreationHandler:
|
|||
extra_users=extra_users,
|
||||
)
|
||||
stream_id = result["stream_id"]
|
||||
event.internal_metadata.stream_ordering = stream_id
|
||||
event_id = result["event_id"]
|
||||
if event_id != event.event_id:
|
||||
# If we get a different event back then it means that its
|
||||
# been de-duplicated, so we replace the given event with the
|
||||
# one already persisted.
|
||||
event = await self.store.get_event(event_id)
|
||||
else:
|
||||
# If we newly persisted the event then we need to update its
|
||||
# stream_ordering entry manually (as it was persisted on
|
||||
# another worker).
|
||||
event.internal_metadata.stream_ordering = stream_id
|
||||
return event
|
||||
|
||||
stream_id = await self.persist_and_notify_client_event(
|
||||
event = await self.persist_and_notify_client_event(
|
||||
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
|
||||
)
|
||||
|
||||
|
@ -965,11 +988,16 @@ class EventCreationHandler:
|
|||
context: EventContext,
|
||||
ratelimit: bool = True,
|
||||
extra_users: List[UserID] = [],
|
||||
) -> int:
|
||||
) -> EventBase:
|
||||
"""Called when we have fully built the event, have already
|
||||
calculated the push actions for the event, and checked auth.
|
||||
|
||||
This should only be run on the instance in charge of persisting events.
|
||||
|
||||
Returns:
|
||||
The persisted event. This may be different than the given event if
|
||||
it was de-duplicated (e.g. because we had already persisted an
|
||||
event with the same transaction ID.)
|
||||
"""
|
||||
assert self.storage.persistence is not None
|
||||
assert self._events_shard_config.should_handle(
|
||||
|
@ -1137,9 +1165,13 @@ class EventCreationHandler:
|
|||
if prev_state_ids:
|
||||
raise AuthError(403, "Changing the room create event is forbidden")
|
||||
|
||||
event_pos, max_stream_token = await self.storage.persistence.persist_event(
|
||||
event, context=context
|
||||
)
|
||||
# Note that this returns the event that was persisted, which may not be
|
||||
# the same as we passed in if it was deduplicated due transaction IDs.
|
||||
(
|
||||
event,
|
||||
event_pos,
|
||||
max_stream_token,
|
||||
) = await self.storage.persistence.persist_event(event, context=context)
|
||||
|
||||
if self._ephemeral_events_enabled:
|
||||
# If there's an expiry timestamp on the event, schedule its expiry.
|
||||
|
@ -1160,7 +1192,7 @@ class EventCreationHandler:
|
|||
# matters as sometimes presence code can take a while.
|
||||
run_in_background(self._bump_active_time, requester.user)
|
||||
|
||||
return event_pos.stream
|
||||
return event
|
||||
|
||||
async def _bump_active_time(self, user: UserID) -> None:
|
||||
try:
|
||||
|
|
|
@ -171,6 +171,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
if requester.is_guest:
|
||||
content["kind"] = "guest"
|
||||
|
||||
# Check if we already have an event with a matching transaction ID. (We
|
||||
# do this check just before we persist an event as well, but may as well
|
||||
# do it up front for efficiency.)
|
||||
if txn_id and requester.access_token_id:
|
||||
existing_event_id = await self.store.get_event_id_from_transaction_id(
|
||||
room_id, requester.user.to_string(), requester.access_token_id, txn_id,
|
||||
)
|
||||
if existing_event_id:
|
||||
event_pos = await self.store.get_position_for_event(existing_event_id)
|
||||
return existing_event_id, event_pos.stream
|
||||
|
||||
event, context = await self.event_creation_handler.create_event(
|
||||
requester,
|
||||
{
|
||||
|
@ -679,7 +690,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
|||
if is_blocked:
|
||||
raise SynapseError(403, "This room has been blocked on this server")
|
||||
|
||||
await self.event_creation_handler.handle_new_client_event(
|
||||
event = await self.event_creation_handler.handle_new_client_event(
|
||||
requester, event, context, extra_users=[target_user], ratelimit=ratelimit
|
||||
)
|
||||
|
||||
|
|
|
@ -46,6 +46,12 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||
"ratelimit": true,
|
||||
"extra_users": [],
|
||||
}
|
||||
|
||||
200 OK
|
||||
|
||||
{ "stream_id": 12345, "event_id": "$abcdef..." }
|
||||
|
||||
The returned event ID may not match the sent event if it was deduplicated.
|
||||
"""
|
||||
|
||||
NAME = "send_event"
|
||||
|
@ -116,11 +122,17 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
|||
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
|
||||
)
|
||||
|
||||
stream_id = await self.event_creation_handler.persist_and_notify_client_event(
|
||||
event = await self.event_creation_handler.persist_and_notify_client_event(
|
||||
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
|
||||
)
|
||||
|
||||
return 200, {"stream_id": stream_id}
|
||||
return (
|
||||
200,
|
||||
{
|
||||
"stream_id": event.internal_metadata.stream_ordering,
|
||||
"event_id": event.event_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
|
|
|
@ -361,6 +361,8 @@ class PersistEventsStore:
|
|||
|
||||
self._store_event_txn(txn, events_and_contexts=events_and_contexts)
|
||||
|
||||
self._persist_transaction_ids_txn(txn, events_and_contexts)
|
||||
|
||||
# Insert into event_to_state_groups.
|
||||
self._store_event_state_mappings_txn(txn, events_and_contexts)
|
||||
|
||||
|
@ -405,6 +407,35 @@ class PersistEventsStore:
|
|||
# room_memberships, where applicable.
|
||||
self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
|
||||
|
||||
def _persist_transaction_ids_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
):
|
||||
"""Persist the mapping from transaction IDs to event IDs (if defined).
|
||||
"""
|
||||
|
||||
to_insert = []
|
||||
for event, _ in events_and_contexts:
|
||||
token_id = getattr(event.internal_metadata, "token_id", None)
|
||||
txn_id = getattr(event.internal_metadata, "txn_id", None)
|
||||
if token_id and txn_id:
|
||||
to_insert.append(
|
||||
{
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"user_id": event.sender,
|
||||
"token_id": token_id,
|
||||
"txn_id": txn_id,
|
||||
"inserted_ts": self._clock.time_msec(),
|
||||
}
|
||||
)
|
||||
|
||||
if to_insert:
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn, table="event_txn_id", values=to_insert,
|
||||
)
|
||||
|
||||
def _update_current_state_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
import threading
|
||||
|
@ -137,6 +136,15 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
db_conn, "events", "stream_ordering", step=-1
|
||||
)
|
||||
|
||||
if not hs.config.worker.worker_app:
|
||||
# We periodically clean out old transaction ID mappings
|
||||
self._clock.looping_call(
|
||||
run_as_background_process,
|
||||
5 * 60 * 1000,
|
||||
"_cleanup_old_transaction_ids",
|
||||
self._cleanup_old_transaction_ids,
|
||||
)
|
||||
|
||||
self._get_event_cache = Cache(
|
||||
"*getEvent*",
|
||||
keylen=3,
|
||||
|
@ -1308,3 +1316,76 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
return await self.db_pool.runInteraction(
|
||||
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
|
||||
)
|
||||
|
||||
async def get_event_id_from_transaction_id(
|
||||
self, room_id: str, user_id: str, token_id: int, txn_id: str
|
||||
) -> Optional[str]:
|
||||
"""Look up if we have already persisted an event for the transaction ID,
|
||||
returning the event ID if so.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="event_txn_id",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"user_id": user_id,
|
||||
"token_id": token_id,
|
||||
"txn_id": txn_id,
|
||||
},
|
||||
retcol="event_id",
|
||||
allow_none=True,
|
||||
desc="get_event_id_from_transaction_id",
|
||||
)
|
||||
|
||||
async def get_already_persisted_events(
|
||||
self, events: Iterable[EventBase]
|
||||
) -> Dict[str, str]:
|
||||
"""Look up if we have already persisted an event for the transaction ID,
|
||||
returning a mapping from event ID in the given list to the event ID of
|
||||
an existing event.
|
||||
|
||||
Also checks if there are duplicates in the given events, if there are
|
||||
will map duplicates to the *first* event.
|
||||
"""
|
||||
|
||||
mapping = {}
|
||||
txn_id_to_event = {} # type: Dict[Tuple[str, int, str], str]
|
||||
|
||||
for event in events:
|
||||
token_id = getattr(event.internal_metadata, "token_id", None)
|
||||
txn_id = getattr(event.internal_metadata, "txn_id", None)
|
||||
|
||||
if token_id and txn_id:
|
||||
# Check if this is a duplicate of an event in the given events.
|
||||
existing = txn_id_to_event.get((event.room_id, token_id, txn_id))
|
||||
if existing:
|
||||
mapping[event.event_id] = existing
|
||||
continue
|
||||
|
||||
# Check if this is a duplicate of an event we've already
|
||||
# persisted.
|
||||
existing = await self.get_event_id_from_transaction_id(
|
||||
event.room_id, event.sender, token_id, txn_id
|
||||
)
|
||||
if existing:
|
||||
mapping[event.event_id] = existing
|
||||
txn_id_to_event[(event.room_id, token_id, txn_id)] = existing
|
||||
else:
|
||||
txn_id_to_event[(event.room_id, token_id, txn_id)] = event.event_id
|
||||
|
||||
return mapping
|
||||
|
||||
async def _cleanup_old_transaction_ids(self):
|
||||
"""Cleans out transaction id mappings older than 24hrs.
|
||||
"""
|
||||
|
||||
def _cleanup_old_transaction_ids_txn(txn):
|
||||
sql = """
|
||||
DELETE FROM event_txn_id
|
||||
WHERE inserted_ts < ?
|
||||
"""
|
||||
one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000
|
||||
txn.execute(sql, (one_day_ago,))
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
|
||||
)
|
||||
|
|
|
@ -1003,7 +1003,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
token: str,
|
||||
device_id: Optional[str],
|
||||
valid_until_ms: Optional[int],
|
||||
) -> None:
|
||||
) -> int:
|
||||
"""Adds an access token for the given user.
|
||||
|
||||
Args:
|
||||
|
@ -1013,6 +1013,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
valid_until_ms: when the token is valid until. None for no expiry.
|
||||
Raises:
|
||||
StoreError if there was a problem adding this.
|
||||
Returns:
|
||||
The token ID
|
||||
"""
|
||||
next_id = self._access_tokens_id_gen.get_next()
|
||||
|
||||
|
@ -1028,6 +1030,8 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||
desc="add_access_token_to_user",
|
||||
)
|
||||
|
||||
return next_id
|
||||
|
||||
def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
|
||||
old_device_id = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn, "access_tokens", {"token": token}, "device_id"
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
-- A map of recent events persisted with transaction IDs. Used to deduplicate
|
||||
-- send event requests with the same transaction ID.
|
||||
--
|
||||
-- Note: transaction IDs are scoped to the room ID/user ID/access token that was
|
||||
-- used to make the request.
|
||||
--
|
||||
-- Note: The foreign key constraints are ON DELETE CASCADE, as if we delete the
|
||||
-- events or access token we don't want to try and de-duplicate the event.
|
||||
CREATE TABLE IF NOT EXISTS event_txn_id (
|
||||
event_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
token_id BIGINT NOT NULL,
|
||||
txn_id TEXT NOT NULL,
|
||||
inserted_ts BIGINT NOT NULL,
|
||||
FOREIGN KEY (event_id)
|
||||
REFERENCES events (event_id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (token_id)
|
||||
REFERENCES access_tokens (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_event_id ON event_txn_id(event_id);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS event_txn_id_txn_id ON event_txn_id(room_id, user_id, token_id, txn_id);
|
||||
CREATE INDEX IF NOT EXISTS event_txn_id_ts ON event_txn_id(inserted_ts);
|
|
@ -96,7 +96,9 @@ class _EventPeristenceQueue:
|
|||
|
||||
Returns:
|
||||
defer.Deferred: a deferred which will resolve once the events are
|
||||
persisted. Runs its callbacks *without* a logcontext.
|
||||
persisted. Runs its callbacks *without* a logcontext. The result
|
||||
is the same as that returned by the callback passed to
|
||||
`handle_queue`.
|
||||
"""
|
||||
queue = self._event_persist_queues.setdefault(room_id, deque())
|
||||
if queue:
|
||||
|
@ -199,7 +201,7 @@ class EventsPersistenceStorage:
|
|||
self,
|
||||
events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool = False,
|
||||
) -> RoomStreamToken:
|
||||
) -> Tuple[List[EventBase], RoomStreamToken]:
|
||||
"""
|
||||
Write events to the database
|
||||
Args:
|
||||
|
@ -209,7 +211,11 @@ class EventsPersistenceStorage:
|
|||
which might update the current state etc.
|
||||
|
||||
Returns:
|
||||
the stream ordering of the latest persisted event
|
||||
List of events persisted, the current position room stream position.
|
||||
The list of events persisted may not be the same as those passed in
|
||||
if they were deduplicated due to an event already existing that
|
||||
matched the transcation ID; the existing event is returned in such
|
||||
a case.
|
||||
"""
|
||||
partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
|
||||
for event, ctx in events_and_contexts:
|
||||
|
@ -225,19 +231,41 @@ class EventsPersistenceStorage:
|
|||
for room_id in partitioned:
|
||||
self._maybe_start_persisting(room_id)
|
||||
|
||||
await make_deferred_yieldable(
|
||||
# Each deferred returns a map from event ID to existing event ID if the
|
||||
# event was deduplicated. (The dict may also include other entries if
|
||||
# the event was persisted in a batch with other events).
|
||||
#
|
||||
# Since we use `defer.gatherResults` we need to merge the returned list
|
||||
# of dicts into one.
|
||||
ret_vals = await make_deferred_yieldable(
|
||||
defer.gatherResults(deferreds, consumeErrors=True)
|
||||
)
|
||||
replaced_events = {}
|
||||
for d in ret_vals:
|
||||
replaced_events.update(d)
|
||||
|
||||
return self.main_store.get_room_max_token()
|
||||
events = []
|
||||
for event, _ in events_and_contexts:
|
||||
existing_event_id = replaced_events.get(event.event_id)
|
||||
if existing_event_id:
|
||||
events.append(await self.main_store.get_event(existing_event_id))
|
||||
else:
|
||||
events.append(event)
|
||||
|
||||
return (
|
||||
events,
|
||||
self.main_store.get_room_max_token(),
|
||||
)
|
||||
|
||||
async def persist_event(
|
||||
self, event: EventBase, context: EventContext, backfilled: bool = False
|
||||
) -> Tuple[PersistedEventPosition, RoomStreamToken]:
|
||||
) -> Tuple[EventBase, PersistedEventPosition, RoomStreamToken]:
|
||||
"""
|
||||
Returns:
|
||||
The stream ordering of `event`, and the stream ordering of the
|
||||
latest persisted event
|
||||
The event, stream ordering of `event`, and the stream ordering of the
|
||||
latest persisted event. The returned event may not match the given
|
||||
event if it was deduplicated due to an existing event matching the
|
||||
transaction ID.
|
||||
"""
|
||||
deferred = self._event_persist_queue.add_to_queue(
|
||||
event.room_id, [(event, context)], backfilled=backfilled
|
||||
|
@ -245,19 +273,33 @@ class EventsPersistenceStorage:
|
|||
|
||||
self._maybe_start_persisting(event.room_id)
|
||||
|
||||
await make_deferred_yieldable(deferred)
|
||||
# The deferred returns a map from event ID to existing event ID if the
|
||||
# event was deduplicated. (The dict may also include other entries if
|
||||
# the event was persisted in a batch with other events.)
|
||||
replaced_events = await make_deferred_yieldable(deferred)
|
||||
replaced_event = replaced_events.get(event.event_id)
|
||||
if replaced_event:
|
||||
event = await self.main_store.get_event(replaced_event)
|
||||
|
||||
event_stream_id = event.internal_metadata.stream_ordering
|
||||
# stream ordering should have been assigned by now
|
||||
assert event_stream_id
|
||||
|
||||
pos = PersistedEventPosition(self._instance_name, event_stream_id)
|
||||
return pos, self.main_store.get_room_max_token()
|
||||
return event, pos, self.main_store.get_room_max_token()
|
||||
|
||||
def _maybe_start_persisting(self, room_id: str):
|
||||
"""Pokes the `_event_persist_queue` to start handling new items in the
|
||||
queue, if not already in progress.
|
||||
|
||||
Causes the deferreds returned by `add_to_queue` to resolve with: a
|
||||
dictionary of event ID to event ID we didn't persist as we already had
|
||||
another event persisted with the same TXN ID.
|
||||
"""
|
||||
|
||||
async def persisting_queue(item):
|
||||
with Measure(self._clock, "persist_events"):
|
||||
await self._persist_events(
|
||||
return await self._persist_events(
|
||||
item.events_and_contexts, backfilled=item.backfilled
|
||||
)
|
||||
|
||||
|
@ -267,12 +309,38 @@ class EventsPersistenceStorage:
|
|||
self,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
backfilled: bool = False,
|
||||
):
|
||||
) -> Dict[str, str]:
|
||||
"""Calculates the change to current state and forward extremities, and
|
||||
persists the given events and with those updates.
|
||||
|
||||
Returns:
|
||||
A dictionary of event ID to event ID we didn't persist as we already
|
||||
had another event persisted with the same TXN ID.
|
||||
"""
|
||||
replaced_events = {} # type: Dict[str, str]
|
||||
if not events_and_contexts:
|
||||
return
|
||||
return replaced_events
|
||||
|
||||
# Check if any of the events have a transaction ID that has already been
|
||||
# persisted, and if so we don't persist it again.
|
||||
#
|
||||
# We should have checked this a long time before we get here, but it's
|
||||
# possible that different send event requests race in such a way that
|
||||
# they both pass the earlier checks. Checking here isn't racey as we can
|
||||
# have only one `_persist_events` per room being called at a time.
|
||||
replaced_events = await self.main_store.get_already_persisted_events(
|
||||
(event for event, _ in events_and_contexts)
|
||||
)
|
||||
|
||||
if replaced_events:
|
||||
events_and_contexts = [
|
||||
(e, ctx)
|
||||
for e, ctx in events_and_contexts
|
||||
if e.event_id not in replaced_events
|
||||
]
|
||||
|
||||
if not events_and_contexts:
|
||||
return replaced_events
|
||||
|
||||
chunks = [
|
||||
events_and_contexts[x : x + 100]
|
||||
|
@ -441,6 +509,8 @@ class EventsPersistenceStorage:
|
|||
|
||||
await self._handle_potentially_left_users(potentially_left_users)
|
||||
|
||||
return replaced_events
|
||||
|
||||
async def _calculate_new_extremities(
|
||||
self,
|
||||
room_id: str,
|
||||
|
|
|
@ -0,0 +1,157 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client.v1 import login, room
|
||||
from synapse.types import create_requester
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from tests import unittest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.handler = self.hs.get_event_creation_handler()
|
||||
self.persist_event_storage = self.hs.get_storage().persistence
|
||||
|
||||
self.user_id = self.register_user("tester", "foobar")
|
||||
self.access_token = self.login("tester", "foobar")
|
||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
||||
|
||||
self.info = self.get_success(
|
||||
self.hs.get_datastore().get_user_by_access_token(self.access_token,)
|
||||
)
|
||||
self.token_id = self.info["token_id"]
|
||||
|
||||
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
|
||||
|
||||
def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
|
||||
"""Create a new event with the given transaction ID. All events produced
|
||||
by this method will be considered duplicates.
|
||||
"""
|
||||
|
||||
# We create a new event with a random body, as otherwise we'll produce
|
||||
# *exactly* the same event with the same hash, and so same event ID.
|
||||
return self.get_success(
|
||||
self.handler.create_event(
|
||||
self.requester,
|
||||
{
|
||||
"type": EventTypes.Message,
|
||||
"room_id": self.room_id,
|
||||
"sender": self.requester.user.to_string(),
|
||||
"content": {"msgtype": "m.text", "body": random_string(5)},
|
||||
},
|
||||
token_id=self.token_id,
|
||||
txn_id=txn_id,
|
||||
)
|
||||
)
|
||||
|
||||
def test_duplicated_txn_id(self):
|
||||
"""Test that attempting to handle/persist an event with a transaction ID
|
||||
that has already been persisted correctly returns the old event and does
|
||||
*not* produce duplicate messages.
|
||||
"""
|
||||
|
||||
txn_id = "something_suitably_random"
|
||||
|
||||
event1, context = self._create_duplicate_event(txn_id)
|
||||
|
||||
ret_event1 = self.get_success(
|
||||
self.handler.handle_new_client_event(self.requester, event1, context)
|
||||
)
|
||||
stream_id1 = ret_event1.internal_metadata.stream_ordering
|
||||
|
||||
self.assertEqual(event1.event_id, ret_event1.event_id)
|
||||
|
||||
event2, context = self._create_duplicate_event(txn_id)
|
||||
|
||||
# We want to test that the deduplication at the persit event end works,
|
||||
# so we want to make sure we test with different events.
|
||||
self.assertNotEqual(event1.event_id, event2.event_id)
|
||||
|
||||
ret_event2 = self.get_success(
|
||||
self.handler.handle_new_client_event(self.requester, event2, context)
|
||||
)
|
||||
stream_id2 = ret_event2.internal_metadata.stream_ordering
|
||||
|
||||
# Assert that the returned values match those from the initial event
|
||||
# rather than the new one.
|
||||
self.assertEqual(ret_event1.event_id, ret_event2.event_id)
|
||||
self.assertEqual(stream_id1, stream_id2)
|
||||
|
||||
# Let's test that calling `persist_event` directly also does the right
|
||||
# thing.
|
||||
event3, context = self._create_duplicate_event(txn_id)
|
||||
self.assertNotEqual(event1.event_id, event3.event_id)
|
||||
|
||||
ret_event3, event_pos3, _ = self.get_success(
|
||||
self.persist_event_storage.persist_event(event3, context)
|
||||
)
|
||||
|
||||
# Assert that the returned values match those from the initial event
|
||||
# rather than the new one.
|
||||
self.assertEqual(ret_event1.event_id, ret_event3.event_id)
|
||||
self.assertEqual(stream_id1, event_pos3.stream)
|
||||
|
||||
# Let's test that calling `persist_events` directly also does the right
|
||||
# thing.
|
||||
event4, context = self._create_duplicate_event(txn_id)
|
||||
self.assertNotEqual(event1.event_id, event3.event_id)
|
||||
|
||||
events, _ = self.get_success(
|
||||
self.persist_event_storage.persist_events([(event3, context)])
|
||||
)
|
||||
ret_event4 = events[0]
|
||||
|
||||
# Assert that the returned values match those from the initial event
|
||||
# rather than the new one.
|
||||
self.assertEqual(ret_event1.event_id, ret_event4.event_id)
|
||||
|
||||
def test_duplicated_txn_id_one_call(self):
|
||||
"""Test that we correctly handle duplicates that we try and persist at
|
||||
the same time.
|
||||
"""
|
||||
|
||||
txn_id = "something_else_suitably_random"
|
||||
|
||||
# Create two duplicate events to persist at the same time
|
||||
event1, context1 = self._create_duplicate_event(txn_id)
|
||||
event2, context2 = self._create_duplicate_event(txn_id)
|
||||
|
||||
# Ensure their event IDs are different to start with
|
||||
self.assertNotEqual(event1.event_id, event2.event_id)
|
||||
|
||||
events, _ = self.get_success(
|
||||
self.persist_event_storage.persist_events(
|
||||
[(event1, context1), (event2, context2)]
|
||||
)
|
||||
)
|
||||
|
||||
# Check that we've deduplicated the events.
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[0].event_id, events[1].event_id)
|
|
@ -107,7 +107,7 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
request, channel = self.make_request(
|
||||
"PUT",
|
||||
"/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
|
||||
"/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id,
|
||||
{},
|
||||
access_token=self.tok,
|
||||
)
|
||||
|
|
|
@ -254,17 +254,24 @@ class HomeserverTestCase(TestCase):
|
|||
if hasattr(self, "user_id"):
|
||||
if self.hijack_auth:
|
||||
|
||||
# We need a valid token ID to satisfy foreign key constraints.
|
||||
token_id = self.get_success(
|
||||
self.hs.get_datastore().add_access_token_to_user(
|
||||
self.helper.auth_user_id, "some_fake_token", None, None,
|
||||
)
|
||||
)
|
||||
|
||||
async def get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.helper.auth_user_id),
|
||||
"token_id": 1,
|
||||
"token_id": token_id,
|
||||
"is_guest": False,
|
||||
}
|
||||
|
||||
async def get_user_by_req(request, allow_guest=False, rights="access"):
|
||||
return create_requester(
|
||||
UserID.from_string(self.helper.auth_user_id),
|
||||
1,
|
||||
token_id,
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
|
|
Loading…
Reference in New Issue