Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

matrix-org/fix_event_sig_checks
Erik Johnston 2018-07-05 10:29:32 +01:00
commit e2a4b7681e
17 changed files with 231 additions and 38 deletions

View File

@ -131,10 +131,14 @@ include the line in your commit or pull request comment::
Signed-off-by: Your Name <your@email.example.org> Signed-off-by: Your Name <your@email.example.org>
...using your real name; unfortunately pseudonyms and anonymous contributions We accept contributions under a legally identifiable name, such as
can't be accepted. Git makes this trivial - just use the -s flag when you do your name on government documentation or common-law names (names
``git commit``, having first set ``user.name`` and ``user.email`` git configs claimed by legitimate usage or repute). Unfortunately, we cannot
(which you should have done anyway :) accept anonymous contributions at this time.
Git allows you to add this signoff automatically when using the ``-s``
flag to ``git commit``, which uses the name and email set in your
``user.name`` and ``user.email`` git configs.
Conclusion Conclusion
~~~~~~~~~~ ~~~~~~~~~~

1
changelog.d/3385.misc Normal file
View File

@ -0,0 +1 @@
Reinstate lost run_on_reactor in unit tests

0
changelog.d/3467.misc Normal file
View File

1
changelog.d/3470.bugfix Normal file
View File

@ -0,0 +1 @@
Fix bug where synapse would explode when receiving unicode in HTTP User-Agent header

1
changelog.d/3473.bugfix Normal file
View File

@ -0,0 +1 @@
Invalidate cache on correct thread to avoid race

0
changelog.d/3474.misc Normal file
View File

1
changelog.d/3480.feature Normal file
View File

@ -0,0 +1 @@
Reject invalid server names in federation requests

1
changelog.d/3483.feature Normal file
View File

@ -0,0 +1 @@
Reject invalid server names in homeserver.yaml

View File

@ -16,6 +16,7 @@
import logging import logging
from synapse.http.endpoint import parse_and_validate_server_name
from ._base import Config, ConfigError from ._base import Config, ConfigError
logger = logging.Logger(__name__) logger = logging.Logger(__name__)
@ -25,6 +26,12 @@ class ServerConfig(Config):
def read_config(self, config): def read_config(self, config):
self.server_name = config["server_name"] self.server_name = config["server_name"]
try:
parse_and_validate_server_name(self.server_name)
except ValueError as e:
raise ConfigError(str(e))
self.pid_file = self.abspath(config.get("pid_file")) self.pid_file = self.abspath(config.get("pid_file"))
self.web_client = config["web_client"] self.web_client = config["web_client"]
self.web_client_location = config.get("web_client_location", None) self.web_client_location = config.get("web_client_location", None)
@ -162,8 +169,8 @@ class ServerConfig(Config):
}) })
def default_config(self, server_name, **kwargs): def default_config(self, server_name, **kwargs):
if ":" in server_name: _, bind_port = parse_and_validate_server_name(server_name)
bind_port = int(server_name.split(":")[1]) if bind_port is not None:
unsecure_port = bind_port - 400 unsecure_port = bind_port - 400
else: else:
bind_port = 8448 bind_port = 8448

View File

@ -76,6 +76,7 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
return return
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
sender_domain = get_domain_from_id(event.sender)
room_id_domain = get_domain_from_id(event.room_id) room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain: if room_id_domain != sender_domain:
raise AuthError( raise AuthError(
@ -524,7 +525,11 @@ def _check_power_levels(event, auth_events):
"to your own" "to your own"
) )
if old_level > user_level or new_level > user_level: # Check if the old and new levels are greater than the user level
# (if defined)
old_level_too_big = old_level is not None and old_level > user_level
new_level_too_big = new_level is not None and new_level > user_level
if old_level_too_big or new_level_too_big:
raise AuthError( raise AuthError(
403, 403,
"You don't have permission to add ops level greater " "You don't have permission to add ops level greater "

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError, FederationDeniedError from synapse.api.errors import Codes, SynapseError, FederationDeniedError
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import ( from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@ -99,26 +100,6 @@ class Authenticator(object):
origin = None origin = None
def parse_auth_header(header_str):
try:
params = auth.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
if value.startswith("\""):
return value[1:-1]
else:
return value
origin = strip_quotes(param_dict["origin"])
key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
except Exception:
raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED
)
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers: if not auth_headers:
@ -127,8 +108,8 @@ class Authenticator(object):
) )
for auth in auth_headers: for auth in auth_headers:
if auth.startswith("X-Matrix"): if auth.startswith(b"X-Matrix"):
(origin, key, sig) = parse_auth_header(auth) (origin, key, sig) = _parse_auth_header(auth)
json_request["origin"] = origin json_request["origin"] = origin
json_request["signatures"].setdefault(origin, {})[key] = sig json_request["signatures"].setdefault(origin, {})[key] = sig
@ -165,6 +146,48 @@ class Authenticator(object):
logger.exception("Error resetting retry timings on %s", origin) logger.exception("Error resetting retry timings on %s", origin)
def _parse_auth_header(header_bytes):
"""Parse an X-Matrix auth header
Args:
header_bytes (bytes): header value
Returns:
Tuple[str, str, str]: origin, key id, signature.
Raises:
AuthenticationError if the header could not be parsed
"""
try:
header_str = header_bytes.decode('utf-8')
params = header_str.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
if value.startswith(b"\""):
return value[1:-1]
else:
return value
origin = strip_quotes(param_dict["origin"])
# ensure that the origin is a valid server name
parse_and_validate_server_name(origin)
key = strip_quotes(param_dict["key"])
sig = strip_quotes(param_dict["sig"])
return origin, key, sig
except Exception as e:
logger.warn(
"Error parsing auth header '%s': %s",
header_bytes.decode('ascii', 'replace'),
e,
)
raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED,
)
class BaseFederationServlet(object): class BaseFederationServlet(object):
REQUIRE_AUTH = True REQUIRE_AUTH = True

View File

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.error import ConnectError from twisted.internet.error import ConnectError
@ -38,6 +40,71 @@ _Server = collections.namedtuple(
) )
def parse_server_name(server_name):
"""Split a server name into host/port parts.
Args:
server_name (str): server name to parse
Returns:
Tuple[str, int|None]: host/port parts.
Raises:
ValueError if the server name could not be parsed.
"""
try:
if server_name[-1] == ']':
# ipv6 literal, hopefully
return server_name, None
domain_port = server_name.rsplit(":", 1)
domain = domain_port[0]
port = int(domain_port[1]) if domain_port[1:] else None
return domain, port
except Exception:
raise ValueError("Invalid server name '%s'" % server_name)
VALID_HOST_REGEX = re.compile(
"\\A[0-9a-zA-Z.-]+\\Z",
)
def parse_and_validate_server_name(server_name):
"""Split a server name into host/port parts and do some basic validation.
Args:
server_name (str): server name to parse
Returns:
Tuple[str, int|None]: host/port parts.
Raises:
ValueError if the server name could not be parsed.
"""
host, port = parse_server_name(server_name)
# these tests don't need to be bulletproof as we'll find out soon enough
# if somebody is giving us invalid data. What we *do* need is to be sure
# that nobody is sneaking IP literals in that look like hostnames, etc.
# look for ipv6 literals
if host[0] == '[':
if host[-1] != ']':
raise ValueError("Mismatched [...] in server name '%s'" % (
server_name,
))
return host, port
# otherwise it should only be alphanumerics.
if not VALID_HOST_REGEX.match(host):
raise ValueError("Server name '%s' contains invalid characters" % (
server_name,
))
return host, port
def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None, def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout=None): timeout=None):
"""Construct an endpoint for the given matrix destination. """Construct an endpoint for the given matrix destination.
@ -50,9 +117,7 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
timeout (int): connection timeout in seconds timeout (int): connection timeout in seconds
""" """
domain_port = destination.split(":") domain, port = parse_server_name(destination)
domain = domain_port[0]
port = int(domain_port[1]) if domain_port[1:] else None
endpoint_kw_args = {} endpoint_kw_args = {}

View File

@ -107,13 +107,28 @@ class SynapseRequest(Request):
end_time = time.time() end_time = time.time()
# need to decode as it could be raw utf-8 bytes
# from a IDN servname in an auth header
authenticated_entity = self.authenticated_entity
if authenticated_entity is not None:
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
# ...or could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically
# with maximum recursion trying to log errors about
# the charset problem.
# c.f. https://github.com/matrix-org/synapse/issues/3471
user_agent = self.get_user_agent()
if user_agent is not None:
user_agent = user_agent.decode("utf-8", "replace")
self.site.access_logger.info( self.site.access_logger.info(
"%s - %s - {%s}" "%s - %s - {%s}"
" Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" " Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
" %sB %s \"%s %s %s\" \"%s\" [%d dbevts]", " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
self.getClientIP(), self.getClientIP(),
self.site.site_tag, self.site.site_tag,
self.authenticated_entity, authenticated_entity,
end_time - self.start_time, end_time - self.start_time,
ru_utime, ru_utime,
ru_stime, ru_stime,
@ -125,7 +140,7 @@ class SynapseRequest(Request):
self.method, self.method,
self.get_redacted_uri(), self.get_redacted_uri(),
self.clientproto, self.clientproto,
self.get_user_agent(), user_agent,
evt_db_fetch_count, evt_db_fetch_count,
) )

View File

@ -801,7 +801,8 @@ class EventsStore(EventsWorkerStore):
] ]
) )
self._curr_state_delta_stream_cache.entity_has_changed( txn.call_after(
self._curr_state_delta_stream_cache.entity_has_changed,
room_id, max_stream_order, room_id, max_stream_order,
) )

0
tests/http/__init__.py Normal file
View File

View File

@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.endpoint import (
parse_server_name,
parse_and_validate_server_name,
)
from tests import unittest
class ServerNameTestCase(unittest.TestCase):
def test_parse_server_name(self):
test_data = {
'localhost': ('localhost', None),
'my-example.com:1234': ('my-example.com', 1234),
'1.2.3.4': ('1.2.3.4', None),
'[0abc:1def::1234]': ('[0abc:1def::1234]', None),
'1.2.3.4:1': ('1.2.3.4', 1),
'[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080),
}
for i, o in test_data.items():
self.assertEqual(parse_server_name(i), o)
def test_validate_bad_server_names(self):
test_data = [
"", # empty
"localhost:http", # non-numeric port
"1234]", # smells like ipv6 literal but isn't
"[1234",
"underscore_.com",
"percent%65.com",
"1234:5678:80", # too many colons
]
for i in test_data:
try:
parse_and_validate_server_name(i)
self.fail(
"Expected parse_and_validate_server_name('%s') to throw" % (
i,
),
)
except ValueError:
pass

View File

@ -19,13 +19,19 @@ import logging
import mock import mock
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.util import logcontext from synapse.util import logcontext
from twisted.internet import defer from twisted.internet import defer, reactor
from synapse.util.caches import descriptors from synapse.util.caches import descriptors
from tests import unittest from tests import unittest
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run_on_reactor():
d = defer.Deferred()
reactor.callLater(0, d.callback, 0)
return logcontext.make_deferred_yieldable(d)
class CacheTestCase(unittest.TestCase): class CacheTestCase(unittest.TestCase):
def test_invalidate_all(self): def test_invalidate_all(self):
cache = descriptors.Cache("testcache") cache = descriptors.Cache("testcache")
@ -194,6 +200,8 @@ class DescriptorTestCase(unittest.TestCase):
def fn(self, arg1): def fn(self, arg1):
@defer.inlineCallbacks @defer.inlineCallbacks
def inner_fn(): def inner_fn():
# we want this to behave like an asynchronous function
yield run_on_reactor()
raise SynapseError(400, "blah") raise SynapseError(400, "blah")
return inner_fn() return inner_fn()
@ -203,7 +211,12 @@ class DescriptorTestCase(unittest.TestCase):
with logcontext.LoggingContext() as c1: with logcontext.LoggingContext() as c1:
c1.name = "c1" c1.name = "c1"
try: try:
yield obj.fn(1) d = obj.fn(1)
self.assertEqual(
logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel,
)
yield d
self.fail("No exception thrown") self.fail("No exception thrown")
except SynapseError: except SynapseError:
pass pass