Use HTTPStatus constants in place of literals in tests. (#13297)

pull/13300/head
Dirk Klimpel 2022-07-15 21:31:27 +02:00 committed by GitHub
parent 7b67e93d49
commit 96cf81e312
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 308 additions and 238 deletions

1
changelog.d/13297.misc Normal file
View File

@ -0,0 +1 @@
Use `HTTPStatus` constants in place of literals in tests.

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from http import HTTPStatus
from unittest.mock import Mock from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
@ -50,7 +51,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request( channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
) )
self.assertEqual(200, channel.code) self.assertEqual(HTTPStatus.OK, channel.code)
complexity = channel.json_body["v1"] complexity = channel.json_body["v1"]
self.assertTrue(complexity > 0, complexity) self.assertTrue(complexity > 0, complexity)
@ -62,7 +63,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request( channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,) "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
) )
self.assertEqual(200, channel.code) self.assertEqual(HTTPStatus.OK, channel.code)
complexity = channel.json_body["v1"] complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23) self.assertEqual(complexity, 1.23)

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from http import HTTPStatus
from parameterized import parameterized from parameterized import parameterized
@ -58,7 +59,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,), "/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
query_content, query_content,
) )
self.assertEqual(400, channel.code, channel.result) self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON") self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
@ -119,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request( channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,) "GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,)
) )
self.assertEqual(403, channel.code, channel.result) self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@ -153,7 +154,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}" f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
f"?ver={DEFAULT_ROOM_VERSION}", f"?ver={DEFAULT_ROOM_VERSION}",
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body return channel.json_body
def test_send_join(self): def test_send_join(self):
@ -171,7 +172,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
f"/_matrix/federation/v2/send_join/{self._room_id}/x", f"/_matrix/federation/v2/send_join/{self._room_id}/x",
content=join_event_dict, content=join_event_dict,
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# we should get complete room state back # we should get complete room state back
returned_state = [ returned_state = [
@ -226,7 +227,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true", f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
content=join_event_dict, content=join_event_dict,
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# expect a reduced room state # expect a reduced room state
returned_state = [ returned_state = [

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import OrderedDict from collections import OrderedDict
from http import HTTPStatus
from typing import Dict, List from typing import Dict, List
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership
@ -255,7 +256,7 @@ class FederationKnockingTestCase(
RoomVersions.V7.identifier, RoomVersions.V7.identifier,
), ),
) )
self.assertEqual(200, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# Note: We don't expect the knock membership event to be sent over federation as # Note: We don't expect the knock membership event to be sent over federation as
# part of the stripped room state, as the knocking homeserver already has that # part of the stripped room state, as the knocking homeserver already has that
@ -293,7 +294,7 @@ class FederationKnockingTestCase(
% (room_id, signed_knock_event.event_id), % (room_id, signed_knock_event.event_id),
signed_knock_event_json, signed_knock_event_json,
) )
self.assertEqual(200, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# Check that we got the stripped room state in return # Check that we got the stripped room state in return
room_state_events = channel.json_body["knock_state_events"] room_state_events = channel.json_body["knock_state_events"]

View File

@ -14,6 +14,7 @@
"""Tests for the password_auth_provider interface""" """Tests for the password_auth_provider interface"""
from http import HTTPStatus
from typing import Any, Type, Union from typing import Any, Type, Union
from unittest.mock import Mock from unittest.mock import Mock
@ -188,14 +189,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# check_password must return an awaitable # check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(True) mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p") channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"]) self.assertEqual("@u:test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:test", "p") mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
# login with mxid should work too # login with mxid should work too
channel = self._send_password_login("@u:bz", "p") channel = self._send_password_login("@u:bz", "p")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:bz", channel.json_body["user_id"]) self.assertEqual("@u:bz", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p") mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
mock_password_provider.reset_mock() mock_password_provider.reset_mock()
@ -204,7 +205,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# in these cases, but at least we can guard against the API changing # in these cases, but at least we can guard against the API changing
# unexpectedly # unexpectedly
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ") channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"]) self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with( mock_password_provider.check_password.assert_called_once_with(
"@ USER🙂NAME :test", " pASS😢word " "@ USER🙂NAME :test", " pASS😢word "
@ -258,10 +259,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# check_password must return an awaitable # check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(False) mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p") channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 403, channel.result) self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
channel = self._send_password_login("localuser", "localpass") channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"]) self.assertEqual("@localuser:test", channel.json_body["user_id"])
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
@ -382,7 +383,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type") # login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("u", "p") channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_password.assert_not_called() mock_password_provider.check_password.assert_not_called()
@override_config(legacy_providers_config(LegacyCustomAuthProvider)) @override_config(legacy_providers_config(LegacyCustomAuthProvider))
@ -406,14 +407,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login with missing param should be rejected # login with missing param should be rejected
channel = self._send_login("test.login_type", "u") channel = self._send_login("test.login_type", "u")
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = make_awaitable( mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None) ("@user:bz", None)
) )
channel = self._send_login("test.login_type", "u", test_field="y") channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"]) self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with( mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"} "u", "test.login_type", {"test_field": "y"}
@ -427,7 +428,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@ MALFORMED! :bz", None) ("@ MALFORMED! :bz", None)
) )
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"]) self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with( mock_password_provider.check_auth.assert_called_once_with(
" USER🙂NAME ", "test.login_type", {"test_field": " abc "} " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
@ -510,7 +511,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@user:bz", callback) ("@user:bz", callback)
) )
channel = self._send_login("test.login_type", "u", test_field="y") channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"]) self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with( mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"} "u", "test.login_type", {"test_field": "y"}
@ -549,7 +550,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type") # login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass") channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.assert_not_called()
@override_config( @override_config(
@ -584,7 +585,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type") # login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass") channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.assert_not_called()
@override_config( @override_config(
@ -615,7 +616,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type") # login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass") channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called() mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_password.assert_not_called() mock_password_provider.check_password.assert_not_called()
@ -646,13 +647,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@localuser:test", None) ("@localuser:test", None)
) )
channel = self._send_login("test.login_type", "localuser", test_field="") channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
tok1 = channel.json_body["access_token"] tok1 = channel.json_body["access_token"]
channel = self._send_login( channel = self._send_login(
"test.login_type", "localuser", test_field="", device_id="dev2" "test.login_type", "localuser", test_field="", device_id="dev2"
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# make the initial request which returns a 401 # make the initial request which returns a 401
channel = self._delete_device(tok1, "dev2") channel = self._delete_device(tok1, "dev2")
@ -721,7 +722,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# password login shouldn't work and should be rejected with a 400 # password login shouldn't work and should be rejected with a 400
# ("unknown login type") # ("unknown login type")
channel = self._send_password_login("localuser", "localpass") channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
def test_on_logged_out(self): def test_on_logged_out(self):
"""Tests that the on_logged_out callback is called when the user logs out.""" """Tests that the on_logged_out callback is called when the user logs out."""
@ -884,7 +885,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}, },
access_token=tok, access_token=tok,
) )
self.assertEqual(channel.code, 403, channel.result) self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
self.assertEqual( self.assertEqual(
channel.json_body["errcode"], channel.json_body["errcode"],
Codes.THREEPID_DENIED, Codes.THREEPID_DENIED,
@ -906,7 +907,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}, },
access_token=tok, access_token=tok,
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertIn("sid", channel.json_body) self.assertIn("sid", channel.json_body)
m.assert_called_once_with("email", "bar@test.com", registration) m.assert_called_once_with("email", "bar@test.com", registration)
@ -949,12 +950,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"register", "register",
{"auth": {"session": session, "type": LoginType.DUMMY}}, {"auth": {"session": session, "type": LoginType.DUMMY}},
) )
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
return channel.json_body return channel.json_body
def _get_login_flows(self) -> JsonDict: def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login") channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
return channel.json_body["flows"] return channel.json_body["flows"]
def _send_password_login(self, user: str, password: str) -> FakeChannel: def _send_password_login(self, user: str, password: str) -> FakeChannel:

View File

@ -1379,7 +1379,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body, content=body,
) )
self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@ -1434,7 +1434,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body, content=body,
) )
self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@ -1512,7 +1512,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "admin": False}, content={"password": "abc123", "admin": False},
) )
self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"]) self.assertFalse(channel.json_body["admin"])
@ -1550,7 +1550,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
) )
# Admin user is not blocked by mau anymore # Admin user is not blocked by mau anymore
self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"]) self.assertFalse(channel.json_body["admin"])
@ -1585,7 +1585,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body, content=body,
) )
self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@ -1626,7 +1626,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body, content=body,
) )
self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@ -1666,7 +1666,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body, content=body,
) )
self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"]) self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"]) self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
@ -2407,7 +2407,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123"}, content={"password": "abc123"},
) )
self.assertEqual(201, channel.code, msg=channel.json_body) self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"]) self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"]) self.assertEqual("bob", channel.json_body["displayname"])

View File

@ -15,6 +15,7 @@ import json
import os import os
import re import re
from email.parser import Parser from email.parser import Parser
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock from unittest.mock import Mock
@ -98,7 +99,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
) )
self.assertEqual(channel.code, 403, channel.result) self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
def test_basic_password_reset(self) -> None: def test_basic_password_reset(self) -> None:
"""Test basic password reset flow""" """Test basic password reset flow"""
@ -347,7 +348,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False, shorthand=False,
) )
self.assertEqual(200, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the # Now POST to the same endpoint, mimicking the same behaviour as clicking the
# password reset confirm button # password reset confirm button
@ -362,7 +363,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False, shorthand=False,
content_is_form=True, content_is_form=True,
) )
self.assertEqual(200, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def _get_link_from_email(self) -> str: def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent" assert self.email_attempts, "No emails have been sent"
@ -390,7 +391,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
new_password: str, new_password: str,
session_id: str, session_id: str,
client_secret: str, client_secret: str,
expected_code: int = 200, expected_code: int = HTTPStatus.OK,
) -> None: ) -> None:
channel = self.make_request( channel = self.make_request(
"POST", "POST",
@ -715,7 +716,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
}, },
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user # Get user
@ -725,7 +728,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"]) self.assertFalse(channel.json_body["threepids"])
def test_delete_email(self) -> None: def test_delete_email(self) -> None:
@ -747,7 +750,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": self.email}, {"medium": "email", "address": self.email},
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# Get user # Get user
channel = self.make_request( channel = self.make_request(
@ -756,7 +759,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"]) self.assertFalse(channel.json_body["threepids"])
def test_delete_email_if_disabled(self) -> None: def test_delete_email_if_disabled(self) -> None:
@ -781,7 +784,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user # Get user
@ -791,7 +796,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"]) self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
@ -817,7 +822,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
}, },
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user # Get user
@ -827,7 +834,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"]) self.assertFalse(channel.json_body["threepids"])
def test_no_valid_token(self) -> None: def test_no_valid_token(self) -> None:
@ -852,7 +859,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
}, },
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"]) self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user # Get user
@ -862,7 +871,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"]) self.assertFalse(channel.json_body["threepids"])
@override_config({"next_link_domain_whitelist": None}) @override_config({"next_link_domain_whitelist": None})
@ -872,7 +881,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com", "something@example.com",
"some_secret", "some_secret",
next_link="https://example.com/a/good/site", next_link="https://example.com/a/good/site",
expect_code=200, expect_code=HTTPStatus.OK,
) )
@override_config({"next_link_domain_whitelist": None}) @override_config({"next_link_domain_whitelist": None})
@ -884,7 +893,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com", "something@example.com",
"some_secret", "some_secret",
next_link="some-protocol://abcdefghijklmopqrstuvwxyz", next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
expect_code=200, expect_code=HTTPStatus.OK,
) )
@override_config({"next_link_domain_whitelist": None}) @override_config({"next_link_domain_whitelist": None})
@ -895,7 +904,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com", "something@example.com",
"some_secret", "some_secret",
next_link="file:///host/path", next_link="file:///host/path",
expect_code=400, expect_code=HTTPStatus.BAD_REQUEST,
) )
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]}) @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
@ -907,28 +916,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com", "something@example.com",
"some_secret", "some_secret",
next_link=None, next_link=None,
expect_code=200, expect_code=HTTPStatus.OK,
) )
self._request_token( self._request_token(
"something@example.com", "something@example.com",
"some_secret", "some_secret",
next_link="https://example.com/some/good/page", next_link="https://example.com/some/good/page",
expect_code=200, expect_code=HTTPStatus.OK,
) )
self._request_token( self._request_token(
"something@example.com", "something@example.com",
"some_secret", "some_secret",
next_link="https://example.org/some/also/good/page", next_link="https://example.org/some/also/good/page",
expect_code=200, expect_code=HTTPStatus.OK,
) )
self._request_token( self._request_token(
"something@example.com", "something@example.com",
"some_secret", "some_secret",
next_link="https://bad.example.org/some/bad/page", next_link="https://bad.example.org/some/bad/page",
expect_code=400, expect_code=HTTPStatus.BAD_REQUEST,
) )
@override_config({"next_link_domain_whitelist": []}) @override_config({"next_link_domain_whitelist": []})
@ -940,7 +949,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com", "something@example.com",
"some_secret", "some_secret",
next_link="https://example.com/a/page", next_link="https://example.com/a/page",
expect_code=400, expect_code=HTTPStatus.BAD_REQUEST,
) )
def _request_token( def _request_token(
@ -948,7 +957,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
email: str, email: str,
client_secret: str, client_secret: str,
next_link: Optional[str] = None, next_link: Optional[str] = None,
expect_code: int = 200, expect_code: int = HTTPStatus.OK,
) -> Optional[str]: ) -> Optional[str]:
"""Request a validation token to add an email address to a user's account """Request a validation token to add an email address to a user's account
@ -993,7 +1002,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
b"account/3pid/email/requestToken", b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1}, {"client_secret": client_secret, "email": email, "send_attempt": 1},
) )
self.assertEqual(400, channel.code, msg=channel.result["body"]) self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(expected_errcode, channel.json_body["errcode"]) self.assertEqual(expected_errcode, channel.json_body["errcode"])
self.assertEqual(expected_error, channel.json_body["error"]) self.assertEqual(expected_error, channel.json_body["error"])
@ -1002,7 +1013,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "") path = link.replace("https://example.com", "")
channel = self.make_request("GET", path, shorthand=False) channel = self.make_request("GET", path, shorthand=False)
self.assertEqual(200, channel.code, channel.result) self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def _get_link_from_email(self) -> str: def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent" assert self.email_attempts, "No emails have been sent"
@ -1052,7 +1063,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# Get user # Get user
channel = self.make_request( channel = self.make_request(
@ -1061,7 +1072,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok, access_token=self.user_id_tok,
) )
self.assertEqual(200, channel.code, msg=channel.result["body"]) self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
@ -1092,7 +1103,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that not providing any MXID raises an error.""" """Tests that not providing any MXID raises an error."""
self._test_status( self._test_status(
users=None, users=None,
expected_status_code=400, expected_status_code=HTTPStatus.BAD_REQUEST,
expected_errcode=Codes.MISSING_PARAM, expected_errcode=Codes.MISSING_PARAM,
) )
@ -1100,7 +1111,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that providing an invalid MXID raises an error.""" """Tests that providing an invalid MXID raises an error."""
self._test_status( self._test_status(
users=["bad:test"], users=["bad:test"],
expected_status_code=400, expected_status_code=HTTPStatus.BAD_REQUEST,
expected_errcode=Codes.INVALID_PARAM, expected_errcode=Codes.INVALID_PARAM,
) )
@ -1286,7 +1297,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
def _test_status( def _test_status(
self, self,
users: Optional[List[str]], users: Optional[List[str]],
expected_status_code: int = 200, expected_status_code: int = HTTPStatus.OK,
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None, expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
expected_failures: Optional[List[str]] = None, expected_failures: Optional[List[str]] = None,
expected_errcode: Optional[str] = None, expected_errcode: Optional[str] = None,

View File

@ -14,6 +14,7 @@
import json import json
import time import time
import urllib.parse import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from unittest.mock import Mock from unittest.mock import Mock
from urllib.parse import urlencode from urllib.parse import urlencode
@ -261,20 +262,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
} }
channel = self.make_request(b"POST", LOGIN_URL, params) channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
access_token = channel.json_body["access_token"] access_token = channel.json_body["access_token"]
device_id = channel.json_body["device_id"] device_id = channel.json_body["device_id"]
# we should now be able to make requests with the access token # we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# time passes # time passes
self.reactor.advance(24 * 3600) self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted # ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True) self.assertEqual(channel.json_body["soft_logout"], True)
@ -288,7 +289,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# more requests with the expired token should still return a soft-logout # more requests with the expired token should still return a soft-logout
self.reactor.advance(3600) self.reactor.advance(3600)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True) self.assertEqual(channel.json_body["soft_logout"], True)
@ -296,7 +297,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self._delete_device(access_token_2, "kermit", "monkey", device_id) self._delete_device(access_token_2, "kermit", "monkey", device_id)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], False) self.assertEqual(channel.json_body["soft_logout"], False)
@ -307,7 +308,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token b"DELETE", "devices/" + device_id, access_token=access_token
) )
self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
# check it's a UI-Auth fail # check it's a UI-Auth fail
self.assertEqual( self.assertEqual(
set(channel.json_body.keys()), set(channel.json_body.keys()),
@ -330,7 +331,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token=access_token, access_token=access_token,
content={"auth": auth}, content={"auth": auth},
) )
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@override_config({"session_lifetime": "24h"}) @override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None: def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
@ -341,14 +342,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token # we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# time passes # time passes
self.reactor.advance(24 * 3600) self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted # ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True) self.assertEqual(channel.json_body["soft_logout"], True)
@ -367,14 +368,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token # we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# time passes # time passes
self.reactor.advance(24 * 3600) self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted # ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token) channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, 401, channel.result) self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN") self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True) self.assertEqual(channel.json_body["soft_logout"], True)
@ -466,7 +467,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_get_login_flows(self) -> None: def test_get_login_flows(self) -> None:
"""GET /login should return password and SSO flows""" """GET /login should return password and SSO flows"""
channel = self.make_request("GET", "/_matrix/client/r0/login") channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
expected_flow_types = [ expected_flow_types = [
"m.login.cas", "m.login.cas",
@ -494,14 +495,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"""/login/sso/redirect should redirect to an identity picker""" """/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker # first hit the redirect url, which should redirect to our idp picker
channel = self._make_sso_redirect_request(None) channel = self._make_sso_redirect_request(None)
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
uri = location_headers[0] uri = location_headers[0]
# hitting that picker should give us some HTML # hitting that picker should give us some HTML
channel = self.make_request("GET", uri) channel = self.make_request("GET", uri)
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
# parse the form to check it has fields assumed elsewhere in this class # parse the form to check it has fields assumed elsewhere in this class
html = channel.result["body"].decode("utf-8") html = channel.result["body"].decode("utf-8")
@ -530,7 +531,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=cas", + "&idp=cas",
shorthand=False, shorthand=False,
) )
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
cas_uri = location_headers[0] cas_uri = location_headers[0]
@ -555,7 +556,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml", + "&idp=saml",
) )
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
saml_uri = location_headers[0] saml_uri = location_headers[0]
@ -579,7 +580,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=oidc", + "&idp=oidc",
) )
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
oidc_uri = location_headers[0] oidc_uri = location_headers[0]
@ -606,7 +607,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
# that should serve a confirmation page # that should serve a confirmation page
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
content_type_headers = channel.headers.getRawHeaders("Content-Type") content_type_headers = channel.headers.getRawHeaders("Content-Type")
assert content_type_headers assert content_type_headers
self.assertTrue(content_type_headers[-1].startswith("text/html")) self.assertTrue(content_type_headers[-1].startswith("text/html"))
@ -634,7 +635,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"/login", "/login",
content={"type": "m.login.token", "token": login_token}, content={"type": "m.login.token", "token": login_token},
) )
self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test") self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self) -> None: def test_multi_sso_redirect_to_unknown(self) -> None:
@ -643,18 +644,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"GET", "GET",
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz", "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
) )
self.assertEqual(channel.code, 400, channel.result) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
def test_client_idp_redirect_to_unknown(self) -> None: def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404""" """If the client tries to pick an unknown IdP, return a 404"""
channel = self._make_sso_redirect_request("xxx") channel = self._make_sso_redirect_request("xxx")
self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_client_idp_redirect_to_oidc(self) -> None: def test_client_idp_redirect_to_oidc(self) -> None:
"""If the client pick a known IdP, redirect to it""" """If the client pick a known IdP, redirect to it"""
channel = self._make_sso_redirect_request("oidc") channel = self._make_sso_redirect_request("oidc")
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
oidc_uri = location_headers[0] oidc_uri = location_headers[0]
@ -765,7 +766,7 @@ class CASTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", cas_ticket_url) channel = self.make_request("GET", cas_ticket_url)
# Test that the response is HTML. # Test that the response is HTML.
self.assertEqual(channel.code, 200, channel.result) self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
content_type_header_value = "" content_type_header_value = ""
for header in channel.result.get("headers", []): for header in channel.result.get("headers", []):
if header[0] == b"Content-Type": if header[0] == b"Content-Type":
@ -1246,7 +1247,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
) )
# that should redirect to the username picker # that should redirect to the username picker
self.assertEqual(channel.code, 302, channel.result) self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
location_headers = channel.headers.getRawHeaders("Location") location_headers = channel.headers.getRawHeaders("Location")
assert location_headers assert location_headers
picker_url = location_headers[0] picker_url = location_headers[0]
@ -1290,7 +1291,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
("Content-Length", str(len(content))), ("Content-Length", str(len(content))),
], ],
) )
self.assertEqual(chan.code, 302, chan.result) self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
location_headers = chan.headers.getRawHeaders("Location") location_headers = chan.headers.getRawHeaders("Location")
assert location_headers assert location_headers
@ -1300,7 +1301,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
path=location_headers[0], path=location_headers[0],
custom_headers=[("Cookie", "username_mapping_session=" + session_id)], custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
) )
self.assertEqual(chan.code, 302, chan.result) self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
location_headers = chan.headers.getRawHeaders("Location") location_headers = chan.headers.getRawHeaders("Location")
assert location_headers assert location_headers
@ -1325,5 +1326,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
"/login", "/login",
content={"type": "m.login.token", "token": login_token}, content={"type": "m.login.token", "token": login_token},
) )
self.assertEqual(chan.code, 200, chan.result) self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
self.assertEqual(chan.json_body["user_id"], "@bobby:test") self.assertEqual(chan.json_body["user_id"], "@bobby:test")

File diff suppressed because it is too large Load Diff