Replace simple_async_mock with AsyncMock (#16180)

Python 3.8 has a native AsyncMock, use it instead of a custom
implementation.
pull/16185/head
Patrick Cloke 2023-08-25 09:27:21 -04:00 committed by GitHub
parent 5c9402b9fd
commit a8a46b1336
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 140 additions and 160 deletions

1
changelog.d/16180.misc Normal file
View File

@ -0,0 +1 @@
Use `AsyncMock` instead of custom code.

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
import pymacaroons import pymacaroons
@ -35,7 +35,6 @@ from synapse.types import Requester, UserID
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import simple_async_mock
from tests.unittest import override_config from tests.unittest import override_config
from tests.utils import mock_getRawHeaders from tests.utils import mock_getRawHeaders
@ -60,16 +59,16 @@ class AuthTestCase(unittest.HomeserverTestCase):
# this is overridden for the appservice tests # this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
self.store.insert_client_ip = simple_async_mock(None) self.store.insert_client_ip = AsyncMock(return_value=None)
self.store.is_support_user = simple_async_mock(False) self.store.is_support_user = AsyncMock(return_value=False)
def test_get_user_by_req_user_valid_token(self) -> None: def test_get_user_by_req_user_valid_token(self) -> None:
user_info = TokenLookupResult( user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device" user_id=self.test_user, token_id=5, device_id="device"
) )
self.store.get_user_by_access_token = simple_async_mock(user_info) self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
self.store.mark_access_token_as_used = simple_async_mock(None) self.store.mark_access_token_as_used = AsyncMock(return_value=None)
self.store.get_user_locked_status = simple_async_mock(False) self.store.get_user_locked_status = AsyncMock(return_value=False)
request = Mock(args={}) request = Mock(args={})
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
@ -78,7 +77,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(requester.user.to_string(), self.test_user) self.assertEqual(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self) -> None: def test_get_user_by_req_user_bad_token(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
@ -91,7 +90,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_get_user_by_req_user_missing_token(self) -> None: def test_get_user_by_req_user_missing_token(self) -> None:
user_info = TokenLookupResult(user_id=self.test_user, token_id=5) user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = simple_async_mock(user_info) self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@ -106,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
@ -125,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ip_range_whitelist=IPSet(["192.168/16"]), ip_range_whitelist=IPSet(["192.168/16"]),
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.getClientAddress.return_value.host = "192.168.10.10" request.getClientAddress.return_value.host = "192.168.10.10"
@ -144,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ip_range_whitelist=IPSet(["192.168/16"]), ip_range_whitelist=IPSet(["192.168/16"]),
) )
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.getClientAddress.return_value.host = "131.111.8.42" request.getClientAddress.return_value.host = "131.111.8.42"
@ -158,7 +157,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_get_user_by_req_appservice_bad_token(self) -> None: def test_get_user_by_req_appservice_bad_token(self) -> None:
self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_app_service_by_token = Mock(return_value=None)
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
@ -172,7 +171,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_get_user_by_req_appservice_missing_token(self) -> None: def test_get_user_by_req_appservice_missing_token(self) -> None:
app_service = Mock(token="foobar", url="a_url", sender=self.test_user) app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@ -190,8 +189,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
app_service.is_interested_in_user = Mock(return_value=True) app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value. # This just needs to return a truth-y value.
self.store.get_user_by_id = simple_async_mock({"is_guest": False}) self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
@ -210,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
) )
app_service.is_interested_in_user = Mock(return_value=False) app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
@ -234,10 +233,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
app_service.is_interested_in_user = Mock(return_value=True) app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value. # This just needs to return a truth-y value.
self.store.get_user_by_id = simple_async_mock({"is_guest": False}) self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
# This also needs to just return a truth-y value # This also needs to just return a truth-y value
self.store.get_device = simple_async_mock({"hidden": False}) self.store.get_device = AsyncMock(return_value={"hidden": False})
request = Mock(args={}) request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
@ -266,10 +265,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
app_service.is_interested_in_user = Mock(return_value=True) app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_app_service_by_token = Mock(return_value=app_service)
# This just needs to return a truth-y value. # This just needs to return a truth-y value.
self.store.get_user_by_id = simple_async_mock({"is_guest": False}) self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
# This also needs to just return a falsey value # This also needs to just return a falsey value
self.store.get_device = simple_async_mock(None) self.store.get_device = AsyncMock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
@ -283,8 +282,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)
def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None: def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None:
self.store.get_user_by_access_token = simple_async_mock( self.store.get_user_by_access_token = AsyncMock(
TokenLookupResult( return_value=TokenLookupResult(
user_id="@baldrick:matrix.org", user_id="@baldrick:matrix.org",
device_id="device", device_id="device",
token_id=5, token_id=5,
@ -292,9 +291,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
token_used=True, token_used=True,
) )
) )
self.store.insert_client_ip = simple_async_mock(None) self.store.insert_client_ip = AsyncMock(return_value=None)
self.store.mark_access_token_as_used = simple_async_mock(None) self.store.mark_access_token_as_used = AsyncMock(return_value=None)
self.store.get_user_locked_status = simple_async_mock(False) self.store.get_user_locked_status = AsyncMock(return_value=False)
request = Mock(args={}) request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
@ -304,8 +303,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
self.auth._track_puppeted_user_ips = True self.auth._track_puppeted_user_ips = True
self.store.get_user_by_access_token = simple_async_mock( self.store.get_user_by_access_token = AsyncMock(
TokenLookupResult( return_value=TokenLookupResult(
user_id="@baldrick:matrix.org", user_id="@baldrick:matrix.org",
device_id="device", device_id="device",
token_id=5, token_id=5,
@ -313,9 +312,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
token_used=True, token_used=True,
) )
) )
self.store.get_user_locked_status = simple_async_mock(False) self.store.get_user_locked_status = AsyncMock(return_value=False)
self.store.insert_client_ip = simple_async_mock(None) self.store.insert_client_ip = AsyncMock(return_value=None)
self.store.mark_access_token_as_used = simple_async_mock(None) self.store.mark_access_token_as_used = AsyncMock(return_value=None)
request = Mock(args={}) request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
@ -324,7 +323,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.store.insert_client_ip.call_count, 2) self.assertEqual(self.store.insert_client_ip.call_count, 2)
def test_get_user_from_macaroon(self) -> None: def test_get_user_from_macaroon(self) -> None:
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
user_id = "@baldrick:matrix.org" user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
@ -342,8 +341,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
) )
def test_get_guest_user_from_macaroon(self) -> None: def test_get_guest_user_from_macaroon(self) -> None:
self.store.get_user_by_id = simple_async_mock({"is_guest": True}) self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
self.store.get_user_by_access_token = simple_async_mock(None) self.store.get_user_by_access_token = AsyncMock(return_value=None)
user_id = "@baldrick:matrix.org" user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
@ -373,7 +372,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = simple_async_mock(lots_of_users) self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users)
e = self.get_failure( e = self.get_failure(
self.auth_blocking.check_auth_blocking(), ResourceLimitError self.auth_blocking.check_auth_blocking(), ResourceLimitError
@ -383,25 +382,27 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(e.value.code, 403) self.assertEqual(e.value.code, 403)
# Ensure does not throw an error # Ensure does not throw an error
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) self.store.get_monthly_active_count = AsyncMock(
return_value=small_number_of_users
)
self.get_success(self.auth_blocking.check_auth_blocking()) self.get_success(self.auth_blocking.check_auth_blocking())
def test_blocking_mau__depending_on_user_type(self) -> None: def test_blocking_mau__depending_on_user_type(self) -> None:
self.auth_blocking._max_mau_value = 50 self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = simple_async_mock(100) self.store.get_monthly_active_count = AsyncMock(return_value=100)
# Support users allowed # Support users allowed
self.get_success( self.get_success(
self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT) self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
) )
self.store.get_monthly_active_count = simple_async_mock(100) self.store.get_monthly_active_count = AsyncMock(return_value=100)
# Bots not allowed # Bots not allowed
self.get_failure( self.get_failure(
self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT), self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
ResourceLimitError, ResourceLimitError,
) )
self.store.get_monthly_active_count = simple_async_mock(100) self.store.get_monthly_active_count = AsyncMock(return_value=100)
# Real users not allowed # Real users not allowed
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
@ -412,9 +413,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._track_appservice_user_ips = False self.auth_blocking._track_appservice_user_ips = False
self.store.get_monthly_active_count = simple_async_mock(100) self.store.get_monthly_active_count = AsyncMock(return_value=100)
self.store.user_last_seen_monthly_active = simple_async_mock() self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
self.store.is_trial_user = simple_async_mock() self.store.is_trial_user = AsyncMock(return_value=False)
appservice = ApplicationService( appservice = ApplicationService(
"abcd", "abcd",
@ -443,9 +444,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._track_appservice_user_ips = True self.auth_blocking._track_appservice_user_ips = True
self.store.get_monthly_active_count = simple_async_mock(100) self.store.get_monthly_active_count = AsyncMock(return_value=100)
self.store.user_last_seen_monthly_active = simple_async_mock() self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
self.store.is_trial_user = simple_async_mock() self.store.is_trial_user = AsyncMock(return_value=False)
appservice = ApplicationService( appservice = ApplicationService(
"abcd", "abcd",
@ -473,7 +474,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_reserved_threepid(self) -> None: def test_reserved_threepid(self) -> None:
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 1 self.auth_blocking._max_mau_value = 1
self.store.get_monthly_active_count = simple_async_mock(2) self.store.get_monthly_active_count = AsyncMock(return_value=2)
threepid = {"medium": "email", "address": "reserved@server.com"} threepid = {"medium": "email", "address": "reserved@server.com"}
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.auth_blocking._mau_limits_reserved_threepids = [threepid] self.auth_blocking._mau_limits_reserved_threepids = [threepid]

View File

@ -13,14 +13,13 @@
# limitations under the License. # limitations under the License.
import re import re
from typing import Any, Generator from typing import Any, Generator
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.appservice import ApplicationService, Namespace from synapse.appservice import ApplicationService, Namespace
from tests import unittest from tests import unittest
from tests.test_utils import simple_async_mock
def _regex(regex: str, exclusive: bool = True) -> Namespace: def _regex(regex: str, exclusive: bool = True) -> Namespace:
@ -43,8 +42,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
) )
self.store = Mock() self.store = Mock()
self.store.get_aliases_for_room = simple_async_mock([]) self.store.get_aliases_for_room = AsyncMock(return_value=[])
self.store.get_local_users_in_room = simple_async_mock([]) self.store.get_local_users_in_room = AsyncMock(return_value=[])
@defer.inlineCallbacks @defer.inlineCallbacks
def test_regex_user_id_prefix_match( def test_regex_user_id_prefix_match(
@ -127,10 +126,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.store.get_aliases_for_room = simple_async_mock( self.store.get_aliases_for_room = AsyncMock(
["#irc_foobar:matrix.org", "#athing:matrix.org"] return_value=["#irc_foobar:matrix.org", "#athing:matrix.org"]
) )
self.store.get_local_users_in_room = simple_async_mock([]) self.store.get_local_users_in_room = AsyncMock(return_value=[])
self.assertTrue( self.assertTrue(
( (
yield self.service.is_interested_in_event( yield self.service.is_interested_in_event(
@ -182,10 +181,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.service.namespaces[ApplicationService.NS_ALIASES].append( self.service.namespaces[ApplicationService.NS_ALIASES].append(
_regex("#irc_.*:matrix.org") _regex("#irc_.*:matrix.org")
) )
self.store.get_aliases_for_room = simple_async_mock( self.store.get_aliases_for_room = AsyncMock(
["#xmpp_foobar:matrix.org", "#athing:matrix.org"] return_value=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
) )
self.store.get_local_users_in_room = simple_async_mock([]) self.store.get_local_users_in_room = AsyncMock(return_value=[])
self.assertFalse( self.assertFalse(
( (
yield defer.ensureDeferred( yield defer.ensureDeferred(
@ -205,8 +204,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
) )
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
self.event.sender = "@irc_foobar:matrix.org" self.event.sender = "@irc_foobar:matrix.org"
self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"]) self.store.get_aliases_for_room = AsyncMock(
self.store.get_local_users_in_room = simple_async_mock([]) return_value=["#irc_barfoo:matrix.org"]
)
self.store.get_local_users_in_room = AsyncMock(return_value=[])
self.assertTrue( self.assertTrue(
( (
yield self.service.is_interested_in_event( yield self.service.is_interested_in_event(
@ -235,10 +236,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]: def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
# Note that @irc_fo:here is the AS user. # Note that @irc_fo:here is the AS user.
self.store.get_local_users_in_room = simple_async_mock( self.store.get_local_users_in_room = AsyncMock(
["@alice:here", "@irc_fo:here", "@bob:here"] return_value=["@alice:here", "@irc_fo:here", "@bob:here"]
) )
self.store.get_aliases_for_room = simple_async_mock([]) self.store.get_aliases_for_room = AsyncMock(return_value=[])
self.event.sender = "@xmpp_foobar:matrix.org" self.event.sender = "@xmpp_foobar:matrix.org"
self.assertTrue( self.assertTrue(

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional, Sequence, Tuple, cast from typing import List, Optional, Sequence, Tuple, cast
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
@ -37,7 +37,6 @@ from synapse.types import DeviceListUpdates, JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import simple_async_mock
from ..utils import MockClock from ..utils import MockClock
@ -62,10 +61,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
txn = Mock(id=txn_id, service=service, events=events) txn = Mock(id=txn_id, service=service, events=events)
# mock methods # mock methods
self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) self.store.get_appservice_state = AsyncMock(
txn.send = simple_async_mock(True) return_value=ApplicationServiceState.UP
txn.complete = simple_async_mock(True) )
self.store.create_appservice_txn = simple_async_mock(txn) txn.send = AsyncMock(return_value=True)
txn.complete = AsyncMock(return_value=True)
self.store.create_appservice_txn = AsyncMock(return_value=txn)
# actual call # actual call
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@ -89,10 +90,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
events = [Mock(), Mock()] events = [Mock(), Mock()]
txn = Mock(id="idhere", service=service, events=events) txn = Mock(id="idhere", service=service, events=events)
self.store.get_appservice_state = simple_async_mock( self.store.get_appservice_state = AsyncMock(
ApplicationServiceState.DOWN return_value=ApplicationServiceState.DOWN
) )
self.store.create_appservice_txn = simple_async_mock(txn) self.store.create_appservice_txn = AsyncMock(return_value=txn)
# actual call # actual call
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@ -118,10 +119,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
txn = Mock(id=txn_id, service=service, events=events) txn = Mock(id=txn_id, service=service, events=events)
# mock methods # mock methods
self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) self.store.get_appservice_state = AsyncMock(
self.store.set_appservice_state = simple_async_mock(True) return_value=ApplicationServiceState.UP
txn.send = simple_async_mock(False) # fails to send )
self.store.create_appservice_txn = simple_async_mock(txn) self.store.set_appservice_state = AsyncMock(return_value=True)
txn.send = AsyncMock(return_value=False) # fails to send
self.store.create_appservice_txn = AsyncMock(return_value=txn)
# actual call # actual call
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@ -150,7 +153,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.as_api = Mock() self.as_api = Mock()
self.store = Mock() self.store = Mock()
self.service = Mock() self.service = Mock()
self.callback = simple_async_mock() self.callback = AsyncMock()
self.recoverer = _Recoverer( self.recoverer = _Recoverer(
clock=cast(Clock, self.clock), clock=cast(Clock, self.clock),
as_api=self.as_api, as_api=self.as_api,
@ -174,8 +177,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover() self.recoverer.recover()
# shouldn't have called anything prior to waiting for exp backoff # shouldn't have called anything prior to waiting for exp backoff
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = simple_async_mock(True) txn.send = AsyncMock(return_value=True)
txn.complete = simple_async_mock(None) txn.complete = AsyncMock(return_value=None)
# wait for exp backoff # wait for exp backoff
self.clock.advance_time(2) self.clock.advance_time(2)
self.assertEqual(1, txn.send.call_count) self.assertEqual(1, txn.send.call_count)
@ -202,8 +205,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.recoverer.recover() self.recoverer.recover()
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = simple_async_mock(False) txn.send = AsyncMock(return_value=False)
txn.complete = simple_async_mock(None) txn.complete = AsyncMock(return_value=None)
self.clock.advance_time(2) self.clock.advance_time(2)
self.assertEqual(1, txn.send.call_count) self.assertEqual(1, txn.send.call_count)
self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, txn.complete.call_count)
@ -216,7 +219,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEqual(3, txn.send.call_count) self.assertEqual(3, txn.send.call_count)
self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, txn.complete.call_count)
self.assertEqual(0, self.callback.call_count) self.assertEqual(0, self.callback.call_count)
txn.send = simple_async_mock(True) # successfully send the txn txn.send = AsyncMock(return_value=True) # successfully send the txn
pop_txn = True # returns the txn the first time, then no more. pop_txn = True # returns the txn the first time, then no more.
self.clock.advance_time(16) self.clock.advance_time(16)
self.assertEqual(1, txn.send.call_count) # new mock reset call count self.assertEqual(1, txn.send.call_count) # new mock reset call count
@ -244,7 +247,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None:
self.scheduler = ApplicationServiceScheduler(hs) self.scheduler = ApplicationServiceScheduler(hs)
self.txn_ctrl = Mock() self.txn_ctrl = Mock()
self.txn_ctrl.send = simple_async_mock() self.txn_ctrl.send = AsyncMock()
# Replace instantiated _TransactionController instances with our Mock # Replace instantiated _TransactionController instances with our Mock
self.scheduler.txn_ctrl = self.txn_ctrl self.scheduler.txn_ctrl = self.txn_ctrl

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
import attr import attr
@ -30,7 +30,6 @@ from synapse.types import JsonDict, StreamToken, create_requester
from synapse.util import Clock from synapse.util import Clock
from tests.handlers.test_sync import generate_sync_config from tests.handlers.test_sync import generate_sync_config
from tests.test_utils import simple_async_mock
from tests.unittest import ( from tests.unittest import (
FederatingHomeserverTestCase, FederatingHomeserverTestCase,
HomeserverTestCase, HomeserverTestCase,
@ -157,7 +156,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation. # Mock out the calls over federation.
self.fed_transport_client = Mock(spec=["send_transaction"]) self.fed_transport_client = Mock(spec=["send_transaction"])
self.fed_transport_client.send_transaction = simple_async_mock({}) self.fed_transport_client.send_transaction = AsyncMock(return_value={})
hs = self.setup_test_homeserver( hs = self.setup_test_homeserver(
federation_transport_client=self.fed_transport_client, federation_transport_client=self.fed_transport_client,

View File

@ -36,7 +36,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from tests import unittest from tests import unittest
from tests.test_utils import event_injection, simple_async_mock from tests.test_utils import event_injection
from tests.unittest import override_config from tests.unittest import override_config
from tests.utils import MockClock from tests.utils import MockClock
@ -399,7 +399,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.hs = hs self.hs = hs
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that # Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events # we can track any outgoing ephemeral events
self.send_mock = simple_async_mock() self.send_mock = AsyncMock()
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment]
# Mock out application services, and allow defining our own in tests # Mock out application services, and allow defining our own in tests
@ -897,7 +897,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
# Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
# will be sent over the wire # will be sent over the wire
self.put_json = simple_async_mock() self.put_json = AsyncMock()
hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment] hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment]
# Mock out application services, and allow defining our own in tests # Mock out application services, and allow defining our own in tests
@ -1003,7 +1003,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that # Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track what's going out # we can track what's going out
self.send_mock = simple_async_mock() self.send_mock = AsyncMock()
hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method.
# Define an application service for the tests # Define an application service for the tests

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict from typing import Any, Dict
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -20,7 +20,6 @@ from synapse.handlers.cas import CasResponse
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
# These are a few constants that are used as config parameters in the tests. # These are a few constants that are used as config parameters in the tests.
@ -61,7 +60,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
cas_response = CasResponse("test_user", {}) cas_response = CasResponse("test_user", {})
request = _mock_request() request = _mock_request()
@ -89,7 +88,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
# Map a user via SSO. # Map a user via SSO.
cas_response = CasResponse("test_user", {}) cas_response = CasResponse("test_user", {})
@ -129,7 +128,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
cas_response = CasResponse("föö", {}) cas_response = CasResponse("föö", {})
request = _mock_request() request = _mock_request()
@ -160,7 +159,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department. # The response doesn't have the proper userGroup or department.
cas_response = CasResponse("test_user", {}) cas_response = CasResponse("test_user", {})

View File

@ -39,7 +39,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock from tests.test_utils import FakeResponse, get_awaitable_result
from tests.unittest import HomeserverTestCase, skip_unless from tests.unittest import HomeserverTestCase, skip_unless
from tests.utils import mock_getRawHeaders from tests.utils import mock_getRawHeaders
@ -147,7 +147,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_inactive_token(self) -> None: def test_inactive_token(self) -> None:
"""The handler should return a 403 where the token is inactive.""" """The handler should return a 403 where the token is inactive."""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={"active": False}, payload={"active": False},
@ -166,7 +166,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_no_scope(self) -> None: def test_active_no_scope(self) -> None:
"""The handler should return a 403 where no scope is given.""" """The handler should return a 403 where no scope is given."""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={"active": True}, payload={"active": True},
@ -185,7 +185,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_user_no_subject(self) -> None: def test_active_user_no_subject(self) -> None:
"""The handler should return a 500 when no subject is present.""" """The handler should return a 500 when no subject is present."""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])}, payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
@ -204,7 +204,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_no_user_scope(self) -> None: def test_active_no_user_scope(self) -> None:
"""The handler should return a 500 when no subject is present.""" """The handler should return a 500 when no subject is present."""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={ payload={
@ -227,7 +227,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_admin_not_user(self) -> None: def test_active_admin_not_user(self) -> None:
"""The handler should raise when the scope has admin right but not user.""" """The handler should raise when the scope has admin right but not user."""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={ payload={
@ -251,7 +251,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_admin(self) -> None: def test_active_admin(self) -> None:
"""The handler should return a requester with admin rights.""" """The handler should return a requester with admin rights."""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={ payload={
@ -281,7 +281,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_admin_highest_privilege(self) -> None: def test_active_admin_highest_privilege(self) -> None:
"""The handler should resolve to the most permissive scope.""" """The handler should resolve to the most permissive scope."""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={ payload={
@ -313,7 +313,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_user(self) -> None: def test_active_user(self) -> None:
"""The handler should return a requester with normal user rights.""" """The handler should return a requester with normal user rights."""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={ payload={
@ -344,7 +344,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
"""The handler should return a requester with normal user rights """The handler should return a requester with normal user rights
and an user ID matching the one specified in query param `user_id`""" and an user ID matching the one specified in query param `user_id`"""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={ payload={
@ -378,7 +378,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_user_with_device(self) -> None: def test_active_user_with_device(self) -> None:
"""The handler should return a requester with normal user rights and a device ID.""" """The handler should return a requester with normal user rights and a device ID."""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={ payload={
@ -408,7 +408,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_multiple_devices(self) -> None: def test_multiple_devices(self) -> None:
"""The handler should raise an error if multiple devices are found in the scope.""" """The handler should raise an error if multiple devices are found in the scope."""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={ payload={
@ -433,7 +433,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_guest_not_allowed(self) -> None: def test_active_guest_not_allowed(self) -> None:
"""The handler should return an insufficient scope error.""" """The handler should return an insufficient scope error."""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={ payload={
@ -463,7 +463,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_guest_allowed(self) -> None: def test_active_guest_allowed(self) -> None:
"""The handler should return a requester with guest user rights and a device ID.""" """The handler should return a requester with guest user rights and a device ID."""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={ payload={
@ -499,19 +499,19 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders() request.requestHeaders.getRawHeaders = mock_getRawHeaders()
# The introspection endpoint is returning an error. # The introspection endpoint is returning an error.
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse(code=500, body=b"Internal Server Error") return_value=FakeResponse(code=500, body=b"Internal Server Error")
) )
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503) self.assertEqual(error.value.code, 503)
# The introspection endpoint request fails. # The introspection endpoint request fails.
self.http_client.request = simple_async_mock(raises=Exception()) self.http_client.request = AsyncMock(side_effect=Exception())
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503) self.assertEqual(error.value.code, 503)
# The introspection endpoint does not return a JSON object. # The introspection endpoint does not return a JSON object.
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, payload=["this is an array", "not an object"] code=200, payload=["this is an array", "not an object"]
) )
@ -520,7 +520,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
self.assertEqual(error.value.code, 503) self.assertEqual(error.value.code, 503)
# The introspection endpoint does not return valid JSON. # The introspection endpoint does not return valid JSON.
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse(code=200, body=b"this is not valid JSON") return_value=FakeResponse(code=200, body=b"this is not valid JSON")
) )
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
@ -528,7 +528,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_introspection_token_cache(self) -> None: def test_introspection_token_cache(self) -> None:
access_token = "open_sesame" access_token = "open_sesame"
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={"active": "true", "scope": "guest", "jti": access_token}, payload={"active": "true", "scope": "guest", "jti": access_token},
@ -559,7 +559,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
# test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a # test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a
# token with a soon-to-expire `exp` field to the cache # token with a soon-to-expire `exp` field to the cache
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={ payload={
@ -640,7 +640,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_cross_signing(self) -> None: def test_cross_signing(self) -> None:
"""Try uploading device keys with OAuth delegation enabled.""" """Try uploading device keys with OAuth delegation enabled."""
self.http_client.request = simple_async_mock( self.http_client.request = AsyncMock(
return_value=FakeResponse.json( return_value=FakeResponse.json(
code=200, code=200,
payload={ payload={

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
from unittest.mock import ANY, Mock, patch from unittest.mock import ANY, AsyncMock, Mock, patch
from urllib.parse import parse_qs, urlparse from urllib.parse import parse_qs, urlparse
import pymacaroons import pymacaroons
@ -28,7 +28,7 @@ from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon from synapse.util.macaroons import get_value_from_macaroon
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock from tests.test_utils import FakeResponse, get_awaitable_result
from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
@ -164,7 +164,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler = hs.get_auth_handler() auth_handler = hs.get_auth_handler()
# Mock the complete SSO login method. # Mock the complete SSO login method.
self.complete_sso_login = simple_async_mock() self.complete_sso_login = AsyncMock()
auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment]
return hs return hs

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Optional, Set, Tuple from typing import Any, Dict, Optional, Set, Tuple
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
import attr import attr
@ -25,7 +25,6 @@ from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
# Check if we have the dependencies to run the tests. # Check if we have the dependencies to run the tests.
@ -134,7 +133,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
# send a mocked-up SAML response to the callback # send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@ -164,7 +163,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
# Map a user via SSO. # Map a user via SSO.
saml_response = FakeAuthnResponse( saml_response = FakeAuthnResponse(
@ -206,7 +205,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
# mock out the error renderer too # mock out the error renderer too
sso_handler = self.hs.get_sso_handler() sso_handler = self.hs.get_sso_handler()
@ -227,7 +226,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler and error renderer # stub out the auth handler and error renderer
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
sso_handler = self.hs.get_sso_handler() sso_handler = self.hs.get_sso_handler()
sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
@ -312,7 +311,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler # stub out the auth handler
auth_handler = self.hs.get_auth_handler() auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] auth_handler.complete_sso_login = AsyncMock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department. # The response doesn't have the proper userGroup or department.
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -33,7 +33,6 @@ from synapse.util import Clock
from tests.events.test_presence_router import send_presence_update, sync_presence from tests.events.test_presence_router import send_presence_update, sync_presence
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.test_utils import simple_async_mock
from tests.test_utils.event_injection import inject_member_event from tests.test_utils.event_injection import inject_member_event
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
@ -70,7 +69,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation. # Mock out the calls over federation.
self.fed_transport_client = Mock(spec=["send_transaction"]) self.fed_transport_client = Mock(spec=["send_transaction"])
self.fed_transport_client.send_transaction = simple_async_mock({}) self.fed_transport_client.send_transaction = AsyncMock(return_value={})
return self.setup_test_homeserver( return self.setup_test_homeserver(
federation_transport_client=self.fed_transport_client, federation_transport_client=self.fed_transport_client,
@ -579,9 +578,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
"""Test that the module API can join a remote room.""" """Test that the module API can join a remote room."""
# Necessary to fake a remote join. # Necessary to fake a remote join.
fake_stream_id = 1 fake_stream_id = 1
mocked_remote_join = simple_async_mock( mocked_remote_join = AsyncMock(return_value=("fake-event-id", fake_stream_id))
return_value=("fake-event-id", fake_stream_id)
)
self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment] self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment]
fake_remote_host = f"{self.module_api.server_name}-remote" fake_remote_host = f"{self.module_api.server_name}-remote"

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from typing import Any, Optional from typing import Any, Optional
from unittest.mock import patch from unittest.mock import AsyncMock, patch
from parameterized import parameterized from parameterized import parameterized
@ -28,7 +28,6 @@ from synapse.server import HomeServer
from synapse.types import JsonDict, create_requester from synapse.types import JsonDict, create_requester
from synapse.util import Clock from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config from tests.unittest import HomeserverTestCase, override_config
@ -191,7 +190,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
# Mock the method which calculates push rules -- we do this instead of # Mock the method which calculates push rules -- we do this instead of
# e.g. checking the results in the database because we want to ensure # e.g. checking the results in the database because we want to ensure
# that code isn't even running. # that code isn't even running.
bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment] bulk_evaluator._action_for_event_by_user = AsyncMock() # type: ignore[assignment]
# Ensure no actions are generated! # Ensure no actions are generated!
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from unittest.mock import Mock from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@ -20,7 +20,6 @@ from synapse.rest.client import login, notifications, receipts, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -45,7 +44,7 @@ class HTTPPusherTests(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation. # Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"]) fed_transport_client = Mock(spec=["send_transaction"])
fed_transport_client.send_transaction = simple_async_mock({}) fed_transport_client.send_transaction = AsyncMock(return_value={})
return self.setup_test_homeserver( return self.setup_test_homeserver(
federation_transport_client=fed_transport_client, federation_transport_client=fed_transport_client,

View File

@ -32,7 +32,6 @@ from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import simple_async_mock
from tests.unittest import override_config from tests.unittest import override_config
@ -348,8 +347,8 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
# Mock out the AsyncContextManager # Mock out the AsyncContextManager
class MockCM: class MockCM:
__aenter__ = simple_async_mock(return_value=None) __aenter__ = AsyncMock(return_value=None)
__aexit__ = simple_async_mock(return_value=None) __aexit__ = AsyncMock(return_value=None)
self._update_ctx_manager = MockCM self._update_ctx_manager = MockCM

View File

@ -19,8 +19,7 @@ import json
import sys import sys
import warnings import warnings
from binascii import unhexlify from binascii import unhexlify
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar
from unittest.mock import Mock
import attr import attr
import zope.interface import zope.interface
@ -62,10 +61,6 @@ def setup_awaitable_errors() -> Callable[[], None]:
""" """
warnings.simplefilter("error", RuntimeWarning) warnings.simplefilter("error", RuntimeWarning)
# unraisablehook was added in Python 3.8.
if not hasattr(sys, "unraisablehook"):
return lambda: None
# State shared between unraisablehook and check_for_unraisable_exceptions. # State shared between unraisablehook and check_for_unraisable_exceptions.
unraisable_exceptions = [] unraisable_exceptions = []
orig_unraisablehook = sys.unraisablehook orig_unraisablehook = sys.unraisablehook
@ -88,18 +83,6 @@ def setup_awaitable_errors() -> Callable[[], None]:
return cleanup return cleanup
def simple_async_mock(
return_value: Optional[TV] = None, raises: Optional[Exception] = None
) -> Mock:
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
if raises:
raise raises
return return_value
return Mock(side_effect=cb)
# Type ignore: it does not fully implement IResponse, but is good enough for tests # Type ignore: it does not fully implement IResponse, but is good enough for tests
@zope.interface.implementer(IResponse) @zope.interface.implementer(IResponse)
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)