Merge remote-tracking branch 'origin/develop' into erikj/type_server

pull/8060/head
Erik Johnston 2020-08-11 22:03:14 +01:00
commit fdb46b5442
19 changed files with 406 additions and 117 deletions

1
changelog.d/8040.misc Normal file
View File

@ -0,0 +1 @@
Change the default log config to reduce disk I/O and storage for new servers.

1
changelog.d/8050.misc Normal file
View File

@ -0,0 +1 @@
Reduce amount of outbound request logging at INFO level.

1
changelog.d/8051.misc Normal file
View File

@ -0,0 +1 @@
It is no longer necessary to explicitly define `filters` in the logging configuration. (Continuing to do so is redundant but harmless.)

1
changelog.d/8052.feature Normal file
View File

@ -0,0 +1 @@
Allow login to be blocked based on the values of SAML attributes.

1
changelog.d/8058.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `Notifier`.

View File

@ -4,16 +4,10 @@ formatters:
precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
filters:
context:
(): synapse.logging.context.LoggingContextFilter
request: ""
handlers:
console:
class: logging.StreamHandler
formatter: precise
filters: [context]
loggers:
synapse.storage.SQL:

View File

@ -1577,6 +1577,17 @@ saml2_config:
#
#grandfathered_mxid_source_attribute: upn
# It is possible to configure Synapse to only allow logins if SAML attributes
# match particular values. The requirements can be listed under
# `attribute_requirements` as shown below. All of the listed attributes must
# match for the login to be permitted.
#
#attribute_requirements:
# - attribute: userGroup
# value: "staff"
# - attribute: department
# value: "sales"
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#

View File

@ -11,24 +11,33 @@ formatters:
precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'
filters:
context:
(): synapse.logging.context.LoggingContextFilter
request: ""
handlers:
file:
class: logging.handlers.RotatingFileHandler
class: logging.handlers.TimedRotatingFileHandler
formatter: precise
filename: /var/log/matrix-synapse/homeserver.log
maxBytes: 104857600
backupCount: 10
filters: [context]
when: midnight
backupCount: 3 # Does not include the current log file.
encoding: utf8
# Default to buffering writes to log file for efficiency. This means that
# will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR
# logs will still be flushed immediately.
buffer:
class: logging.handlers.MemoryHandler
target: file
# The capacity is the number of log lines that are buffered before
# being written to disk. Increasing this will lead to better
# performance, at the expensive of it taking longer for log lines to
# be written to disk.
capacity: 10
flushLevel: 30 # Flush for WARNING logs as well
# A handler that writes logs to stderr. Unused by default, but can be used
# instead of "buffer" and "file" in the logger handlers.
console:
class: logging.StreamHandler
formatter: precise
filters: [context]
loggers:
synapse.storage.SQL:
@ -36,8 +45,23 @@ loggers:
# information such as access tokens.
level: INFO
twisted:
# We send the twisted logging directly to the file handler,
# to work around https://github.com/matrix-org/synapse/issues/3471
# when using "buffer" logger. Use "console" to log to stderr instead.
handlers: [file]
propagate: false
root:
level: INFO
handlers: [file, console]
# Write logs to the `buffer` handler, which will buffer them together in memory,
# then write them to a file.
#
# Replace "buffer" with "console" to log to stderr instead. (Note that you'll
# also need to update the configuation for the `twisted` logger above, in
# this case.)
#
handlers: [buffer]
disable_existing_loggers: false

49
synapse/config/_util.py Normal file
View File

@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
import jsonschema
from synapse.config._base import ConfigError
from synapse.types import JsonDict
def validate_config(json_schema: JsonDict, config: Any, config_path: List[str]) -> None:
"""Validates a config setting against a JsonSchema definition
This can be used to validate a section of the config file against a schema
definition. If the validation fails, a ConfigError is raised with a textual
description of the problem.
Args:
json_schema: the schema to validate against
config: the configuration value to be validated
config_path: the path within the config file. This will be used as a basis
for the error message.
"""
try:
jsonschema.validate(config, json_schema)
except jsonschema.ValidationError as e:
# copy `config_path` before modifying it.
path = list(config_path)
for p in list(e.path):
if isinstance(p, int):
path.append("<item %i>" % p)
else:
path.append(str(p))
raise ConfigError(
"Unable to parse configuration: %s at %s" % (e.message, ".".join(path))
)

View File

@ -55,24 +55,33 @@ formatters:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
%(request)s - %(message)s'
filters:
context:
(): synapse.logging.context.LoggingContextFilter
request: ""
handlers:
file:
class: logging.handlers.RotatingFileHandler
class: logging.handlers.TimedRotatingFileHandler
formatter: precise
filename: ${log_file}
maxBytes: 104857600
backupCount: 10
filters: [context]
when: midnight
backupCount: 3 # Does not include the current log file.
encoding: utf8
# Default to buffering writes to log file for efficiency. This means that
# will be a delay for INFO/DEBUG logs to get written, but WARNING/ERROR
# logs will still be flushed immediately.
buffer:
class: logging.handlers.MemoryHandler
target: file
# The capacity is the number of log lines that are buffered before
# being written to disk. Increasing this will lead to better
# performance, at the expensive of it taking longer for log lines to
# be written to disk.
capacity: 10
flushLevel: 30 # Flush for WARNING logs as well
# A handler that writes logs to stderr. Unused by default, but can be used
# instead of "buffer" and "file" in the logger handlers.
console:
class: logging.StreamHandler
formatter: precise
filters: [context]
loggers:
synapse.storage.SQL:
@ -80,9 +89,24 @@ loggers:
# information such as access tokens.
level: INFO
twisted:
# We send the twisted logging directly to the file handler,
# to work around https://github.com/matrix-org/synapse/issues/3471
# when using "buffer" logger. Use "console" to log to stderr instead.
handlers: [file]
propagate: false
root:
level: INFO
handlers: [file, console]
# Write logs to the `buffer` handler, which will buffer them together in memory,
# then write them to a file.
#
# Replace "buffer" with "console" to log to stderr instead. (Note that you'll
# also need to update the configuation for the `twisted` logger above, in
# this case.)
#
handlers: [buffer]
disable_existing_loggers: false
"""
@ -168,11 +192,26 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
handler = logging.StreamHandler()
handler.setFormatter(formatter)
handler.addFilter(LoggingContextFilter(request=""))
logger.addHandler(handler)
else:
logging.config.dictConfig(log_config)
# We add a log record factory that runs all messages through the
# LoggingContextFilter so that we get the context *at the time we log*
# rather than when we write to a handler. This can be done in config using
# filter options, but care must when using e.g. MemoryHandler to buffer
# writes.
log_filter = LoggingContextFilter(request="")
old_factory = logging.getLogRecordFactory()
def factory(*args, **kwargs):
record = old_factory(*args, **kwargs)
log_filter.filter(record)
return record
logging.setLogRecordFactory(factory)
# Route Twisted's native logging through to the standard library logging
# system.
observer = STDLibLogObserver()

View File

@ -15,7 +15,9 @@
# limitations under the License.
import logging
from typing import Any, List
import attr
import jinja2
import pkg_resources
@ -23,6 +25,7 @@ from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError
from ._util import validate_config
logger = logging.getLogger(__name__)
@ -80,6 +83,11 @@ class SAML2Config(Config):
self.saml2_enabled = True
attribute_requirements = saml2_config.get("attribute_requirements") or []
self.attribute_requirements = _parse_attribute_requirements_def(
attribute_requirements
)
self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
"grandfathered_mxid_source_attribute", "uid"
)
@ -341,6 +349,17 @@ class SAML2Config(Config):
#
#grandfathered_mxid_source_attribute: upn
# It is possible to configure Synapse to only allow logins if SAML attributes
# match particular values. The requirements can be listed under
# `attribute_requirements` as shown below. All of the listed attributes must
# match for the login to be permitted.
#
#attribute_requirements:
# - attribute: userGroup
# value: "staff"
# - attribute: department
# value: "sales"
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
@ -368,3 +387,34 @@ class SAML2Config(Config):
""" % {
"config_dir_path": config_dir_path
}
@attr.s(frozen=True)
class SamlAttributeRequirement:
"""Object describing a single requirement for SAML attributes."""
attribute = attr.ib(type=str)
value = attr.ib(type=str)
JSON_SCHEMA = {
"type": "object",
"properties": {"attribute": {"type": "string"}, "value": {"type": "string"}},
"required": ["attribute", "value"],
}
ATTRIBUTE_REQUIREMENTS_SCHEMA = {
"type": "array",
"items": SamlAttributeRequirement.JSON_SCHEMA,
}
def _parse_attribute_requirements_def(
attribute_requirements: Any,
) -> List[SamlAttributeRequirement]:
validate_config(
ATTRIBUTE_REQUIREMENTS_SCHEMA,
attribute_requirements,
config_path=["saml2_config", "attribute_requirements"],
)
return [SamlAttributeRequirement(**x) for x in attribute_requirements]

View File

@ -57,13 +57,10 @@ class EventStreamHandler(BaseHandler):
timeout=0,
as_client_event=True,
affect_presence=True,
only_keys=None,
room_id=None,
is_guest=False,
):
"""Fetches the events stream for a given user.
If `only_keys` is not None, events from keys will be sent down.
"""
if room_id:
@ -93,7 +90,6 @@ class EventStreamHandler(BaseHandler):
auth_user,
pagin_config,
timeout,
only_keys=only_keys,
is_guest=is_guest,
explicit_room_id=room_id,
)

View File

@ -14,15 +14,16 @@
# limitations under the License.
import logging
import re
from typing import Callable, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Callable, Dict, Optional, Set, Tuple
import attr
import saml2
import saml2.response
from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.api.errors import AuthError, SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
@ -34,6 +35,9 @@ from synapse.types import (
from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
import synapse.server
logger = logging.getLogger(__name__)
@ -49,7 +53,7 @@ class Saml2SessionData:
class SamlHandler:
def __init__(self, hs):
def __init__(self, hs: "synapse.server.HomeServer"):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
@ -62,6 +66,7 @@ class SamlHandler:
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
)
self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
@ -73,7 +78,7 @@ class SamlHandler:
self._auth_provider_id = "saml"
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {}
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
@ -165,11 +170,18 @@ class SamlHandler:
saml2.BINDING_HTTP_POST,
outstanding=self._outstanding_requests_dict,
)
except saml2.response.UnsolicitedResponse as e:
# the pysaml2 library helpfully logs an ERROR here, but neglects to log
# the session ID. I don't really want to put the full text of the exception
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
raise SynapseError(400, "Unexpected SAML2 login.")
except Exception as e:
raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,))
raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))
if saml2_auth.not_signed:
raise SynapseError(400, "SAML2 response was not signed")
raise SynapseError(400, "SAML2 response was not signed.")
logger.debug("SAML2 response: %s", saml2_auth.origxml)
for assertion in saml2_auth.assertions:
@ -188,6 +200,9 @@ class SamlHandler:
saml2_auth.in_response_to, None
)
for requirement in self._saml2_attribute_requirements:
_check_attribute_requirement(saml2_auth.ava, requirement)
remote_user_id = self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
)
@ -294,6 +309,21 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid]
def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
values = ava.get(req.attribute, [])
for v in values:
if v == req.value:
return
logger.info(
"SAML2 attribute %s did not match required value '%s' (was '%s')",
req.attribute,
req.value,
values,
)
raise AuthError(403, "You are not authorized to log in here.")
DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)

View File

@ -297,7 +297,7 @@ class SimpleHttpClient(object):
outgoing_requests_counter.labels(method).inc()
# log request but strip `access_token` (AS requests for example include this)
logger.info("Sending request %s %s", method, redact_uri(uri))
logger.debug("Sending request %s %s", method, redact_uri(uri))
with start_active_span(
"outgoing-client-request",

View File

@ -247,7 +247,7 @@ class MatrixHostnameEndpoint(object):
port = server.port
try:
logger.info("Connecting to %s:%i", host.decode("ascii"), port)
logger.debug("Connecting to %s:%i", host.decode("ascii"), port)
endpoint = HostnameEndpoint(self._reactor, host, port)
if self._tls_options:
endpoint = wrapClientTLS(self._tls_options, endpoint)

View File

@ -29,10 +29,11 @@ from zope.interface import implementer
from twisted.internet import defer, protocol
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver
from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
import synapse.metrics
import synapse.util.retryutils
@ -74,7 +75,7 @@ MAXINT = sys.maxsize
_next_id = 1
@attr.s
@attr.s(frozen=True)
class MatrixFederationRequest(object):
method = attr.ib()
"""HTTP method
@ -110,26 +111,52 @@ class MatrixFederationRequest(object):
:type: str|None
"""
uri = attr.ib(init=False, type=bytes)
"""The URI of this request
"""
def __attrs_post_init__(self):
global _next_id
self.txn_id = "%s-O-%s" % (self.method, _next_id)
txn_id = "%s-O-%s" % (self.method, _next_id)
_next_id = (_next_id + 1) % (MAXINT - 1)
object.__setattr__(self, "txn_id", txn_id)
destination_bytes = self.destination.encode("ascii")
path_bytes = self.path.encode("ascii")
if self.query:
query_bytes = encode_query_args(self.query)
else:
query_bytes = b""
# The object is frozen so we can pre-compute this.
uri = urllib.parse.urlunparse(
(b"matrix", destination_bytes, path_bytes, None, query_bytes, b"")
)
object.__setattr__(self, "uri", uri)
def get_json(self):
if self.json_callback:
return self.json_callback()
return self.json
async def _handle_json_response(reactor, timeout_sec, request, response):
async def _handle_json_response(
reactor: IReactorTime,
timeout_sec: float,
request: MatrixFederationRequest,
response: IResponse,
start_ms: int,
):
"""
Reads the JSON body of a response, with a timeout
Args:
reactor (IReactor): twisted reactor, for the timeout
timeout_sec (float): number of seconds to wait for response to complete
request (MatrixFederationRequest): the request that triggered the response
response (IResponse): response to the request
reactor: twisted reactor, for the timeout
timeout_sec: number of seconds to wait for response to complete
request: the request that triggered the response
response: response to the request
start_ms: Timestamp when request was made
Returns:
dict: parsed JSON response
@ -143,23 +170,35 @@ async def _handle_json_response(reactor, timeout_sec, request, response):
body = await make_deferred_yieldable(d)
except TimeoutError as e:
logger.warning(
"{%s} [%s] Timed out reading response", request.txn_id, request.destination,
"{%s} [%s] Timed out reading response - %s %s",
request.txn_id,
request.destination,
request.method,
request.uri.decode("ascii"),
)
raise RequestSendFailed(e, can_retry=True) from e
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response: %s",
"{%s} [%s] Error reading response %s %s: %s",
request.txn_id,
request.destination,
request.method,
request.uri.decode("ascii"),
e,
)
raise
time_taken_secs = reactor.seconds() - start_ms / 1000
logger.info(
"{%s} [%s] Completed: %d %s",
"{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
time_taken_secs,
request.method,
request.uri.decode("ascii"),
)
return body
@ -261,7 +300,9 @@ class MatrixFederationHttpClient(object):
# 'M_UNRECOGNIZED' which some endpoints can return when omitting a
# trailing slash on Synapse <= v0.99.3.
logger.info("Retrying request with trailing slash")
request.path += "/"
# Request is frozen so we create a new instance
request = attr.evolve(request, path=request.path + "/")
response = await self._send_request(request, **send_request_args)
@ -373,9 +414,7 @@ class MatrixFederationHttpClient(object):
else:
retries_left = MAX_SHORT_RETRIES
url_bytes = urllib.parse.urlunparse(
(b"matrix", destination_bytes, path_bytes, None, query_bytes, b"")
)
url_bytes = request.uri
url_str = url_bytes.decode("ascii")
url_to_sign_bytes = urllib.parse.urlunparse(
@ -402,7 +441,7 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers
logger.info(
logger.debug(
"{%s} [%s] Sending request: %s %s; timeout %fs",
request.txn_id,
request.destination,
@ -436,7 +475,6 @@ class MatrixFederationHttpClient(object):
except DNSLookupError as e:
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
except Exception as e:
logger.info("Failed to send request: %s", e)
raise RequestSendFailed(e, can_retry=True) from e
incoming_responses_counter.labels(
@ -496,7 +534,7 @@ class MatrixFederationHttpClient(object):
break
except RequestSendFailed as e:
logger.warning(
logger.info(
"{%s} [%s] Request failed: %s %s: %s",
request.txn_id,
request.destination,
@ -654,6 +692,8 @@ class MatrixFederationHttpClient(object):
json=data,
)
start_ms = self.clock.time_msec()
response = await self._send_request_with_optional_trailing_slash(
request,
try_trailing_slash_on_400,
@ -664,7 +704,7 @@ class MatrixFederationHttpClient(object):
)
body = await _handle_json_response(
self.reactor, self.default_timeout, request, response
self.reactor, self.default_timeout, request, response, start_ms
)
return body
@ -720,6 +760,8 @@ class MatrixFederationHttpClient(object):
method="POST", destination=destination, path=path, query=args, json=data
)
start_ms = self.clock.time_msec()
response = await self._send_request(
request,
long_retries=long_retries,
@ -733,7 +775,7 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout
body = await _handle_json_response(
self.reactor, _sec_timeout, request, response
self.reactor, _sec_timeout, request, response, start_ms,
)
return body
@ -786,6 +828,8 @@ class MatrixFederationHttpClient(object):
method="GET", destination=destination, path=path, query=args
)
start_ms = self.clock.time_msec()
response = await self._send_request_with_optional_trailing_slash(
request,
try_trailing_slash_on_400,
@ -796,7 +840,7 @@ class MatrixFederationHttpClient(object):
)
body = await _handle_json_response(
self.reactor, self.default_timeout, request, response
self.reactor, self.default_timeout, request, response, start_ms
)
return body
@ -846,6 +890,8 @@ class MatrixFederationHttpClient(object):
method="DELETE", destination=destination, path=path, query=args
)
start_ms = self.clock.time_msec()
response = await self._send_request(
request,
long_retries=long_retries,
@ -854,7 +900,7 @@ class MatrixFederationHttpClient(object):
)
body = await _handle_json_response(
self.reactor, self.default_timeout, request, response
self.reactor, self.default_timeout, request, response, start_ms
)
return body
@ -914,12 +960,14 @@ class MatrixFederationHttpClient(object):
)
raise
logger.info(
"{%s} [%s] Completed: %d %s [%d bytes]",
"{%s} [%s] Completed: %d %s [%d bytes] %s %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
length,
request.method,
request.uri.decode("ascii"),
)
return (length, headers)

View File

@ -15,7 +15,17 @@
import logging
from collections import namedtuple
from typing import Callable, Iterable, List, TypeVar
from typing import (
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
)
from prometheus_client import Counter
@ -24,12 +34,14 @@ from twisted.internet import defer
import synapse.server
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import StreamToken
from synapse.streams.config import PaginationConfig
from synapse.types import Collection, StreamToken, UserID
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
@ -77,7 +89,13 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class.
"""
def __init__(self, user_id, rooms, current_token, time_now_ms):
def __init__(
self,
user_id: str,
rooms: Collection[str],
current_token: StreamToken,
time_now_ms: int,
):
self.user_id = user_id
self.rooms = set(rooms)
self.current_token = current_token
@ -93,13 +111,13 @@ class _NotifierUserStream(object):
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms):
def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
"""Notify any listeners for this user of a new event from an
event source.
Args:
stream_key(str): The stream the event came from.
stream_id(str): The new id for the stream the event came from.
time_now_ms(int): The current time in milliseconds.
stream_key: The stream the event came from.
stream_id: The new id for the stream the event came from.
time_now_ms: The current time in milliseconds.
"""
self.current_token = self.current_token.copy_and_advance(stream_key, stream_id)
self.last_notified_token = self.current_token
@ -112,7 +130,7 @@ class _NotifierUserStream(object):
self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token)
def remove(self, notifier):
def remove(self, notifier: "Notifier"):
""" Remove this listener from all the indexes in the Notifier
it knows about.
"""
@ -123,10 +141,10 @@ class _NotifierUserStream(object):
notifier.user_to_user_stream.pop(self.user_id)
def count_listeners(self):
def count_listeners(self) -> int:
return len(self.notify_deferred.observers())
def new_listener(self, token):
def new_listener(self, token: StreamToken) -> _NotificationListener:
"""Returns a deferred that is resolved when there is a new token
greater than the given token.
@ -159,14 +177,16 @@ class Notifier(object):
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
def __init__(self, hs: "synapse.server.HomeServer"):
self.user_to_user_stream = {}
self.room_to_user_streams = {}
self.user_to_user_stream = {} # type: Dict[str, _NotifierUserStream]
self.room_to_user_streams = {} # type: Dict[str, Set[_NotifierUserStream]]
self.hs = hs
self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = []
self.pending_new_room_events = (
[]
) # type: List[Tuple[int, EventBase, Collection[str]]]
# Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]]
@ -178,10 +198,9 @@ class Notifier(object):
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
self.federation_sender = None
if hs.should_send_federation():
self.federation_sender = hs.get_federation_sender()
else:
self.federation_sender = None
self.state_handler = hs.get_state_handler()
@ -193,12 +212,12 @@ class Notifier(object):
# when rendering the metrics page, which is likely once per minute at
# most when scraping it.
def count_listeners():
all_user_streams = set()
all_user_streams = set() # type: Set[_NotifierUserStream]
for x in list(self.room_to_user_streams.values()):
all_user_streams |= x
for x in list(self.user_to_user_stream.values()):
all_user_streams.add(x)
for streams in list(self.room_to_user_streams.values()):
all_user_streams |= streams
for stream in list(self.user_to_user_stream.values()):
all_user_streams.add(stream)
return sum(stream.count_listeners() for stream in all_user_streams)
@ -223,7 +242,11 @@ class Notifier(object):
self.replication_callbacks.append(cb)
def on_new_room_event(
self, event, room_stream_id, max_room_stream_id, extra_users=[]
self,
event: EventBase,
room_stream_id: int,
max_room_stream_id: int,
extra_users: Collection[str] = [],
):
""" Used by handlers to inform the notifier something has happened
in the room, room event wise.
@ -241,11 +264,11 @@ class Notifier(object):
self.notify_replication()
def _notify_pending_new_room_events(self, max_room_stream_id):
def _notify_pending_new_room_events(self, max_room_stream_id: int):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
Args:
max_room_stream_id(int): The highest stream_id below which all
max_room_stream_id: The highest stream_id below which all
events have been persisted.
"""
pending = self.pending_new_room_events
@ -258,7 +281,9 @@ class Notifier(object):
else:
self._on_new_room_event(event, room_stream_id, extra_users)
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
def _on_new_room_event(
self, event: EventBase, room_stream_id: int, extra_users: Collection[str] = []
):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
run_as_background_process(
@ -275,13 +300,19 @@ class Notifier(object):
"room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
)
async def _notify_app_services(self, room_stream_id):
async def _notify_app_services(self, room_stream_id: int):
try:
await self.appservice_handler.notify_interested_services(room_stream_id)
except Exception:
logger.exception("Error notifying application services of event")
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
def on_new_event(
self,
stream_key: str,
new_token: int,
users: Collection[str] = [],
rooms: Collection[str] = [],
):
""" Used to inform listeners that something has happened event wise.
Will wake up all listeners for the given users and rooms.
@ -307,14 +338,19 @@ class Notifier(object):
self.notify_replication()
def on_new_replication_data(self):
def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
self.notify_replication()
async def wait_for_events(
self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
):
self,
user_id: str,
timeout: int,
callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
room_ids=None,
from_token=StreamToken.START,
) -> T:
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
@ -377,19 +413,16 @@ class Notifier(object):
async def get_events_for(
self,
user,
pagination_config,
timeout,
only_keys=None,
is_guest=False,
explicit_room_id=None,
):
user: UserID,
pagination_config: PaginationConfig,
timeout: int,
is_guest: bool = False,
explicit_room_id: str = None,
) -> EventStreamResult:
""" For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning.
If `only_keys` is not None, events from keys will be sent down.
If explicit_room_id is not set, the user's joined rooms will be polled
for events.
If explicit_room_id is set, that room will be polled for events only if
@ -404,11 +437,13 @@ class Notifier(object):
room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
is_peeking = not is_joined
async def check_for_updates(before_token, after_token):
async def check_for_updates(
before_token: StreamToken, after_token: StreamToken
) -> EventStreamResult:
if not after_token.is_after(before_token):
return EventStreamResult([], (from_token, from_token))
events = []
events = [] # type: List[EventBase]
end_token = from_token
for name, source in self.event_sources.sources.items():
@ -417,8 +452,6 @@ class Notifier(object):
after_id = getattr(after_token, keyname)
if before_id == after_id:
continue
if only_keys and name not in only_keys:
continue
new_events, new_key = await source.get_new_events(
user=user,
@ -476,7 +509,9 @@ class Notifier(object):
return result
async def _get_room_ids(self, user, explicit_room_id):
async def _get_room_ids(
self, user: UserID, explicit_room_id: Optional[str]
) -> Tuple[Collection[str], bool]:
joined_room_ids = await self.store.get_rooms_for_user(user.to_string())
if explicit_room_id:
if explicit_room_id in joined_room_ids:
@ -486,7 +521,7 @@ class Notifier(object):
raise AuthError(403, "Non-joined access not allowed")
return joined_room_ids, True
async def _is_world_readable(self, room_id):
async def _is_world_readable(self, room_id: str) -> bool:
state = await self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
@ -496,7 +531,7 @@ class Notifier(object):
return False
@log_function
def remove_expired_streams(self):
def remove_expired_streams(self) -> None:
time_now_ms = self.clock.time_msec()
expired_streams = []
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
@ -510,21 +545,21 @@ class Notifier(object):
expired_stream.remove(self)
@log_function
def _register_with_keys(self, user_stream):
def _register_with_keys(self, user_stream: _NotifierUserStream):
self.user_to_user_stream[user_stream.user_id] = user_stream
for room in user_stream.rooms:
s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream)
def _user_joined_room(self, user_id, room_id):
def _user_joined_room(self, user_id: str, room_id: str):
new_user_stream = self.user_to_user_stream.get(user_id)
if new_user_stream is not None:
room_streams = self.room_to_user_streams.setdefault(room_id, set())
room_streams.add(new_user_stream)
new_user_stream.rooms.add(room_id)
def notify_replication(self):
def notify_replication(self) -> None:
"""Notify the any replication listeners that there's a new event"""
for cb in self.replication_callbacks:
cb()

View File

@ -2,10 +2,17 @@
<html lang="en">
<head>
<meta charset="UTF-8">
<title>SSO error</title>
<title>SSO login error</title>
</head>
<body>
<p>Oops! Something went wrong during authentication<span id="errormsg"></span>.</p>
{# a 403 means we have actively rejected their login #}
{% if code == 403 %}
<p>You are not allowed to log in here.</p>
{% else %}
<p>
There was an error during authentication:
</p>
<div id="errormsg" style="margin:20px 80px">{{ msg }}</div>
<p>
If you are seeing this page after clicking a link sent to you via email, make
sure you only click the confirmation link once, and that you open the
@ -37,9 +44,9 @@
// to print one.
let errorDesc = new URLSearchParams(searchStr).get("error_description")
if (errorDesc) {
document.getElementById("errormsg").innerText = ` ("${errorDesc}")`;
document.getElementById("errormsg").innerText = errorDesc;
}
</script>
{% endif %}
</body>
</html>
</html>

View File

@ -198,6 +198,7 @@ commands = mypy \
synapse/logging/ \
synapse/metrics \
synapse/module_api \
synapse/notifier.py \
synapse/push/pusherpool.py \
synapse/push/push_rule_evaluator.py \
synapse/replication \