325 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			325 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
| # -*- coding: utf-8 -*-
 | |
| # Copyright 2014-2016 OpenMarket Ltd
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| import logging
 | |
| 
 | |
| from six import iteritems, itervalues
 | |
| 
 | |
| import attr
 | |
| 
 | |
| from synapse.api.constants import EventTypes
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
@attr.s(slots=True)
class StateFilter(object):
    """A filter used when querying for state.

    Attributes:
        types (dict[str, set[str]|None]): Map from type to set of state keys (or
            None). This specifies which state_keys for the given type to fetch
            from the DB. If None then all events with that type are fetched. If
            the set is empty then no events with that type are fetched.
        include_others (bool): Whether to fetch events with types that do not
            appear in `types`.
    """

    types = attr.ib()
    include_others = attr.ib(default=False)

    def __attrs_post_init__(self):
        # If `include_others` is set we canonicalise the filter by removing
        # wildcards from the types dictionary: a type mapped to None matches
        # exactly what an absent type matches when `include_others` is set.
        # This also means `is_full()` recognises such filters as full.
        if self.include_others:
            self.types = {k: v for k, v in iteritems(self.types) if v is not None}

    @staticmethod
    def all():
        """Creates a filter that fetches everything.

        Returns:
            StateFilter
        """
        return StateFilter(types={}, include_others=True)

    @staticmethod
    def none():
        """Creates a filter that fetches nothing.

        Returns:
            StateFilter
        """
        return StateFilter(types={}, include_others=False)

    @staticmethod
    def from_types(types):
        """Creates a filter that only fetches the given types

        Args:
            types (Iterable[tuple[str, str|None]]): A list of type and state
                keys to fetch. A state_key of None fetches everything for
                that type

        Returns:
            StateFilter
        """
        type_dict = {}
        for typ, s in types:
            # Once a type has been marked as a wildcard (None) it stays a
            # wildcard, regardless of any specific state keys seen later.
            if typ in type_dict and type_dict[typ] is None:
                continue

            if s is None:
                # A wildcard replaces any specific state keys collected so far.
                type_dict[typ] = None
                continue

            type_dict.setdefault(typ, set()).add(s)

        return StateFilter(types=type_dict)

    @staticmethod
    def from_lazy_load_member_list(members):
        """Creates a filter that returns all non-member events, plus the member
        events for the given users

        Args:
            members (iterable[str]): Set of user IDs

        Returns:
            StateFilter
        """
        return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)

    def return_expanded(self):
        """Creates a new StateFilter where type wild cards have been removed
        (except for memberships). The returned filter is a superset of the
        current one, i.e. anything that passes the current filter will pass
        the returned filter.

        This helps the caching as the DictionaryCache knows if it has *all* the
        state, but does not know if it has all of the keys of a particular type,
        which makes wildcard lookups expensive unless we have a complete cache.
        Hence, if we are doing a wildcard lookup, populate the cache fully so
        that we can do an efficient lookup next time.

        Note that since we have two caches, one for membership events and one for
        other events, we can be a bit more clever than simply returning
        `StateFilter.all()` if `has_wildcards()` is True.

        We return a StateFilter where:
            1. the list of membership events to return is the same
            2. if there is a wildcard that matches non-member events we
               return all non-member events

        Returns:
            StateFilter
        """

        if self.is_full():
            # If we're going to return everything then there's nothing to do
            return self

        if not self.has_wildcards():
            # If there are no wild cards, there's nothing to do
            return self

        if EventTypes.Member in self.types:
            get_all_members = self.types[EventTypes.Member] is None
        else:
            # No explicit entry for memberships, so they are only matched if
            # we are including types that aren't listed.
            get_all_members = self.include_others

        has_non_member_wildcard = self.include_others or any(
            state_keys is None
            for t, state_keys in iteritems(self.types)
            if t != EventTypes.Member
        )

        if not has_non_member_wildcard:
            # If there are no non-member wild cards we can just return ourselves
            return self

        if get_all_members:
            # We want to return everything.
            return StateFilter.all()

        if EventTypes.Member in self.types:
            # We want to return all non-members, but only particular
            # memberships
            return StateFilter(
                types={EventTypes.Member: self.types[EventTypes.Member]},
                include_others=True,
            )

        # There is a non-member wildcard but the filter matches no membership
        # events at all (no entry for them and `include_others` is unset), so
        # return all non-members and no memberships. Indexing `self.types`
        # here would raise a KeyError.
        return StateFilter(types={EventTypes.Member: set()}, include_others=True)

    def make_sql_filter_clause(self):
        """Converts the filter to an SQL clause.

        For example:

            f = StateFilter.from_types([("m.room.create", "")])
            clause, args = f.make_sql_filter_clause()
            clause == "(type = ? AND state_key = ?)"
            args == ['m.room.create', '']


        Returns:
            tuple[str, list]: The SQL string (may be empty) and arguments. An
            empty SQL string is returned when the filter matches everything
            (i.e. is "full"). A filter that matches nothing yields the
            always-false clause "1 = 2" with no arguments.
        """

        where_clause = ""
        where_args = []

        if self.is_full():
            return where_clause, where_args

        if not self.include_others and not self.types:
            # i.e. this is an empty filter, so we need to return a clause that
            # will match nothing
            return "1 = 2", []

        # First we build up a list of clauses for each type/state_key combo
        clauses = []
        for etype, state_keys in iteritems(self.types):
            if state_keys is None:
                clauses.append("(type = ?)")
                where_args.append(etype)
                continue

            for state_key in state_keys:
                clauses.append("(type = ? AND state_key = ?)")
                where_args.extend((etype, state_key))

        # This will match anything that appears in `self.types`
        where_clause = " OR ".join(clauses)

        # If we want to include stuff that's not in the types dict then we add
        # a `OR type NOT IN (...)` clause to the end.
        if self.include_others:
            if where_clause:
                where_clause += " OR "

            where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
            where_args.extend(self.types)

        if not where_clause:
            # Every listed type had an empty set of state keys and we are not
            # including other types, so the filter matches nothing. An empty
            # clause would be read as "match everything", so return a clause
            # that can never be true instead.
            return "1 = 2", []

        return where_clause, where_args

    def max_entries_returned(self):
        """Returns the maximum number of entries this filter will return if
        known, otherwise returns None.

        For example a simple state filter asking for `("m.room.create", "")`
        will return 1, whereas the default state filter will return None.

        This is used to bail out early if the right number of entries have been
        fetched.

        Returns:
            int|None
        """
        if self.has_wildcards():
            return None

        return len(self.concrete_types())

    def filter_state(self, state_dict):
        """Returns the state filtered by this StateFilter

        Args:
            state_dict (dict[tuple[str, str], Any]): The state map to filter

        Returns:
            dict[tuple[str, str], Any]: The filtered state map. This is always
            a new dict; `state_dict` itself is not modified.
        """
        if self.is_full():
            return dict(state_dict)

        filtered_state = {}
        for k, v in iteritems(state_dict):
            typ, state_key = k
            if typ in self.types:
                # A listed type is governed solely by its entry: None matches
                # every state key, otherwise only the listed keys match.
                state_keys = self.types[typ]
                if state_keys is None or state_key in state_keys:
                    filtered_state[k] = v
            elif self.include_others:
                filtered_state[k] = v

        return filtered_state

    def is_full(self):
        """Whether this filter fetches everything or not

        Returns:
            bool
        """
        return self.include_others and not self.types

    def has_wildcards(self):
        """Whether the filter includes wildcards or is attempting to fetch
        specific state.

        Returns:
            bool
        """

        return self.include_others or any(
            state_keys is None for state_keys in itervalues(self.types)
        )

    def concrete_types(self):
        """Returns a list of concrete type/state_keys (i.e. not None) that
        will be fetched. This will be a complete list if `has_wildcards`
        returns False, but otherwise will be a subset (or even empty).

        Returns:
            list[tuple[str,str]]
        """
        return [
            (t, s)
            for t, state_keys in iteritems(self.types)
            if state_keys is not None
            for s in state_keys
        ]

    def get_member_split(self):
        """Return the filter split into two: one which assumes it's exclusively
        matching against member state, and one which assumes it's matching
        against non member state.

        This is useful due to the returned filters giving correct results for
        `is_full()`, `has_wildcards()`, etc, when operating against maps that
        either exclusively contain member events or only contain non-member
        events. (Which is the case when dealing with the member vs non-member
        state caches).

        Returns:
            tuple[StateFilter, StateFilter]: The member and non member filters
        """

        if EventTypes.Member in self.types:
            state_keys = self.types[EventTypes.Member]
            if state_keys is None:
                # Wildcard on memberships: against a member-only map this is
                # the same as fetching everything.
                member_filter = StateFilter.all()
            else:
                member_filter = StateFilter({EventTypes.Member: state_keys})
        elif self.include_others:
            # Memberships aren't listed, so they fall under "others".
            member_filter = StateFilter.all()
        else:
            member_filter = StateFilter.none()

        non_member_filter = StateFilter(
            types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member},
            include_others=self.include_others,
        )

        return member_filter, non_member_filter
 |