Remove some boilerplate in tests (#4156)

pull/4161/head
Amber Brown 2018-11-07 03:00:00 +11:00 committed by GitHub
parent 0f5e51f726
commit e62f7f17b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 163 additions and 217 deletions

1
changelog.d/4156.misc Normal file
View File

@ -0,0 +1 @@
HTTP tests have been refactored to contain less boilerplate.

View File

@ -19,24 +19,17 @@ import json
from mock import Mock from mock import Mock
from synapse.http.server import JsonResource
from synapse.rest.client.v1.admin import register_servlets from synapse.rest.client.v1.admin import register_servlets
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock,
make_request,
render,
setup_test_homeserver,
)
class UserRegisterTestCase(unittest.TestCase): class UserRegisterTestCase(unittest.HomeserverTestCase):
def setUp(self):
servlets = [register_servlets]
def make_homeserver(self, reactor, clock):
self.clock = ThreadedMemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.url = "/_matrix/client/r0/admin/register" self.url = "/_matrix/client/r0/admin/register"
self.registration_handler = Mock() self.registration_handler = Mock()
@ -50,17 +43,14 @@ class UserRegisterTestCase(unittest.TestCase):
self.secrets = Mock() self.secrets = Mock()
self.hs = setup_test_homeserver( self.hs = self.setup_test_homeserver()
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
)
self.hs.config.registration_shared_secret = u"shared" self.hs.config.registration_shared_secret = u"shared"
self.hs.get_media_repository = Mock() self.hs.get_media_repository = Mock()
self.hs.get_deactivate_account_handler = Mock() self.hs.get_deactivate_account_handler = Mock()
self.resource = JsonResource(self.hs) return self.hs
register_servlets(self.hs, self.resource)
def test_disabled(self): def test_disabled(self):
""" """
@ -69,8 +59,8 @@ class UserRegisterTestCase(unittest.TestCase):
""" """
self.hs.config.registration_shared_secret = None self.hs.config.registration_shared_secret = None
request, channel = make_request("POST", self.url, b'{}') request, channel = self.make_request("POST", self.url, b'{}')
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual( self.assertEqual(
@ -87,8 +77,8 @@ class UserRegisterTestCase(unittest.TestCase):
self.hs.get_secrets = Mock(return_value=secrets) self.hs.get_secrets = Mock(return_value=secrets)
request, channel = make_request("GET", self.url) request, channel = self.make_request("GET", self.url)
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(channel.json_body, {"nonce": "abcd"}) self.assertEqual(channel.json_body, {"nonce": "abcd"})
@ -97,25 +87,25 @@ class UserRegisterTestCase(unittest.TestCase):
Calling GET on the endpoint will return a randomised nonce, which will Calling GET on the endpoint will return a randomised nonce, which will
only last for SALT_TIMEOUT (60s). only last for SALT_TIMEOUT (60s).
""" """
request, channel = make_request("GET", self.url) request, channel = self.make_request("GET", self.url)
render(request, self.resource, self.clock) self.render(request)
nonce = channel.json_body["nonce"] nonce = channel.json_body["nonce"]
# 59 seconds # 59 seconds
self.clock.advance(59) self.reactor.advance(59)
body = json.dumps({"nonce": nonce}) body = json.dumps({"nonce": nonce})
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"]) self.assertEqual('username must be specified', channel.json_body["error"])
# 61 seconds # 61 seconds
self.clock.advance(2) self.reactor.advance(2)
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"]) self.assertEqual('unrecognised nonce', channel.json_body["error"])
@ -124,8 +114,8 @@ class UserRegisterTestCase(unittest.TestCase):
""" """
Only the provided nonce can be used, as it's checked in the MAC. Only the provided nonce can be used, as it's checked in the MAC.
""" """
request, channel = make_request("GET", self.url) request, channel = self.make_request("GET", self.url)
render(request, self.resource, self.clock) self.render(request)
nonce = channel.json_body["nonce"] nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@ -141,8 +131,8 @@ class UserRegisterTestCase(unittest.TestCase):
"mac": want_mac, "mac": want_mac,
} }
) )
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("HMAC incorrect", channel.json_body["error"]) self.assertEqual("HMAC incorrect", channel.json_body["error"])
@ -152,8 +142,8 @@ class UserRegisterTestCase(unittest.TestCase):
When the correct nonce is provided, and the right key is provided, the When the correct nonce is provided, and the right key is provided, the
user is registered. user is registered.
""" """
request, channel = make_request("GET", self.url) request, channel = self.make_request("GET", self.url)
render(request, self.resource, self.clock) self.render(request)
nonce = channel.json_body["nonce"] nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@ -169,8 +159,8 @@ class UserRegisterTestCase(unittest.TestCase):
"mac": want_mac, "mac": want_mac,
} }
) )
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"]) self.assertEqual("@bob:test", channel.json_body["user_id"])
@ -179,8 +169,8 @@ class UserRegisterTestCase(unittest.TestCase):
""" """
A valid unrecognised nonce. A valid unrecognised nonce.
""" """
request, channel = make_request("GET", self.url) request, channel = self.make_request("GET", self.url)
render(request, self.resource, self.clock) self.render(request)
nonce = channel.json_body["nonce"] nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@ -196,15 +186,15 @@ class UserRegisterTestCase(unittest.TestCase):
"mac": want_mac, "mac": want_mac,
} }
) )
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@bob:test", channel.json_body["user_id"]) self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it # Now, try and reuse it
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('unrecognised nonce', channel.json_body["error"]) self.assertEqual('unrecognised nonce', channel.json_body["error"])
@ -217,8 +207,8 @@ class UserRegisterTestCase(unittest.TestCase):
""" """
def nonce(): def nonce():
request, channel = make_request("GET", self.url) request, channel = self.make_request("GET", self.url)
render(request, self.resource, self.clock) self.render(request)
return channel.json_body["nonce"] return channel.json_body["nonce"]
# #
@ -227,8 +217,8 @@ class UserRegisterTestCase(unittest.TestCase):
# Must be present # Must be present
body = json.dumps({}) body = json.dumps({})
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('nonce must be specified', channel.json_body["error"]) self.assertEqual('nonce must be specified', channel.json_body["error"])
@ -239,32 +229,32 @@ class UserRegisterTestCase(unittest.TestCase):
# Must be present # Must be present
body = json.dumps({"nonce": nonce()}) body = json.dumps({"nonce": nonce()})
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('username must be specified', channel.json_body["error"]) self.assertEqual('username must be specified', channel.json_body["error"])
# Must be a string # Must be a string
body = json.dumps({"nonce": nonce(), "username": 1234}) body = json.dumps({"nonce": nonce(), "username": 1234})
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"]) self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes # Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"}) body = json.dumps({"nonce": nonce(), "username": u"abcd\u0000"})
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"]) self.assertEqual('Invalid username', channel.json_body["error"])
# Must not have null bytes # Must not have null bytes
body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid username', channel.json_body["error"]) self.assertEqual('Invalid username', channel.json_body["error"])
@ -275,16 +265,16 @@ class UserRegisterTestCase(unittest.TestCase):
# Must be present # Must be present
body = json.dumps({"nonce": nonce(), "username": "a"}) body = json.dumps({"nonce": nonce(), "username": "a"})
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('password must be specified', channel.json_body["error"]) self.assertEqual('password must be specified', channel.json_body["error"])
# Must be a string # Must be a string
body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"]) self.assertEqual('Invalid password', channel.json_body["error"])
@ -293,16 +283,16 @@ class UserRegisterTestCase(unittest.TestCase):
body = json.dumps( body = json.dumps(
{"nonce": nonce(), "username": "a", "password": u"abcd\u0000"} {"nonce": nonce(), "username": "a", "password": u"abcd\u0000"}
) )
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"]) self.assertEqual('Invalid password', channel.json_body["error"])
# Super long # Super long
body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
request, channel = make_request("POST", self.url, body.encode('utf8')) request, channel = self.make_request("POST", self.url, body.encode('utf8'))
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual('Invalid password', channel.json_body["error"]) self.assertEqual('Invalid password', channel.json_body["error"])

View File

@ -45,11 +45,11 @@ class CreateUserServletTestCase(unittest.TestCase):
) )
handlers = Mock(registration_handler=self.registration_handler) handlers = Mock(registration_handler=self.registration_handler)
self.clock = MemoryReactorClock() self.reactor = MemoryReactorClock()
self.hs_clock = Clock(self.clock) self.hs_clock = Clock(self.reactor)
self.hs = self.hs = setup_test_homeserver( self.hs = self.hs = setup_test_homeserver(
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor
) )
self.hs.get_datastore = Mock(return_value=self.datastore) self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.get_handlers = Mock(return_value=handlers) self.hs.get_handlers = Mock(return_value=handlers)
@ -76,8 +76,8 @@ class CreateUserServletTestCase(unittest.TestCase):
return_value=(user_id, token) return_value=(user_id, token)
) )
request, channel = make_request(b"POST", url, request_data) request, channel = make_request(self.reactor, b"POST", url, request_data)
render(request, res, self.clock) render(request, res, self.reactor)
self.assertEquals(channel.result["code"], b"200") self.assertEquals(channel.result["code"], b"200")

View File

@ -169,7 +169,7 @@ class RestHelper(object):
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
request, channel = make_request( request, channel = make_request(
"POST", path, json.dumps(content).encode('utf8') self.hs.get_reactor(), "POST", path, json.dumps(content).encode('utf8')
) )
render(request, self.resource, self.hs.get_reactor()) render(request, self.resource, self.hs.get_reactor())
@ -217,7 +217,9 @@ class RestHelper(object):
data = {"membership": membership} data = {"membership": membership}
request, channel = make_request("PUT", path, json.dumps(data).encode('utf8')) request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(data).encode('utf8')
)
render(request, self.resource, self.hs.get_reactor()) render(request, self.resource, self.hs.get_reactor())
@ -228,18 +230,6 @@ class RestHelper(object):
self.auth_user_id = temp_id self.auth_user_id = temp_id
@defer.inlineCallbacks
def register(self, user_id):
(code, response) = yield self.mock_resource.trigger(
"POST",
"/_matrix/client/r0/register",
json.dumps(
{"user": user_id, "password": "test", "type": "m.login.password"}
),
)
self.assertEquals(200, code)
defer.returnValue(response)
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200): def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if txn_id is None: if txn_id is None:
txn_id = "m%s" % (str(time.time())) txn_id = "m%s" % (str(time.time()))
@ -251,7 +241,9 @@ class RestHelper(object):
if tok: if tok:
path = path + "?access_token=%s" % tok path = path + "?access_token=%s" % tok
request, channel = make_request("PUT", path, json.dumps(content).encode('utf8')) request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps(content).encode('utf8')
)
render(request, self.resource, self.hs.get_reactor()) render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, ( assert int(channel.result["code"]) == expect_code, (

View File

@ -13,84 +13,47 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import synapse.types
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha import filter from synapse.rest.client.v2_alpha import filter
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock as MemoryReactorClock,
make_request,
render,
setup_test_homeserver,
)
PATH_PREFIX = "/_matrix/client/v2_alpha" PATH_PREFIX = "/_matrix/client/v2_alpha"
class FilterTestCase(unittest.TestCase): class FilterTestCase(unittest.HomeserverTestCase):
USER_ID = "@apple:test" user_id = "@apple:test"
hijack_auth = True
EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}} EXAMPLE_FILTER = {"room": {"timeline": {"types": ["m.room.message"]}}}
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}' EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
TO_REGISTER = [filter] servlets = [filter.register_servlets]
def setUp(self): def prepare(self, reactor, clock, hs):
self.clock = MemoryReactorClock() self.filtering = hs.get_filtering()
self.hs_clock = Clock(self.clock) self.store = hs.get_datastore()
self.hs = setup_test_homeserver(
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
)
self.auth = self.hs.get_auth()
def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.USER_ID),
"token_id": 1,
"is_guest": False,
}
def get_user_by_req(request, allow_guest=False, rights="access"):
return synapse.types.create_requester(
UserID.from_string(self.USER_ID), 1, False, None
)
self.auth.get_user_by_access_token = get_user_by_access_token
self.auth.get_user_by_req = get_user_by_req
self.store = self.hs.get_datastore()
self.filtering = self.hs.get_filtering()
self.resource = JsonResource(self.hs)
for r in self.TO_REGISTER:
r.register_servlets(self.hs, self.resource)
def test_add_filter(self): def test_add_filter(self):
request, channel = make_request( request, channel = self.make_request(
"POST", "POST",
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), "/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, {"filter_id": "0"}) self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.store.get_user_filter(user_localpart="apple", filter_id=0) filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
self.clock.advance(0) self.pump()
self.assertEquals(filter.result, self.EXAMPLE_FILTER) self.assertEquals(filter.result, self.EXAMPLE_FILTER)
def test_add_filter_for_other_user(self): def test_add_filter_for_other_user(self):
request, channel = make_request( request, channel = self.make_request(
"POST", "POST",
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"), "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.result["code"], b"403")
self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN) self.assertEquals(channel.json_body["errcode"], Codes.FORBIDDEN)
@ -98,12 +61,12 @@ class FilterTestCase(unittest.TestCase):
def test_add_filter_non_local_user(self): def test_add_filter_non_local_user(self):
_is_mine = self.hs.is_mine _is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False self.hs.is_mine = lambda target_user: False
request, channel = make_request( request, channel = self.make_request(
"POST", "POST",
"/_matrix/client/r0/user/%s/filter" % (self.USER_ID), "/_matrix/client/r0/user/%s/filter" % (self.user_id),
self.EXAMPLE_FILTER_JSON, self.EXAMPLE_FILTER_JSON,
) )
render(request, self.resource, self.clock) self.render(request)
self.hs.is_mine = _is_mine self.hs.is_mine = _is_mine
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.result["code"], b"403")
@ -113,21 +76,21 @@ class FilterTestCase(unittest.TestCase):
filter_id = self.filtering.add_user_filter( filter_id = self.filtering.add_user_filter(
user_localpart="apple", user_filter=self.EXAMPLE_FILTER user_localpart="apple", user_filter=self.EXAMPLE_FILTER
) )
self.clock.advance(1) self.reactor.advance(1)
filter_id = filter_id.result filter_id = filter_id.result
request, channel = make_request( request, channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.USER_ID, filter_id) "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
self.assertEquals(channel.json_body, self.EXAMPLE_FILTER) self.assertEquals(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self): def test_get_filter_non_existant(self):
request, channel = make_request( request, channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.USER_ID) "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.result["code"], b"400")
self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND) self.assertEquals(channel.json_body["errcode"], Codes.NOT_FOUND)
@ -135,18 +98,18 @@ class FilterTestCase(unittest.TestCase):
# Currently invalid params do not have an appropriate errcode # Currently invalid params do not have an appropriate errcode
# in errors.py # in errors.py
def test_get_filter_invalid_id(self): def test_get_filter_invalid_id(self):
request, channel = make_request( request, channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.USER_ID) "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.result["code"], b"400")
# No ID also returns an invalid_id error # No ID also returns an invalid_id error
def test_get_filter_no_id(self): def test_get_filter_no_id(self):
request, channel = make_request( request, channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.USER_ID) "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.result["code"], b"400")

View File

@ -3,22 +3,19 @@ import json
from mock import Mock from mock import Mock
from twisted.python import failure from twisted.python import failure
from twisted.test.proto_helpers import MemoryReactorClock
from synapse.api.errors import InteractiveAuthIncompleteError from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha.register import register_servlets from synapse.rest.client.v2_alpha.register import register_servlets
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import make_request, render, setup_test_homeserver
class RegisterRestServletTestCase(unittest.TestCase): class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def setUp(self):
servlets = [register_servlets]
def make_homeserver(self, reactor, clock):
self.clock = MemoryReactorClock()
self.hs_clock = Clock(self.clock)
self.url = b"/_matrix/client/r0/register" self.url = b"/_matrix/client/r0/register"
self.appservice = None self.appservice = None
@ -46,9 +43,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
identity_handler=self.identity_handler, identity_handler=self.identity_handler,
login_handler=self.login_handler, login_handler=self.login_handler,
) )
self.hs = setup_test_homeserver( self.hs = self.setup_test_homeserver()
self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.clock
)
self.hs.get_auth = Mock(return_value=self.auth) self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers) self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
@ -58,8 +53,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.config.registrations_require_3pid = [] self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = [] self.hs.config.auto_join_rooms = []
self.resource = JsonResource(self.hs) return self.hs
register_servlets(self.hs, self.resource)
def test_POST_appservice_registration_valid(self): def test_POST_appservice_registration_valid(self):
user_id = "@kermit:muppet" user_id = "@kermit:muppet"
@ -69,10 +63,10 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
request_data = json.dumps({"username": "kermit"}) request_data = json.dumps({"username": "kermit"})
request, channel = make_request( request, channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = { det_data = {
@ -85,25 +79,25 @@ class RegisterRestServletTestCase(unittest.TestCase):
def test_POST_appservice_registration_invalid(self): def test_POST_appservice_registration_invalid(self):
self.appservice = None # no application service exists self.appservice = None # no application service exists
request_data = json.dumps({"username": "kermit"}) request_data = json.dumps({"username": "kermit"})
request, channel = make_request( request, channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
) )
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(channel.result["code"], b"401", channel.result) self.assertEquals(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self): def test_POST_bad_password(self):
request_data = json.dumps({"username": "kermit", "password": 666}) request_data = json.dumps({"username": "kermit", "password": 666})
request, channel = make_request(b"POST", self.url, request_data) request, channel = self.make_request(b"POST", self.url, request_data)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid password") self.assertEquals(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self): def test_POST_bad_username(self):
request_data = json.dumps({"username": 777, "password": "monkey"}) request_data = json.dumps({"username": 777, "password": "monkey"})
request, channel = make_request(b"POST", self.url, request_data) request, channel = self.make_request(b"POST", self.url, request_data)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(channel.result["code"], b"400", channel.result) self.assertEquals(channel.result["code"], b"400", channel.result)
self.assertEquals(channel.json_body["error"], "Invalid username") self.assertEquals(channel.json_body["error"], "Invalid username")
@ -121,8 +115,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
self.device_handler.check_device_registered = Mock(return_value=device_id) self.device_handler.check_device_registered = Mock(return_value=device_id)
request, channel = make_request(b"POST", self.url, request_data) request, channel = self.make_request(b"POST", self.url, request_data)
render(request, self.resource, self.clock) self.render(request)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
@ -143,8 +137,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
self.registration_handler.register = Mock(return_value=("@user:id", "t")) self.registration_handler.register = Mock(return_value=("@user:id", "t"))
request, channel = make_request(b"POST", self.url, request_data) request, channel = self.make_request(b"POST", self.url, request_data)
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Registration has been disabled") self.assertEquals(channel.json_body["error"], "Registration has been disabled")
@ -155,8 +149,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.config.allow_guest_access = True self.hs.config.allow_guest_access = True
self.registration_handler.register = Mock(return_value=(user_id, None)) self.registration_handler.register = Mock(return_value=(user_id, None))
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
render(request, self.resource, self.clock) self.render(request)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
@ -169,8 +163,8 @@ class RegisterRestServletTestCase(unittest.TestCase):
def test_POST_disabled_guest_registration(self): def test_POST_disabled_guest_registration(self):
self.hs.config.allow_guest_access = False self.hs.config.allow_guest_access = False
request, channel = make_request(b"POST", self.url + b"?kind=guest", b"{}") request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
render(request, self.resource, self.clock) self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result) self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled") self.assertEquals(channel.json_body["error"], "Guest access is disabled")

View File

@ -34,6 +34,7 @@ class FakeChannel(object):
wire). wire).
""" """
_reactor = attr.ib()
result = attr.ib(default=attr.Factory(dict)) result = attr.ib(default=attr.Factory(dict))
_producer = None _producer = None
@ -63,6 +64,15 @@ class FakeChannel(object):
def registerProducer(self, producer, streaming): def registerProducer(self, producer, streaming):
self._producer = producer self._producer = producer
self.producerStreaming = streaming
def _produce():
if self._producer:
self._producer.resumeProducing()
self._reactor.callLater(0.1, _produce)
if not streaming:
self._reactor.callLater(0.0, _produce)
def unregisterProducer(self): def unregisterProducer(self):
if self._producer is None: if self._producer is None:
@ -105,7 +115,13 @@ class FakeSite:
def make_request( def make_request(
method, path, content=b"", access_token=None, request=SynapseRequest, shorthand=True reactor,
method,
path,
content=b"",
access_token=None,
request=SynapseRequest,
shorthand=True,
): ):
""" """
Make a web request using the given method and path, feed it the Make a web request using the given method and path, feed it the
@ -138,7 +154,7 @@ def make_request(
content = content.encode('utf8') content = content.encode('utf8')
site = FakeSite() site = FakeSite()
channel = FakeChannel() channel = FakeChannel(reactor)
req = request(site, channel) req = request(site, channel)
req.process = lambda: b"" req.process = lambda: b""

View File

@ -21,30 +21,20 @@ from mock import Mock, NonCallableMock
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http.server import JsonResource
from synapse.rest.client.v2_alpha import register, sync from synapse.rest.client.v2_alpha import register, sync
from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import (
ThreadedMemoryReactorClock,
make_request,
render,
setup_test_homeserver,
)
class TestMauLimit(unittest.TestCase): class TestMauLimit(unittest.HomeserverTestCase):
def setUp(self):
self.reactor = ThreadedMemoryReactorClock()
self.clock = Clock(self.reactor)
self.hs = setup_test_homeserver( servlets = [register.register_servlets, sync.register_servlets]
self.addCleanup,
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver(
"red", "red",
http_client=None, http_client=None,
clock=self.clock,
reactor=self.reactor,
federation_client=Mock(), federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]), ratelimiter=NonCallableMock(spec_set=["send_message"]),
) )
@ -63,10 +53,7 @@ class TestMauLimit(unittest.TestCase):
self.hs.config.server_notices_mxid_display_name = None self.hs.config.server_notices_mxid_display_name = None
self.hs.config.server_notices_mxid_avatar_url = None self.hs.config.server_notices_mxid_avatar_url = None
self.hs.config.server_notices_room_name = "Test Server Notice Room" self.hs.config.server_notices_room_name = "Test Server Notice Room"
return self.hs
self.resource = JsonResource(self.hs)
register.register_servlets(self.hs, self.resource)
sync.register_servlets(self.hs, self.resource)
def test_simple_deny_mau(self): def test_simple_deny_mau(self):
# Create and sync so that the MAU counts get updated # Create and sync so that the MAU counts get updated
@ -193,8 +180,8 @@ class TestMauLimit(unittest.TestCase):
} }
) )
request, channel = make_request("POST", "/register", request_data) request, channel = self.make_request("POST", "/register", request_data)
render(request, self.resource, self.reactor) self.render(request)
if channel.code != 200: if channel.code != 200:
raise HttpResponseException( raise HttpResponseException(
@ -206,10 +193,10 @@ class TestMauLimit(unittest.TestCase):
return access_token return access_token
def do_sync_for_user(self, token): def do_sync_for_user(self, token):
request, channel = make_request( request, channel = self.make_request(
"GET", "/sync", access_token=token "GET", "/sync", access_token=token
) )
render(request, self.resource, self.reactor) self.render(request)
if channel.code != 200: if channel.code != 200:
raise HttpResponseException( raise HttpResponseException(

View File

@ -57,7 +57,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback "GET", [re.compile("^/_matrix/foo/(?P<room_id>[^/]*)$")], _callback
) )
request, channel = make_request(b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83") request, channel = make_request(
self.reactor, b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
)
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]})
@ -75,7 +77,7 @@ class JsonResourceTests(unittest.TestCase):
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/_matrix/foo") request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'500') self.assertEqual(channel.result["code"], b'500')
@ -98,7 +100,7 @@ class JsonResourceTests(unittest.TestCase):
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/_matrix/foo") request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'500') self.assertEqual(channel.result["code"], b'500')
@ -115,7 +117,7 @@ class JsonResourceTests(unittest.TestCase):
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/_matrix/foo") request, channel = make_request(self.reactor, b"GET", b"/_matrix/foo")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'403') self.assertEqual(channel.result["code"], b'403')
@ -136,7 +138,7 @@ class JsonResourceTests(unittest.TestCase):
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback) res.register_paths("GET", [re.compile("^/_matrix/foo$")], _callback)
request, channel = make_request(b"GET", b"/_matrix/foobar") request, channel = make_request(self.reactor, b"GET", b"/_matrix/foobar")
render(request, res, self.reactor) render(request, res, self.reactor)
self.assertEqual(channel.result["code"], b'400') self.assertEqual(channel.result["code"], b'400')

View File

@ -23,7 +23,6 @@ from synapse.rest.client.v2_alpha.register import register_servlets
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import make_request
class TermsTestCase(unittest.HomeserverTestCase): class TermsTestCase(unittest.HomeserverTestCase):
@ -92,7 +91,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.registration_handler.check_username = Mock(return_value=True) self.registration_handler.check_username = Mock(return_value=True)
request, channel = make_request(b"POST", self.url, request_data) request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request) self.render(request)
# We don't bother checking that the response is correct - we'll leave that to # We don't bother checking that the response is correct - we'll leave that to
@ -110,7 +109,7 @@ class TermsTestCase(unittest.HomeserverTestCase):
}, },
} }
) )
request, channel = make_request(b"POST", self.url, request_data) request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request) self.render(request)
# We're interested in getting a response that looks like a successful # We're interested in getting a response that looks like a successful

View File

@ -189,11 +189,11 @@ class HomeserverTestCase(TestCase):
for servlet in self.servlets: for servlet in self.servlets:
servlet(self.hs, self.resource) servlet(self.hs, self.resource)
if hasattr(self, "user_id"):
from tests.rest.client.v1.utils import RestHelper from tests.rest.client.v1.utils import RestHelper
self.helper = RestHelper(self.hs, self.resource, self.user_id) self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
if hasattr(self, "user_id"):
if self.hijack_auth: if self.hijack_auth:
def get_user_by_access_token(token=None, allow_guest=False): def get_user_by_access_token(token=None, allow_guest=False):
@ -285,7 +285,9 @@ class HomeserverTestCase(TestCase):
if isinstance(content, dict): if isinstance(content, dict):
content = json.dumps(content).encode('utf8') content = json.dumps(content).encode('utf8')
return make_request(method, path, content, access_token, request, shorthand) return make_request(
self.reactor, method, path, content, access_token, request, shorthand
)
def render(self, request): def render(self, request):
""" """