Replace simple_async_mock with AsyncMock (#16180)
Python 3.8 has a native AsyncMock, use it instead of a custom implementation.pull/16185/head
parent
5c9402b9fd
commit
a8a46b1336
|
@ -0,0 +1 @@
|
||||||
|
Use `AsyncMock` instead of custom code.
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from unittest.mock import Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
|
||||||
|
@ -35,7 +35,6 @@ from synapse.types import Requester, UserID
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import simple_async_mock
|
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
from tests.utils import mock_getRawHeaders
|
from tests.utils import mock_getRawHeaders
|
||||||
|
|
||||||
|
@ -60,16 +59,16 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
# this is overridden for the appservice tests
|
# this is overridden for the appservice tests
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||||
|
|
||||||
self.store.insert_client_ip = simple_async_mock(None)
|
self.store.insert_client_ip = AsyncMock(return_value=None)
|
||||||
self.store.is_support_user = simple_async_mock(False)
|
self.store.is_support_user = AsyncMock(return_value=False)
|
||||||
|
|
||||||
def test_get_user_by_req_user_valid_token(self) -> None:
|
def test_get_user_by_req_user_valid_token(self) -> None:
|
||||||
user_info = TokenLookupResult(
|
user_info = TokenLookupResult(
|
||||||
user_id=self.test_user, token_id=5, device_id="device"
|
user_id=self.test_user, token_id=5, device_id="device"
|
||||||
)
|
)
|
||||||
self.store.get_user_by_access_token = simple_async_mock(user_info)
|
self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
|
||||||
self.store.mark_access_token_as_used = simple_async_mock(None)
|
self.store.mark_access_token_as_used = AsyncMock(return_value=None)
|
||||||
self.store.get_user_locked_status = simple_async_mock(False)
|
self.store.get_user_locked_status = AsyncMock(return_value=False)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
|
@ -78,7 +77,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(requester.user.to_string(), self.test_user)
|
self.assertEqual(requester.user.to_string(), self.test_user)
|
||||||
|
|
||||||
def test_get_user_by_req_user_bad_token(self) -> None:
|
def test_get_user_by_req_user_bad_token(self) -> None:
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
|
@ -91,7 +90,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_get_user_by_req_user_missing_token(self) -> None:
|
def test_get_user_by_req_user_missing_token(self) -> None:
|
||||||
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
|
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
|
||||||
self.store.get_user_by_access_token = simple_async_mock(user_info)
|
self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
@ -106,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
|
||||||
)
|
)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
|
@ -125,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
ip_range_whitelist=IPSet(["192.168/16"]),
|
ip_range_whitelist=IPSet(["192.168/16"]),
|
||||||
)
|
)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientAddress.return_value.host = "192.168.10.10"
|
request.getClientAddress.return_value.host = "192.168.10.10"
|
||||||
|
@ -144,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
ip_range_whitelist=IPSet(["192.168/16"]),
|
ip_range_whitelist=IPSet(["192.168/16"]),
|
||||||
)
|
)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientAddress.return_value.host = "131.111.8.42"
|
request.getClientAddress.return_value.host = "131.111.8.42"
|
||||||
|
@ -158,7 +157,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_bad_token(self) -> None:
|
def test_get_user_by_req_appservice_bad_token(self) -> None:
|
||||||
self.store.get_app_service_by_token = Mock(return_value=None)
|
self.store.get_app_service_by_token = Mock(return_value=None)
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
|
@ -172,7 +171,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
def test_get_user_by_req_appservice_missing_token(self) -> None:
|
def test_get_user_by_req_appservice_missing_token(self) -> None:
|
||||||
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
@ -190,8 +189,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
app_service.is_interested_in_user = Mock(return_value=True)
|
app_service.is_interested_in_user = Mock(return_value=True)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
# This just needs to return a truth-y value.
|
# This just needs to return a truth-y value.
|
||||||
self.store.get_user_by_id = simple_async_mock({"is_guest": False})
|
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
|
@ -210,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
app_service.is_interested_in_user = Mock(return_value=False)
|
app_service.is_interested_in_user = Mock(return_value=False)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
|
@ -234,10 +233,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
app_service.is_interested_in_user = Mock(return_value=True)
|
app_service.is_interested_in_user = Mock(return_value=True)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
# This just needs to return a truth-y value.
|
# This just needs to return a truth-y value.
|
||||||
self.store.get_user_by_id = simple_async_mock({"is_guest": False})
|
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
# This also needs to just return a truth-y value
|
# This also needs to just return a truth-y value
|
||||||
self.store.get_device = simple_async_mock({"hidden": False})
|
self.store.get_device = AsyncMock(return_value={"hidden": False})
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
|
@ -266,10 +265,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
app_service.is_interested_in_user = Mock(return_value=True)
|
app_service.is_interested_in_user = Mock(return_value=True)
|
||||||
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
self.store.get_app_service_by_token = Mock(return_value=app_service)
|
||||||
# This just needs to return a truth-y value.
|
# This just needs to return a truth-y value.
|
||||||
self.store.get_user_by_id = simple_async_mock({"is_guest": False})
|
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
# This also needs to just return a falsey value
|
# This also needs to just return a falsey value
|
||||||
self.store.get_device = simple_async_mock(None)
|
self.store.get_device = AsyncMock(return_value=None)
|
||||||
|
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
|
@ -283,8 +282,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)
|
self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)
|
||||||
|
|
||||||
def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None:
|
def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None:
|
||||||
self.store.get_user_by_access_token = simple_async_mock(
|
self.store.get_user_by_access_token = AsyncMock(
|
||||||
TokenLookupResult(
|
return_value=TokenLookupResult(
|
||||||
user_id="@baldrick:matrix.org",
|
user_id="@baldrick:matrix.org",
|
||||||
device_id="device",
|
device_id="device",
|
||||||
token_id=5,
|
token_id=5,
|
||||||
|
@ -292,9 +291,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
token_used=True,
|
token_used=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.store.insert_client_ip = simple_async_mock(None)
|
self.store.insert_client_ip = AsyncMock(return_value=None)
|
||||||
self.store.mark_access_token_as_used = simple_async_mock(None)
|
self.store.mark_access_token_as_used = AsyncMock(return_value=None)
|
||||||
self.store.get_user_locked_status = simple_async_mock(False)
|
self.store.get_user_locked_status = AsyncMock(return_value=False)
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
|
@ -304,8 +303,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
|
def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
|
||||||
self.auth._track_puppeted_user_ips = True
|
self.auth._track_puppeted_user_ips = True
|
||||||
self.store.get_user_by_access_token = simple_async_mock(
|
self.store.get_user_by_access_token = AsyncMock(
|
||||||
TokenLookupResult(
|
return_value=TokenLookupResult(
|
||||||
user_id="@baldrick:matrix.org",
|
user_id="@baldrick:matrix.org",
|
||||||
device_id="device",
|
device_id="device",
|
||||||
token_id=5,
|
token_id=5,
|
||||||
|
@ -313,9 +312,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
token_used=True,
|
token_used=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.store.get_user_locked_status = simple_async_mock(False)
|
self.store.get_user_locked_status = AsyncMock(return_value=False)
|
||||||
self.store.insert_client_ip = simple_async_mock(None)
|
self.store.insert_client_ip = AsyncMock(return_value=None)
|
||||||
self.store.mark_access_token_as_used = simple_async_mock(None)
|
self.store.mark_access_token_as_used = AsyncMock(return_value=None)
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.getClientAddress.return_value.host = "127.0.0.1"
|
request.getClientAddress.return_value.host = "127.0.0.1"
|
||||||
request.args[b"access_token"] = [self.test_token]
|
request.args[b"access_token"] = [self.test_token]
|
||||||
|
@ -324,7 +323,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(self.store.insert_client_ip.call_count, 2)
|
self.assertEqual(self.store.insert_client_ip.call_count, 2)
|
||||||
|
|
||||||
def test_get_user_from_macaroon(self) -> None:
|
def test_get_user_from_macaroon(self) -> None:
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
user_id = "@baldrick:matrix.org"
|
user_id = "@baldrick:matrix.org"
|
||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
|
@ -342,8 +341,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_guest_user_from_macaroon(self) -> None:
|
def test_get_guest_user_from_macaroon(self) -> None:
|
||||||
self.store.get_user_by_id = simple_async_mock({"is_guest": True})
|
self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
|
||||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
self.store.get_user_by_access_token = AsyncMock(return_value=None)
|
||||||
|
|
||||||
user_id = "@baldrick:matrix.org"
|
user_id = "@baldrick:matrix.org"
|
||||||
macaroon = pymacaroons.Macaroon(
|
macaroon = pymacaroons.Macaroon(
|
||||||
|
@ -373,7 +372,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.auth_blocking._limit_usage_by_mau = True
|
self.auth_blocking._limit_usage_by_mau = True
|
||||||
|
|
||||||
self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
|
self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users)
|
||||||
|
|
||||||
e = self.get_failure(
|
e = self.get_failure(
|
||||||
self.auth_blocking.check_auth_blocking(), ResourceLimitError
|
self.auth_blocking.check_auth_blocking(), ResourceLimitError
|
||||||
|
@ -383,25 +382,27 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(e.value.code, 403)
|
self.assertEqual(e.value.code, 403)
|
||||||
|
|
||||||
# Ensure does not throw an error
|
# Ensure does not throw an error
|
||||||
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
|
self.store.get_monthly_active_count = AsyncMock(
|
||||||
|
return_value=small_number_of_users
|
||||||
|
)
|
||||||
self.get_success(self.auth_blocking.check_auth_blocking())
|
self.get_success(self.auth_blocking.check_auth_blocking())
|
||||||
|
|
||||||
def test_blocking_mau__depending_on_user_type(self) -> None:
|
def test_blocking_mau__depending_on_user_type(self) -> None:
|
||||||
self.auth_blocking._max_mau_value = 50
|
self.auth_blocking._max_mau_value = 50
|
||||||
self.auth_blocking._limit_usage_by_mau = True
|
self.auth_blocking._limit_usage_by_mau = True
|
||||||
|
|
||||||
self.store.get_monthly_active_count = simple_async_mock(100)
|
self.store.get_monthly_active_count = AsyncMock(return_value=100)
|
||||||
# Support users allowed
|
# Support users allowed
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
|
self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
|
||||||
)
|
)
|
||||||
self.store.get_monthly_active_count = simple_async_mock(100)
|
self.store.get_monthly_active_count = AsyncMock(return_value=100)
|
||||||
# Bots not allowed
|
# Bots not allowed
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
|
self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
|
||||||
ResourceLimitError,
|
ResourceLimitError,
|
||||||
)
|
)
|
||||||
self.store.get_monthly_active_count = simple_async_mock(100)
|
self.store.get_monthly_active_count = AsyncMock(return_value=100)
|
||||||
# Real users not allowed
|
# Real users not allowed
|
||||||
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
|
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
|
||||||
|
|
||||||
|
@ -412,9 +413,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.auth_blocking._limit_usage_by_mau = True
|
self.auth_blocking._limit_usage_by_mau = True
|
||||||
self.auth_blocking._track_appservice_user_ips = False
|
self.auth_blocking._track_appservice_user_ips = False
|
||||||
|
|
||||||
self.store.get_monthly_active_count = simple_async_mock(100)
|
self.store.get_monthly_active_count = AsyncMock(return_value=100)
|
||||||
self.store.user_last_seen_monthly_active = simple_async_mock()
|
self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
|
||||||
self.store.is_trial_user = simple_async_mock()
|
self.store.is_trial_user = AsyncMock(return_value=False)
|
||||||
|
|
||||||
appservice = ApplicationService(
|
appservice = ApplicationService(
|
||||||
"abcd",
|
"abcd",
|
||||||
|
@ -443,9 +444,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
self.auth_blocking._limit_usage_by_mau = True
|
self.auth_blocking._limit_usage_by_mau = True
|
||||||
self.auth_blocking._track_appservice_user_ips = True
|
self.auth_blocking._track_appservice_user_ips = True
|
||||||
|
|
||||||
self.store.get_monthly_active_count = simple_async_mock(100)
|
self.store.get_monthly_active_count = AsyncMock(return_value=100)
|
||||||
self.store.user_last_seen_monthly_active = simple_async_mock()
|
self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
|
||||||
self.store.is_trial_user = simple_async_mock()
|
self.store.is_trial_user = AsyncMock(return_value=False)
|
||||||
|
|
||||||
appservice = ApplicationService(
|
appservice = ApplicationService(
|
||||||
"abcd",
|
"abcd",
|
||||||
|
@ -473,7 +474,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||||
def test_reserved_threepid(self) -> None:
|
def test_reserved_threepid(self) -> None:
|
||||||
self.auth_blocking._limit_usage_by_mau = True
|
self.auth_blocking._limit_usage_by_mau = True
|
||||||
self.auth_blocking._max_mau_value = 1
|
self.auth_blocking._max_mau_value = 1
|
||||||
self.store.get_monthly_active_count = simple_async_mock(2)
|
self.store.get_monthly_active_count = AsyncMock(return_value=2)
|
||||||
threepid = {"medium": "email", "address": "reserved@server.com"}
|
threepid = {"medium": "email", "address": "reserved@server.com"}
|
||||||
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
|
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
|
||||||
self.auth_blocking._mau_limits_reserved_threepids = [threepid]
|
self.auth_blocking._mau_limits_reserved_threepids = [threepid]
|
||||||
|
|
|
@ -13,14 +13,13 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import re
|
import re
|
||||||
from typing import Any, Generator
|
from typing import Any, Generator
|
||||||
from unittest.mock import Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.appservice import ApplicationService, Namespace
|
from synapse.appservice import ApplicationService, Namespace
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import simple_async_mock
|
|
||||||
|
|
||||||
|
|
||||||
def _regex(regex: str, exclusive: bool = True) -> Namespace:
|
def _regex(regex: str, exclusive: bool = True) -> Namespace:
|
||||||
|
@ -43,8 +42,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.store = Mock()
|
self.store = Mock()
|
||||||
self.store.get_aliases_for_room = simple_async_mock([])
|
self.store.get_aliases_for_room = AsyncMock(return_value=[])
|
||||||
self.store.get_local_users_in_room = simple_async_mock([])
|
self.store.get_local_users_in_room = AsyncMock(return_value=[])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_regex_user_id_prefix_match(
|
def test_regex_user_id_prefix_match(
|
||||||
|
@ -127,10 +126,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||||
_regex("#irc_.*:matrix.org")
|
_regex("#irc_.*:matrix.org")
|
||||||
)
|
)
|
||||||
self.store.get_aliases_for_room = simple_async_mock(
|
self.store.get_aliases_for_room = AsyncMock(
|
||||||
["#irc_foobar:matrix.org", "#athing:matrix.org"]
|
return_value=["#irc_foobar:matrix.org", "#athing:matrix.org"]
|
||||||
)
|
)
|
||||||
self.store.get_local_users_in_room = simple_async_mock([])
|
self.store.get_local_users_in_room = AsyncMock(return_value=[])
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
(
|
(
|
||||||
yield self.service.is_interested_in_event(
|
yield self.service.is_interested_in_event(
|
||||||
|
@ -182,10 +181,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
self.service.namespaces[ApplicationService.NS_ALIASES].append(
|
||||||
_regex("#irc_.*:matrix.org")
|
_regex("#irc_.*:matrix.org")
|
||||||
)
|
)
|
||||||
self.store.get_aliases_for_room = simple_async_mock(
|
self.store.get_aliases_for_room = AsyncMock(
|
||||||
["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
|
return_value=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
|
||||||
)
|
)
|
||||||
self.store.get_local_users_in_room = simple_async_mock([])
|
self.store.get_local_users_in_room = AsyncMock(return_value=[])
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
(
|
(
|
||||||
yield defer.ensureDeferred(
|
yield defer.ensureDeferred(
|
||||||
|
@ -205,8 +204,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||||
self.event.sender = "@irc_foobar:matrix.org"
|
self.event.sender = "@irc_foobar:matrix.org"
|
||||||
self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"])
|
self.store.get_aliases_for_room = AsyncMock(
|
||||||
self.store.get_local_users_in_room = simple_async_mock([])
|
return_value=["#irc_barfoo:matrix.org"]
|
||||||
|
)
|
||||||
|
self.store.get_local_users_in_room = AsyncMock(return_value=[])
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
(
|
(
|
||||||
yield self.service.is_interested_in_event(
|
yield self.service.is_interested_in_event(
|
||||||
|
@ -235,10 +236,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
|
def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
|
||||||
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
|
||||||
# Note that @irc_fo:here is the AS user.
|
# Note that @irc_fo:here is the AS user.
|
||||||
self.store.get_local_users_in_room = simple_async_mock(
|
self.store.get_local_users_in_room = AsyncMock(
|
||||||
["@alice:here", "@irc_fo:here", "@bob:here"]
|
return_value=["@alice:here", "@irc_fo:here", "@bob:here"]
|
||||||
)
|
)
|
||||||
self.store.get_aliases_for_room = simple_async_mock([])
|
self.store.get_aliases_for_room = AsyncMock(return_value=[])
|
||||||
|
|
||||||
self.event.sender = "@xmpp_foobar:matrix.org"
|
self.event.sender = "@xmpp_foobar:matrix.org"
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import List, Optional, Sequence, Tuple, cast
|
from typing import List, Optional, Sequence, Tuple, cast
|
||||||
from unittest.mock import Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
|
@ -37,7 +37,6 @@ from synapse.types import DeviceListUpdates, JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import simple_async_mock
|
|
||||||
|
|
||||||
from ..utils import MockClock
|
from ..utils import MockClock
|
||||||
|
|
||||||
|
@ -62,10 +61,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
||||||
txn = Mock(id=txn_id, service=service, events=events)
|
txn = Mock(id=txn_id, service=service, events=events)
|
||||||
|
|
||||||
# mock methods
|
# mock methods
|
||||||
self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
|
self.store.get_appservice_state = AsyncMock(
|
||||||
txn.send = simple_async_mock(True)
|
return_value=ApplicationServiceState.UP
|
||||||
txn.complete = simple_async_mock(True)
|
)
|
||||||
self.store.create_appservice_txn = simple_async_mock(txn)
|
txn.send = AsyncMock(return_value=True)
|
||||||
|
txn.complete = AsyncMock(return_value=True)
|
||||||
|
self.store.create_appservice_txn = AsyncMock(return_value=txn)
|
||||||
|
|
||||||
# actual call
|
# actual call
|
||||||
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
||||||
|
@ -89,10 +90,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
||||||
events = [Mock(), Mock()]
|
events = [Mock(), Mock()]
|
||||||
|
|
||||||
txn = Mock(id="idhere", service=service, events=events)
|
txn = Mock(id="idhere", service=service, events=events)
|
||||||
self.store.get_appservice_state = simple_async_mock(
|
self.store.get_appservice_state = AsyncMock(
|
||||||
ApplicationServiceState.DOWN
|
return_value=ApplicationServiceState.DOWN
|
||||||
)
|
)
|
||||||
self.store.create_appservice_txn = simple_async_mock(txn)
|
self.store.create_appservice_txn = AsyncMock(return_value=txn)
|
||||||
|
|
||||||
# actual call
|
# actual call
|
||||||
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
||||||
|
@ -118,10 +119,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
||||||
txn = Mock(id=txn_id, service=service, events=events)
|
txn = Mock(id=txn_id, service=service, events=events)
|
||||||
|
|
||||||
# mock methods
|
# mock methods
|
||||||
self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
|
self.store.get_appservice_state = AsyncMock(
|
||||||
self.store.set_appservice_state = simple_async_mock(True)
|
return_value=ApplicationServiceState.UP
|
||||||
txn.send = simple_async_mock(False) # fails to send
|
)
|
||||||
self.store.create_appservice_txn = simple_async_mock(txn)
|
self.store.set_appservice_state = AsyncMock(return_value=True)
|
||||||
|
txn.send = AsyncMock(return_value=False) # fails to send
|
||||||
|
self.store.create_appservice_txn = AsyncMock(return_value=txn)
|
||||||
|
|
||||||
# actual call
|
# actual call
|
||||||
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
|
||||||
|
@ -150,7 +153,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
||||||
self.as_api = Mock()
|
self.as_api = Mock()
|
||||||
self.store = Mock()
|
self.store = Mock()
|
||||||
self.service = Mock()
|
self.service = Mock()
|
||||||
self.callback = simple_async_mock()
|
self.callback = AsyncMock()
|
||||||
self.recoverer = _Recoverer(
|
self.recoverer = _Recoverer(
|
||||||
clock=cast(Clock, self.clock),
|
clock=cast(Clock, self.clock),
|
||||||
as_api=self.as_api,
|
as_api=self.as_api,
|
||||||
|
@ -174,8 +177,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
||||||
self.recoverer.recover()
|
self.recoverer.recover()
|
||||||
# shouldn't have called anything prior to waiting for exp backoff
|
# shouldn't have called anything prior to waiting for exp backoff
|
||||||
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
|
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
|
||||||
txn.send = simple_async_mock(True)
|
txn.send = AsyncMock(return_value=True)
|
||||||
txn.complete = simple_async_mock(None)
|
txn.complete = AsyncMock(return_value=None)
|
||||||
# wait for exp backoff
|
# wait for exp backoff
|
||||||
self.clock.advance_time(2)
|
self.clock.advance_time(2)
|
||||||
self.assertEqual(1, txn.send.call_count)
|
self.assertEqual(1, txn.send.call_count)
|
||||||
|
@ -202,8 +205,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.recoverer.recover()
|
self.recoverer.recover()
|
||||||
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
|
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
|
||||||
txn.send = simple_async_mock(False)
|
txn.send = AsyncMock(return_value=False)
|
||||||
txn.complete = simple_async_mock(None)
|
txn.complete = AsyncMock(return_value=None)
|
||||||
self.clock.advance_time(2)
|
self.clock.advance_time(2)
|
||||||
self.assertEqual(1, txn.send.call_count)
|
self.assertEqual(1, txn.send.call_count)
|
||||||
self.assertEqual(0, txn.complete.call_count)
|
self.assertEqual(0, txn.complete.call_count)
|
||||||
|
@ -216,7 +219,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
|
||||||
self.assertEqual(3, txn.send.call_count)
|
self.assertEqual(3, txn.send.call_count)
|
||||||
self.assertEqual(0, txn.complete.call_count)
|
self.assertEqual(0, txn.complete.call_count)
|
||||||
self.assertEqual(0, self.callback.call_count)
|
self.assertEqual(0, self.callback.call_count)
|
||||||
txn.send = simple_async_mock(True) # successfully send the txn
|
txn.send = AsyncMock(return_value=True) # successfully send the txn
|
||||||
pop_txn = True # returns the txn the first time, then no more.
|
pop_txn = True # returns the txn the first time, then no more.
|
||||||
self.clock.advance_time(16)
|
self.clock.advance_time(16)
|
||||||
self.assertEqual(1, txn.send.call_count) # new mock reset call count
|
self.assertEqual(1, txn.send.call_count) # new mock reset call count
|
||||||
|
@ -244,7 +247,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
|
||||||
def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None:
|
||||||
self.scheduler = ApplicationServiceScheduler(hs)
|
self.scheduler = ApplicationServiceScheduler(hs)
|
||||||
self.txn_ctrl = Mock()
|
self.txn_ctrl = Mock()
|
||||||
self.txn_ctrl.send = simple_async_mock()
|
self.txn_ctrl.send = AsyncMock()
|
||||||
|
|
||||||
# Replace instantiated _TransactionController instances with our Mock
|
# Replace instantiated _TransactionController instances with our Mock
|
||||||
self.scheduler.txn_ctrl = self.txn_ctrl
|
self.scheduler.txn_ctrl = self.txn_ctrl
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
|
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -30,7 +30,6 @@ from synapse.types import JsonDict, StreamToken, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.handlers.test_sync import generate_sync_config
|
from tests.handlers.test_sync import generate_sync_config
|
||||||
from tests.test_utils import simple_async_mock
|
|
||||||
from tests.unittest import (
|
from tests.unittest import (
|
||||||
FederatingHomeserverTestCase,
|
FederatingHomeserverTestCase,
|
||||||
HomeserverTestCase,
|
HomeserverTestCase,
|
||||||
|
@ -157,7 +156,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
# Mock out the calls over federation.
|
# Mock out the calls over federation.
|
||||||
self.fed_transport_client = Mock(spec=["send_transaction"])
|
self.fed_transport_client = Mock(spec=["send_transaction"])
|
||||||
self.fed_transport_client.send_transaction = simple_async_mock({})
|
self.fed_transport_client.send_transaction = AsyncMock(return_value={})
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
federation_transport_client=self.fed_transport_client,
|
federation_transport_client=self.fed_transport_client,
|
||||||
|
|
|
@ -36,7 +36,7 @@ from synapse.util import Clock
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import event_injection, simple_async_mock
|
from tests.test_utils import event_injection
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
from tests.utils import MockClock
|
from tests.utils import MockClock
|
||||||
|
|
||||||
|
@ -399,7 +399,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
|
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
|
||||||
# we can track any outgoing ephemeral events
|
# we can track any outgoing ephemeral events
|
||||||
self.send_mock = simple_async_mock()
|
self.send_mock = AsyncMock()
|
||||||
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment]
|
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment]
|
||||||
|
|
||||||
# Mock out application services, and allow defining our own in tests
|
# Mock out application services, and allow defining our own in tests
|
||||||
|
@ -897,7 +897,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
|
||||||
|
|
||||||
# Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
|
# Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
|
||||||
# will be sent over the wire
|
# will be sent over the wire
|
||||||
self.put_json = simple_async_mock()
|
self.put_json = AsyncMock()
|
||||||
hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment]
|
hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment]
|
||||||
|
|
||||||
# Mock out application services, and allow defining our own in tests
|
# Mock out application services, and allow defining our own in tests
|
||||||
|
@ -1003,7 +1003,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
|
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
|
||||||
# we can track what's going out
|
# we can track what's going out
|
||||||
self.send_mock = simple_async_mock()
|
self.send_mock = AsyncMock()
|
||||||
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method.
|
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method.
|
||||||
|
|
||||||
# Define an application service for the tests
|
# Define an application service for the tests
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from unittest.mock import Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ from synapse.handlers.cas import CasResponse
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.test_utils import simple_async_mock
|
|
||||||
from tests.unittest import HomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
# These are a few constants that are used as config parameters in the tests.
|
# These are a few constants that are used as config parameters in the tests.
|
||||||
|
@ -61,7 +60,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
|
||||||
|
|
||||||
cas_response = CasResponse("test_user", {})
|
cas_response = CasResponse("test_user", {})
|
||||||
request = _mock_request()
|
request = _mock_request()
|
||||||
|
@ -89,7 +88,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
|
||||||
|
|
||||||
# Map a user via SSO.
|
# Map a user via SSO.
|
||||||
cas_response = CasResponse("test_user", {})
|
cas_response = CasResponse("test_user", {})
|
||||||
|
@ -129,7 +128,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
|
||||||
|
|
||||||
cas_response = CasResponse("föö", {})
|
cas_response = CasResponse("föö", {})
|
||||||
request = _mock_request()
|
request = _mock_request()
|
||||||
|
@ -160,7 +159,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
|
||||||
|
|
||||||
# The response doesn't have the proper userGroup or department.
|
# The response doesn't have the proper userGroup or department.
|
||||||
cas_response = CasResponse("test_user", {})
|
cas_response = CasResponse("test_user", {})
|
||||||
|
|
|
@ -39,7 +39,7 @@ from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
from tests.test_utils import FakeResponse, get_awaitable_result
|
||||||
from tests.unittest import HomeserverTestCase, skip_unless
|
from tests.unittest import HomeserverTestCase, skip_unless
|
||||||
from tests.utils import mock_getRawHeaders
|
from tests.utils import mock_getRawHeaders
|
||||||
|
|
||||||
|
@ -147,7 +147,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
def test_inactive_token(self) -> None:
|
def test_inactive_token(self) -> None:
|
||||||
"""The handler should return a 403 where the token is inactive."""
|
"""The handler should return a 403 where the token is inactive."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={"active": False},
|
payload={"active": False},
|
||||||
|
@ -166,7 +166,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
def test_active_no_scope(self) -> None:
|
def test_active_no_scope(self) -> None:
|
||||||
"""The handler should return a 403 where no scope is given."""
|
"""The handler should return a 403 where no scope is given."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={"active": True},
|
payload={"active": True},
|
||||||
|
@ -185,7 +185,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
def test_active_user_no_subject(self) -> None:
|
def test_active_user_no_subject(self) -> None:
|
||||||
"""The handler should return a 500 when no subject is present."""
|
"""The handler should return a 500 when no subject is present."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
|
payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
|
||||||
|
@ -204,7 +204,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
def test_active_no_user_scope(self) -> None:
|
def test_active_no_user_scope(self) -> None:
|
||||||
"""The handler should return a 500 when no subject is present."""
|
"""The handler should return a 500 when no subject is present."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={
|
payload={
|
||||||
|
@ -227,7 +227,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
def test_active_admin_not_user(self) -> None:
|
def test_active_admin_not_user(self) -> None:
|
||||||
"""The handler should raise when the scope has admin right but not user."""
|
"""The handler should raise when the scope has admin right but not user."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={
|
payload={
|
||||||
|
@ -251,7 +251,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
def test_active_admin(self) -> None:
|
def test_active_admin(self) -> None:
|
||||||
"""The handler should return a requester with admin rights."""
|
"""The handler should return a requester with admin rights."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={
|
payload={
|
||||||
|
@ -281,7 +281,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
def test_active_admin_highest_privilege(self) -> None:
|
def test_active_admin_highest_privilege(self) -> None:
|
||||||
"""The handler should resolve to the most permissive scope."""
|
"""The handler should resolve to the most permissive scope."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={
|
payload={
|
||||||
|
@ -313,7 +313,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
def test_active_user(self) -> None:
|
def test_active_user(self) -> None:
|
||||||
"""The handler should return a requester with normal user rights."""
|
"""The handler should return a requester with normal user rights."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={
|
payload={
|
||||||
|
@ -344,7 +344,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
"""The handler should return a requester with normal user rights
|
"""The handler should return a requester with normal user rights
|
||||||
and an user ID matching the one specified in query param `user_id`"""
|
and an user ID matching the one specified in query param `user_id`"""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={
|
payload={
|
||||||
|
@ -378,7 +378,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
def test_active_user_with_device(self) -> None:
|
def test_active_user_with_device(self) -> None:
|
||||||
"""The handler should return a requester with normal user rights and a device ID."""
|
"""The handler should return a requester with normal user rights and a device ID."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={
|
payload={
|
||||||
|
@ -408,7 +408,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
def test_multiple_devices(self) -> None:
|
def test_multiple_devices(self) -> None:
|
||||||
"""The handler should raise an error if multiple devices are found in the scope."""
|
"""The handler should raise an error if multiple devices are found in the scope."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={
|
payload={
|
||||||
|
@ -433,7 +433,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
def test_active_guest_not_allowed(self) -> None:
|
def test_active_guest_not_allowed(self) -> None:
|
||||||
"""The handler should return an insufficient scope error."""
|
"""The handler should return an insufficient scope error."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={
|
payload={
|
||||||
|
@ -463,7 +463,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
def test_active_guest_allowed(self) -> None:
|
def test_active_guest_allowed(self) -> None:
|
||||||
"""The handler should return a requester with guest user rights and a device ID."""
|
"""The handler should return a requester with guest user rights and a device ID."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={
|
payload={
|
||||||
|
@ -499,19 +499,19 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||||
|
|
||||||
# The introspection endpoint is returning an error.
|
# The introspection endpoint is returning an error.
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse(code=500, body=b"Internal Server Error")
|
return_value=FakeResponse(code=500, body=b"Internal Server Error")
|
||||||
)
|
)
|
||||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||||
self.assertEqual(error.value.code, 503)
|
self.assertEqual(error.value.code, 503)
|
||||||
|
|
||||||
# The introspection endpoint request fails.
|
# The introspection endpoint request fails.
|
||||||
self.http_client.request = simple_async_mock(raises=Exception())
|
self.http_client.request = AsyncMock(side_effect=Exception())
|
||||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||||
self.assertEqual(error.value.code, 503)
|
self.assertEqual(error.value.code, 503)
|
||||||
|
|
||||||
# The introspection endpoint does not return a JSON object.
|
# The introspection endpoint does not return a JSON object.
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200, payload=["this is an array", "not an object"]
|
code=200, payload=["this is an array", "not an object"]
|
||||||
)
|
)
|
||||||
|
@ -520,7 +520,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
self.assertEqual(error.value.code, 503)
|
self.assertEqual(error.value.code, 503)
|
||||||
|
|
||||||
# The introspection endpoint does not return valid JSON.
|
# The introspection endpoint does not return valid JSON.
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse(code=200, body=b"this is not valid JSON")
|
return_value=FakeResponse(code=200, body=b"this is not valid JSON")
|
||||||
)
|
)
|
||||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||||
|
@ -528,7 +528,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
|
|
||||||
def test_introspection_token_cache(self) -> None:
|
def test_introspection_token_cache(self) -> None:
|
||||||
access_token = "open_sesame"
|
access_token = "open_sesame"
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={"active": "true", "scope": "guest", "jti": access_token},
|
payload={"active": "true", "scope": "guest", "jti": access_token},
|
||||||
|
@ -559,7 +559,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
|
|
||||||
# test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a
|
# test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a
|
||||||
# token with a soon-to-expire `exp` field to the cache
|
# token with a soon-to-expire `exp` field to the cache
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={
|
payload={
|
||||||
|
@ -640,7 +640,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||||
def test_cross_signing(self) -> None:
|
def test_cross_signing(self) -> None:
|
||||||
"""Try uploading device keys with OAuth delegation enabled."""
|
"""Try uploading device keys with OAuth delegation enabled."""
|
||||||
|
|
||||||
self.http_client.request = simple_async_mock(
|
self.http_client.request = AsyncMock(
|
||||||
return_value=FakeResponse.json(
|
return_value=FakeResponse.json(
|
||||||
code=200,
|
code=200,
|
||||||
payload={
|
payload={
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
|
from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
|
||||||
from unittest.mock import ANY, Mock, patch
|
from unittest.mock import ANY, AsyncMock, Mock, patch
|
||||||
from urllib.parse import parse_qs, urlparse
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
import pymacaroons
|
import pymacaroons
|
||||||
|
@ -28,7 +28,7 @@ from synapse.util import Clock
|
||||||
from synapse.util.macaroons import get_value_from_macaroon
|
from synapse.util.macaroons import get_value_from_macaroon
|
||||||
from synapse.util.stringutils import random_string
|
from synapse.util.stringutils import random_string
|
||||||
|
|
||||||
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
from tests.test_utils import FakeResponse, get_awaitable_result
|
||||||
from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
|
from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
|
||||||
from tests.unittest import HomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
|
@ -164,7 +164,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
auth_handler = hs.get_auth_handler()
|
auth_handler = hs.get_auth_handler()
|
||||||
# Mock the complete SSO login method.
|
# Mock the complete SSO login method.
|
||||||
self.complete_sso_login = simple_async_mock()
|
self.complete_sso_login = AsyncMock()
|
||||||
auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment]
|
auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment]
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Set, Tuple
|
from typing import Any, Dict, Optional, Set, Tuple
|
||||||
from unittest.mock import Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
|
@ -25,7 +25,6 @@ from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.test_utils import simple_async_mock
|
|
||||||
from tests.unittest import HomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
# Check if we have the dependencies to run the tests.
|
# Check if we have the dependencies to run the tests.
|
||||||
|
@ -134,7 +133,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
|
||||||
|
|
||||||
# send a mocked-up SAML response to the callback
|
# send a mocked-up SAML response to the callback
|
||||||
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
|
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
|
||||||
|
@ -164,7 +163,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
|
||||||
|
|
||||||
# Map a user via SSO.
|
# Map a user via SSO.
|
||||||
saml_response = FakeAuthnResponse(
|
saml_response = FakeAuthnResponse(
|
||||||
|
@ -206,7 +205,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
|
||||||
|
|
||||||
# mock out the error renderer too
|
# mock out the error renderer too
|
||||||
sso_handler = self.hs.get_sso_handler()
|
sso_handler = self.hs.get_sso_handler()
|
||||||
|
@ -227,7 +226,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler and error renderer
|
# stub out the auth handler and error renderer
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
|
||||||
sso_handler = self.hs.get_sso_handler()
|
sso_handler = self.hs.get_sso_handler()
|
||||||
sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
|
sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
|
||||||
|
|
||||||
|
@ -312,7 +311,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
|
||||||
|
|
||||||
# The response doesn't have the proper userGroup or department.
|
# The response doesn't have the proper userGroup or department.
|
||||||
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
|
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
from unittest.mock import Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
@ -33,7 +33,6 @@ from synapse.util import Clock
|
||||||
|
|
||||||
from tests.events.test_presence_router import send_presence_update, sync_presence
|
from tests.events.test_presence_router import send_presence_update, sync_presence
|
||||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||||
from tests.test_utils import simple_async_mock
|
|
||||||
from tests.test_utils.event_injection import inject_member_event
|
from tests.test_utils.event_injection import inject_member_event
|
||||||
from tests.unittest import HomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
|
@ -70,7 +69,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
# Mock out the calls over federation.
|
# Mock out the calls over federation.
|
||||||
self.fed_transport_client = Mock(spec=["send_transaction"])
|
self.fed_transport_client = Mock(spec=["send_transaction"])
|
||||||
self.fed_transport_client.send_transaction = simple_async_mock({})
|
self.fed_transport_client.send_transaction = AsyncMock(return_value={})
|
||||||
|
|
||||||
return self.setup_test_homeserver(
|
return self.setup_test_homeserver(
|
||||||
federation_transport_client=self.fed_transport_client,
|
federation_transport_client=self.fed_transport_client,
|
||||||
|
@ -579,9 +578,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
||||||
"""Test that the module API can join a remote room."""
|
"""Test that the module API can join a remote room."""
|
||||||
# Necessary to fake a remote join.
|
# Necessary to fake a remote join.
|
||||||
fake_stream_id = 1
|
fake_stream_id = 1
|
||||||
mocked_remote_join = simple_async_mock(
|
mocked_remote_join = AsyncMock(return_value=("fake-event-id", fake_stream_id))
|
||||||
return_value=("fake-event-id", fake_stream_id)
|
|
||||||
)
|
|
||||||
self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment]
|
self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment]
|
||||||
fake_remote_host = f"{self.module_api.server_name}-remote"
|
fake_remote_host = f"{self.module_api.server_name}-remote"
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from unittest.mock import patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
|
@ -28,7 +28,6 @@ from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict, create_requester
|
from synapse.types import JsonDict, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.test_utils import simple_async_mock
|
|
||||||
from tests.unittest import HomeserverTestCase, override_config
|
from tests.unittest import HomeserverTestCase, override_config
|
||||||
|
|
||||||
|
|
||||||
|
@ -191,7 +190,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
|
||||||
# Mock the method which calculates push rules -- we do this instead of
|
# Mock the method which calculates push rules -- we do this instead of
|
||||||
# e.g. checking the results in the database because we want to ensure
|
# e.g. checking the results in the database because we want to ensure
|
||||||
# that code isn't even running.
|
# that code isn't even running.
|
||||||
bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment]
|
bulk_evaluator._action_for_event_by_user = AsyncMock() # type: ignore[assignment]
|
||||||
|
|
||||||
# Ensure no actions are generated!
|
# Ensure no actions are generated!
|
||||||
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
|
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from unittest.mock import Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ from synapse.rest.client import login, notifications, receipts, room
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests.test_utils import simple_async_mock
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,7 +44,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
# Mock out the calls over federation.
|
# Mock out the calls over federation.
|
||||||
fed_transport_client = Mock(spec=["send_transaction"])
|
fed_transport_client = Mock(spec=["send_transaction"])
|
||||||
fed_transport_client.send_transaction = simple_async_mock({})
|
fed_transport_client.send_transaction = AsyncMock(return_value={})
|
||||||
|
|
||||||
return self.setup_test_homeserver(
|
return self.setup_test_homeserver(
|
||||||
federation_transport_client=fed_transport_client,
|
federation_transport_client=fed_transport_client,
|
||||||
|
|
|
@ -32,7 +32,6 @@ from synapse.types import JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import simple_async_mock
|
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
|
|
||||||
|
|
||||||
|
@ -348,8 +347,8 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Mock out the AsyncContextManager
|
# Mock out the AsyncContextManager
|
||||||
class MockCM:
|
class MockCM:
|
||||||
__aenter__ = simple_async_mock(return_value=None)
|
__aenter__ = AsyncMock(return_value=None)
|
||||||
__aexit__ = simple_async_mock(return_value=None)
|
__aexit__ = AsyncMock(return_value=None)
|
||||||
|
|
||||||
self._update_ctx_manager = MockCM
|
self._update_ctx_manager = MockCM
|
||||||
|
|
||||||
|
|
|
@ -19,8 +19,7 @@ import json
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from binascii import unhexlify
|
from binascii import unhexlify
|
||||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
|
from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar
|
||||||
from unittest.mock import Mock
|
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import zope.interface
|
import zope.interface
|
||||||
|
@ -62,10 +61,6 @@ def setup_awaitable_errors() -> Callable[[], None]:
|
||||||
"""
|
"""
|
||||||
warnings.simplefilter("error", RuntimeWarning)
|
warnings.simplefilter("error", RuntimeWarning)
|
||||||
|
|
||||||
# unraisablehook was added in Python 3.8.
|
|
||||||
if not hasattr(sys, "unraisablehook"):
|
|
||||||
return lambda: None
|
|
||||||
|
|
||||||
# State shared between unraisablehook and check_for_unraisable_exceptions.
|
# State shared between unraisablehook and check_for_unraisable_exceptions.
|
||||||
unraisable_exceptions = []
|
unraisable_exceptions = []
|
||||||
orig_unraisablehook = sys.unraisablehook
|
orig_unraisablehook = sys.unraisablehook
|
||||||
|
@ -88,18 +83,6 @@ def setup_awaitable_errors() -> Callable[[], None]:
|
||||||
return cleanup
|
return cleanup
|
||||||
|
|
||||||
|
|
||||||
def simple_async_mock(
|
|
||||||
return_value: Optional[TV] = None, raises: Optional[Exception] = None
|
|
||||||
) -> Mock:
|
|
||||||
# AsyncMock is not available in python3.5, this mimics part of its behaviour
|
|
||||||
async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
|
|
||||||
if raises:
|
|
||||||
raise raises
|
|
||||||
return return_value
|
|
||||||
|
|
||||||
return Mock(side_effect=cb)
|
|
||||||
|
|
||||||
|
|
||||||
# Type ignore: it does not fully implement IResponse, but is good enough for tests
|
# Type ignore: it does not fully implement IResponse, but is good enough for tests
|
||||||
@zope.interface.implementer(IResponse)
|
@zope.interface.implementer(IResponse)
|
||||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
|
Loading…
Reference in New Issue