Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
commit
c16bb06d25
|
@ -195,7 +195,7 @@ By default Synapse uses SQLite in and doing so trades performance for convenienc
|
|||
SQLite is only recommended in Synapse for testing purposes or for servers with
|
||||
light workloads.
|
||||
|
||||
Almost all installations should opt to use PostreSQL. Advantages include:
|
||||
Almost all installations should opt to use PostgreSQL. Advantages include:
|
||||
|
||||
* significant performance improvements due to the superior threading and
|
||||
caching model, smarter query optimiser
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
Deprecate `m.login.jwt` login method in favour of `org.matrix.login.jwt`, as `m.login.jwt` is not part of the Matrix spec.
|
|
@ -0,0 +1 @@
|
|||
Fix the tables ignored by `synapse_port_db` to be in sync the current database schema.
|
|
@ -0,0 +1 @@
|
|||
Media can now be marked as safe from quarantined.
|
|
@ -0,0 +1 @@
|
|||
Corrected misspelling of PostgreSQL.
|
|
@ -0,0 +1 @@
|
|||
Speed up state res v2 across large state differences.
|
|
@ -0,0 +1 @@
|
|||
Convert directory handler to async/await.
|
|
@ -0,0 +1 @@
|
|||
Fix missing `Content-Length` on HTTP responses from the metrics handler.
|
|
@ -0,0 +1 @@
|
|||
Fix large state resolutions from stalling Synapse for seconds at a time.
|
|
@ -89,6 +89,7 @@ BOOLEAN_COLUMNS = {
|
|||
"account_validity": ["email_sent"],
|
||||
"redactions": ["have_censored"],
|
||||
"room_stats_state": ["is_federatable"],
|
||||
"local_media_repository": ["safe_from_quarantine"],
|
||||
}
|
||||
|
||||
|
||||
|
@ -128,10 +129,20 @@ APPEND_ONLY_TABLES = [
|
|||
|
||||
|
||||
IGNORED_TABLES = {
|
||||
# We don't port these tables, as they're a faff and we can regenerate
|
||||
# them anyway.
|
||||
"user_directory",
|
||||
"user_directory_search",
|
||||
"users_who_share_rooms",
|
||||
"users_in_pubic_room",
|
||||
"user_directory_search_content",
|
||||
"user_directory_search_docsize",
|
||||
"user_directory_search_segdir",
|
||||
"user_directory_search_segments",
|
||||
"user_directory_search_stat",
|
||||
"user_directory_search_pos",
|
||||
"users_who_share_private_rooms",
|
||||
"users_in_public_room",
|
||||
# UI auth sessions have foreign keys so additional care needs to be taken,
|
||||
# the sessions are transient anyway, so ignore them.
|
||||
"ui_auth_sessions",
|
||||
"ui_auth_sessions_credentials",
|
||||
}
|
||||
|
@ -300,8 +311,6 @@ class Porter(object):
|
|||
return
|
||||
|
||||
if table in IGNORED_TABLES:
|
||||
# We don't port these tables, as they're a faff and we can regenerate
|
||||
# them anyway.
|
||||
self.progress.update(table, table_size) # Mark table as done
|
||||
return
|
||||
|
||||
|
|
|
@ -17,8 +17,6 @@ import logging
|
|||
import string
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
|
@ -55,8 +53,7 @@ class DirectoryHandler(BaseHandler):
|
|||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_association(
|
||||
async def _create_association(
|
||||
self,
|
||||
room_alias: RoomAlias,
|
||||
room_id: str,
|
||||
|
@ -76,13 +73,13 @@ class DirectoryHandler(BaseHandler):
|
|||
# TODO(erikj): Add transactions.
|
||||
# TODO(erikj): Check if there is a current association.
|
||||
if not servers:
|
||||
users = yield self.state.get_current_users_in_room(room_id)
|
||||
users = await self.state.get_current_users_in_room(room_id)
|
||||
servers = {get_domain_from_id(u) for u in users}
|
||||
|
||||
if not servers:
|
||||
raise SynapseError(400, "Failed to get server list")
|
||||
|
||||
yield self.store.create_room_alias_association(
|
||||
await self.store.create_room_alias_association(
|
||||
room_alias, room_id, servers, creator=creator
|
||||
)
|
||||
|
||||
|
@ -93,7 +90,7 @@ class DirectoryHandler(BaseHandler):
|
|||
room_id: str,
|
||||
servers: Optional[List[str]] = None,
|
||||
check_membership: bool = True,
|
||||
):
|
||||
) -> None:
|
||||
"""Attempt to create a new alias
|
||||
|
||||
Args:
|
||||
|
@ -103,9 +100,6 @@ class DirectoryHandler(BaseHandler):
|
|||
servers: Iterable of servers that others servers should try and join via
|
||||
check_membership: Whether to check if the user is in the room
|
||||
before the alias can be set (if the server's config requires it).
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
|
||||
user_id = requester.user.to_string()
|
||||
|
@ -148,7 +142,7 @@ class DirectoryHandler(BaseHandler):
|
|||
# per alias creation rule?
|
||||
raise SynapseError(403, "Not allowed to create alias")
|
||||
|
||||
can_create = await self.can_modify_alias(room_alias, user_id=user_id)
|
||||
can_create = self.can_modify_alias(room_alias, user_id=user_id)
|
||||
if not can_create:
|
||||
raise AuthError(
|
||||
400,
|
||||
|
@ -158,7 +152,9 @@ class DirectoryHandler(BaseHandler):
|
|||
|
||||
await self._create_association(room_alias, room_id, servers, creator=user_id)
|
||||
|
||||
async def delete_association(self, requester: Requester, room_alias: RoomAlias):
|
||||
async def delete_association(
|
||||
self, requester: Requester, room_alias: RoomAlias
|
||||
) -> str:
|
||||
"""Remove an alias from the directory
|
||||
|
||||
(this is only meant for human users; AS users should call
|
||||
|
@ -169,7 +165,7 @@ class DirectoryHandler(BaseHandler):
|
|||
room_alias
|
||||
|
||||
Returns:
|
||||
Deferred[unicode]: room id that the alias used to point to
|
||||
room id that the alias used to point to
|
||||
|
||||
Raises:
|
||||
NotFoundError: if the alias doesn't exist
|
||||
|
@ -191,7 +187,7 @@ class DirectoryHandler(BaseHandler):
|
|||
if not can_delete:
|
||||
raise AuthError(403, "You don't have permission to delete the alias.")
|
||||
|
||||
can_delete = await self.can_modify_alias(room_alias, user_id=user_id)
|
||||
can_delete = self.can_modify_alias(room_alias, user_id=user_id)
|
||||
if not can_delete:
|
||||
raise SynapseError(
|
||||
400,
|
||||
|
@ -208,8 +204,7 @@ class DirectoryHandler(BaseHandler):
|
|||
|
||||
return room_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_appservice_association(
|
||||
async def delete_appservice_association(
|
||||
self, service: ApplicationService, room_alias: RoomAlias
|
||||
):
|
||||
if not service.is_interested_in_alias(room_alias.to_string()):
|
||||
|
@ -218,29 +213,27 @@ class DirectoryHandler(BaseHandler):
|
|||
"This application service has not reserved this kind of alias",
|
||||
errcode=Codes.EXCLUSIVE,
|
||||
)
|
||||
yield self._delete_association(room_alias)
|
||||
await self._delete_association(room_alias)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _delete_association(self, room_alias: RoomAlias):
|
||||
async def _delete_association(self, room_alias: RoomAlias):
|
||||
if not self.hs.is_mine(room_alias):
|
||||
raise SynapseError(400, "Room alias must be local")
|
||||
|
||||
room_id = yield self.store.delete_room_alias(room_alias)
|
||||
room_id = await self.store.delete_room_alias(room_alias)
|
||||
|
||||
return room_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_association(self, room_alias: RoomAlias):
|
||||
async def get_association(self, room_alias: RoomAlias):
|
||||
room_id = None
|
||||
if self.hs.is_mine(room_alias):
|
||||
result = yield self.get_association_from_room_alias(room_alias)
|
||||
result = await self.get_association_from_room_alias(room_alias)
|
||||
|
||||
if result:
|
||||
room_id = result.room_id
|
||||
servers = result.servers
|
||||
else:
|
||||
try:
|
||||
result = yield self.federation.make_query(
|
||||
result = await self.federation.make_query(
|
||||
destination=room_alias.domain,
|
||||
query_type="directory",
|
||||
args={"room_alias": room_alias.to_string()},
|
||||
|
@ -265,7 +258,7 @@ class DirectoryHandler(BaseHandler):
|
|||
Codes.NOT_FOUND,
|
||||
)
|
||||
|
||||
users = yield self.state.get_current_users_in_room(room_id)
|
||||
users = await self.state.get_current_users_in_room(room_id)
|
||||
extra_servers = {get_domain_from_id(u) for u in users}
|
||||
servers = set(extra_servers) | set(servers)
|
||||
|
||||
|
@ -277,13 +270,12 @@ class DirectoryHandler(BaseHandler):
|
|||
|
||||
return {"room_id": room_id, "servers": servers}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_directory_query(self, args):
|
||||
async def on_directory_query(self, args):
|
||||
room_alias = RoomAlias.from_string(args["room_alias"])
|
||||
if not self.hs.is_mine(room_alias):
|
||||
raise SynapseError(400, "Room Alias is not hosted on this homeserver")
|
||||
|
||||
result = yield self.get_association_from_room_alias(room_alias)
|
||||
result = await self.get_association_from_room_alias(room_alias)
|
||||
|
||||
if result is not None:
|
||||
return {"room_id": result.room_id, "servers": result.servers}
|
||||
|
@ -344,16 +336,15 @@ class DirectoryHandler(BaseHandler):
|
|||
ratelimit=False,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_association_from_room_alias(self, room_alias: RoomAlias):
|
||||
result = yield self.store.get_association_from_room_alias(room_alias)
|
||||
async def get_association_from_room_alias(self, room_alias: RoomAlias):
|
||||
result = await self.store.get_association_from_room_alias(room_alias)
|
||||
if not result:
|
||||
# Query AS to see if it exists
|
||||
as_handler = self.appservice_handler
|
||||
result = yield as_handler.query_room_alias_exists(room_alias)
|
||||
result = await as_handler.query_room_alias_exists(room_alias)
|
||||
return result
|
||||
|
||||
def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None):
|
||||
def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None) -> bool:
|
||||
# Any application service "interested" in an alias they are regexing on
|
||||
# can modify the alias.
|
||||
# Users can only modify the alias if ALL the interested services have
|
||||
|
@ -366,12 +357,12 @@ class DirectoryHandler(BaseHandler):
|
|||
for service in interested_services:
|
||||
if user_id == service.sender:
|
||||
# this user IS the app service so they can do whatever they like
|
||||
return defer.succeed(True)
|
||||
return True
|
||||
elif service.is_exclusive_alias(alias.to_string()):
|
||||
# another service has an exclusive lock on this alias.
|
||||
return defer.succeed(False)
|
||||
return False
|
||||
# either no interested services, or no service with an exclusive lock
|
||||
return defer.succeed(True)
|
||||
return True
|
||||
|
||||
async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
|
||||
"""Determine whether a user can delete an alias.
|
||||
|
@ -459,8 +450,7 @@ class DirectoryHandler(BaseHandler):
|
|||
|
||||
await self.store.set_room_is_public(room_id, making_public)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def edit_published_appservice_room_list(
|
||||
async def edit_published_appservice_room_list(
|
||||
self, appservice_id: str, network_id: str, room_id: str, visibility: str
|
||||
):
|
||||
"""Add or remove a room from the appservice/network specific public
|
||||
|
@ -475,7 +465,7 @@ class DirectoryHandler(BaseHandler):
|
|||
if visibility not in ["public", "private"]:
|
||||
raise SynapseError(400, "Invalid visibility setting")
|
||||
|
||||
yield self.store.set_room_is_public_appservice(
|
||||
await self.store.set_room_is_public_appservice(
|
||||
room_id, appservice_id, network_id, visibility == "public"
|
||||
)
|
||||
|
||||
|
|
|
@ -376,6 +376,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
room_version = await self.store.get_room_version_id(room_id)
|
||||
state_map = await resolve_events_with_store(
|
||||
self.clock,
|
||||
room_id,
|
||||
room_version,
|
||||
state_maps,
|
||||
|
|
|
@ -881,7 +881,9 @@ class EventCreationHandler(object):
|
|||
"""
|
||||
room_alias = RoomAlias.from_string(room_alias_str)
|
||||
try:
|
||||
mapping = yield directory_handler.get_association(room_alias)
|
||||
mapping = yield defer.ensureDeferred(
|
||||
directory_handler.get_association(room_alias)
|
||||
)
|
||||
except SynapseError as e:
|
||||
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
|
||||
if e.errcode == Codes.NOT_FOUND:
|
||||
|
|
|
@ -208,6 +208,7 @@ class MetricsHandler(BaseHTTPRequestHandler):
|
|||
raise
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", CONTENT_TYPE_LATEST)
|
||||
self.send_header("Content-Length", str(len(output)))
|
||||
self.end_headers()
|
||||
self.wfile.write(output)
|
||||
|
||||
|
@ -261,4 +262,6 @@ class MetricsResource(Resource):
|
|||
|
||||
def render_GET(self, request):
|
||||
request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
|
||||
return generate_latest(self.registry)
|
||||
response = generate_latest(self.registry)
|
||||
request.setHeader(b"Content-Length", str(len(response)))
|
||||
return response
|
||||
|
|
|
@ -133,6 +133,8 @@ class HttpPusher(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _update_badge(self):
|
||||
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
|
||||
# to be largely redundant. perhaps we can remove it.
|
||||
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
|
||||
yield self._send_badge(badge)
|
||||
|
||||
|
|
|
@ -81,7 +81,8 @@ class LoginRestServlet(RestServlet):
|
|||
CAS_TYPE = "m.login.cas"
|
||||
SSO_TYPE = "m.login.sso"
|
||||
TOKEN_TYPE = "m.login.token"
|
||||
JWT_TYPE = "m.login.jwt"
|
||||
JWT_TYPE = "org.matrix.login.jwt"
|
||||
JWT_TYPE_DEPRECATED = "m.login.jwt"
|
||||
|
||||
def __init__(self, hs):
|
||||
super(LoginRestServlet, self).__init__()
|
||||
|
@ -116,6 +117,7 @@ class LoginRestServlet(RestServlet):
|
|||
flows = []
|
||||
if self.jwt_enabled:
|
||||
flows.append({"type": LoginRestServlet.JWT_TYPE})
|
||||
flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED})
|
||||
|
||||
if self.cas_enabled:
|
||||
# we advertise CAS for backwards compat, though MSC1721 renamed it
|
||||
|
@ -149,6 +151,7 @@ class LoginRestServlet(RestServlet):
|
|||
try:
|
||||
if self.jwt_enabled and (
|
||||
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
|
||||
):
|
||||
result = await self.do_jwt_login(login_submission)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
|
|
|
@ -32,6 +32,7 @@ from synapse.logging.utils import log_function
|
|||
from synapse.state import v1, v2
|
||||
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
|
||||
from synapse.types import StateMap
|
||||
from synapse.util import Clock
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.metrics import Measure, measure_func
|
||||
|
@ -414,6 +415,7 @@ class StateHandler(object):
|
|||
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = yield resolve_events_with_store(
|
||||
self.clock,
|
||||
event.room_id,
|
||||
room_version,
|
||||
state_set_ids,
|
||||
|
@ -516,6 +518,7 @@ class StateResolutionHandler(object):
|
|||
logger.info("Resolving conflicted state for %r", room_id)
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = yield resolve_events_with_store(
|
||||
self.clock,
|
||||
room_id,
|
||||
room_version,
|
||||
list(state_groups_ids.values()),
|
||||
|
@ -589,6 +592,7 @@ def _make_state_cache_entry(new_state, state_groups_ids):
|
|||
|
||||
|
||||
def resolve_events_with_store(
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
state_sets: List[StateMap[str]],
|
||||
|
@ -625,7 +629,7 @@ def resolve_events_with_store(
|
|||
)
|
||||
else:
|
||||
return v2.resolve_events_with_store(
|
||||
room_id, room_version, state_sets, event_map, state_res_store
|
||||
clock, room_id, room_version, state_sets, event_map, state_res_store
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -27,12 +27,20 @@ from synapse.api.errors import AuthError
|
|||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import StateMap
|
||||
from synapse.util import Clock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# We want to yield to the reactor occasionally during state res when dealing
|
||||
# with large data sets, so that we don't exhaust the reactor. This is done by
|
||||
# yielding to reactor during loops every N iterations.
|
||||
_YIELD_AFTER_ITERATIONS = 100
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_events_with_store(
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
state_sets: List[StateMap[str]],
|
||||
|
@ -42,13 +50,11 @@ def resolve_events_with_store(
|
|||
"""Resolves the state using the v2 state resolution algorithm
|
||||
|
||||
Args:
|
||||
clock
|
||||
room_id: the room we are working in
|
||||
|
||||
room_version: The room version
|
||||
|
||||
state_sets: List of dicts of (type, state_key) -> event_id,
|
||||
which are the different state groups to resolve.
|
||||
|
||||
event_map:
|
||||
a dict from event_id to event, for any events that we happen to
|
||||
have in flight (eg, those currently being persisted). This will be
|
||||
|
@ -113,7 +119,7 @@ def resolve_events_with_store(
|
|||
)
|
||||
|
||||
sorted_power_events = yield _reverse_topological_power_sort(
|
||||
room_id, power_events, event_map, state_res_store, full_conflicted_set
|
||||
clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
|
||||
)
|
||||
|
||||
logger.debug("sorted %d power events", len(sorted_power_events))
|
||||
|
@ -133,15 +139,16 @@ def resolve_events_with_store(
|
|||
# OK, so we've now resolved the power events. Now sort the remaining
|
||||
# events using the mainline of the resolved power level.
|
||||
|
||||
set_power_events = set(sorted_power_events)
|
||||
leftover_events = [
|
||||
ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events
|
||||
ev_id for ev_id in full_conflicted_set if ev_id not in set_power_events
|
||||
]
|
||||
|
||||
logger.debug("sorting %d remaining events", len(leftover_events))
|
||||
|
||||
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
|
||||
leftover_events = yield _mainline_sort(
|
||||
room_id, leftover_events, pl, event_map, state_res_store
|
||||
clock, room_id, leftover_events, pl, event_map, state_res_store
|
||||
)
|
||||
|
||||
logger.debug("resolving remaining events")
|
||||
|
@ -316,12 +323,13 @@ def _add_event_and_auth_chain_to_graph(
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _reverse_topological_power_sort(
|
||||
room_id, event_ids, event_map, state_res_store, auth_diff
|
||||
clock, room_id, event_ids, event_map, state_res_store, auth_diff
|
||||
):
|
||||
"""Returns a list of the event_ids sorted by reverse topological ordering,
|
||||
and then by power level and origin_server_ts
|
||||
|
||||
Args:
|
||||
clock (Clock)
|
||||
room_id (str): the room we are working in
|
||||
event_ids (list[str]): The events to sort
|
||||
event_map (dict[str,FrozenEvent])
|
||||
|
@ -333,18 +341,28 @@ def _reverse_topological_power_sort(
|
|||
"""
|
||||
|
||||
graph = {}
|
||||
for event_id in event_ids:
|
||||
for idx, event_id in enumerate(event_ids, start=1):
|
||||
yield _add_event_and_auth_chain_to_graph(
|
||||
graph, room_id, event_id, event_map, state_res_store, auth_diff
|
||||
)
|
||||
|
||||
# We yield occasionally when we're working with large data sets to
|
||||
# ensure that we don't block the reactor loop for too long.
|
||||
if idx % _YIELD_AFTER_ITERATIONS == 0:
|
||||
yield clock.sleep(0)
|
||||
|
||||
event_to_pl = {}
|
||||
for event_id in graph:
|
||||
for idx, event_id in enumerate(graph, start=1):
|
||||
pl = yield _get_power_level_for_sender(
|
||||
room_id, event_id, event_map, state_res_store
|
||||
)
|
||||
event_to_pl[event_id] = pl
|
||||
|
||||
# We yield occasionally when we're working with large data sets to
|
||||
# ensure that we don't block the reactor loop for too long.
|
||||
if idx % _YIELD_AFTER_ITERATIONS == 0:
|
||||
yield clock.sleep(0)
|
||||
|
||||
def _get_power_order(event_id):
|
||||
ev = event_map[event_id]
|
||||
pl = event_to_pl[event_id]
|
||||
|
@ -422,12 +440,13 @@ def _iterative_auth_checks(
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def _mainline_sort(
|
||||
room_id, event_ids, resolved_power_event_id, event_map, state_res_store
|
||||
clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
|
||||
):
|
||||
"""Returns a sorted list of event_ids sorted by mainline ordering based on
|
||||
the given event resolved_power_event_id
|
||||
|
||||
Args:
|
||||
clock (Clock)
|
||||
room_id (str): room we're working in
|
||||
event_ids (list[str]): Events to sort
|
||||
resolved_power_event_id (str): The final resolved power level event ID
|
||||
|
@ -437,8 +456,14 @@ def _mainline_sort(
|
|||
Returns:
|
||||
Deferred[list[str]]: The sorted list
|
||||
"""
|
||||
if not event_ids:
|
||||
# It's possible for there to be no event IDs here to sort, so we can
|
||||
# skip calculating the mainline in that case.
|
||||
return []
|
||||
|
||||
mainline = []
|
||||
pl = resolved_power_event_id
|
||||
idx = 0
|
||||
while pl:
|
||||
mainline.append(pl)
|
||||
pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
|
||||
|
@ -452,17 +477,29 @@ def _mainline_sort(
|
|||
pl = aid
|
||||
break
|
||||
|
||||
# We yield occasionally when we're working with large data sets to
|
||||
# ensure that we don't block the reactor loop for too long.
|
||||
if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0:
|
||||
yield clock.sleep(0)
|
||||
|
||||
idx += 1
|
||||
|
||||
mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}
|
||||
|
||||
event_ids = list(event_ids)
|
||||
|
||||
order_map = {}
|
||||
for ev_id in event_ids:
|
||||
for idx, ev_id in enumerate(event_ids, start=1):
|
||||
depth = yield _get_mainline_depth_for_event(
|
||||
event_map[ev_id], mainline_map, event_map, state_res_store
|
||||
)
|
||||
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
|
||||
|
||||
# We yield occasionally when we're working with large data sets to
|
||||
# ensure that we don't block the reactor loop for too long.
|
||||
if idx % _YIELD_AFTER_ITERATIONS == 0:
|
||||
yield clock.sleep(0)
|
||||
|
||||
event_ids.sort(key=lambda ev_id: order_map[ev_id])
|
||||
|
||||
return event_ids
|
||||
|
|
|
@ -81,6 +81,15 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
desc="store_local_media",
|
||||
)
|
||||
|
||||
def mark_local_media_as_safe(self, media_id: str):
|
||||
"""Mark a local media as safe from quarantining."""
|
||||
return self.db.simple_update_one(
|
||||
table="local_media_repository",
|
||||
keyvalues={"media_id": media_id},
|
||||
updatevalues={"safe_from_quarantine": True},
|
||||
desc="mark_local_media_as_safe",
|
||||
)
|
||||
|
||||
def get_url_cache(self, url, ts):
|
||||
"""Get the media_id and ts for a cached URL as of the given timestamp
|
||||
Returns:
|
||||
|
|
|
@ -626,36 +626,10 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
|
||||
def _quarantine_media_in_room_txn(txn):
|
||||
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
|
||||
total_media_quarantined = 0
|
||||
|
||||
# Now update all the tables to set the quarantined_by flag
|
||||
|
||||
txn.executemany(
|
||||
"""
|
||||
UPDATE local_media_repository
|
||||
SET quarantined_by = ?
|
||||
WHERE media_id = ?
|
||||
""",
|
||||
((quarantined_by, media_id) for media_id in local_mxcs),
|
||||
return self._quarantine_media_txn(
|
||||
txn, local_mxcs, remote_mxcs, quarantined_by
|
||||
)
|
||||
|
||||
txn.executemany(
|
||||
"""
|
||||
UPDATE remote_media_cache
|
||||
SET quarantined_by = ?
|
||||
WHERE media_origin = ? AND media_id = ?
|
||||
""",
|
||||
(
|
||||
(quarantined_by, origin, media_id)
|
||||
for origin, media_id in remote_mxcs
|
||||
),
|
||||
)
|
||||
|
||||
total_media_quarantined += len(local_mxcs)
|
||||
total_media_quarantined += len(remote_mxcs)
|
||||
|
||||
return total_media_quarantined
|
||||
|
||||
return self.db.runInteraction(
|
||||
"quarantine_media_in_room", _quarantine_media_in_room_txn
|
||||
)
|
||||
|
@ -805,17 +779,17 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
Returns:
|
||||
The total number of media items quarantined
|
||||
"""
|
||||
total_media_quarantined = 0
|
||||
|
||||
# Update all the tables to set the quarantined_by flag
|
||||
txn.executemany(
|
||||
"""
|
||||
UPDATE local_media_repository
|
||||
SET quarantined_by = ?
|
||||
WHERE media_id = ?
|
||||
WHERE media_id = ? AND safe_from_quarantine = ?
|
||||
""",
|
||||
((quarantined_by, media_id) for media_id in local_mxcs),
|
||||
((quarantined_by, media_id, False) for media_id in local_mxcs),
|
||||
)
|
||||
# Note that a rowcount of -1 can be used to indicate no rows were affected.
|
||||
total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0
|
||||
|
||||
txn.executemany(
|
||||
"""
|
||||
|
@ -825,9 +799,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
""",
|
||||
((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
|
||||
)
|
||||
|
||||
total_media_quarantined += len(local_mxcs)
|
||||
total_media_quarantined += len(remote_mxcs)
|
||||
total_media_quarantined += txn.rowcount if txn.rowcount > 0 else 0
|
||||
|
||||
return total_media_quarantined
|
||||
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- The local_media_repository should have files which do not get quarantined,
|
||||
-- e.g. files from sticker packs.
|
||||
ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT FALSE;
|
|
@ -0,0 +1,18 @@
|
|||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- The local_media_repository should have files which do not get quarantined,
|
||||
-- e.g. files from sticker packs.
|
||||
ALTER TABLE local_media_repository ADD COLUMN safe_from_quarantine BOOLEAN NOT NULL DEFAULT 0;
|
|
@ -220,6 +220,24 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
return hs
|
||||
|
||||
def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
|
||||
"""Ensure a piece of media is quarantined when trying to access it."""
|
||||
request, channel = self.make_request(
|
||||
"GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
|
||||
)
|
||||
request.render(self.download_resource)
|
||||
self.pump(1.0)
|
||||
|
||||
# Should be quarantined
|
||||
self.assertEqual(
|
||||
404,
|
||||
int(channel.code),
|
||||
msg=(
|
||||
"Expected to receive a 404 on accessing quarantined media: %s"
|
||||
% server_and_media_id
|
||||
),
|
||||
)
|
||||
|
||||
def test_quarantine_media_requires_admin(self):
|
||||
self.register_user("nonadmin", "pass", admin=False)
|
||||
non_admin_user_tok = self.login("nonadmin", "pass")
|
||||
|
@ -292,24 +310,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(200, int(channel.code), msg=channel.result["body"])
|
||||
|
||||
# Attempt to access the media
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
server_name_and_media_id,
|
||||
shorthand=False,
|
||||
access_token=admin_user_tok,
|
||||
)
|
||||
request.render(self.download_resource)
|
||||
self.pump(1.0)
|
||||
|
||||
# Should be quarantined
|
||||
self.assertEqual(
|
||||
404,
|
||||
int(channel.code),
|
||||
msg=(
|
||||
"Expected to receive a 404 on accessing quarantined media: %s"
|
||||
% server_name_and_media_id
|
||||
),
|
||||
)
|
||||
self._ensure_quarantined(admin_user_tok, server_name_and_media_id)
|
||||
|
||||
def test_quarantine_all_media_in_room(self, override_url_template=None):
|
||||
self.register_user("room_admin", "pass", admin=True)
|
||||
|
@ -371,45 +372,10 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
|
|||
server_and_media_id_2 = mxc_2[6:]
|
||||
|
||||
# Test that we cannot download any of the media anymore
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
server_and_media_id_1,
|
||||
shorthand=False,
|
||||
access_token=non_admin_user_tok,
|
||||
)
|
||||
request.render(self.download_resource)
|
||||
self.pump(1.0)
|
||||
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
|
||||
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
|
||||
|
||||
# Should be quarantined
|
||||
self.assertEqual(
|
||||
404,
|
||||
int(channel.code),
|
||||
msg=(
|
||||
"Expected to receive a 404 on accessing quarantined media: %s"
|
||||
% server_and_media_id_1
|
||||
),
|
||||
)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
server_and_media_id_2,
|
||||
shorthand=False,
|
||||
access_token=non_admin_user_tok,
|
||||
)
|
||||
request.render(self.download_resource)
|
||||
self.pump(1.0)
|
||||
|
||||
# Should be quarantined
|
||||
self.assertEqual(
|
||||
404,
|
||||
int(channel.code),
|
||||
msg=(
|
||||
"Expected to receive a 404 on accessing quarantined media: %s"
|
||||
% server_and_media_id_2
|
||||
),
|
||||
)
|
||||
|
||||
def test_quaraantine_all_media_in_room_deprecated_api_path(self):
|
||||
def test_quarantine_all_media_in_room_deprecated_api_path(self):
|
||||
# Perform the above test with the deprecated API path
|
||||
self.test_quarantine_all_media_in_room("/_synapse/admin/v1/quarantine_media/%s")
|
||||
|
||||
|
@ -449,24 +415,51 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
# Attempt to access each piece of media
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
server_and_media_id_1,
|
||||
shorthand=False,
|
||||
access_token=non_admin_user_tok,
|
||||
)
|
||||
request.render(self.download_resource)
|
||||
self.pump(1.0)
|
||||
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
|
||||
self._ensure_quarantined(admin_user_tok, server_and_media_id_2)
|
||||
|
||||
# Should be quarantined
|
||||
self.assertEqual(
|
||||
404,
|
||||
int(channel.code),
|
||||
msg=(
|
||||
"Expected to receive a 404 on accessing quarantined media: %s"
|
||||
% server_and_media_id_1,
|
||||
),
|
||||
def test_cannot_quarantine_safe_media(self):
|
||||
self.register_user("user_admin", "pass", admin=True)
|
||||
admin_user_tok = self.login("user_admin", "pass")
|
||||
|
||||
non_admin_user = self.register_user("user_nonadmin", "pass", admin=False)
|
||||
non_admin_user_tok = self.login("user_nonadmin", "pass")
|
||||
|
||||
# Upload some media
|
||||
response_1 = self.helper.upload_media(
|
||||
self.upload_resource, self.image_data, tok=non_admin_user_tok
|
||||
)
|
||||
response_2 = self.helper.upload_media(
|
||||
self.upload_resource, self.image_data, tok=non_admin_user_tok
|
||||
)
|
||||
|
||||
# Extract media IDs
|
||||
server_and_media_id_1 = response_1["content_uri"][6:]
|
||||
server_and_media_id_2 = response_2["content_uri"][6:]
|
||||
|
||||
# Mark the second item as safe from quarantine.
|
||||
_, media_id_2 = server_and_media_id_2.split("/")
|
||||
self.get_success(self.store.mark_local_media_as_safe(media_id_2))
|
||||
|
||||
# Quarantine all media by this user
|
||||
url = "/_synapse/admin/v1/user/%s/media/quarantine" % urllib.parse.quote(
|
||||
non_admin_user
|
||||
)
|
||||
request, channel = self.make_request(
|
||||
"POST", url.encode("ascii"), access_token=admin_user_tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.pump(1.0)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(
|
||||
json.loads(channel.result["body"].decode("utf-8")),
|
||||
{"num_quarantined": 1},
|
||||
"Expected 1 quarantined item",
|
||||
)
|
||||
|
||||
# Attempt to access each piece of media, the first should fail, the
|
||||
# second should succeed.
|
||||
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
|
||||
|
||||
# Attempt to access each piece of media
|
||||
request, channel = self.make_request(
|
||||
|
@ -478,12 +471,12 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
|
|||
request.render(self.download_resource)
|
||||
self.pump(1.0)
|
||||
|
||||
# Should be quarantined
|
||||
# Shouldn't be quarantined
|
||||
self.assertEqual(
|
||||
404,
|
||||
200,
|
||||
int(channel.code),
|
||||
msg=(
|
||||
"Expected to receive a 404 on accessing quarantined media: %s"
|
||||
"Expected to receive a 200 on accessing not-quarantined media: %s"
|
||||
% server_and_media_id_2
|
||||
),
|
||||
)
|
||||
|
|
|
@ -526,7 +526,9 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|||
return jwt.encode(token, secret, "HS256").decode("ascii")
|
||||
|
||||
def jwt_login(self, *args):
|
||||
params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
|
||||
params = json.dumps(
|
||||
{"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
|
||||
)
|
||||
request, channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
self.render(request)
|
||||
return channel
|
||||
|
@ -568,7 +570,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEqual(channel.json_body["error"], "Invalid JWT")
|
||||
|
||||
def test_login_no_token(self):
|
||||
params = json.dumps({"type": "m.login.jwt"})
|
||||
params = json.dumps({"type": "org.matrix.login.jwt"})
|
||||
request, channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||
|
@ -640,7 +642,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
|
|||
return jwt.encode(token, secret, "RS256").decode("ascii")
|
||||
|
||||
def jwt_login(self, *args):
|
||||
params = json.dumps({"type": "m.login.jwt", "token": self.jwt_encode(*args)})
|
||||
params = json.dumps(
|
||||
{"type": "org.matrix.login.jwt", "token": self.jwt_encode(*args)}
|
||||
)
|
||||
request, channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
self.render(request)
|
||||
return channel
|
||||
|
|
|
@ -17,6 +17,8 @@ import itertools
|
|||
|
||||
import attr
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.event_auth import auth_types_for_event
|
||||
|
@ -41,6 +43,11 @@ MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
|
|||
ORIGIN_SERVER_TS = 0
|
||||
|
||||
|
||||
class FakeClock:
|
||||
def sleep(self, msec):
|
||||
return defer.succeed(None)
|
||||
|
||||
|
||||
class FakeEvent(object):
|
||||
"""A fake event we use as a convenience.
|
||||
|
||||
|
@ -417,6 +424,7 @@ class StateTestCase(unittest.TestCase):
|
|||
state_before = dict(state_at_event[prev_events[0]])
|
||||
else:
|
||||
state_d = resolve_events_with_store(
|
||||
FakeClock(),
|
||||
ROOM_ID,
|
||||
RoomVersions.V2.identifier,
|
||||
[state_at_event[n] for n in prev_events],
|
||||
|
@ -565,6 +573,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
|
|||
# Test that we correctly handle passing `None` as the event_map
|
||||
|
||||
state_d = resolve_events_with_store(
|
||||
FakeClock(),
|
||||
ROOM_ID,
|
||||
RoomVersions.V2.identifier,
|
||||
[self.state_at_bob, self.state_at_charlie],
|
||||
|
|
Loading…
Reference in New Issue