Merge pull request #1638 from matrix-org/kegan/sync-event-fields

Implement "event_fields" in filters
pull/1640/head
Kegsay 2016-11-22 14:02:38 +00:00 committed by GitHub
commit d4a459f7cb
4 changed files with 296 additions and 15 deletions

View File

@ -71,6 +71,21 @@ class Filtering(object):
if key in user_filter_json["room"]: if key in user_filter_json["room"]:
self._check_definition(user_filter_json["room"][key]) self._check_definition(user_filter_json["room"][key])
if "event_fields" in user_filter_json:
if type(user_filter_json["event_fields"]) != list:
raise SynapseError(400, "event_fields must be a list of strings")
for field in user_filter_json["event_fields"]:
if not isinstance(field, basestring):
raise SynapseError(400, "Event field must be a string")
# Don't allow '\\' in event field filters. This makes matching
# events a lot easier as we can then use a negative lookbehind
# assertion to split '\.' If we allowed \\ then it would
# incorrectly split '\\.' See synapse.events.utils.serialize_event
if r'\\' in field:
raise SynapseError(
400, r'The escape character \ cannot itself be escaped'
)
def _check_definition_room_lists(self, definition): def _check_definition_room_lists(self, definition):
"""Check that "rooms" and "not_rooms" are lists of room ids if they """Check that "rooms" and "not_rooms" are lists of room ids if they
are present are present
@ -152,6 +167,7 @@ class FilterCollection(object):
self.include_leave = filter_json.get("room", {}).get( self.include_leave = filter_json.get("room", {}).get(
"include_leave", False "include_leave", False
) )
self.event_fields = filter_json.get("event_fields", [])
def __repr__(self): def __repr__(self):
return "<FilterCollection %s>" % (json.dumps(self._filter_json),) return "<FilterCollection %s>" % (json.dumps(self._filter_json),)

View File

@ -16,6 +16,17 @@
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from . import EventBase from . import EventBase
from frozendict import frozendict
import re
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.')
def prune_event(event): def prune_event(event):
""" Returns a pruned version of the given event, which removes all keys we """ Returns a pruned version of the given event, which removes all keys we
@ -97,6 +108,83 @@ def prune_event(event):
) )
def _copy_field(src, dst, field):
"""Copy the field in 'src' to 'dst'.
For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
then dst={"foo":{"bar":5}}.
Args:
src(dict): The dict to read from.
dst(dict): The dict to modify.
field(list<str>): List of keys to drill down to in 'src'.
"""
if len(field) == 0: # this should be impossible
return
if len(field) == 1: # common case e.g. 'origin_server_ts'
if field[0] in src:
dst[field[0]] = src[field[0]]
return
# Else is a nested field e.g. 'content.body'
# Pop the last field as that's the key to move across and we need the
# parent dict in order to access the data. Drill down to the right dict.
key_to_move = field.pop(-1)
sub_dict = src
for sub_field in field: # e.g. sub_field => "content"
if sub_field in sub_dict and type(sub_dict[sub_field]) in [dict, frozendict]:
sub_dict = sub_dict[sub_field]
else:
return
if key_to_move not in sub_dict:
return
# Insert the key into the output dictionary, creating nested objects
# as required. We couldn't do this any earlier or else we'd need to delete
# the empty objects if the key didn't exist.
sub_out_dict = dst
for sub_field in field:
sub_out_dict = sub_out_dict.setdefault(sub_field, {})
sub_out_dict[key_to_move] = sub_dict[key_to_move]
def only_fields(dictionary, fields):
"""Return a new dict with only the fields in 'dictionary' which are present
in 'fields'.
If there are no event fields specified then all fields are included.
The entries may include '.' charaters to indicate sub-fields.
So ['content.body'] will include the 'body' field of the 'content' object.
A literal '.' character in a field name may be escaped using a '\'.
Args:
dictionary(dict): The dictionary to read from.
fields(list<str>): A list of fields to copy over. Only shallow refs are
taken.
Returns:
dict: A new dictionary with only the given fields. If fields was empty,
the same dictionary is returned.
"""
if len(fields) == 0:
return dictionary
# for each field, convert it:
# ["content.body.thing\.with\.dots"] => [["content", "body", "thing\.with\.dots"]]
split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields]
# for each element of the output array of arrays:
# remove escaping so we can use the right key names.
split_fields[:] = [
[f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
]
output = {}
for field_array in split_fields:
_copy_field(dictionary, output, field_array)
return output
def format_event_raw(d): def format_event_raw(d):
return d return d
@ -137,7 +225,7 @@ def format_event_for_client_v2_without_room_id(d):
def serialize_event(e, time_now_ms, as_client_event=True, def serialize_event(e, time_now_ms, as_client_event=True,
event_format=format_event_for_client_v1, event_format=format_event_for_client_v1,
token_id=None): token_id=None, only_event_fields=None):
# FIXME(erikj): To handle the case of presence events and the like # FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase): if not isinstance(e, EventBase):
return e return e
@ -164,6 +252,12 @@ def serialize_event(e, time_now_ms, as_client_event=True,
d["unsigned"]["transaction_id"] = txn_id d["unsigned"]["transaction_id"] = txn_id
if as_client_event: if as_client_event:
return event_format(d) d = event_format(d)
else:
if only_event_fields:
if (not isinstance(only_event_fields, list) or
not all(isinstance(f, basestring) for f in only_event_fields)):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)
return d return d

View File

@ -162,7 +162,7 @@ class SyncRestServlet(RestServlet):
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
joined = self.encode_joined( joined = self.encode_joined(
sync_result.joined, time_now, requester.access_token_id sync_result.joined, time_now, requester.access_token_id, filter.event_fields
) )
invited = self.encode_invited( invited = self.encode_invited(
@ -170,7 +170,7 @@ class SyncRestServlet(RestServlet):
) )
archived = self.encode_archived( archived = self.encode_archived(
sync_result.archived, time_now, requester.access_token_id sync_result.archived, time_now, requester.access_token_id, filter.event_fields
) )
response_content = { response_content = {
@ -197,7 +197,7 @@ class SyncRestServlet(RestServlet):
formatted.append(event) formatted.append(event)
return {"events": formatted} return {"events": formatted}
def encode_joined(self, rooms, time_now, token_id): def encode_joined(self, rooms, time_now, token_id, event_fields):
""" """
Encode the joined rooms in a sync result Encode the joined rooms in a sync result
@ -208,7 +208,8 @@ class SyncRestServlet(RestServlet):
calculations calculations
token_id(int): ID of the user's auth token - used for namespacing token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
event_fields(list<str>): List of event fields to include. If empty,
all fields will be returned.
Returns: Returns:
dict[str, dict[str, object]]: the joined rooms list, in our dict[str, dict[str, object]]: the joined rooms list, in our
response format response format
@ -216,7 +217,7 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
room, time_now, token_id room, time_now, token_id, only_fields=event_fields
) )
return joined return joined
@ -253,7 +254,7 @@ class SyncRestServlet(RestServlet):
return invited return invited
def encode_archived(self, rooms, time_now, token_id): def encode_archived(self, rooms, time_now, token_id, event_fields):
""" """
Encode the archived rooms in a sync result Encode the archived rooms in a sync result
@ -264,7 +265,8 @@ class SyncRestServlet(RestServlet):
calculations calculations
token_id(int): ID of the user's auth token - used for namespacing token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs of transaction IDs
event_fields(list<str>): List of event fields to include. If empty,
all fields will be returned.
Returns: Returns:
dict[str, dict[str, object]]: The invited rooms list, in our dict[str, dict[str, object]]: The invited rooms list, in our
response format response format
@ -272,13 +274,13 @@ class SyncRestServlet(RestServlet):
joined = {} joined = {}
for room in rooms: for room in rooms:
joined[room.room_id] = self.encode_room( joined[room.room_id] = self.encode_room(
room, time_now, token_id, joined=False room, time_now, token_id, joined=False, only_fields=event_fields
) )
return joined return joined
@staticmethod @staticmethod
def encode_room(room, time_now, token_id, joined=True): def encode_room(room, time_now, token_id, joined=True, only_fields=None):
""" """
Args: Args:
room (JoinedSyncResult|ArchivedSyncResult): sync result for a room (JoinedSyncResult|ArchivedSyncResult): sync result for a
@ -289,7 +291,7 @@ class SyncRestServlet(RestServlet):
of transaction IDs of transaction IDs
joined (bool): True if the user is joined to this room - will mean joined (bool): True if the user is joined to this room - will mean
we handle ephemeral events we handle ephemeral events
only_fields(list<str>): Optional. The list of event fields to include.
Returns: Returns:
dict[str, object]: the room, encoded in our response format dict[str, object]: the room, encoded in our response format
""" """
@ -298,6 +300,7 @@ class SyncRestServlet(RestServlet):
return serialize_event( return serialize_event(
event, time_now, token_id=token_id, event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id, event_format=format_event_for_client_v2_without_room_id,
only_event_fields=only_fields,
) )
state_dict = room.state state_dict = room.state

View File

@ -17,7 +17,11 @@
from .. import unittest from .. import unittest
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import prune_event from synapse.events.utils import prune_event, serialize_event
def MockEvent(**kwargs):
return FrozenEvent(kwargs)
class PruneEventTestCase(unittest.TestCase): class PruneEventTestCase(unittest.TestCase):
@ -114,3 +118,167 @@ class PruneEventTestCase(unittest.TestCase):
'unsigned': {}, 'unsigned': {},
} }
) )
class SerializeEventTestCase(unittest.TestCase):
def serialize(self, ev, fields):
return serialize_event(ev, 1479807801915, only_event_fields=fields)
def test_event_fields_works_with_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar"
),
["room_id"]
),
{
"room_id": "!foo:bar",
}
)
def test_event_fields_works_with_nested_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar",
content={
"body": "A message",
},
),
["content.body"]
),
{
"content": {
"body": "A message",
}
}
)
def test_event_fields_works_with_dot_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar",
content={
"key.with.dots": {},
},
),
["content.key\.with\.dots"]
),
{
"content": {
"key.with.dots": {},
}
}
)
def test_event_fields_works_with_nested_dot_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar",
content={
"not_me": 1,
"nested.dot.key": {
"leaf.key": 42,
"not_me_either": 1,
},
},
),
["content.nested\.dot\.key.leaf\.key"]
),
{
"content": {
"nested.dot.key": {
"leaf.key": 42,
},
}
}
)
def test_event_fields_nops_with_unknown_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar",
content={
"foo": "bar",
},
),
["content.foo", "content.notexists"]
),
{
"content": {
"foo": "bar",
}
}
)
def test_event_fields_nops_with_non_dict_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar",
content={
"foo": ["I", "am", "an", "array"],
},
),
["content.foo.am"]
),
{}
)
def test_event_fields_nops_with_array_keys(self):
self.assertEquals(
self.serialize(
MockEvent(
sender="@alice:localhost",
room_id="!foo:bar",
content={
"foo": ["I", "am", "an", "array"],
},
),
["content.foo.1"]
),
{}
)
def test_event_fields_all_fields_if_empty(self):
self.assertEquals(
self.serialize(
MockEvent(
room_id="!foo:bar",
content={
"foo": "bar",
},
),
[]
),
{
"room_id": "!foo:bar",
"content": {
"foo": "bar",
},
"unsigned": {}
}
)
def test_event_fields_fail_if_fields_not_str(self):
with self.assertRaises(TypeError):
self.serialize(
MockEvent(
room_id="!foo:bar",
content={
"foo": "bar",
},
),
["room_id", 4]
)