Verify signatures for server2server requests
parent
10ef8e6e4b
commit
6684855767
|
@ -22,6 +22,7 @@ from .transport import TransportLayer
|
||||||
|
|
||||||
def initialize_http_replication(homeserver):
|
def initialize_http_replication(homeserver):
|
||||||
transport = TransportLayer(
|
transport = TransportLayer(
|
||||||
|
homeserver,
|
||||||
homeserver.hostname,
|
homeserver.hostname,
|
||||||
server=homeserver.get_resource_for_federation(),
|
server=homeserver.get_resource_for_federation(),
|
||||||
client=homeserver.get_http_client()
|
client=homeserver.get_http_client()
|
||||||
|
|
|
@ -54,7 +54,7 @@ class TransportLayer(object):
|
||||||
we receive data.
|
we receive data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, server_name, server, client):
|
def __init__(self, homeserver, server_name, server, client):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
server_name (str): Local home server host
|
server_name (str): Local home server host
|
||||||
|
@ -63,6 +63,7 @@ class TransportLayer(object):
|
||||||
client (synapse.protocol.http.HttpClient): the http client used to
|
client (synapse.protocol.http.HttpClient): the http client used to
|
||||||
send requests
|
send requests
|
||||||
"""
|
"""
|
||||||
|
self.keyring = homeserver.get_keyring()
|
||||||
self.server_name = server_name
|
self.server_name = server_name
|
||||||
self.server = server
|
self.server = server
|
||||||
self.client = client
|
self.client = client
|
||||||
|
@ -195,6 +196,66 @@ class TransportLayer(object):
|
||||||
|
|
||||||
defer.returnValue(response)
|
defer.returnValue(response)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _authenticate_request(self, request):
|
||||||
|
json_request = {
|
||||||
|
"method": request.method,
|
||||||
|
"uri": request.uri,
|
||||||
|
"destination": self.server_name,
|
||||||
|
"signatures": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
content = None
|
||||||
|
origin = None
|
||||||
|
|
||||||
|
if request.method == "PUT":
|
||||||
|
#TODO: Handle other method types? other content types?
|
||||||
|
content_bytes = request.content.read()
|
||||||
|
content = json.loads(content_bytes)
|
||||||
|
json_request["content"] = content
|
||||||
|
|
||||||
|
def parse_auth_header(header_str):
|
||||||
|
params = auth.split(" ")[1].split(",")
|
||||||
|
param_dict = dict(kv.split("=") for kv in params)
|
||||||
|
def strip_quotes(value):
|
||||||
|
if value.startswith("\""):
|
||||||
|
return value[1:-1]
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
origin = strip_quotes(param_dict["origin"])
|
||||||
|
key = strip_quotes(param_dict["key"])
|
||||||
|
sig = strip_quotes(param_dict["sig"])
|
||||||
|
return (origin, key, sig)
|
||||||
|
|
||||||
|
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||||
|
|
||||||
|
if not auth_headers:
|
||||||
|
#TODO(markjh): Send a 401 response?
|
||||||
|
raise Exception("Missing auth headers")
|
||||||
|
|
||||||
|
for auth in auth_headers:
|
||||||
|
if auth.startswith("X-Matrix"):
|
||||||
|
(origin, key, sig) = parse_auth_header(auth)
|
||||||
|
json_request["origin"] = origin
|
||||||
|
json_request["signatures"].setdefault(origin,{})[key] = sig
|
||||||
|
|
||||||
|
from syutil.jsonutil import encode_canonical_json
|
||||||
|
logger.debug("Checking %s %s",
|
||||||
|
origin, encode_canonical_json(json_request))
|
||||||
|
yield self.keyring.verify_json_for_server(origin, json_request)
|
||||||
|
|
||||||
|
defer.returnValue((origin, content))
|
||||||
|
|
||||||
|
def _with_authentication(self, handler):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def new_handler(request, *args, **kwargs):
|
||||||
|
(origin, content) = yield self._authenticate_request(request)
|
||||||
|
response = yield handler(
|
||||||
|
origin, content, request.args, *args, **kwargs
|
||||||
|
)
|
||||||
|
defer.returnValue(response)
|
||||||
|
return new_handler
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def register_received_handler(self, handler):
|
def register_received_handler(self, handler):
|
||||||
""" Register a handler that will be fired when we receive data.
|
""" Register a handler that will be fired when we receive data.
|
||||||
|
@ -208,7 +269,7 @@ class TransportLayer(object):
|
||||||
self.server.register_path(
|
self.server.register_path(
|
||||||
"PUT",
|
"PUT",
|
||||||
re.compile("^" + PREFIX + "/send/([^/]*)/$"),
|
re.compile("^" + PREFIX + "/send/([^/]*)/$"),
|
||||||
self._on_send_request
|
self._with_authentication(self._on_send_request)
|
||||||
)
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
|
@ -226,9 +287,9 @@ class TransportLayer(object):
|
||||||
self.server.register_path(
|
self.server.register_path(
|
||||||
"GET",
|
"GET",
|
||||||
re.compile("^" + PREFIX + "/pull/$"),
|
re.compile("^" + PREFIX + "/pull/$"),
|
||||||
lambda request: handler.on_pull_request(
|
self._with_authentication(
|
||||||
request.args["origin"][0],
|
lambda origin, content, query:
|
||||||
request.args["v"]
|
handler.on_pull_request(query["origin"][0], query["v"])
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -237,8 +298,9 @@ class TransportLayer(object):
|
||||||
self.server.register_path(
|
self.server.register_path(
|
||||||
"GET",
|
"GET",
|
||||||
re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
|
re.compile("^" + PREFIX + "/pdu/([^/]*)/([^/]*)/$"),
|
||||||
lambda request, pdu_origin, pdu_id: handler.on_pdu_request(
|
self._with_authentication(
|
||||||
pdu_origin, pdu_id
|
lambda origin, content, query, pdu_origin, pdu_id:
|
||||||
|
handler.on_pdu_request(pdu_origin, pdu_id)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -246,38 +308,47 @@ class TransportLayer(object):
|
||||||
self.server.register_path(
|
self.server.register_path(
|
||||||
"GET",
|
"GET",
|
||||||
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
|
re.compile("^" + PREFIX + "/state/([^/]*)/$"),
|
||||||
lambda request, context: handler.on_context_state_request(
|
self._with_authentication(
|
||||||
context
|
lambda origin, content, query, context:
|
||||||
|
handler.on_context_state_request(context)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.server.register_path(
|
self.server.register_path(
|
||||||
"GET",
|
"GET",
|
||||||
re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
|
re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
|
||||||
lambda request, context: self._on_backfill_request(
|
self._with_authentication(
|
||||||
context, request.args["v"],
|
lambda origin, content, query, context:
|
||||||
request.args["limit"]
|
self._on_backfill_request(
|
||||||
|
context, query["v"], query["limit"]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.server.register_path(
|
self.server.register_path(
|
||||||
"GET",
|
"GET",
|
||||||
re.compile("^" + PREFIX + "/context/([^/]*)/$"),
|
re.compile("^" + PREFIX + "/context/([^/]*)/$"),
|
||||||
lambda request, context: handler.on_context_pdus_request(context)
|
self._with_authentication(
|
||||||
|
lambda origin, content, query, context:
|
||||||
|
handler.on_context_pdus_request(context)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# This is when we receive a server-server Query
|
# This is when we receive a server-server Query
|
||||||
self.server.register_path(
|
self.server.register_path(
|
||||||
"GET",
|
"GET",
|
||||||
re.compile("^" + PREFIX + "/query/([^/]*)$"),
|
re.compile("^" + PREFIX + "/query/([^/]*)$"),
|
||||||
lambda request, query_type: handler.on_query_request(
|
self._with_authentication(
|
||||||
query_type, {k: v[0] for k, v in request.args.items()}
|
lambda origin, content, query, query_type:
|
||||||
|
handler.on_query_request(
|
||||||
|
query_type, {k: v[0] for k, v in query.items()}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def _on_send_request(self, request, transaction_id):
|
def _on_send_request(self, origin, content, query, transaction_id):
|
||||||
""" Called on PUT /send/<transaction_id>/
|
""" Called on PUT /send/<transaction_id>/
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -292,12 +363,7 @@ class TransportLayer(object):
|
||||||
"""
|
"""
|
||||||
# Parse the request
|
# Parse the request
|
||||||
try:
|
try:
|
||||||
data = request.content.read()
|
transaction_data = content
|
||||||
|
|
||||||
l = data[:20].encode("string_escape")
|
|
||||||
logger.debug("Got data: \"%s\"", l)
|
|
||||||
|
|
||||||
transaction_data = json.loads(data)
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Decoded %s: %s",
|
"Decoded %s: %s",
|
||||||
|
|
|
@ -177,16 +177,20 @@ class MatrixHttpClient(BaseHttpClient):
|
||||||
|
|
||||||
request = sign_json(request, self.server_name, self.signing_key)
|
request = sign_json(request, self.server_name, self.signing_key)
|
||||||
|
|
||||||
|
from syutil.jsonutil import encode_canonical_json
|
||||||
|
logger.debug("Signing " + " " * 11 + "%s %s",
|
||||||
|
self.server_name, encode_canonical_json(request))
|
||||||
|
|
||||||
auth_headers = []
|
auth_headers = []
|
||||||
|
|
||||||
for key,sig in request["signatures"][self.server_name].items():
|
for key,sig in request["signatures"][self.server_name].items():
|
||||||
auth_headers.append(
|
auth_headers.append(bytes(
|
||||||
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
|
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
|
||||||
self.server_name, key, sig,
|
self.server_name, key, sig,
|
||||||
)
|
)
|
||||||
)
|
))
|
||||||
|
|
||||||
headers_dict["Authorization"] = auth_headers
|
headers_dict[b"Authorization"] = auth_headers
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def put_json(self, destination, path, data={}, json_data_callback=None):
|
def put_json(self, destination, path, data={}, json_data_callback=None):
|
||||||
|
|
|
@ -221,6 +221,7 @@ class FederationTestCase(unittest.TestCase):
|
||||||
json_data_callback=ANY,
|
json_data_callback=ANY,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_recv_edu(self):
|
def test_recv_edu(self):
|
||||||
recv_observer = Mock()
|
recv_observer = Mock()
|
||||||
|
|
|
@ -76,6 +76,9 @@ class MockHttpResource(HttpServer):
|
||||||
mock_content.configure_mock(**config)
|
mock_content.configure_mock(**config)
|
||||||
mock_request.content = mock_content
|
mock_request.content = mock_content
|
||||||
|
|
||||||
|
mock_request.method = http_method
|
||||||
|
mock_request.uri = path
|
||||||
|
|
||||||
# return the right path if the event requires it
|
# return the right path if the event requires it
|
||||||
mock_request.path = path
|
mock_request.path = path
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue