Merge pull request #1624 from matrix-org/kegan/idempotent-requests

Store Promise<Response> instead of Response for HTTP API transactions
pull/1628/head
Kegsay 2016-11-14 12:45:30 +00:00 committed by GitHub
commit 9355a5c42b
7 changed files with 202 additions and 184 deletions

View File

@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
logger = logging.getLogger(__name__)
def get_transaction_key(request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user.
Args:
request (twisted.web.http.Request): The incoming request. Must
contain an access_token.
Returns:
str: A transaction key
"""
token = get_access_token_from_request(request)
return request.path + "/" + token
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object):
def __init__(self, clock):
self.clock = clock
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
# Try to clean entries every 30 mins. This means entries will exist
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
See:
fetch_or_execute
"""
return self.fetch_or_execute(
get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
"""Fetches the response for this transaction, or executes the given function
to produce a response for this transaction.
Args:
txn_key (str): A key to ensure idempotency should fetch_or_execute be
called again at a later point in time.
fn (function): A function which returns a tuple of
(response_code, response_dict).
*args: Arguments to pass to fn.
**kwargs: Keyword arguments to pass to fn.
Returns:
Deferred which resolves to a tuple of (response_code, response_dict).
"""
try:
return self.transactions[txn_key][0].observe()
except (KeyError, IndexError):
pass # execute the function instead.
deferred = fn(*args, **kwargs)
observable = ObservableDeferred(deferred)
self.transactions[txn_key] = (observable, self.clock.time_msec())
return observable.observe()
def _cleanup(self):
now = self.clock.time_msec()
for key in self.transactions.keys():
ts = self.transactions[key][1]
if now > (ts + CLEANUP_PERIOD_MS): # after cleanup period
del self.transactions[key]

View File

@ -18,7 +18,8 @@
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.api.urls import CLIENT_PREFIX from synapse.api.urls import CLIENT_PREFIX
from .transactions import HttpTransactionStore from synapse.rest.client.transactions import HttpTransactionCache
import re import re
import logging import logging
@ -59,4 +60,4 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs self.hs = hs
self.builder_factory = hs.get_event_builder_factory() self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_v1auth() self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore() self.txns = HttpTransactionCache(hs.get_clock())

View File

@ -53,19 +53,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
client_path_patterns("/createRoom(?:/.*)?$"), client_path_patterns("/createRoom(?:/.*)?$"),
self.on_OPTIONS) self.on_OPTIONS)
@defer.inlineCallbacks
def on_PUT(self, request, txn_id): def on_PUT(self, request, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self.on_POST, request
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
response = yield self.on_POST(request)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -214,19 +205,10 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def on_GET(self, request, room_id, event_type, txn_id): def on_GET(self, request, room_id, event_type, txn_id):
return (200, "Not implemented") return (200, "Not implemented")
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, txn_id): def on_PUT(self, request, room_id, event_type, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self.on_POST, request, room_id, event_type, txn_id
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
response = yield self.on_POST(request, room_id, event_type, txn_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
@ -283,19 +265,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
defer.returnValue((200, {"room_id": room_id})) defer.returnValue((200, {"room_id": room_id}))
@defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id): def on_PUT(self, request, room_identifier, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self.on_POST, request, room_identifier, txn_id
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
response = yield self.on_POST(request, room_identifier, txn_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing # TODO: Needs unit testing
@ -537,21 +510,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, txn_id): def on_PUT(self, request, room_id, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self.on_POST, request, room_id, txn_id
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
response = yield self.on_POST(
request, room_id, txn_id
)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing # TODO: Needs unit testing
@ -623,21 +585,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
return False return False
return True return True
@defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id): def on_PUT(self, request, room_id, membership_action, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self.on_POST, request, room_id, membership_action, txn_id
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
response = yield self.on_POST(
request, room_id, membership_action, txn_id
)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet):
@ -669,19 +620,10 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
defer.returnValue((200, {"event_id": event.event_id})) defer.returnValue((200, {"event_id": event.event_id}))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_id, txn_id): def on_PUT(self, request, room_id, event_id, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self.on_POST, request, room_id, event_id, txn_id
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
response = yield self.on_POST(request, room_id, event_id, txn_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
class RoomTypingRestServlet(ClientV1RestServlet): class RoomTypingRestServlet(ClientV1RestServlet):

View File

@ -1,97 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
from synapse.api.auth import get_access_token_from_request
logger = logging.getLogger(__name__)
# FIXME: elsewhere we use FooStore to indicate something in the storage layer...
class HttpTransactionStore(object):
def __init__(self):
# { key : (txn_id, response) }
self.transactions = {}
def get_response(self, key, txn_id):
"""Retrieve a response for this request.
Args:
key (str): A transaction-independent key for this request. Usually
this is a combination of the path (without the transaction id)
and the user's access token.
txn_id (str): The transaction ID for this request
Returns:
A tuple of (HTTP response code, response content) or None.
"""
try:
logger.debug("get_response TxnId: %s", txn_id)
(last_txn_id, response) = self.transactions[key]
if txn_id == last_txn_id:
logger.info("get_response: Returning a response for %s", txn_id)
return response
except KeyError:
pass
return None
def store_response(self, key, txn_id, response):
"""Stores an HTTP response tuple.
Args:
key (str): A transaction-independent key for this request. Usually
this is a combination of the path (without the transaction id)
and the user's access token.
txn_id (str): The transaction ID for this request.
response (tuple): A tuple of (HTTP response code, response content)
"""
logger.debug("store_response TxnId: %s", txn_id)
self.transactions[key] = (txn_id, response)
def store_client_transaction(self, request, txn_id, response):
"""Stores the request/response pair of an HTTP transaction.
Args:
request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment.
response (tuple): A tuple of (response code, response dict)
txn_id (str): The transaction ID for this request.
"""
self.store_response(self._get_key(request), txn_id, response)
def get_client_transaction(self, request, txn_id):
"""Retrieves a stored response if there was one.
Args:
request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment.
txn_id (str): The transaction ID for this request.
Returns:
The response tuple.
Raises:
KeyError if the transaction was not found.
"""
response = self.get_response(self._get_key(request), txn_id)
if response is None:
raise KeyError("Transaction not found.")
return response
def _get_key(self, request):
token = get_access_token_from_request(request)
path_without_txn_id = request.path.rsplit("/", 1)[0]
return path_without_txn_id + "/" + token

View File

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.http import servlet from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.v1.transactions import HttpTransactionStore from synapse.rest.client.transactions import HttpTransactionCache
from ._base import client_v2_patterns from ._base import client_v2_patterns
@ -40,18 +40,16 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__() super(SendToDeviceRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.txns = HttpTransactionStore() self.txns = HttpTransactionCache(hs.get_clock())
self.device_message_handler = hs.get_device_message_handler() self.device_message_handler = hs.get_device_message_handler()
@defer.inlineCallbacks
def on_PUT(self, request, message_type, txn_id): def on_PUT(self, request, message_type, txn_id):
try: return self.txns.fetch_or_execute_request(
defer.returnValue( request, self._put, request, message_type, txn_id
self.txns.get_client_transaction(request, txn_id)
) )
except KeyError:
pass
@defer.inlineCallbacks
def _put(self, request, message_type, txn_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -63,7 +61,6 @@ class SendToDeviceRestServlet(servlet.RestServlet):
) )
response = (200, {}) response = (200, {})
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response) defer.returnValue(response)

View File

@ -34,7 +34,7 @@ class Clock(object):
"""A small utility that obtains current time-of-day so that time may be """A small utility that obtains current time-of-day so that time may be
mocked during unit-tests. mocked during unit-tests.
TODO(paul): Also move the sleep() functionallity into it TODO(paul): Also move the sleep() functionality into it
""" """
def time(self): def time(self):
@ -46,6 +46,14 @@ class Clock(object):
return int(self.time() * 1000) return int(self.time() * 1000)
def looping_call(self, f, msec): def looping_call(self, f, msec):
"""Call a function repeatedly.
Waits `msec` initially before calling `f` for the first time.
Args:
f(function): The function to call repeatedly.
msec(float): How long to wait between calls in milliseconds.
"""
l = task.LoopingCall(f) l = task.LoopingCall(f)
l.start(msec / 1000.0, now=False) l.start(msec / 1000.0, now=False)
return l return l

View File

@ -0,0 +1,69 @@
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
from twisted.internet import defer
from mock import Mock, call
from tests import unittest
from tests.utils import MockClock
class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self):
self.clock = MockClock()
self.cache = HttpTransactionCache(self.clock)
self.mock_http_response = (200, "GOOD JOB!")
self.mock_key = "foo"
@defer.inlineCallbacks
def test_executes_given_function(self):
cb = Mock(
return_value=defer.succeed(self.mock_http_response)
)
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg"
)
cb.assert_called_once_with("some_arg", keyword="arg")
self.assertEqual(res, self.mock_http_response)
@defer.inlineCallbacks
def test_deduplicates_based_on_key(self):
cb = Mock(
return_value=defer.succeed(self.mock_http_response)
)
for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
)
self.assertEqual(res, self.mock_http_response)
# expect only a single call to do the work
cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0)
@defer.inlineCallbacks
def test_cleans_up(self):
cb = Mock(
return_value=defer.succeed(self.mock_http_response)
)
yield self.cache.fetch_or_execute(
self.mock_key, cb, "an arg"
)
# should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
yield self.cache.fetch_or_execute(
self.mock_key, cb, "an arg"
)
# still using cache
cb.assert_called_once_with("an arg")
self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
yield self.cache.fetch_or_execute(
self.mock_key, cb, "an arg"
)
# no longer using cache
self.assertEqual(cb.call_count, 2)
self.assertEqual(
cb.call_args_list,
[call("an arg",), call("an arg",)]
)