Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

pull/8675/head
Erik Johnston 2020-06-25 09:39:01 +01:00
commit c16bb06d25
25 changed files with 254 additions and 172 deletions


@ -195,7 +195,7 @@ By default Synapse uses SQLite and in doing so trades performance for convenience
SQLite is only recommended in Synapse for testing purposes or for servers with
light workloads.
Almost all installations should opt to use PostreSQL. Advantages include:
Almost all installations should opt to use PostgreSQL. Advantages include:
* significant performance improvements due to the superior threading and
caching model, smarter query optimiser

changelog.d/7675.removal

@ -0,0 +1 @@
Deprecate `m.login.jwt` login method in favour of `org.matrix.login.jwt`, as `m.login.jwt` is not part of the Matrix spec.

changelog.d/7717.bugfix

@ -0,0 +1 @@
Fix the tables ignored by `synapse_port_db` to be in sync with the current database schema.

changelog.d/7718.feature

@ -0,0 +1 @@
Media can now be marked as safe from being quarantined.

changelog.d/7724.doc

@ -0,0 +1 @@
Corrected misspelling of PostgreSQL.

changelog.d/7725.misc

@ -0,0 +1 @@
Speed up state res v2 across large state differences.

changelog.d/7727.misc

@ -0,0 +1 @@
Convert directory handler to async/await.

changelog.d/7730.bugfix

@ -0,0 +1 @@
Fix missing `Content-Length` on HTTP responses from the metrics handler.

changelog.d/7735.bugfix

@ -0,0 +1 @@
Fix large state resolutions from stalling Synapse for seconds at a time.


@ -89,6 +89,7 @@ BOOLEAN_COLUMNS = {
"account_validity": ["email_sent"],
"redactions": ["have_censored"],
"room_stats_state": ["is_federatable"],
"local_media_repository": ["safe_from_quarantine"],
}
@ -128,10 +129,20 @@ APPEND_ONLY_TABLES = [
IGNORED_TABLES = {
# We don't port these tables, as they're a faff and we can regenerate
# them anyway.
"user_directory",
"user_directory_search",
"users_who_share_rooms",
"users_in_pubic_room",
"user_directory_search_content",
"user_directory_search_docsize",
"user_directory_search_segdir",
"user_directory_search_segments",
"user_directory_search_stat",
"user_directory_search_pos",
"users_who_share_private_rooms",
"users_in_public_room",
# UI auth sessions have foreign keys so additional care needs to be taken,
# the sessions are transient anyway, so ignore them.
"ui_auth_sessions",
"ui_auth_sessions_credentials",
}
@ -300,8 +311,6 @@ class Porter(object):
return
if table in IGNORED_TABLES:
# We don't port these tables, as they're a faff and we can regenerate
# them anyway.
self.progress.update(table, table_size) # Mark table as done
return


@ -17,8 +17,6 @@ import logging
import string
from typing import Iterable, List, Optional
from twisted.internet import defer
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
from synapse.api.errors import (
AuthError,
@ -55,8 +53,7 @@ class DirectoryHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks
def _create_association(
async def _create_association(
self,
room_alias: RoomAlias,
room_id: str,
@ -76,13 +73,13 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association.
if not servers:
users = yield self.state.get_current_users_in_room(room_id)
users = await self.state.get_current_users_in_room(room_id)
servers = {get_domain_from_id(u) for u in users}
if not servers:
raise SynapseError(400, "Failed to get server list")
yield self.store.create_room_alias_association(
await self.store.create_room_alias_association(
room_alias, room_id, servers, creator=creator
)
@ -93,7 +90,7 @@ class DirectoryHandler(BaseHandler):
room_id: str,
servers: Optional[List[str]] = None,
check_membership: bool = True,
):
) -> None:
"""Attempt to create a new alias
Args:
@ -103,9 +100,6 @@ class DirectoryHandler(BaseHandler):
servers: Iterable of servers that other servers should try and join via
check_membership: Whether to check if the user is in the room
before the alias can be set (if the server's config requires it).
Returns:
Deferred
"""
user_id = requester.user.to_string()
@ -148,7 +142,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to create alias")
can_create = await self.can_modify_alias(room_alias, user_id=user_id)
can_create = self.can_modify_alias(room_alias, user_id=user_id)
if not can_create:
raise AuthError(
400,
@ -158,7 +152,9 @@ class DirectoryHandler(BaseHandler):
await self._create_association(room_alias, room_id, servers, creator=user_id)
async def delete_association(self, requester: Requester, room_alias: RoomAlias):
async def delete_association(
self, requester: Requester, room_alias: RoomAlias
) -> str:
"""Remove an alias from the directory
(this is only meant for human users; AS users should call
@ -169,7 +165,7 @@ class DirectoryHandler(BaseHandler):
room_alias
Returns:
Deferred[unicode]: room id that the alias used to point to
room id that the alias used to point to
Raises:
NotFoundError: if the alias doesn't exist
@ -191,7 +187,7 @@ class DirectoryHandler(BaseHandler):
if not can_delete:
raise AuthError(403, "You don't have permission to delete the alias.")
can_delete = await self.can_modify_alias(room_alias, user_id=user_id)
can_delete = self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete:
raise SynapseError(
400,
@ -208,8 +204,7 @@ class DirectoryHandler(BaseHandler):
return room_id
@defer.inlineCallbacks
def delete_appservice_association(
async def delete_appservice_association(
self, service: ApplicationService, room_alias: RoomAlias
):
if not service.is_interested_in_alias(room_alias.to_string()):
@ -218,29 +213,27 @@ class DirectoryHandler(BaseHandler):
"This application service has not reserved this kind of alias",
errcode=Codes.EXCLUSIVE,
)
yield self._delete_association(room_alias)
await self._delete_association(room_alias)
@defer.inlineCallbacks
def _delete_association(self, room_alias: RoomAlias):
async def _delete_association(self, room_alias: RoomAlias):
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
room_id = yield self.store.delete_room_alias(room_alias)
room_id = await self.store.delete_room_alias(room_alias)
return room_id
@defer.inlineCallbacks
def get_association(self, room_alias: RoomAlias):
async def get_association(self, room_alias: RoomAlias):
room_id = None
if self.hs.is_mine(room_alias):
result = yield self.get_association_from_room_alias(room_alias)
result = await self.get_association_from_room_alias(room_alias)
if result:
room_id = result.room_id
servers = result.servers
else:
try:
result = yield self.federation.make_query(
result = await self.federation.make_query(
destination=room_alias.domain,
query_type="directory",
args={"room_alias": room_alias.to_string()},
@ -265,7 +258,7 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND,
)
users = yield self.state.get_current_users_in_room(room_id)
users = await self.state.get_current_users_in_room(room_id)
extra_servers = {get_domain_from_id(u) for u in users}
servers = set(extra_servers) | set(servers)
@ -277,13 +270,12 @@ class DirectoryHandler(BaseHandler):
return {"room_id": room_id, "servers": servers}
@defer.inlineCallbacks
def on_directory_query(self, args):
async def on_directory_query(self, args):
room_alias = RoomAlias.from_string(args["room_alias"])
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room Alias is not hosted on this homeserver")
result = yield self.get_association_from_room_alias(room_alias)
result = await self.get_association_from_room_alias(room_alias)
if result is not None:
return {"room_id": result.room_id, "servers": result.servers}
@ -344,16 +336,15 @@ class DirectoryHandler(BaseHandler):
ratelimit=False,
)
@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias: RoomAlias):
result = yield self.store.get_association_from_room_alias(room_alias)
async def get_association_from_room_alias(self, room_alias: RoomAlias):
result = await self.store.get_association_from_room_alias(room_alias)
if not result:
# Query AS to see if it exists
as_handler = self.appservice_handler
result = yield as_handler.query_room_alias_exists(room_alias)
result = await as_handler.query_room_alias_exists(room_alias)
return result
def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None):
def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None) -> bool:
# Any application service "interested" in an alias they are regexing on
# can modify the alias.
# Users can only modify the alias if ALL the interested services have
@ -366,12 +357,12 @@ class DirectoryHandler(BaseHandler):
for service in interested_services:
if user_id == service.sender:
# this user IS the app service so they can do whatever they like
return defer.succeed(True)
return True
elif service.is_exclusive_alias(alias.to_string()):
# another service has an exclusive lock on this alias.
return defer.succeed(False)
return False
# either no interested services, or no service with an exclusive lock
return defer.succeed(True)
return True
async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
"""Determine whether a user can delete an alias.
@ -459,8 +450,7 @@ class DirectoryHandler(BaseHandler):
await self.store.set_room_is_public(room_id, making_public)
@defer.inlineCallbacks
def edit_published_appservice_room_list(
async def edit_published_appservice_room_list(
self, appservice_id: str, network_id: str, room_id: str, visibility: str
):
"""Add or remove a room from the appservice/network specific public
@ -475,7 +465,7 @@ class DirectoryHandler(BaseHandler):
if visibility not in ["public", "private"]:
raise SynapseError(400, "Invalid visibility setting")
yield self.store.set_room_is_public_appservice(
await self.store.set_room_is_public_appservice(
room_id, appservice_id, network_id, visibility == "public"
)
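
The conversion above is mechanical: drop the @defer.inlineCallbacks decorator, declare the method async, and turn each yield on a Deferred into an await. A minimal standalone sketch of the two shapes (illustrative names, not Synapse's handler; assumes Twisted is installed):

from twisted.internet import defer

class FakeStore:
    def lookup(self, key):
        # Deferreds are awaitable, so both styles below can consume this.
        return defer.succeed(key.upper())

class Handler:
    def __init__(self):
        self.store = FakeStore()

    @defer.inlineCallbacks
    def get_thing_old(self, key):  # old style: generator yielding Deferreds
        result = yield self.store.lookup(key)
        return result

    async def get_thing_new(self, key):  # new style: native coroutine
        result = await self.store.lookup(key)
        return result

d_old = Handler().get_thing_old("abc")                        # already a Deferred
d_new = defer.ensureDeferred(Handler().get_thing_new("abc"))  # coroutine -> Deferred

Note also that can_modify_alias now returns a plain bool instead of defer.succeed(...), which is why its callers above drop the await.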


@ -376,6 +376,7 @@ class FederationHandler(BaseHandler):
room_version = await self.store.get_room_version_id(room_id)
state_map = await resolve_events_with_store(
self.clock,
room_id,
room_version,
state_maps,


@ -881,7 +881,9 @@ class EventCreationHandler(object):
"""
room_alias = RoomAlias.from_string(room_alias_str)
try:
mapping = yield directory_handler.get_association(room_alias)
mapping = yield defer.ensureDeferred(
directory_handler.get_association(room_alias)
)
except SynapseError as e:
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
if e.errcode == Codes.NOT_FOUND:
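
get_association is now a coroutine, but this call site still lives in an @defer.inlineCallbacks generator, which can only yield Deferreds; wrapping the coroutine in defer.ensureDeferred bridges the two styles. A standalone sketch of the pattern (illustrative values):

from twisted.internet import defer

async def get_association(room_alias):
    # Stand-in for the handler method that is now async.
    return {"room_id": "!abc:example.org", "servers": ["example.org"]}

@defer.inlineCallbacks
def legacy_caller():
    # ensureDeferred drives the coroutine and hands back a Deferred that
    # the inlineCallbacks generator can yield on.
    mapping = yield defer.ensureDeferred(get_association("#alias:example.org"))
    return mapping

legacy_caller().addCallback(print)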


@ -208,6 +208,7 @@ class MetricsHandler(BaseHTTPRequestHandler):
raise
self.send_response(200)
self.send_header("Content-Type", CONTENT_TYPE_LATEST)
self.send_header("Content-Length", str(len(output)))
self.end_headers()
self.wfile.write(output)
@ -261,4 +262,6 @@ class MetricsResource(Resource):
def render_GET(self, request):
request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
return generate_latest(self.registry)
response = generate_latest(self.registry)
request.setHeader(b"Content-Length", str(len(response)))
return response
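
Both fixes compute the response body first and then advertise its exact size; without an explicit Content-Length, twisted.web falls back to chunked transfer-encoding (or connection-close framing on HTTP/1.0). A standalone sketch of the same idea in a bare resource (illustrative, not Synapse's metrics code):

from twisted.web.resource import Resource

class MetricsLike(Resource):
    isLeaf = True

    def render_GET(self, request):
        body = b"example_metric_total 1\n"
        request.setHeader(b"Content-Type", b"text/plain; version=0.0.4")
        # The body is fully built, so its exact size can be advertised.
        request.setHeader(b"Content-Length", str(len(body)))
        return body

# Serve with twisted.web.server.Site(MetricsLike()) on a listening port.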


@ -133,6 +133,8 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _update_badge(self):
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it.
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
yield self._send_badge(badge)


@ -81,7 +81,8 @@ class LoginRestServlet(RestServlet):
CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token"
JWT_TYPE = "m.login.jwt"
JWT_TYPE = "org.matrix.login.jwt"
JWT_TYPE_DEPRECATED = "m.login.jwt"
def __init__(self, hs):
super(LoginRestServlet, self).__init__()
@ -116,6 +117,7 @@ class LoginRestServlet(RestServlet):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED})
if self.cas_enabled:
# we advertise CAS for backwards compat, though MSC1721 renamed it
@ -149,6 +151,7 @@ class LoginRestServlet(RestServlet):
try:
if self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
result = await self.do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
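
Clients should switch to the new identifier; the deprecated m.login.jwt type is still advertised and accepted for backwards compatibility. An example login submission (illustrative token value):

import json

login_submission = {
    "type": "org.matrix.login.jwt",  # formerly "m.login.jwt"
    "token": "<signed JWT goes here>",
}
# POST this JSON body to /_matrix/client/r0/login
print(json.dumps(login_submission))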


@ -32,6 +32,7 @@ from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
@ -414,6 +415,7 @@ class StateHandler(object):
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
self.clock,
event.room_id,
room_version,
state_set_ids,
@ -516,6 +518,7 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
self.clock,
room_id,
room_version,
list(state_groups_ids.values()),
@ -589,6 +592,7 @@ def _make_state_cache_entry(new_state, state_groups_ids):
def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
state_sets: List[StateMap[str]],
@ -625,7 +629,7 @@ def resolve_events_with_store(
)
else:
return v2.resolve_events_with_store(
room_id, room_version, state_sets, event_map, state_res_store
clock, room_id, room_version, state_sets, event_map, state_res_store
)


@ -27,12 +27,20 @@ from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.types import StateMap
from synapse.util import Clock
logger = logging.getLogger(__name__)
# We want to yield to the reactor occasionally during state res when dealing
# with large data sets, so that we don't exhaust the reactor. This is done by
# yielding to reactor during loops every N iterations.
_YIELD_AFTER_ITERATIONS = 100
@defer.inlineCallbacks
def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
state_sets: List[StateMap[str]],
@ -42,13 +50,11 @@ def resolve_events_with_store(
"""Resolves the state using the v2 state resolution algorithm
Args:
clock
room_id: the room we are working in
room_version: The room version
state_sets: List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
@ -113,7 +119,7 @@ def resolve_events_with_store(
)
sorted_power_events = yield _reverse_topological_power_sort(
room_id, power_events, event_map, state_res_store, full_conflicted_set
clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
)
logger.debug("sorted %d power events", len(sorted_power_events))
@ -133,15 +139,16 @@ def resolve_events_with_store(
# OK, so we've now resolved the power events. Now sort the remaining
# events using the mainline of the resolved power level.
set_power_events = set(sorted_power_events)
leftover_events = [
ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events
ev_id for ev_id in full_conflicted_set if ev_id not in set_power_events
]
logger.debug("sorting %d remaining events", len(leftover_events))
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
leftover_events = yield _mainline_sort(
room_id, leftover_events, pl, event_map, state_res_store
clock, room_id, leftover_events, pl, event_map, state_res_store
)
logger.debug("resolving remaining events")
@ -316,12 +323,13 @@ def _add_event_and_auth_chain_to_graph(
@defer.inlineCallbacks
def _reverse_topological_power_sort(
room_id, event_ids, event_map, state_res_store, auth_diff
clock, room_id, event_ids, event_map, state_res_store, auth_diff
):
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
Args:
clock (Clock)
room_id (str): the room we are working in
event_ids (list[str]): The events to sort
event_map (dict[str,FrozenEvent])
@ -333,18 +341,28 @@ def _reverse_topological_power_sort(
"""
graph = {}
for event_id in event_ids:
for idx, event_id in enumerate(event_ids, start=1):
yield _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
)
# We yield occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
if idx % _YIELD_AFTER_ITERATIONS == 0:
yield clock.sleep(0)
event_to_pl = {}
for event_id in graph:
for idx, event_id in enumerate(graph, start=1):
pl = yield _get_power_level_for_sender(
room_id, event_id, event_map, state_res_store
)
event_to_pl[event_id] = pl
# We yield occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
if idx % _YIELD_AFTER_ITERATIONS == 0:
yield clock.sleep(0)
def _get_power_order(event_id):
ev = event_map[event_id]
pl = event_to_pl[event_id]
@ -422,12 +440,13 @@ def _iterative_auth_checks(
@defer.inlineCallbacks
def _mainline_sort(
room_id, event_ids, resolved_power_event_id, event_map, state_res_store
clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
):
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
Args:
clock (Clock)
room_id (str): room we're working in
event_ids (list[str]): Events to sort
resolved_power_event_id (str): The final resolved power level event ID
@ -437,8 +456,14 @@ def _mainline_sort(
Returns:
Deferred[list[str]]: The sorted list
"""
if not event_ids:
# It's possible for there to be no event IDs here to sort, so we can
# skip calculating the mainline in that case.
return []
mainline = []
pl = resolved_power_event_id
idx = 0
while pl:
mainline.append(pl)
pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
@ -452,17 +477,29 @@ def _mainline_sort(
pl = aid
break
# We yield occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0:
yield clock.sleep(0)
idx += 1
mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}
event_ids = list(event_ids)
order_map = {}
for ev_id in event_ids:
for idx, ev_id in enumerate(event_ids, start=1):
depth = yield _get_mainline_depth_for_event(
event_map[ev_id], mainline_map, event_map, state_res_store
)
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
# We yield occasionally when we're working with large data sets to
# ensure that we don't block the reactor loop for too long.
if idx % _YIELD_AFTER_ITERATIONS == 0:
yield clock.sleep(0)
event_ids.sort(key=lambda ev_id: order_map[ev_id])
return event_ids
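
All three loops use the same trick: every _YIELD_AFTER_ITERATIONS iterations they sleep for zero seconds, which schedules a fresh reactor turn so timers and network I/O can run between batches instead of stalling behind one large state resolution. A standalone sketch of the pattern (clock.sleep(0) is roughly deferLater(reactor, 0, ...); illustrative code, not the v2 algorithm itself):

from twisted.internet import defer, reactor, task

_YIELD_AFTER_ITERATIONS = 100

@defer.inlineCallbacks
def process_all(items, handle):
    for idx, item in enumerate(items, start=1):
        handle(item)
        # Hand control back to the reactor between batches so other
        # queued work gets a chance to run.
        if idx % _YIELD_AFTER_ITERATIONS == 0:
            yield task.deferLater(reactor, 0, lambda: None)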


@ -81,6 +81,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_media",
)
def mark_local_media_as_safe(self, media_id: str):
"""Mark a local media as safe from quarantining."""
return self.db.simple_update_one(
table="local_media_repository",
keyvalues={"media_id": media_id},
updatevalues={"safe_from_quarantine": True},
desc="mark_local_media_as_safe",
)
def get_url_cache(self, url, ts):
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
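
The new mark_local_media_as_safe method is, at the SQL level, a single-row update of the safe_from_quarantine column added by the schema deltas further down (SQLite has no native boolean type, so the flag is stored as 0/1). A standalone sqlite3 illustration, not Synapse's database layer:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE local_media_repository"
    " (media_id TEXT, safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0)"
)
conn.execute("INSERT INTO local_media_repository (media_id) VALUES ('abc')")
# The equivalent of mark_local_media_as_safe("abc"):
conn.execute(
    "UPDATE local_media_repository SET safe_from_quarantine = ?"
    " WHERE media_id = ?",
    (True, "abc"),
)
print(conn.execute("SELECT * FROM local_media_repository").fetchall())
# -> [('abc', 1)]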


@ -626,36 +626,10 @@ class RoomWorkerStore(SQLBaseStore):
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
# Now update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""",
((quarantined_by, media_id) for media_id in local_mxcs),
return self._quarantine_media_txn(
txn, local_mxcs, remote_mxcs, quarantined_by
)
txn.executemany(
"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin = ? AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_mxcs
),
)
total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
return self.db.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@ -805,17 +779,17 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
The total number of media items quarantined
"""
total_media_quarantined = 0
# Update all the tables to set the quarantined_by flag
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
WHERE media_id = ? AND safe_from_quarantine = ?
""",
((quarantined_by, media_id) for media_id in local_mxcs),
((quarantined_by, media_id, False) for media_id in local_mxcs),
)
# Note that a rowcount of -1 can be used to indicate no rows were affected.
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
txn.executemany(
"""
@ -825,9 +799,7 @@ class RoomWorkerStore(SQLBaseStore):
""",
((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
)
total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_mxcs)
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
return total_media_quarantined
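
Counting via txn.rowcount instead of len() matters now that the WHERE clause can skip rows: media marked safe_from_quarantine is filtered out of the update, so the number of rows actually changed can be smaller than the number of media IDs passed in. A standalone sqlite3 illustration of the difference:

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute(
    "CREATE TABLE media (media_id TEXT, quarantined_by TEXT,"
    " safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0)"
)
cur.executemany(
    "INSERT INTO media (media_id, safe_from_quarantine) VALUES (?, ?)",
    [("a", 0), ("b", 1)],  # "b" has been marked safe
)
cur.executemany(
    "UPDATE media SET quarantined_by = ?"
    " WHERE media_id = ? AND safe_from_quarantine = ?",
    [("@admin:example.org", m, False) for m in ("a", "b")],
)
print(cur.rowcount)  # -> 1: only "a" was updated; len() would have said 2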


@ -0,0 +1,18 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- The local_media_repository should have files which do not get quarantined,
-- e.g. files from sticker packs.
ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT FALSE;


@ -0,0 +1,18 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- The local_media_repository should have files which do not get quarantined,
-- e.g. files from sticker packs.
ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0;


@ -220,6 +220,24 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
return hs
def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it."""
request, channel = self.make_request(
"GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id
),
)
def test_quarantine_media_requires_admin(self):
self.register_user("nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("nonadmin", "pass")
@ -292,24 +310,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
# Attempt to access the media
request, channel = self.make_request(
"GET",
server_name_and_media_id,
shorthand=False,
access_token=admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_name_and_media_id
),
)
self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
def test_quarantine_all_media_in_room(self, override_url_template=None):
self.register_user("room_admin", "pass", admin=True)
@ -371,45 +372,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
server_and_media_id_2 = mxc_2[6:]
# Test that we cannot download any of the media anymore
request, channel = self.make_request(
"GET",
server_and_media_id_1,
shorthand=False,
access_token=non_admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id_1
),
)
request, channel = self.make_request(
"GET",
server_and_media_id_2,
shorthand=False,
access_token=non_admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id_2
),
)
def test_quaraantine_all_media_in_room_deprecated_api_path(self):
def test_quarantine_all_media_in_room_deprecated_api_path(self):
# Perform the above test with the deprecated API path
self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s")
@ -449,24 +415,51 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
)
# Attempt to access each piece of media
request, channel = self.make_request(
"GET",
server_and_media_id_1,
shorthand=False,
access_token=non_admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
# Should be quarantined
self.assertEqual(
404,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
% server_and_media_id_1,
),
def test_cannot_quarantine_safe_media(self):
self.register_user("user_admin", "pass", admin=True)
admin_user_tok = self.login("user_admin", "pass")
non_admin_user = self.register_user("user_nonadmin", "pass", admin=False)
non_admin_user_tok = self.login("user_nonadmin", "pass")
# Upload some media
response_1 = self.helper.upload_media(
self.upload_resource, self.image_data, tok=non_admin_user_tok
)
response_2 = self.helper.upload_media(
self.upload_resource, self.image_data, tok=non_admin_user_tok
)
# Extract media IDs
server_and_media_id_1 = response_1["content_uri"][6:]
server_and_media_id_2 = response_2["content_uri"][6:]
# Mark the second item as safe from quarantine.
_, media_id_2 = server_and_media_id_2.split("/")
self.get_success(self.store.mark_local_media_as_safe(media_id_2))
# Quarantine all media by this user
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
non_admin_user
)
request, channel = self.make_request(
"POST", url.encode("ascii"), access_token=admin_user_tok,
)
self.render(request)
self.pump(1.0)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(
json.loads(channel.result["body"].decode("utf-8")),
{"num_quarantined": 1},
"Expected 1 quarantined item",
)
# Attempt to access each piece of media, the first should fail, the
# second should succeed.
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
# Attempt to access each piece of media
request, channel = self.make_request(
@ -478,12 +471,12 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
request.render(self.download_resource)
self.pump(1.0)
# Should be quarantined
# Shouldn't be quarantined
self.assertEqual(
404,
200,
int(channel.code),
msg=(
"Expected to receive a 404 on accessing quarantined media: %s"
"Expected to receive a 200 on accessing not-quarantined media: %s"
% server_and_media_id_2
),
)


@ -526,7 +526,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
return jwt.encode(token, secret, "HS256").decode("ascii")
def jwt_login(self, *args):
params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
params = json.dumps(
{"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel
@ -568,7 +570,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.json_body["error"], "Invalid JWT")
def test_login_no_token(self):
params = json.dumps({"type": "m.login.jwt"})
params = json.dumps({"type": "org.matrix.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
self.assertEqual(channel.result["code"], b"401", channel.result)
@ -640,7 +642,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
return jwt.encode(token, secret, "RS256").decode("ascii")
def jwt_login(self, *args):
params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
params = json.dumps(
{"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
return channel


@ -17,6 +17,8 @@ import itertools
import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
@ -41,6 +43,11 @@ MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
ORIGIN_SERVER_TS = 0
class FakeClock:
def sleep(self, msec):
return defer.succeed(None)
class FakeEvent(object):
"""A fake event we use as a convenience.
@ -417,6 +424,7 @@ class StateTestCase(unittest.TestCase):
state_before = dict(state_at_event[prev_events[0]])
else:
state_d = resolve_events_with_store(
FakeClock(),
ROOM_ID,
RoomVersions.V2.identifier,
[state_at_event[n] for n in prev_events],
@ -565,6 +573,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
# Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store(
FakeClock(),
ROOM_ID,
RoomVersions.V2.identifier,
[self.state_at_bob, self.state_at_charlie],