diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 2718e9482e..28f5300dc9 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -46,15 +46,20 @@ class SearchHandler(BaseHandler): """ try: - search_term = content["search_categories"]["room_events"]["search_term"] - keys = content["search_categories"]["room_events"].get("keys", [ + room_cat = content["search_categories"]["room_events"] + search_term = room_cat["search_term"] + keys = room_cat.get("keys", [ "content.body", "content.name", "content.topic", ]) - filter_dict = content["search_categories"]["room_events"].get("filter", {}) - event_context = content["search_categories"]["room_events"].get( + filter_dict = room_cat.get("filter", {}) + order_by = room_cat.get("order_by", "rank") + event_context = room_cat.get( "event_context", None ) + group_by = room_cat.get("groupings", {}).get("group_by", {}) + group_keys = [g["key"] for g in group_by] + if event_context is not None: before_limit = int(event_context.get( "before_limit", 5 @@ -65,6 +70,15 @@ class SearchHandler(BaseHandler): except KeyError: raise SynapseError(400, "Invalid search query") + if order_by not in ("rank", "recent"): + raise SynapseError(400, "Invalid order by: %r" % (order_by,)) + + if set(group_keys) - {"room_id", "sender"}: + raise SynapseError( + 400, + "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},) + ) + search_filter = Filter(filter_dict) # TODO: Search through left rooms too @@ -77,18 +91,88 @@ class SearchHandler(BaseHandler): room_ids = search_filter.filter_rooms(room_ids) - rank_map, event_map, _ = yield self.store.search_msgs( - room_ids, search_term, keys - ) + rank_map = {} + allowed_events = [] + room_groups = {} + sender_group = {} - filtered_events = search_filter.filter(event_map.values()) + if order_by == "rank": + rank_map, event_map, _ = yield self.store.search_msgs( + room_ids, search_term, keys + ) - allowed_events = yield self._filter_events_for_client( - user.to_string(), filtered_events - ) + filtered_events = search_filter.filter(event_map.values()) - allowed_events.sort(key=lambda e: -rank_map[e.event_id]) - allowed_events = allowed_events[:search_filter.limit()] + events = yield self._filter_events_for_client( + user.to_string(), filtered_events + ) + + events.sort(key=lambda e: -rank_map[e.event_id]) + allowed_events = events[:search_filter.limit()] + + for e in allowed_events: + rm = room_groups.setdefault(e.room_id, { + "results": [], + "order": rank_map[e.event_id], + }) + rm["results"].append(e.event_id) + + s = sender_group.setdefault(e.sender, { + "results": [], + "order": rank_map[e.event_id], + }) + s["results"].append(e.event_id) + + elif order_by == "recent": + for room_id in room_ids: + room_events = [] + pagination_token = None + i = 0 + + while len(room_events) < search_filter.limit() and i < 5: + i += 5 + r_map, event_map, pagination_token = yield self.store.search_room( + room_id, search_term, keys, search_filter.limit() * 2, + pagination_token=pagination_token, + ) + rank_map.update(r_map) + + filtered_events = search_filter.filter(event_map.values()) + + events = yield self._filter_events_for_client( + user.to_string(), filtered_events + ) + + room_events.extend(events) + room_events = room_events[:search_filter.limit()] + + if len(event_map) < search_filter.limit() * 2: + break + + if room_events: + group = room_groups.setdefault(room_id, {}) + if pagination_token: + group["next_batch"] = pagination_token + + group["results"] = [e.event_id for e in room_events] + group["order"] = max( + e.origin_server_ts/1000 for e in room_events + if hasattr(e, "origin_server_ts") + ) + + allowed_events.extend(room_events) + + # Normalize the group ranks + if room_groups: + mx = max(g["order"] for g in room_groups.values()) + mn = min(g["order"] for g in room_groups.values()) + + for g in room_groups.values(): + g["order"] = (g["order"] - mn) * 1.0 / (mx - mn) + + else: + # We should never get here due to the guard earlier. + raise NotImplementedError() if event_context is not None: now_token = yield self.hs.get_event_sources().get_current_token() @@ -144,11 +228,19 @@ class SearchHandler(BaseHandler): logger.info("Found %d results", len(results)) + rooms_cat_res = { + "results": results, + "count": len(results) + } + + if room_groups and "room_id" in group_keys: + rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups + + if sender_group and "sender" in group_keys: + rooms_cat_res.setdefault("groups", {})["sender"] = sender_group + defer.returnValue({ "search_categories": { - "room_events": { - "results": results, - "count": len(results) - } + "room_events": rooms_cat_res } }) diff --git a/synapse/storage/search.py b/synapse/storage/search.py index cdf003502f..e37e56c1f2 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -20,6 +20,12 @@ from synapse.storage.engines import PostgresEngine, Sqlite3Engine from collections import namedtuple +import logging + + +logger = logging.getLogger(__name__) + + """The result of a search. Fields: @@ -109,3 +115,93 @@ class SearchStore(SQLBaseStore): event_map, None )) + + @defer.inlineCallbacks + def search_room(self, room_id, search_term, keys, limit, pagination_token=None): + """Performs a full text search over events with given keys. + + Args: + room_id (str): The room_id to search in + search_term (str): Search term to search for + keys (list): List of keys to search in, currently supports + "content.body", "content.name", "content.topic" + pagination_token (str): A pagination token previously returned + + Returns: + SearchResult + """ + clauses = [] + args = [search_term, room_id] + + local_clauses = [] + for key in keys: + local_clauses.append("key = ?") + args.append(key) + + clauses.append( + "(%s)" % (" OR ".join(local_clauses),) + ) + + if pagination_token: + topo, stream = pagination_token.split(",") + clauses.append( + "(topological_ordering < ?" + " OR (topological_ordering = ? AND stream_ordering < ?))" + ) + args.extend([topo, topo, stream]) + + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "SELECT ts_rank_cd(vector, query) as rank," + " topological_ordering, stream_ordering, room_id, event_id" + " FROM plainto_tsquery('english', ?) as query, event_search" + " NATURAL JOIN events" + " WHERE vector @@ query AND room_id = ?" + ) + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id" + " topological_ordering, stream_ordering" + " FROM event_search" + " NATURAL JOIN events" + " WHERE value MATCH ? AND room_id = ?" + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + for clause in clauses: + sql += " AND " + clause + + # We add an arbitrary limit here to ensure we don't try to pull the + # entire table from the database. + sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" + + args.append(limit) + + results = yield self._execute( + "search_rooms", self.cursor_to_dict, sql, *args + ) + + events = yield self._get_events([r["event_id"] for r in results]) + + event_map = { + ev.event_id: ev + for ev in events + } + + pagination_token = None + if results: + topo = results[-1]["topological_ordering"] + stream = results[-1]["stream_ordering"] + pagination_token = "%s,%s" % (topo, stream) + + defer.returnValue(SearchResult( + { + r["event_id"]: r["rank"] + for r in results + if r["event_id"] in event_map + }, + event_map, + pagination_token + ))