Add new module API for adding custom fields to events `unsigned` section (#16549)

pull/16561/head
Erik Johnston 2023-10-27 10:04:08 +01:00 committed by GitHub
parent 679c691f6f
commit c02406ac71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 194 additions and 44 deletions

View File

@ -0,0 +1 @@
Add a new module API callback that allows adding extra fields to events' unsigned section when sent down to clients.

View File

@ -48,6 +48,7 @@
- [Password auth provider callbacks](modules/password_auth_provider_callbacks.md)
- [Background update controller callbacks](modules/background_update_controller_callbacks.md)
- [Account data callbacks](modules/account_data_callbacks.md)
- [Add extra fields to client events unsigned section callbacks](modules/add_extra_fields_to_client_events_unsigned.md)
- [Porting a legacy module to the new interface](modules/porting_legacy_module.md)
- [Workers](workers.md)
- [Using `synctl` with Workers](synctl_workers.md)

View File

@ -0,0 +1,32 @@
# Add extra fields to client events unsigned section callbacks
_First introduced in Synapse v1.96.0_
This callback allows modules to add extra fields to the unsigned section of
events when they get sent down to clients.
These get called *every* time an event is to be sent to clients, so care should
be taken to ensure with respect to performance.
### API
To register the callback, use
`register_add_extra_fields_to_unsigned_client_event_callbacks` on the
`ModuleApi`.
The callback should be of the form
```python
async def add_field_to_unsigned(
event: EventBase,
) -> JsonDict:
```
where the extra fields to add to the event's unsigned section is returned.
(Modules must not attempt to modify the `event` directly).
This cannot be used to alter the "core" fields in the unsigned section emitted
by Synapse itself.
If multiple such callbacks try to add the same field to an event's unsigned
section, the last-registered callback wins.

View File

@ -17,6 +17,7 @@ import re
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
@ -45,6 +46,7 @@ from . import EventBase
if TYPE_CHECKING:
from synapse.handlers.relations import BundledAggregations
from synapse.server import HomeServer
# Split strings on "." but not "\." (or "\\\.").
@ -56,6 +58,13 @@ CANONICALJSON_MAX_INT = (2**53) - 1
CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT
# Module API callback that allows adding fields to the unsigned section of
# events that are sent to clients.
ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK = Callable[
[EventBase], Awaitable[JsonDict]
]
def prune_event(event: EventBase) -> EventBase:
"""Returns a pruned version of the given event, which removes all keys we
don't know about or think could potentially be dodgy.
@ -509,7 +518,13 @@ class EventClientSerializer:
clients.
"""
def serialize_event(
def __init__(self, hs: "HomeServer") -> None:
self._store = hs.get_datastores().main
self._add_extra_fields_to_unsigned_client_event_callbacks: List[
ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
] = []
async def serialize_event(
self,
event: Union[JsonDict, EventBase],
time_now: int,
@ -535,10 +550,21 @@ class EventClientSerializer:
serialized_event = serialize_event(event, time_now, config=config)
new_unsigned = {}
for callback in self._add_extra_fields_to_unsigned_client_event_callbacks:
u = await callback(event)
new_unsigned.update(u)
if new_unsigned:
# We do the `update` this way round so that modules can't clobber
# existing fields.
new_unsigned.update(serialized_event["unsigned"])
serialized_event["unsigned"] = new_unsigned
# Check if there are any bundled aggregations to include with the event.
if bundle_aggregations:
if event.event_id in bundle_aggregations:
self._inject_bundled_aggregations(
await self._inject_bundled_aggregations(
event,
time_now,
config,
@ -548,7 +574,7 @@ class EventClientSerializer:
return serialized_event
def _inject_bundled_aggregations(
async def _inject_bundled_aggregations(
self,
event: EventBase,
time_now: int,
@ -590,7 +616,7 @@ class EventClientSerializer:
# said that we should only include the `event_id`, `origin_server_ts` and
# `sender` of the edit; however MSC3925 proposes extending it to the whole
# of the edit, which is what we do here.
serialized_aggregations[RelationTypes.REPLACE] = self.serialize_event(
serialized_aggregations[RelationTypes.REPLACE] = await self.serialize_event(
event_aggregations.replace,
time_now,
config=config,
@ -600,7 +626,7 @@ class EventClientSerializer:
if event_aggregations.thread:
thread = event_aggregations.thread
serialized_latest_event = self.serialize_event(
serialized_latest_event = await self.serialize_event(
thread.latest_event,
time_now,
config=config,
@ -623,7 +649,7 @@ class EventClientSerializer:
"m.relations", {}
).update(serialized_aggregations)
def serialize_events(
async def serialize_events(
self,
events: Iterable[Union[JsonDict, EventBase]],
time_now: int,
@ -645,7 +671,7 @@ class EventClientSerializer:
The list of serialized events
"""
return [
self.serialize_event(
await self.serialize_event(
event,
time_now,
config=config,
@ -654,6 +680,14 @@ class EventClientSerializer:
for event in events
]
def register_add_extra_fields_to_unsigned_client_event_callback(
self, callback: ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
) -> None:
"""Register a callback that returns additions to the unsigned section of
serialized events.
"""
self._add_extra_fields_to_unsigned_client_event_callbacks.append(callback)
_PowerLevel = Union[str, int]
PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]

View File

@ -120,7 +120,7 @@ class EventStreamHandler:
events.extend(to_add)
chunks = self._event_serializer.serialize_events(
chunks = await self._event_serializer.serialize_events(
events,
time_now,
config=SerializeEventConfig(

View File

@ -173,7 +173,7 @@ class InitialSyncHandler:
d["inviter"] = event.sender
invite_event = await self.store.get_event(event.event_id)
d["invite"] = self._event_serializer.serialize_event(
d["invite"] = await self._event_serializer.serialize_event(
invite_event,
time_now,
config=serializer_options,
@ -225,7 +225,7 @@ class InitialSyncHandler:
d["messages"] = {
"chunk": (
self._event_serializer.serialize_events(
await self._event_serializer.serialize_events(
messages,
time_now=time_now,
config=serializer_options,
@ -235,7 +235,7 @@ class InitialSyncHandler:
"end": await end_token.to_string(self.store),
}
d["state"] = self._event_serializer.serialize_events(
d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
config=serializer_options,
@ -387,7 +387,7 @@ class InitialSyncHandler:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
self._event_serializer.serialize_events(
await self._event_serializer.serialize_events(
messages, time_now, config=serialize_options
)
),
@ -396,7 +396,7 @@ class InitialSyncHandler:
},
"state": (
# Don't bundle aggregations as this is a deprecated API.
self._event_serializer.serialize_events(
await self._event_serializer.serialize_events(
room_state.values(), time_now, config=serialize_options
)
),
@ -420,7 +420,7 @@ class InitialSyncHandler:
time_now = self.clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester)
# Don't bundle aggregations as this is a deprecated API.
state = self._event_serializer.serialize_events(
state = await self._event_serializer.serialize_events(
current_state.values(),
time_now,
config=serialize_options,
@ -497,7 +497,7 @@ class InitialSyncHandler:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
self._event_serializer.serialize_events(
await self._event_serializer.serialize_events(
messages, time_now, config=serialize_options
)
),

View File

@ -244,7 +244,7 @@ class MessageHandler:
)
room_state = room_state_events[membership_event_id]
events = self._event_serializer.serialize_events(
events = await self._event_serializer.serialize_events(
room_state.values(),
self.clock.time_msec(),
config=SerializeEventConfig(requester=requester),

View File

@ -657,7 +657,7 @@ class PaginationHandler:
chunk = {
"chunk": (
self._event_serializer.serialize_events(
await self._event_serializer.serialize_events(
events,
time_now,
config=serialize_options,
@ -669,7 +669,7 @@ class PaginationHandler:
}
if state:
chunk["state"] = self._event_serializer.serialize_events(
chunk["state"] = await self._event_serializer.serialize_events(
state, time_now, config=serialize_options
)

View File

@ -167,7 +167,7 @@ class RelationsHandler:
now = self._clock.time_msec()
serialize_options = SerializeEventConfig(requester=requester)
return_value: JsonDict = {
"chunk": self._event_serializer.serialize_events(
"chunk": await self._event_serializer.serialize_events(
events,
now,
bundle_aggregations=aggregations,
@ -177,7 +177,9 @@ class RelationsHandler:
if include_original_event:
# Do not bundle aggregations when retrieving the original event because
# we want the content before relations are applied to it.
return_value["original_event"] = self._event_serializer.serialize_event(
return_value[
"original_event"
] = await self._event_serializer.serialize_event(
event,
now,
bundle_aggregations=None,
@ -602,7 +604,7 @@ class RelationsHandler:
)
now = self._clock.time_msec()
serialized_events = self._event_serializer.serialize_events(
serialized_events = await self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)

View File

@ -374,13 +374,13 @@ class SearchHandler:
serialize_options = SerializeEventConfig(requester=requester)
for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events(
context["events_before"] = await self._event_serializer.serialize_events(
context["events_before"],
time_now,
bundle_aggregations=aggregations,
config=serialize_options,
)
context["events_after"] = self._event_serializer.serialize_events(
context["events_after"] = await self._event_serializer.serialize_events(
context["events_after"],
time_now,
bundle_aggregations=aggregations,
@ -390,7 +390,7 @@ class SearchHandler:
results = [
{
"rank": search_result.rank_map[e.event_id],
"result": self._event_serializer.serialize_event(
"result": await self._event_serializer.serialize_event(
e,
time_now,
bundle_aggregations=aggregations,
@ -409,7 +409,7 @@ class SearchHandler:
if state_results:
rooms_cat_res["state"] = {
room_id: self._event_serializer.serialize_events(
room_id: await self._event_serializer.serialize_events(
state_events, time_now, config=serialize_options
)
for room_id, state_events in state_results.items()

View File

@ -48,6 +48,7 @@ from synapse.events.presence_router import (
GET_USERS_FOR_STATES_CALLBACK,
PresenceRouter,
)
from synapse.events.utils import ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
from synapse.handlers.account_data import ON_ACCOUNT_DATA_UPDATED_CALLBACK
from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK,
@ -259,6 +260,7 @@ class ModuleApi:
self.custom_template_dir = hs.config.server.custom_template_directory
self._callbacks = hs.get_module_api_callbacks()
self.msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled
self._event_serializer = hs.get_event_client_serializer()
try:
app_name = self._hs.config.email.email_app_name
@ -490,6 +492,25 @@ class ModuleApi:
"""
self._hs.register_module_web_resource(path, resource)
def register_add_extra_fields_to_unsigned_client_event_callbacks(
self,
*,
add_field_to_unsigned_callback: Optional[
ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
] = None,
) -> None:
"""Registers a callback that can be used to add fields to the unsigned
section of events.
The callback is called every time an event is sent down to a client.
Added in Synapse 1.96.0
"""
if add_field_to_unsigned_callback is not None:
self._event_serializer.register_add_extra_fields_to_unsigned_client_event_callback(
add_field_to_unsigned_callback
)
#########################################################################
# The following methods can be called by the module at any point in time.

View File

@ -444,7 +444,7 @@ class RoomStateRestServlet(RestServlet):
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
events = await self.store.get_events(event_ids.values())
now = self.clock.time_msec()
room_state = self._event_serializer.serialize_events(events.values(), now)
room_state = await self._event_serializer.serialize_events(events.values(), now)
ret = {"state": room_state}
return HTTPStatus.OK, ret
@ -789,22 +789,22 @@ class RoomEventContextServlet(RestServlet):
time_now = self.clock.time_msec()
results = {
"events_before": self._event_serializer.serialize_events(
"events_before": await self._event_serializer.serialize_events(
event_context.events_before,
time_now,
bundle_aggregations=event_context.aggregations,
),
"event": self._event_serializer.serialize_event(
"event": await self._event_serializer.serialize_event(
event_context.event,
time_now,
bundle_aggregations=event_context.aggregations,
),
"events_after": self._event_serializer.serialize_events(
"events_after": await self._event_serializer.serialize_events(
event_context.events_after,
time_now,
bundle_aggregations=event_context.aggregations,
),
"state": self._event_serializer.serialize_events(
"state": await self._event_serializer.serialize_events(
event_context.state, time_now
),
"start": event_context.start,

View File

@ -93,7 +93,7 @@ class EventRestServlet(RestServlet):
event = await self.event_handler.get_event(requester.user, None, event_id)
if event:
result = self._event_serializer.serialize_event(
result = await self._event_serializer.serialize_event(
event,
self.clock.time_msec(),
config=SerializeEventConfig(requester=requester),

View File

@ -87,7 +87,7 @@ class NotificationsServlet(RestServlet):
"actions": pa.actions,
"ts": pa.received_ts,
"event": (
self._event_serializer.serialize_event(
await self._event_serializer.serialize_event(
notif_events[pa.event_id],
now,
config=serialize_options,

View File

@ -859,7 +859,7 @@ class RoomEventServlet(RestServlet):
# per MSC2676, /rooms/{roomId}/event/{eventId}, should return the
# *original* event, rather than the edited version
event_dict = self._event_serializer.serialize_event(
event_dict = await self._event_serializer.serialize_event(
event,
self.clock.time_msec(),
bundle_aggregations=aggregations,
@ -911,25 +911,25 @@ class RoomEventContextServlet(RestServlet):
time_now = self.clock.time_msec()
serializer_options = SerializeEventConfig(requester=requester)
results = {
"events_before": self._event_serializer.serialize_events(
"events_before": await self._event_serializer.serialize_events(
event_context.events_before,
time_now,
bundle_aggregations=event_context.aggregations,
config=serializer_options,
),
"event": self._event_serializer.serialize_event(
"event": await self._event_serializer.serialize_event(
event_context.event,
time_now,
bundle_aggregations=event_context.aggregations,
config=serializer_options,
),
"events_after": self._event_serializer.serialize_events(
"events_after": await self._event_serializer.serialize_events(
event_context.events_after,
time_now,
bundle_aggregations=event_context.aggregations,
config=serializer_options,
),
"state": self._event_serializer.serialize_events(
"state": await self._event_serializer.serialize_events(
event_context.state,
time_now,
config=serializer_options,

View File

@ -384,7 +384,7 @@ class SyncRestServlet(RestServlet):
"""
invited = {}
for room in rooms:
invite = self._event_serializer.serialize_event(
invite = await self._event_serializer.serialize_event(
room.invite, time_now, config=serialize_options
)
unsigned = dict(invite.get("unsigned", {}))
@ -415,7 +415,7 @@ class SyncRestServlet(RestServlet):
"""
knocked = {}
for room in rooms:
knock = self._event_serializer.serialize_event(
knock = await self._event_serializer.serialize_event(
room.knock, time_now, config=serialize_options
)
@ -506,10 +506,10 @@ class SyncRestServlet(RestServlet):
event.room_id,
)
serialized_state = self._event_serializer.serialize_events(
serialized_state = await self._event_serializer.serialize_events(
state_events, time_now, config=serialize_options
)
serialized_timeline = self._event_serializer.serialize_events(
serialized_timeline = await self._event_serializer.serialize_events(
timeline_events,
time_now,
config=serialize_options,

View File

@ -786,7 +786,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_event_client_serializer(self) -> EventClientSerializer:
return EventClientSerializer()
return EventClientSerializer(self)
@cache_in_self
def get_password_policy_handler(self) -> PasswordPolicyHandler:

View File

@ -0,0 +1,59 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.events import EventBase
from synapse.rest import admin, login, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
class EventUnsignedAdditionTestCase(HomeserverTestCase):
servlets = [
room.register_servlets,
admin.register_servlets,
login.register_servlets,
]
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self._store = homeserver.get_datastores().main
self._module_api = homeserver.get_module_api()
self._account_data_mgr = self._module_api.account_data_manager
def test_annotate_event(self) -> None:
"""Test that we can annotate an event when we request it from the
server.
"""
async def add_unsigned_event(event: EventBase) -> JsonDict:
return {"test_key": event.event_id}
self._module_api.register_add_extra_fields_to_unsigned_client_event_callbacks(
add_field_to_unsigned_callback=add_unsigned_event
)
user_id = self.register_user("user", "password")
token = self.login("user", "password")
room_id = self.helper.create_room_as(user_id, tok=token)
result = self.helper.send(room_id, "Hello!", tok=token)
event_id = result["event_id"]
event_json = self.helper.get_event(room_id, event_id, tok=token)
self.assertEqual(event_json["unsigned"].get("test_key"), event_id)

View File

@ -243,7 +243,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
assert event is not None
time_now = self.clock.time_msec()
serialized = self.serializer.serialize_event(event, time_now)
serialized = self.get_success(self.serializer.serialize_event(event, time_now))
return serialized