Avoid unnecessary copies when filtering private read receipts. (#12711)

A minor optimization to avoid unnecessary copying/building
identical dictionaries when filtering private read receipts.

Also clarifies comments and cleans-up some tests.
pull/12749/head
Šimon Brandner 2022-05-16 17:06:23 +02:00 committed by GitHub
parent b4eb163434
commit 3ce15cc7be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 70 deletions

1
changelog.d/12711.misc Normal file
View File

@ -0,0 +1 @@
Optimize private read receipt filtering.

View File

@ -143,7 +143,7 @@ class InitialSyncHandler:
to_key=int(now_token.receipt_key), to_key=int(now_token.receipt_key),
) )
if self.hs.config.experimental.msc2285_enabled: if self.hs.config.experimental.msc2285_enabled:
receipt = ReceiptEventSource.filter_out_private(receipt, user_id) receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
tags_by_room = await self.store.get_tags_for_user(user_id) tags_by_room = await self.store.get_tags_for_user(user_id)
@ -449,7 +449,9 @@ class InitialSyncHandler:
if not receipts: if not receipts:
return [] return []
if self.hs.config.experimental.msc2285_enabled: if self.hs.config.experimental.msc2285_enabled:
receipts = ReceiptEventSource.filter_out_private(receipts, user_id) receipts = ReceiptEventSource.filter_out_private_receipts(
receipts, user_id
)
return receipts return receipts
presence, receipts, (messages, token) = await make_deferred_yieldable( presence, receipts, (messages, token) = await make_deferred_yieldable(

View File

@ -165,43 +165,69 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
self.config = hs.config self.config = hs.config
@staticmethod @staticmethod
def filter_out_private(events: List[JsonDict], user_id: str) -> List[JsonDict]: def filter_out_private_receipts(
rooms: List[JsonDict], user_id: str
) -> List[JsonDict]:
""" """
This method takes in what is returned by Filters a list of serialized receipts (as returned by /sync and /initialSync)
get_linearized_receipts_for_rooms() and goes through read receipts and removes private read receipts of other users.
filtering out m.read.private receipts if they were not sent by the
current user. This operates on the return value of get_linearized_receipts_for_rooms(),
which is wrapped in a cache. Care must be taken to ensure that the input
values are not modified.
Args:
rooms: A list of mappings, each mapping has a `content` field, which
is a map of event ID -> receipt type -> user ID -> receipt information.
Returns:
The same as rooms, but filtered.
""" """
visible_events = [] result = []
# filter out private receipts the user shouldn't see # Iterate through each room's receipt content.
for event in events: for room in rooms:
content = event.get("content", {}) # The receipt content with other user's private read receipts removed.
new_event = event.copy() content = {}
new_event["content"] = {}
for event_id, event_content in content.items(): # Iterate over each event ID / receipts for that event.
receipt_event = {} for event_id, orig_event_content in room.get("content", {}).items():
for receipt_type, receipt_content in event_content.items(): event_content = orig_event_content
if receipt_type == ReceiptTypes.READ_PRIVATE: # If there are private read receipts, additional logic is necessary.
user_rr = receipt_content.get(user_id, None) if ReceiptTypes.READ_PRIVATE in event_content:
if user_rr: # Make a copy without private read receipts to avoid leaking
receipt_event[ReceiptTypes.READ_PRIVATE] = { # other user's private read receipts..
user_id: user_rr.copy() event_content = {
} receipt_type: receipt_value
else: for receipt_type, receipt_value in event_content.items()
receipt_event[receipt_type] = receipt_content.copy() if receipt_type != ReceiptTypes.READ_PRIVATE
}
# Only include the receipt event if it is non-empty. # Copy the current user's private read receipt from the
if receipt_event: # original content, if it exists.
new_event["content"][event_id] = receipt_event user_private_read_receipt = orig_event_content[
ReceiptTypes.READ_PRIVATE
].get(user_id, None)
if user_private_read_receipt:
event_content[ReceiptTypes.READ_PRIVATE] = {
user_id: user_private_read_receipt
}
# Append new_event to visible_events unless empty # Include the event if there is at least one non-private read
if len(new_event["content"].keys()) > 0: # receipt or the current user has a private read receipt.
visible_events.append(new_event) if event_content:
content[event_id] = event_content
return visible_events # Include the event if there is at least one non-private read receipt
# or the current user has a private read receipt.
if content:
# Build a new event to avoid mutating the cache.
new_room = {k: v for k, v in room.items() if k != "content"}
new_room["content"] = content
result.append(new_room)
return result
async def get_new_events( async def get_new_events(
self, self,
@ -223,7 +249,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
) )
if self.config.experimental.msc2285_enabled: if self.config.experimental.msc2285_enabled:
events = ReceiptEventSource.filter_out_private(events, user.to_string()) events = ReceiptEventSource.filter_out_private_receipts(
events, user.to_string()
)
return events, to_key return events, to_key

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from copy import deepcopy
from typing import List from typing import List
from synapse.api.constants import ReceiptTypes from synapse.api.constants import ReceiptTypes
@ -125,42 +125,6 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_handles_missing_content_of_m_read(self):
self._test_filters_private(
[
{
"content": {
"$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}},
"$1435641916114394fHBLK:matrix.org": {
ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
}
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
"type": "m.receipt",
}
],
[
{
"content": {
"$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}},
"$1435641916114394fHBLK:matrix.org": {
ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
}
},
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
"type": "m.receipt",
}
],
)
def test_handles_empty_event(self): def test_handles_empty_event(self):
self._test_filters_private( self._test_filters_private(
[ [
@ -332,9 +296,33 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
], ],
) )
def test_we_do_not_mutate(self):
"""Ensure the input values are not modified."""
events = [
{
"content": {
"$1435641916114394fHBLK:matrix.org": {
ReceiptTypes.READ_PRIVATE: {
"@rikj:jki.re": {
"ts": 1436451550453,
}
}
}
},
"room_id": "!jEsUZKDJdhlrceRyVU:example.org",
"type": "m.receipt",
}
]
original_events = deepcopy(events)
self._test_filters_private(events, [])
# Since the events are fed in from a cache they should not be modified.
self.assertEqual(events, original_events)
def _test_filters_private( def _test_filters_private(
self, events: List[JsonDict], expected_output: List[JsonDict] self, events: List[JsonDict], expected_output: List[JsonDict]
): ):
"""Tests that the _filter_out_private returns the expected output""" """Tests that the _filter_out_private returns the expected output"""
filtered_events = self.event_source.filter_out_private(events, "@me:server.org") filtered_events = self.event_source.filter_out_private_receipts(
events, "@me:server.org"
)
self.assertEqual(filtered_events, expected_output) self.assertEqual(filtered_events, expected_output)