Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

pull/4396/head
Amber Brown 2018-10-31 04:41:03 +11:00
commit f91aefd245
21 changed files with 431 additions and 76 deletions

View File

@ -23,6 +23,9 @@ branches:
- develop - develop
- /^release-v/ - /^release-v/
# When running the tox environments that call Twisted Trial, we can pass the -j
# flag to run the tests concurrently. We set this to 2 for CPU bound tests
# (SQLite) and 4 for I/O bound tests (PostgreSQL).
matrix: matrix:
fast_finish: true fast_finish: true
include: include:
@ -33,10 +36,10 @@ matrix:
env: TOX_ENV="pep8,check_isort" env: TOX_ENV="pep8,check_isort"
- python: 2.7 - python: 2.7
env: TOX_ENV=py27 env: TOX_ENV=py27 TRIAL_FLAGS="-j 2"
- python: 2.7 - python: 2.7
env: TOX_ENV=py27-old env: TOX_ENV=py27-old TRIAL_FLAGS="-j 2"
- python: 2.7 - python: 2.7
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4" env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
@ -44,10 +47,10 @@ matrix:
- postgresql - postgresql
- python: 3.5 - python: 3.5
env: TOX_ENV=py35 env: TOX_ENV=py35 TRIAL_FLAGS="-j 2"
- python: 3.6 - python: 3.6
env: TOX_ENV=py36 env: TOX_ENV=py36 TRIAL_FLAGS="-j 2"
- python: 3.6 - python: 3.6
env: TOX_ENV=py36-postgres TRIAL_FLAGS="-j 4" env: TOX_ENV=py36-postgres TRIAL_FLAGS="-j 4"

1
changelog.d/4006.misc Normal file
View File

@ -0,0 +1 @@
Delete unreferenced state groups during history purge

1
changelog.d/4095.bugfix Normal file
View File

@ -0,0 +1 @@
Fix exceptions when using the email mailer on Python 3.

1
changelog.d/4118.removal Normal file
View File

@ -0,0 +1 @@
The obsolete and non-functional /pull federation endpoint has been removed.

1
changelog.d/4122.bugfix Normal file
View File

@ -0,0 +1 @@
Searches that request profile info now no longer fail with a 500.

View File

@ -323,11 +323,6 @@ class FederationServer(FederationBase):
else: else:
defer.returnValue((404, "")) defer.returnValue((404, ""))
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
raise NotImplementedError("Pull transactions not implemented")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_query_request(self, query_type, args): def on_query_request(self, query_type, args):
received_queries_counter.labels(query_type).inc() received_queries_counter.labels(query_type).inc()

View File

@ -362,14 +362,6 @@ class FederationSendServlet(BaseFederationServlet):
defer.returnValue((code, response)) defer.returnValue((code, response))
class FederationPullServlet(BaseFederationServlet):
PATH = "/pull/"
# This is for when someone asks us for everything since version X
def on_GET(self, origin, content, query):
return self.handler.on_pull_request(query["origin"][0], query["v"])
class FederationEventServlet(BaseFederationServlet): class FederationEventServlet(BaseFederationServlet):
PATH = "/event/(?P<event_id>[^/]*)/" PATH = "/event/(?P<event_id>[^/]*)/"
@ -1261,7 +1253,6 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
FEDERATION_SERVLET_CLASSES = ( FEDERATION_SERVLET_CLASSES = (
FederationSendServlet, FederationSendServlet,
FederationPullServlet,
FederationEventServlet, FederationEventServlet,
FederationStateServlet, FederationStateServlet,
FederationStateIdsServlet, FederationStateIdsServlet,

View File

@ -24,6 +24,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
@ -324,9 +325,12 @@ class SearchHandler(BaseHandler):
else: else:
last_event_id = event.event_id last_event_id = event.event_id
state_filter = StateFilter.from_types(
[(EventTypes.Member, sender) for sender in senders]
)
state = yield self.store.get_state_for_event( state = yield self.store.get_state_for_event(
last_event_id, last_event_id, state_filter
types=[(EventTypes.Member, sender) for sender in senders]
) )
res["profile_info"] = { res["profile_info"] = {

View File

@ -85,7 +85,10 @@ class EmailPusher(object):
self.timed_call = None self.timed_call = None
def on_new_notifications(self, min_stream_ordering, max_stream_ordering): def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) if self.max_stream_ordering:
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
else:
self.max_stream_ordering = max_stream_ordering
self._start_processing() self._start_processing()
def on_new_receipts(self, min_stream_id, max_stream_id): def on_new_receipts(self, min_stream_id, max_stream_id):

View File

@ -26,7 +26,6 @@ import bleach
import jinja2 import jinja2
from twisted.internet import defer from twisted.internet import defer
from twisted.mail.smtp import sendmail
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
@ -85,6 +84,7 @@ class Mailer(object):
self.notif_template_html = notif_template_html self.notif_template_html = notif_template_html
self.notif_template_text = notif_template_text self.notif_template_text = notif_template_text
self.sendmail = self.hs.get_sendmail()
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.macaroon_gen = self.hs.get_macaroon_generator() self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler() self.state_handler = self.hs.get_state_handler()
@ -191,11 +191,11 @@ class Mailer(object):
multipart_msg.attach(html_part) multipart_msg.attach(html_part)
logger.info("Sending email push notification to %s" % email_address) logger.info("Sending email push notification to %s" % email_address)
# logger.debug(html_text)
yield sendmail( yield self.sendmail(
self.hs.config.email_smtp_host, self.hs.config.email_smtp_host,
raw_from, raw_to, multipart_msg.as_string(), raw_from, raw_to, multipart_msg.as_string().encode('utf8'),
reactor=self.hs.get_reactor(),
port=self.hs.config.email_smtp_port, port=self.hs.config.email_smtp_port,
requireAuthentication=self.hs.config.email_smtp_user is not None, requireAuthentication=self.hs.config.email_smtp_user is not None,
username=self.hs.config.email_smtp_user, username=self.hs.config.email_smtp_user,
@ -333,7 +333,7 @@ class Mailer(object):
notif_events, user_id, reason): notif_events, user_id, reason):
if len(notifs_by_room) == 1: if len(notifs_by_room) == 1:
# Only one room has new stuff # Only one room has new stuff
room_id = notifs_by_room.keys()[0] room_id = list(notifs_by_room.keys())[0]
# If the room has some kind of name, use it, but we don't # If the room has some kind of name, use it, but we don't
# want the generated-from-names one here otherwise we'll # want the generated-from-names one here otherwise we'll

View File

@ -23,6 +23,7 @@ import abc
import logging import logging
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.mail.smtp import sendmail
from twisted.web.client import BrowserLikePolicyForHTTPS from twisted.web.client import BrowserLikePolicyForHTTPS
from synapse.api.auth import Auth from synapse.api.auth import Auth
@ -174,6 +175,7 @@ class HomeServer(object):
'message_handler', 'message_handler',
'pagination_handler', 'pagination_handler',
'room_context_handler', 'room_context_handler',
'sendmail',
] ]
# This is overridden in derived application classes # This is overridden in derived application classes
@ -269,6 +271,9 @@ class HomeServer(object):
def build_room_creation_handler(self): def build_room_creation_handler(self):
return RoomCreationHandler(self) return RoomCreationHandler(self)
def build_sendmail(self):
return sendmail
def build_state_handler(self): def build_state_handler(self):
return StateHandler(self) return StateHandler(self)

View File

@ -38,6 +38,7 @@ from synapse.state import StateResolutionStore
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import batch_iter from synapse.util import batch_iter
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
@ -205,7 +206,8 @@ def _retry_on_integrity_error(func):
# inherits from EventFederationStore so that we can call _update_backward_extremities # inherits from EventFederationStore so that we can call _update_backward_extremities
# and _handle_mult_prev_events (though arguably those could both be moved in here) # and _handle_mult_prev_events (though arguably those could both be moved in here)
class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore): class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
BackgroundUpdateStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
@ -2034,55 +2036,37 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
logger.info("[purge] finding redundant state groups") logger.info("[purge] finding redundant state groups")
# Get all state groups that are only referenced by events that are # Get all state groups that are referenced by events that are to be
# to be deleted. # deleted. We then go and check if they are referenced by other events
# This works by first getting state groups that we may want to delete, # or state groups, and if not we delete them.
# joining against event_to_state_groups to get events that use that
# state group, then left joining against events_to_purge again. Any
# state group where the left join produce *no nulls* are referenced
# only by events that are going to be purged.
txn.execute(""" txn.execute("""
SELECT state_group FROM SELECT DISTINCT state_group FROM events_to_purge
( INNER JOIN event_to_state_groups USING (event_id)
SELECT DISTINCT state_group FROM events_to_purge
INNER JOIN event_to_state_groups USING (event_id)
) AS sp
INNER JOIN event_to_state_groups USING (state_group)
LEFT JOIN events_to_purge AS ep USING (event_id)
GROUP BY state_group
HAVING SUM(CASE WHEN ep.event_id IS NULL THEN 1 ELSE 0 END) = 0
""") """)
state_rows = txn.fetchall() referenced_state_groups = set(sg for sg, in txn)
logger.info("[purge] found %i redundant state groups", len(state_rows)) logger.info(
"[purge] found %i referenced state groups",
len(referenced_state_groups),
)
# make a set of the redundant state groups, so that we can look them up logger.info("[purge] finding state groups that can be deleted")
# efficiently
state_groups_to_delete = set([sg for sg, in state_rows])
# Now we get all the state groups that rely on these state groups state_groups_to_delete, remaining_state_groups = (
logger.info("[purge] finding state groups which depend on redundant" self._find_unreferenced_groups_during_purge(
" state groups") txn, referenced_state_groups,
remaining_state_groups = []
for i in range(0, len(state_rows), 100):
chunk = [sg for sg, in state_rows[i:i + 100]]
# look for state groups whose prev_state_group is one we are about
# to delete
rows = self._simple_select_many_txn(
txn,
table="state_group_edges",
column="prev_state_group",
iterable=chunk,
retcols=["state_group"],
keyvalues={},
) )
remaining_state_groups.extend( )
row["state_group"] for row in rows
# exclude state groups we are about to delete: no point in logger.info(
# updating them "[purge] found %i state groups to delete",
if row["state_group"] not in state_groups_to_delete len(state_groups_to_delete),
) )
logger.info(
"[purge] de-delta-ing %i remaining state groups",
len(remaining_state_groups),
)
# Now we turn the state groups that reference to-be-deleted state # Now we turn the state groups that reference to-be-deleted state
# groups to non delta versions. # groups to non delta versions.
@ -2127,11 +2111,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
logger.info("[purge] removing redundant state groups") logger.info("[purge] removing redundant state groups")
txn.executemany( txn.executemany(
"DELETE FROM state_groups_state WHERE state_group = ?", "DELETE FROM state_groups_state WHERE state_group = ?",
state_rows ((sg,) for sg in state_groups_to_delete),
) )
txn.executemany( txn.executemany(
"DELETE FROM state_groups WHERE id = ?", "DELETE FROM state_groups WHERE id = ?",
state_rows ((sg,) for sg in state_groups_to_delete),
) )
logger.info("[purge] removing events from event_to_state_groups") logger.info("[purge] removing events from event_to_state_groups")
@ -2227,6 +2211,85 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
logger.info("[purge] done") logger.info("[purge] done")
def _find_unreferenced_groups_during_purge(self, txn, state_groups):
"""Used when purging history to figure out which state groups can be
deleted and which need to be de-delta'ed (due to one of its prev groups
being scheduled for deletion).
Args:
txn
state_groups (set[int]): Set of state groups referenced by events
that are going to be deleted.
Returns:
tuple[set[int], set[int]]: The set of state groups that can be
deleted and the set of state groups that need to be de-delta'ed
"""
# Graph of state group -> previous group
graph = {}
# Set of events that we have found to be referenced by events
referenced_groups = set()
# Set of state groups we've already seen
state_groups_seen = set(state_groups)
# Set of state groups to handle next.
next_to_search = set(state_groups)
while next_to_search:
# We bound size of groups we're looking up at once, to stop the
# SQL query getting too big
if len(next_to_search) < 100:
current_search = next_to_search
next_to_search = set()
else:
current_search = set(itertools.islice(next_to_search, 100))
next_to_search -= current_search
# Check if state groups are referenced
sql = """
SELECT DISTINCT state_group FROM event_to_state_groups
LEFT JOIN events_to_purge AS ep USING (event_id)
WHERE state_group IN (%s) AND ep.event_id IS NULL
""" % (",".join("?" for _ in current_search),)
txn.execute(sql, list(current_search))
referenced = set(sg for sg, in txn)
referenced_groups |= referenced
# We don't continue iterating up the state group graphs for state
# groups that are referenced.
current_search -= referenced
rows = self._simple_select_many_txn(
txn,
table="state_group_edges",
column="prev_state_group",
iterable=current_search,
keyvalues={},
retcols=("prev_state_group", "state_group",),
)
prevs = set(row["state_group"] for row in rows)
# We don't bother re-handling groups we've already seen
prevs -= state_groups_seen
next_to_search |= prevs
state_groups_seen |= prevs
for row in rows:
# Note: Each state group can have at most one prev group
graph[row["state_group"]] = row["prev_state_group"]
to_delete = state_groups_seen - referenced_groups
to_dedelta = set()
for sg in referenced_groups:
prev_sg = graph.get(sg)
if prev_sg and prev_sg in to_delete:
to_dedelta.add(sg)
return to_delete, to_dedelta
@defer.inlineCallbacks @defer.inlineCallbacks
def is_event_after(self, event_id1, event_id2): def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream """Returns True if event_id1 is after event_id2 in the stream

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 51 SCHEMA_VERSION = 52
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -0,0 +1,19 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- This is needed to efficiently check for unreferenced state groups during
-- purge. Added events_to_state_group(state_group) index
INSERT into background_updates (update_name, progress_json)
VALUES ('event_to_state_groups_sg_index', '{}');

View File

@ -1257,6 +1257,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs) super(StateStore, self).__init__(db_conn, hs)
@ -1275,6 +1276,12 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
columns=["state_key"], columns=["state_key"],
where_clause="type='m.room.member'", where_clause="type='m.room.member'",
) )
self.register_background_index_update(
self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
index_name="event_to_state_groups_sg_index",
table="event_to_state_groups",
columns=["state_group"],
)
def _store_event_state_mappings_txn(self, txn, events_and_contexts): def _store_event_state_mappings_txn(self, txn, events_and_contexts):
state_groups = {} state_groups = {}

0
tests/push/__init__.py Normal file
View File

148
tests/push/test_email.py Normal file
View File

@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pkg_resources
from twisted.internet.defer import Deferred
from synapse.rest.client.v1 import admin, login, room
from tests.unittest import HomeserverTestCase
try:
from synapse.push.mailer import load_jinja2_templates
except Exception:
load_jinja2_templates = None
class EmailPusherTests(HomeserverTestCase):
skip = "No Jinja installed" if not load_jinja2_templates else None
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
user_id = True
hijack_auth = False
def make_homeserver(self, reactor, clock):
# List[Tuple[Deferred, args, kwargs]]
self.email_attempts = []
def sendmail(*args, **kwargs):
d = Deferred()
self.email_attempts.append((d, args, kwargs))
return d
config = self.default_config()
config.email_enable_notifs = True
config.start_pushers = True
config.email_template_dir = os.path.abspath(
pkg_resources.resource_filename('synapse', 'res/templates')
)
config.email_notif_template_html = "notif_mail.html"
config.email_notif_template_text = "notif_mail.txt"
config.email_smtp_host = "127.0.0.1"
config.email_smtp_port = 20
config.require_transport_security = False
config.email_smtp_user = None
config.email_app_name = "Matrix"
config.email_notif_from = "test@example.com"
hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
return hs
def test_sends_email(self):
# Register the user who gets notified
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")
# Register the user who sends the message
other_user_id = self.register_user("otheruser", "pass")
other_access_token = self.login("otheruser", "pass")
# Register the pusher
user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token)
)
token_id = user_tuple["token_id"]
self.get_success(
self.hs.get_pusherpool().add_pusher(
user_id=user_id,
access_token=token_id,
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
device_display_name="a@example.com",
pushkey="a@example.com",
lang=None,
data={},
)
)
# Create a room
room = self.helper.create_room_as(user_id, tok=access_token)
# Invite the other person
self.helper.invite(room=room, src=user_id, tok=access_token, targ=other_user_id)
# The other user joins
self.helper.join(room=room, user=other_user_id, tok=other_access_token)
# The other user sends some messages
self.helper.send(room, body="Hi!", tok=other_access_token)
self.helper.send(room, body="There!", tok=other_access_token)
# Get the stream ordering before it gets sent
pushers = self.get_success(
self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0]["last_stream_ordering"]
# Advance time a bit, so the pusher will register something has happened
self.pump(100)
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success(
self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
# One email was attempted to be sent
self.assertEqual(len(self.email_attempts), 1)
# Make the email succeed
self.email_attempts[0][0].callback(True)
self.pump()
# One email was attempted to be sent
self.assertEqual(len(self.email_attempts), 1)
# The stream ordering has increased
pushers = self.get_success(
self.hs.get_datastore().get_pushers_by(dict(user_name=user_id))
)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)

View File

@ -23,7 +23,7 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.rest.client.v1 import room from synapse.rest.client.v1 import admin, login, room
from tests import unittest from tests import unittest
@ -801,3 +801,107 @@ class RoomMessageListTestCase(RoomBase):
self.assertEquals(token, channel.json_body['start']) self.assertEquals(token, channel.json_body['start'])
self.assertTrue("chunk" in channel.json_body) self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body) self.assertTrue("end" in channel.json_body)
class RoomSearchTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
user_id = True
hijack_auth = False
def prepare(self, reactor, clock, hs):
# Register the user who does the searching
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
# Register the user who sends the message
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
# Create a room
self.room = self.helper.create_room_as(self.user_id, tok=self.access_token)
# Invite the other person
self.helper.invite(
room=self.room,
src=self.user_id,
tok=self.access_token,
targ=self.other_user_id,
)
# The other user joins
self.helper.join(
room=self.room, user=self.other_user_id, tok=self.other_access_token
)
def test_finds_message(self):
"""
The search functionality will search for content in messages if asked to
do so.
"""
# The other user sends some messages
self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
self.helper.send(self.room, body="There!", tok=self.other_access_token)
request, channel = self.make_request(
"POST",
"/search?access_token=%s" % (self.access_token,),
{
"search_categories": {
"room_events": {"keys": ["content.body"], "search_term": "Hi"}
}
},
)
self.render(request)
# Check we get the results we expect -- one search result, of the sent
# messages
self.assertEqual(channel.code, 200)
results = channel.json_body["search_categories"]["room_events"]
self.assertEqual(results["count"], 1)
self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!")
# No context was requested, so we should get none.
self.assertEqual(results["results"][0]["context"], {})
def test_include_context(self):
"""
When event_context includes include_profile, profile information will be
included in the search response.
"""
# The other user sends some messages
self.helper.send(self.room, body="Hi!", tok=self.other_access_token)
self.helper.send(self.room, body="There!", tok=self.other_access_token)
request, channel = self.make_request(
"POST",
"/search?access_token=%s" % (self.access_token,),
{
"search_categories": {
"room_events": {
"keys": ["content.body"],
"search_term": "Hi",
"event_context": {"include_profile": True},
}
}
},
)
self.render(request)
# Check we get the results we expect -- one search result, of the sent
# messages
self.assertEqual(channel.code, 200)
results = channel.json_body["search_categories"]["room_events"]
self.assertEqual(results["count"], 1)
self.assertEqual(results["results"][0]["result"]["content"]["body"], "Hi!")
# We should get context info, like the two users, and the display names.
context = results["results"][0]["context"]
self.assertEqual(len(context["profile_info"].keys()), 2)
self.assertEqual(
context["profile_info"][self.other_user_id]["displayname"], "otheruser"
)

View File

@ -125,7 +125,9 @@ def make_request(method, path, content=b"", access_token=None, request=SynapseRe
req.content = BytesIO(content) req.content = BytesIO(content)
if access_token: if access_token:
req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token) req.requestHeaders.addRawHeader(
b"Authorization", b"Bearer " + access_token.encode('ascii')
)
if content: if content:
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")

View File

@ -207,7 +207,7 @@ class TestMauLimit(unittest.TestCase):
def do_sync_for_user(self, token): def do_sync_for_user(self, token):
request, channel = make_request( request, channel = make_request(
"GET", "/sync", access_token=token.encode('ascii') "GET", "/sync", access_token=token
) )
render(request, self.resource, self.reactor) render(request, self.resource, self.reactor)

View File

@ -146,6 +146,13 @@ def DEBUG(target):
return target return target
def INFO(target):
"""A decorator to set the .loglevel attribute to logging.INFO.
Can apply to either a TestCase or an individual test method."""
target.loglevel = logging.INFO
return target
class HomeserverTestCase(TestCase): class HomeserverTestCase(TestCase):
""" """
A base TestCase that reduces boilerplate for HomeServer-using test cases. A base TestCase that reduces boilerplate for HomeServer-using test cases.
@ -373,5 +380,5 @@ class HomeserverTestCase(TestCase):
self.render(request) self.render(request)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
access_token = channel.json_body["access_token"].encode('ascii') access_token = channel.json_body["access_token"]
return access_token return access_token