Create a separate filter object to do the actual filtering, so that we can

split the storage and management of filters from the actual filter code
and don't have to load a filter from the db each time we filter an event
pull/37/head
Mark Haines 2015-01-29 17:41:48 +00:00
parent 295322048d
commit 93ed31dda2
3 changed files with 166 additions and 164 deletions

View File

@ -25,127 +25,25 @@ class Filtering(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
def get_user_filter(self, user_localpart, filter_id): def get_user_filter(self, user_localpart, filter_id):
return self.store.get_user_filter(user_localpart, filter_id) result = self.store.get_user_filter(user_localpart, filter_id)
result.addCallback(Filter)
return result
def add_user_filter(self, user_localpart, user_filter): def add_user_filter(self, user_localpart, user_filter):
self._check_valid_filter(user_filter) self._check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter) return self.store.add_user_filter(user_localpart, user_filter)
def filter_public_user_data(self, events, user, filter_id):
return self._filter_on_key(
events, user, filter_id, ["public_user_data"]
)
def filter_private_user_data(self, events, user, filter_id):
return self._filter_on_key(
events, user, filter_id, ["private_user_data"]
)
def filter_room_state(self, events, user, filter_id):
return self._filter_on_key(
events, user, filter_id, ["room", "state"]
)
def filter_room_events(self, events, user, filter_id):
return self._filter_on_key(
events, user, filter_id, ["room", "events"]
)
def filter_room_ephemeral(self, events, user, filter_id):
return self._filter_on_key(
events, user, filter_id, ["room", "ephemeral"]
)
# TODO(paul): surely we should probably add a delete_user_filter or # TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for # replace_user_filter at some point? There's no REST API specified for
# them however # them however
@defer.inlineCallbacks def _check_valid_filter(self, user_filter_json):
def _filter_on_key(self, events, user, filter_id, keys):
filter_json = yield self.get_user_filter(user.localpart, filter_id)
if not filter_json:
defer.returnValue(events)
try:
# extract the right definition from the filter
definition = filter_json
for key in keys:
definition = definition[key]
defer.returnValue(self._filter_with_definition(events, definition))
except KeyError:
# return all events if definition isn't specified.
defer.returnValue(events)
def _filter_with_definition(self, events, definition):
return [e for e in events if self._passes_definition(definition, e)]
def _passes_definition(self, definition, event):
"""Check if the event passes through the given definition.
Args:
definition(dict): The definition to check against.
event(Event): The event to check.
Returns:
True if the event passes through the filter.
"""
# Algorithm notes:
# For each key in the definition, check the event meets the criteria:
# * For types: Literal match or prefix match (if ends with wildcard)
# * For senders/rooms: Literal match only
# * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
# and 'not_types' then it is treated as only being in 'not_types')
# room checks
if hasattr(event, "room_id"):
room_id = event.room_id
allow_rooms = definition.get("rooms", None)
reject_rooms = definition.get("not_rooms", None)
if reject_rooms and room_id in reject_rooms:
return False
if allow_rooms and room_id not in allow_rooms:
return False
# sender checks
if hasattr(event, "sender"):
# Should we be including event.state_key for some event types?
sender = event.sender
allow_senders = definition.get("senders", None)
reject_senders = definition.get("not_senders", None)
if reject_senders and sender in reject_senders:
return False
if allow_senders and sender not in allow_senders:
return False
# type checks
if "not_types" in definition:
for def_type in definition["not_types"]:
if self._event_matches_type(event, def_type):
return False
if "types" in definition:
included = False
for def_type in definition["types"]:
if self._event_matches_type(event, def_type):
included = True
break
if not included:
return False
return True
def _event_matches_type(self, event, def_type):
if def_type.endswith("*"):
type_prefix = def_type[:-1]
return event.type.startswith(type_prefix)
else:
return event.type == def_type
def _check_valid_filter(self, user_filter):
"""Check if the provided filter is valid. """Check if the provided filter is valid.
This inspects all definitions contained within the filter. This inspects all definitions contained within the filter.
Args: Args:
user_filter(dict): The filter user_filter_json(dict): The filter
Raises: Raises:
SynapseError: If the filter is not valid. SynapseError: If the filter is not valid.
""" """
@ -162,13 +60,13 @@ class Filtering(object):
] ]
for key in top_level_definitions: for key in top_level_definitions:
if key in user_filter: if key in user_filter_json:
self._check_definition(user_filter[key]) self._check_definition(user_filter_json[key])
if "room" in user_filter: if "room" in user_filter_json:
for key in room_level_definitions: for key in room_level_definitions:
if key in user_filter["room"]: if key in user_filter_json["room"]:
self._check_definition(user_filter["room"][key]) self._check_definition(user_filter_json["room"][key])
def _check_definition(self, definition): def _check_definition(self, definition):
"""Check if the provided definition is valid. """Check if the provided definition is valid.
@ -237,3 +135,101 @@ class Filtering(object):
if ("bundle_updates" in definition and if ("bundle_updates" in definition and
type(definition["bundle_updates"]) != bool): type(definition["bundle_updates"]) != bool):
raise SynapseError(400, "Bad bundle_updates: expected bool.") raise SynapseError(400, "Bad bundle_updates: expected bool.")
class Filter(object):
def __init__(self, filter_json):
self.filter_json = filter_json
def filter_public_user_data(self, events):
return self._filter_on_key(events, ["public_user_data"])
def filter_private_user_data(self, events):
return self._filter_on_key(events, ["private_user_data"])
def filter_room_state(self, events):
return self._filter_on_key(events, ["room", "state"])
def filter_room_events(self, events):
return self._filter_on_key(events, ["room", "events"])
def filter_room_ephemeral(self, events):
return self._filter_on_key(events, ["room", "ephemeral"])
def _filter_on_key(self, events, keys):
filter_json = self.filter_json
if not filter_json:
return events
try:
# extract the right definition from the filter
definition = filter_json
for key in keys:
definition = definition[key]
return self._filter_with_definition(events, definition)
except KeyError:
# return all events if definition isn't specified.
return events
def _filter_with_definition(self, events, definition):
return [e for e in events if self._passes_definition(definition, e)]
def _passes_definition(self, definition, event):
"""Check if the event passes through the given definition.
Args:
definition(dict): The definition to check against.
event(Event): The event to check.
Returns:
True if the event passes through the filter.
"""
# Algorithm notes:
# For each key in the definition, check the event meets the criteria:
# * For types: Literal match or prefix match (if ends with wildcard)
# * For senders/rooms: Literal match only
# * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
# and 'not_types' then it is treated as only being in 'not_types')
# room checks
if hasattr(event, "room_id"):
room_id = event.room_id
allow_rooms = definition.get("rooms", None)
reject_rooms = definition.get("not_rooms", None)
if reject_rooms and room_id in reject_rooms:
return False
if allow_rooms and room_id not in allow_rooms:
return False
# sender checks
if hasattr(event, "sender"):
# Should we be including event.state_key for some event types?
sender = event.sender
allow_senders = definition.get("senders", None)
reject_senders = definition.get("not_senders", None)
if reject_senders and sender in reject_senders:
return False
if allow_senders and sender not in allow_senders:
return False
# type checks
if "not_types" in definition:
for def_type in definition["not_types"]:
if self._event_matches_type(event, def_type):
return False
if "types" in definition:
included = False
for def_type in definition["types"]:
if self._event_matches_type(event, def_type):
included = True
break
if not included:
return False
return True
def _event_matches_type(self, event, def_type):
if def_type.endswith("*"):
type_prefix = def_type[:-1]
return event.type.startswith(type_prefix)
else:
return event.type == def_type

View File

@ -59,7 +59,7 @@ class GetFilterRestServlet(RestServlet):
filter_id=filter_id, filter_id=filter_id,
) )
defer.returnValue((200, filter)) defer.returnValue((200, filter.filter_json))
except KeyError: except KeyError:
raise SynapseError(400, "No such filter") raise SynapseError(400, "No such filter")

View File

@ -25,6 +25,7 @@ from tests.utils import (
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID from synapse.types import UserID
from synapse.api.filtering import Filter
user_localpart = "test_user" user_localpart = "test_user"
MockEvent = namedtuple("MockEvent", "sender type room_id") MockEvent = namedtuple("MockEvent", "sender type room_id")
@ -53,6 +54,7 @@ class FilteringTestCase(unittest.TestCase):
) )
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
self.filter = Filter({})
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
@ -66,7 +68,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_types_works_with_wildcards(self): def test_definition_types_works_with_wildcards(self):
@ -79,7 +81,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_types_works_with_unknowns(self): def test_definition_types_works_with_unknowns(self):
@ -92,7 +94,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_not_types_works_with_literals(self): def test_definition_not_types_works_with_literals(self):
@ -105,7 +107,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_not_types_works_with_wildcards(self): def test_definition_not_types_works_with_wildcards(self):
@ -118,7 +120,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_not_types_works_with_unknowns(self): def test_definition_not_types_works_with_unknowns(self):
@ -131,7 +133,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_not_types_takes_priority_over_types(self): def test_definition_not_types_takes_priority_over_types(self):
@ -145,7 +147,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_senders_works_with_literals(self): def test_definition_senders_works_with_literals(self):
@ -158,7 +160,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_senders_works_with_unknowns(self): def test_definition_senders_works_with_unknowns(self):
@ -171,7 +173,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_not_senders_works_with_literals(self): def test_definition_not_senders_works_with_literals(self):
@ -184,7 +186,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_not_senders_works_with_unknowns(self): def test_definition_not_senders_works_with_unknowns(self):
@ -197,7 +199,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_not_senders_takes_priority_over_senders(self): def test_definition_not_senders_takes_priority_over_senders(self):
@ -211,7 +213,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_rooms_works_with_literals(self): def test_definition_rooms_works_with_literals(self):
@ -224,7 +226,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!secretbase:unknown" room_id="!secretbase:unknown"
) )
self.assertTrue( self.assertTrue(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_rooms_works_with_unknowns(self): def test_definition_rooms_works_with_unknowns(self):
@ -237,7 +239,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown"
) )
self.assertFalse( self.assertFalse(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_not_rooms_works_with_literals(self): def test_definition_not_rooms_works_with_literals(self):
@ -250,7 +252,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown"
) )
self.assertFalse( self.assertFalse(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_not_rooms_works_with_unknowns(self): def test_definition_not_rooms_works_with_unknowns(self):
@ -263,7 +265,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown"
) )
self.assertTrue( self.assertTrue(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_not_rooms_takes_priority_over_rooms(self): def test_definition_not_rooms_takes_priority_over_rooms(self):
@ -277,7 +279,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!secretbase:unknown" room_id="!secretbase:unknown"
) )
self.assertFalse( self.assertFalse(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_combined_event(self): def test_definition_combined_event(self):
@ -295,7 +297,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup room_id="!stage:unknown" # yup
) )
self.assertTrue( self.assertTrue(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_combined_event_bad_sender(self): def test_definition_combined_event_bad_sender(self):
@ -313,7 +315,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup room_id="!stage:unknown" # yup
) )
self.assertFalse( self.assertFalse(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_combined_event_bad_room(self): def test_definition_combined_event_bad_room(self):
@ -331,7 +333,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!piggyshouse:muppets" # nope room_id="!piggyshouse:muppets" # nope
) )
self.assertFalse( self.assertFalse(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
def test_definition_combined_event_bad_type(self): def test_definition_combined_event_bad_type(self):
@ -349,12 +351,12 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup room_id="!stage:unknown" # yup
) )
self.assertFalse( self.assertFalse(
self.filtering._passes_definition(definition, event) self.filter._passes_definition(definition, event)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_public_user_data_match(self): def test_filter_public_user_data_match(self):
user_filter = { user_filter_json = {
"public_user_data": { "public_user_data": {
"types": ["m.*"] "types": ["m.*"]
} }
@ -362,7 +364,7 @@ class FilteringTestCase(unittest.TestCase):
user = UserID.from_string("@" + user_localpart + ":test") user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
user_filter=user_filter, user_filter=user_filter_json,
) )
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -371,16 +373,17 @@ class FilteringTestCase(unittest.TestCase):
) )
events = [event] events = [event]
results = yield self.filtering.filter_public_user_data( user_filter = yield self.filtering.get_user_filter(
events=events, user_localpart=user_localpart,
user=user, filter_id=filter_id,
filter_id=filter_id
) )
results = user_filter.filter_public_user_data(events=events)
self.assertEquals(events, results) self.assertEquals(events, results)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_public_user_data_no_match(self): def test_filter_public_user_data_no_match(self):
user_filter = { user_filter_json = {
"public_user_data": { "public_user_data": {
"types": ["m.*"] "types": ["m.*"]
} }
@ -388,7 +391,7 @@ class FilteringTestCase(unittest.TestCase):
user = UserID.from_string("@" + user_localpart + ":test") user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
user_filter=user_filter, user_filter=user_filter_json,
) )
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -397,16 +400,17 @@ class FilteringTestCase(unittest.TestCase):
) )
events = [event] events = [event]
results = yield self.filtering.filter_public_user_data( user_filter = yield self.filtering.get_user_filter(
events=events, user_localpart=user_localpart,
user=user, filter_id=filter_id,
filter_id=filter_id
) )
results = user_filter.filter_public_user_data(events=events)
self.assertEquals([], results) self.assertEquals([], results)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_room_state_match(self): def test_filter_room_state_match(self):
user_filter = { user_filter_json = {
"room": { "room": {
"state": { "state": {
"types": ["m.*"] "types": ["m.*"]
@ -416,7 +420,7 @@ class FilteringTestCase(unittest.TestCase):
user = UserID.from_string("@" + user_localpart + ":test") user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
user_filter=user_filter, user_filter=user_filter_json,
) )
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -425,16 +429,17 @@ class FilteringTestCase(unittest.TestCase):
) )
events = [event] events = [event]
results = yield self.filtering.filter_room_state( user_filter = yield self.filtering.get_user_filter(
events=events, user_localpart=user_localpart,
user=user, filter_id=filter_id,
filter_id=filter_id
) )
results = user_filter.filter_room_state(events=events)
self.assertEquals(events, results) self.assertEquals(events, results)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filter_room_state_no_match(self): def test_filter_room_state_no_match(self):
user_filter = { user_filter_json = {
"room": { "room": {
"state": { "state": {
"types": ["m.*"] "types": ["m.*"]
@ -444,7 +449,7 @@ class FilteringTestCase(unittest.TestCase):
user = UserID.from_string("@" + user_localpart + ":test") user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
user_filter=user_filter, user_filter=user_filter_json,
) )
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
@ -453,16 +458,17 @@ class FilteringTestCase(unittest.TestCase):
) )
events = [event] events = [event]
results = yield self.filtering.filter_room_state( user_filter = yield self.filtering.get_user_filter(
events=events, user_localpart=user_localpart,
user=user, filter_id=filter_id,
filter_id=filter_id
) )
results = user_filter.filter_room_state(events)
self.assertEquals([], results) self.assertEquals([], results)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_add_filter(self): def test_add_filter(self):
user_filter = { user_filter_json = {
"room": { "room": {
"state": { "state": {
"types": ["m.*"] "types": ["m.*"]
@ -472,11 +478,11 @@ class FilteringTestCase(unittest.TestCase):
filter_id = yield self.filtering.add_user_filter( filter_id = yield self.filtering.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
user_filter=user_filter, user_filter=user_filter_json,
) )
self.assertEquals(filter_id, 0) self.assertEquals(filter_id, 0)
self.assertEquals(user_filter, self.assertEquals(user_filter_json,
(yield self.datastore.get_user_filter( (yield self.datastore.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
filter_id=0, filter_id=0,
@ -485,7 +491,7 @@ class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_filter(self): def test_get_filter(self):
user_filter = { user_filter_json = {
"room": { "room": {
"state": { "state": {
"types": ["m.*"] "types": ["m.*"]
@ -495,7 +501,7 @@ class FilteringTestCase(unittest.TestCase):
filter_id = yield self.datastore.add_user_filter( filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
user_filter=user_filter, user_filter=user_filter_json,
) )
filter = yield self.filtering.get_user_filter( filter = yield self.filtering.get_user_filter(
@ -503,4 +509,4 @@ class FilteringTestCase(unittest.TestCase):
filter_id=filter_id, filter_id=filter_id,
) )
self.assertEquals(filter, user_filter) self.assertEquals(filter.filter_json, user_filter_json)