Add final type hint to tests.unittest. (#15072)
Adds a return type to HomeServerTestCase.make_homeserver and deal with any variables which are no longer Any.pull/15037/head
parent
119e0795a5
commit
42aea0d8af
|
@ -0,0 +1 @@
|
||||||
|
Improve type hints.
|
3
mypy.ini
3
mypy.ini
|
@ -56,9 +56,6 @@ disallow_untyped_defs = False
|
||||||
[mypy-synapse.storage.database]
|
[mypy-synapse.storage.database]
|
||||||
disallow_untyped_defs = False
|
disallow_untyped_defs = False
|
||||||
|
|
||||||
[mypy-tests.unittest]
|
|
||||||
disallow_untyped_defs = False
|
|
||||||
|
|
||||||
[mypy-tests.util.caches.test_descriptors]
|
[mypy-tests.util.caches.test_descriptors]
|
||||||
disallow_untyped_defs = False
|
disallow_untyped_defs = False
|
||||||
|
|
||||||
|
|
|
@ -67,7 +67,9 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
|
|
||||||
# Listen with the config
|
# Listen with the config
|
||||||
self.hs._listen_http(parse_listener_def(0, config))
|
hs = self.hs
|
||||||
|
assert isinstance(hs, GenericWorkerServer)
|
||||||
|
hs._listen_http(parse_listener_def(0, config))
|
||||||
|
|
||||||
# Grab the resource from the site that was told to listen
|
# Grab the resource from the site that was told to listen
|
||||||
site = self.reactor.tcpServers[0][1]
|
site = self.reactor.tcpServers[0][1]
|
||||||
|
@ -115,7 +117,9 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
|
||||||
}
|
}
|
||||||
|
|
||||||
# Listen with the config
|
# Listen with the config
|
||||||
self.hs._listener_http(self.hs.config, parse_listener_def(0, config))
|
hs = self.hs
|
||||||
|
assert isinstance(hs, SynapseHomeServer)
|
||||||
|
hs._listener_http(self.hs.config, parse_listener_def(0, config))
|
||||||
|
|
||||||
# Grab the resource from the site that was told to listen
|
# Grab the resource from the site that was told to listen
|
||||||
site = self.reactor.tcpServers[0][1]
|
site = self.reactor.tcpServers[0][1]
|
||||||
|
|
|
@ -192,7 +192,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
key1 = signedjson.key.generate_signing_key("1")
|
key1 = signedjson.key.generate_signing_key("1")
|
||||||
r = self.hs.get_datastores().main.store_server_verify_keys(
|
r = self.hs.get_datastores().main.store_server_verify_keys(
|
||||||
"server9",
|
"server9",
|
||||||
time.time() * 1000,
|
int(time.time() * 1000),
|
||||||
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
|
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
|
||||||
)
|
)
|
||||||
self.get_success(r)
|
self.get_success(r)
|
||||||
|
@ -287,7 +287,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
|
||||||
key1 = signedjson.key.generate_signing_key("1")
|
key1 = signedjson.key.generate_signing_key("1")
|
||||||
r = self.hs.get_datastores().main.store_server_verify_keys(
|
r = self.hs.get_datastores().main.store_server_verify_keys(
|
||||||
"server9",
|
"server9",
|
||||||
time.time() * 1000,
|
int(time.time() * 1000),
|
||||||
# None is not a valid value in FetchKeyResult, but we're abusing this
|
# None is not a valid value in FetchKeyResult, but we're abusing this
|
||||||
# API to insert null values into the database. The nulls get converted
|
# API to insert null values into the database. The nulls get converted
|
||||||
# to 0 when fetched in KeyStore.get_server_verify_keys.
|
# to 0 when fetched in KeyStore.get_server_verify_keys.
|
||||||
|
@ -466,9 +466,9 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
key_json = self.get_success(
|
key_json = self.get_success(
|
||||||
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
|
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
|
||||||
)
|
)
|
||||||
res = key_json[lookup_triplet]
|
res_keys = key_json[lookup_triplet]
|
||||||
self.assertEqual(len(res), 1)
|
self.assertEqual(len(res_keys), 1)
|
||||||
res = res[0]
|
res = res_keys[0]
|
||||||
self.assertEqual(res["key_id"], testverifykey_id)
|
self.assertEqual(res["key_id"], testverifykey_id)
|
||||||
self.assertEqual(res["from_server"], SERVER_NAME)
|
self.assertEqual(res["from_server"], SERVER_NAME)
|
||||||
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
|
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
|
||||||
|
@ -584,9 +584,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
key_json = self.get_success(
|
key_json = self.get_success(
|
||||||
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
|
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
|
||||||
)
|
)
|
||||||
res = key_json[lookup_triplet]
|
res_keys = key_json[lookup_triplet]
|
||||||
self.assertEqual(len(res), 1)
|
self.assertEqual(len(res_keys), 1)
|
||||||
res = res[0]
|
res = res_keys[0]
|
||||||
self.assertEqual(res["key_id"], testverifykey_id)
|
self.assertEqual(res["key_id"], testverifykey_id)
|
||||||
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
|
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
|
||||||
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
|
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
|
||||||
|
@ -705,9 +705,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
|
||||||
key_json = self.get_success(
|
key_json = self.get_success(
|
||||||
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
|
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
|
||||||
)
|
)
|
||||||
res = key_json[lookup_triplet]
|
res_keys = key_json[lookup_triplet]
|
||||||
self.assertEqual(len(res), 1)
|
self.assertEqual(len(res_keys), 1)
|
||||||
res = res[0]
|
res = res_keys[0]
|
||||||
self.assertEqual(res["key_id"], testverifykey_id)
|
self.assertEqual(res["key_id"], testverifykey_id)
|
||||||
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
|
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
|
||||||
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
|
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
|
||||||
|
|
|
@ -156,11 +156,11 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
|
||||||
|
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
# Mock out the calls over federation.
|
# Mock out the calls over federation.
|
||||||
fed_transport_client = Mock(spec=["send_transaction"])
|
self.fed_transport_client = Mock(spec=["send_transaction"])
|
||||||
fed_transport_client.send_transaction = simple_async_mock({})
|
self.fed_transport_client.send_transaction = simple_async_mock({})
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
federation_transport_client=fed_transport_client,
|
federation_transport_client=self.fed_transport_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
load_legacy_presence_router(hs)
|
load_legacy_presence_router(hs)
|
||||||
|
@ -422,7 +422,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
|
||||||
#
|
#
|
||||||
# Thus we reset the mock, and try sending all online local user
|
# Thus we reset the mock, and try sending all online local user
|
||||||
# presence again
|
# presence again
|
||||||
self.hs.get_federation_transport_client().send_transaction.reset_mock()
|
self.fed_transport_client.send_transaction.reset_mock()
|
||||||
|
|
||||||
# Broadcast local user online presence
|
# Broadcast local user online presence
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -447,9 +447,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
|
||||||
}
|
}
|
||||||
found_users = set()
|
found_users = set()
|
||||||
|
|
||||||
calls = (
|
calls = self.fed_transport_client.send_transaction.call_args_list
|
||||||
self.hs.get_federation_transport_client().send_transaction.call_args_list
|
|
||||||
)
|
|
||||||
for call in calls:
|
for call in calls:
|
||||||
call_args = call[0]
|
call_args = call[0]
|
||||||
federation_transaction: Transaction = call_args[0]
|
federation_transaction: Transaction = call_args[0]
|
||||||
|
|
|
@ -17,7 +17,7 @@ from unittest.mock import Mock
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID, create_requester
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
|
@ -56,7 +56,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||||
|
|
||||||
# Artificially raise the complexity
|
# Artificially raise the complexity
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
|
|
||||||
|
async def get_current_state_event_counts(room_id: str) -> int:
|
||||||
|
return int(500 * 1.23)
|
||||||
|
|
||||||
|
store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
|
||||||
|
|
||||||
# Get the room complexity again -- make sure it's our artificial value
|
# Get the room complexity again -- make sure it's our artificial value
|
||||||
channel = self.make_signed_federation_request(
|
channel = self.make_signed_federation_request(
|
||||||
|
@ -75,12 +79,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||||
|
|
||||||
# Mock out some things, because we don't want to test the whole join
|
# Mock out some things, because we don't want to test the whole join
|
||||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||||
handler.federation_handler.do_invite_join = Mock(
|
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(("", 1))
|
return_value=make_awaitable(("", 1))
|
||||||
)
|
)
|
||||||
|
|
||||||
d = handler._remote_join(
|
d = handler._remote_join(
|
||||||
None,
|
create_requester(u1),
|
||||||
["other.example.com"],
|
["other.example.com"],
|
||||||
"roomid",
|
"roomid",
|
||||||
UserID.from_string(u1),
|
UserID.from_string(u1),
|
||||||
|
@ -106,12 +110,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||||
|
|
||||||
# Mock out some things, because we don't want to test the whole join
|
# Mock out some things, because we don't want to test the whole join
|
||||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||||
handler.federation_handler.do_invite_join = Mock(
|
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(("", 1))
|
return_value=make_awaitable(("", 1))
|
||||||
)
|
)
|
||||||
|
|
||||||
d = handler._remote_join(
|
d = handler._remote_join(
|
||||||
None,
|
create_requester(u1),
|
||||||
["other.example.com"],
|
["other.example.com"],
|
||||||
"roomid",
|
"roomid",
|
||||||
UserID.from_string(u1),
|
UserID.from_string(u1),
|
||||||
|
@ -144,17 +148,18 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||||
|
|
||||||
# Mock out some things, because we don't want to test the whole join
|
# Mock out some things, because we don't want to test the whole join
|
||||||
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
|
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
|
||||||
handler.federation_handler.do_invite_join = Mock(
|
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(("", 1))
|
return_value=make_awaitable(("", 1))
|
||||||
)
|
)
|
||||||
|
|
||||||
# Artificially raise the complexity
|
# Artificially raise the complexity
|
||||||
self.hs.get_datastores().main.get_current_state_event_counts = (
|
async def get_current_state_event_counts(room_id: str) -> int:
|
||||||
lambda x: make_awaitable(600)
|
return 600
|
||||||
)
|
|
||||||
|
self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
|
||||||
|
|
||||||
d = handler._remote_join(
|
d = handler._remote_join(
|
||||||
None,
|
create_requester(u1),
|
||||||
["other.example.com"],
|
["other.example.com"],
|
||||||
room_1,
|
room_1,
|
||||||
UserID.from_string(u1),
|
UserID.from_string(u1),
|
||||||
|
@ -200,12 +205,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
|
||||||
|
|
||||||
# Mock out some things, because we don't want to test the whole join
|
# Mock out some things, because we don't want to test the whole join
|
||||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||||
handler.federation_handler.do_invite_join = Mock(
|
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(("", 1))
|
return_value=make_awaitable(("", 1))
|
||||||
)
|
)
|
||||||
|
|
||||||
d = handler._remote_join(
|
d = handler._remote_join(
|
||||||
None,
|
create_requester(u1),
|
||||||
["other.example.com"],
|
["other.example.com"],
|
||||||
"roomid",
|
"roomid",
|
||||||
UserID.from_string(u1),
|
UserID.from_string(u1),
|
||||||
|
@ -230,12 +235,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
|
||||||
|
|
||||||
# Mock out some things, because we don't want to test the whole join
|
# Mock out some things, because we don't want to test the whole join
|
||||||
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
|
||||||
handler.federation_handler.do_invite_join = Mock(
|
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(("", 1))
|
return_value=make_awaitable(("", 1))
|
||||||
)
|
)
|
||||||
|
|
||||||
d = handler._remote_join(
|
d = handler._remote_join(
|
||||||
None,
|
create_requester(u1),
|
||||||
["other.example.com"],
|
["other.example.com"],
|
||||||
"roomid",
|
"roomid",
|
||||||
UserID.from_string(u1),
|
UserID.from_string(u1),
|
||||||
|
|
|
@ -5,7 +5,11 @@ from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.federation.sender import PerDestinationQueue, TransactionManager
|
from synapse.federation.sender import (
|
||||||
|
FederationSender,
|
||||||
|
PerDestinationQueue,
|
||||||
|
TransactionManager,
|
||||||
|
)
|
||||||
from synapse.federation.units import Edu, Transaction
|
from synapse.federation.units import Edu, Transaction
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
|
@ -33,8 +37,9 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
self.federation_transport_client = Mock(spec=["send_transaction"])
|
||||||
return self.setup_test_homeserver(
|
return self.setup_test_homeserver(
|
||||||
federation_transport_client=Mock(spec=["send_transaction"]),
|
federation_transport_client=self.federation_transport_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
@ -52,10 +57,14 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
|
||||||
self.pdus: List[JsonDict] = []
|
self.pdus: List[JsonDict] = []
|
||||||
self.failed_pdus: List[JsonDict] = []
|
self.failed_pdus: List[JsonDict] = []
|
||||||
self.is_online = True
|
self.is_online = True
|
||||||
self.hs.get_federation_transport_client().send_transaction.side_effect = (
|
self.federation_transport_client.send_transaction.side_effect = (
|
||||||
self.record_transaction
|
self.record_transaction
|
||||||
)
|
)
|
||||||
|
|
||||||
|
federation_sender = hs.get_federation_sender()
|
||||||
|
assert isinstance(federation_sender, FederationSender)
|
||||||
|
self.federation_sender = federation_sender
|
||||||
|
|
||||||
def default_config(self) -> JsonDict:
|
def default_config(self) -> JsonDict:
|
||||||
config = super().default_config()
|
config = super().default_config()
|
||||||
config["federation_sender_instances"] = None
|
config["federation_sender_instances"] = None
|
||||||
|
@ -229,11 +238,11 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
|
||||||
# let's delete the federation transmission queue
|
# let's delete the federation transmission queue
|
||||||
# (this pretends we are starting up fresh.)
|
# (this pretends we are starting up fresh.)
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
self.hs.get_federation_sender()
|
self.federation_sender._per_destination_queues[
|
||||||
._per_destination_queues["host2"]
|
"host2"
|
||||||
.transmission_loop_running
|
].transmission_loop_running
|
||||||
)
|
)
|
||||||
del self.hs.get_federation_sender()._per_destination_queues["host2"]
|
del self.federation_sender._per_destination_queues["host2"]
|
||||||
|
|
||||||
# let's also clear any backoffs
|
# let's also clear any backoffs
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -322,6 +331,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
|
||||||
# also fetch event 5 so we know its last_successful_stream_ordering later
|
# also fetch event 5 so we know its last_successful_stream_ordering later
|
||||||
event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5))
|
event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5))
|
||||||
|
|
||||||
|
assert event_2.internal_metadata.stream_ordering is not None
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
|
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
|
||||||
"host2", event_2.internal_metadata.stream_ordering
|
"host2", event_2.internal_metadata.stream_ordering
|
||||||
|
@ -425,15 +435,16 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
|
||||||
def wake_destination_track(destination: str) -> None:
|
def wake_destination_track(destination: str) -> None:
|
||||||
woken.append(destination)
|
woken.append(destination)
|
||||||
|
|
||||||
self.hs.get_federation_sender().wake_destination = wake_destination_track
|
self.federation_sender.wake_destination = wake_destination_track # type: ignore[assignment]
|
||||||
|
|
||||||
# cancel the pre-existing timer for _wake_destinations_needing_catchup
|
# cancel the pre-existing timer for _wake_destinations_needing_catchup
|
||||||
# this is because we are calling it manually rather than waiting for it
|
# this is because we are calling it manually rather than waiting for it
|
||||||
# to be called automatically
|
# to be called automatically
|
||||||
self.hs.get_federation_sender()._catchup_after_startup_timer.cancel()
|
assert self.federation_sender._catchup_after_startup_timer is not None
|
||||||
|
self.federation_sender._catchup_after_startup_timer.cancel()
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0
|
self.federation_sender._wake_destinations_needing_catchup(), by=5.0
|
||||||
)
|
)
|
||||||
|
|
||||||
# ASSERT (_wake_destinations_needing_catchup):
|
# ASSERT (_wake_destinations_needing_catchup):
|
||||||
|
@ -475,6 +486,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert event_1.internal_metadata.stream_ordering is not None
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
|
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
|
||||||
"host2", event_1.internal_metadata.stream_ordering
|
"host2", event_1.internal_metadata.stream_ordering
|
||||||
|
|
|
@ -178,7 +178,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
|
||||||
RoomVersions.V9,
|
RoomVersions.V9,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(pulled_pdu_info2)
|
assert pulled_pdu_info2 is not None
|
||||||
remote_pdu2 = pulled_pdu_info2.pdu
|
remote_pdu2 = pulled_pdu_info2.pdu
|
||||||
|
|
||||||
# Sanity check that we are working against the same event
|
# Sanity check that we are working against the same event
|
||||||
|
@ -226,7 +226,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
|
||||||
RoomVersions.V9,
|
RoomVersions.V9,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(pulled_pdu_info)
|
assert pulled_pdu_info is not None
|
||||||
remote_pdu = pulled_pdu_info.pdu
|
remote_pdu = pulled_pdu_info.pdu
|
||||||
|
|
||||||
# check the right call got made to the agent
|
# check the right call got made to the agent
|
||||||
|
|
|
@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
|
from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
|
||||||
from synapse.federation.units import Transaction
|
from synapse.federation.units import Transaction
|
||||||
|
from synapse.handlers.device import DeviceHandler
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login
|
from synapse.rest.client import login
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -41,8 +42,9 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
self.federation_transport_client = Mock(spec=["send_transaction"])
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
federation_transport_client=Mock(spec=["send_transaction"]),
|
federation_transport_client=self.federation_transport_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
|
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
|
||||||
|
@ -61,9 +63,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def test_send_receipts(self) -> None:
|
def test_send_receipts(self) -> None:
|
||||||
mock_send_transaction = (
|
mock_send_transaction = self.federation_transport_client.send_transaction
|
||||||
self.hs.get_federation_transport_client().send_transaction
|
|
||||||
)
|
|
||||||
mock_send_transaction.return_value = make_awaitable({})
|
mock_send_transaction.return_value = make_awaitable({})
|
||||||
|
|
||||||
sender = self.hs.get_federation_sender()
|
sender = self.hs.get_federation_sender()
|
||||||
|
@ -103,9 +103,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_send_receipts_thread(self) -> None:
|
def test_send_receipts_thread(self) -> None:
|
||||||
mock_send_transaction = (
|
mock_send_transaction = self.federation_transport_client.send_transaction
|
||||||
self.hs.get_federation_transport_client().send_transaction
|
|
||||||
)
|
|
||||||
mock_send_transaction.return_value = make_awaitable({})
|
mock_send_transaction.return_value = make_awaitable({})
|
||||||
|
|
||||||
# Create receipts for:
|
# Create receipts for:
|
||||||
|
@ -181,9 +179,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
|
||||||
def test_send_receipts_with_backoff(self) -> None:
|
def test_send_receipts_with_backoff(self) -> None:
|
||||||
"""Send two receipts in quick succession; the second should be flushed, but
|
"""Send two receipts in quick succession; the second should be flushed, but
|
||||||
only after 20ms"""
|
only after 20ms"""
|
||||||
mock_send_transaction = (
|
mock_send_transaction = self.federation_transport_client.send_transaction
|
||||||
self.hs.get_federation_transport_client().send_transaction
|
|
||||||
)
|
|
||||||
mock_send_transaction.return_value = make_awaitable({})
|
mock_send_transaction.return_value = make_awaitable({})
|
||||||
|
|
||||||
sender = self.hs.get_federation_sender()
|
sender = self.hs.get_federation_sender()
|
||||||
|
@ -277,10 +273,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
self.federation_transport_client = Mock(
|
||||||
|
spec=["send_transaction", "query_user_devices"]
|
||||||
|
)
|
||||||
return self.setup_test_homeserver(
|
return self.setup_test_homeserver(
|
||||||
federation_transport_client=Mock(
|
federation_transport_client=self.federation_transport_client,
|
||||||
spec=["send_transaction", "query_user_devices"]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def default_config(self) -> JsonDict:
|
def default_config(self) -> JsonDict:
|
||||||
|
@ -310,9 +307,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
|
||||||
|
|
||||||
hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment]
|
hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment]
|
||||||
|
|
||||||
|
device_handler = hs.get_device_handler()
|
||||||
|
assert isinstance(device_handler, DeviceHandler)
|
||||||
|
self.device_handler = device_handler
|
||||||
|
|
||||||
# whenever send_transaction is called, record the edu data
|
# whenever send_transaction is called, record the edu data
|
||||||
self.edus: List[JsonDict] = []
|
self.edus: List[JsonDict] = []
|
||||||
self.hs.get_federation_transport_client().send_transaction.side_effect = (
|
self.federation_transport_client.send_transaction.side_effect = (
|
||||||
self.record_transaction
|
self.record_transaction
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -353,7 +354,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
|
||||||
|
|
||||||
# Send the server a device list EDU for the other user, this will cause
|
# Send the server a device list EDU for the other user, this will cause
|
||||||
# it to try and resync the device lists.
|
# it to try and resync the device lists.
|
||||||
self.hs.get_federation_transport_client().query_user_devices.return_value = (
|
self.federation_transport_client.query_user_devices.return_value = (
|
||||||
make_awaitable(
|
make_awaitable(
|
||||||
{
|
{
|
||||||
"stream_id": "1",
|
"stream_id": "1",
|
||||||
|
@ -364,7 +365,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_device_handler().device_list_updater.incoming_device_list_update(
|
self.device_handler.device_list_updater.incoming_device_list_update(
|
||||||
"host2",
|
"host2",
|
||||||
{
|
{
|
||||||
"user_id": "@user2:host2",
|
"user_id": "@user2:host2",
|
||||||
|
@ -507,9 +508,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
|
||||||
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)
|
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)
|
||||||
|
|
||||||
# delete them again
|
# delete them again
|
||||||
self.get_success(
|
self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
|
||||||
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# We queue up device list updates to be sent over federation, so we
|
# We queue up device list updates to be sent over federation, so we
|
||||||
# advance to clear the queue.
|
# advance to clear the queue.
|
||||||
|
@ -533,7 +532,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
|
||||||
"""If the destination server is unreachable, all the updates should get sent on
|
"""If the destination server is unreachable, all the updates should get sent on
|
||||||
recovery
|
recovery
|
||||||
"""
|
"""
|
||||||
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
|
mock_send_txn = self.federation_transport_client.send_transaction
|
||||||
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
|
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
|
||||||
|
|
||||||
# create devices
|
# create devices
|
||||||
|
@ -543,9 +542,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
|
||||||
self.login("user", "pass", device_id="D3")
|
self.login("user", "pass", device_id="D3")
|
||||||
|
|
||||||
# delete them again
|
# delete them again
|
||||||
self.get_success(
|
self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
|
||||||
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# We queue up device list updates to be sent over federation, so we
|
# We queue up device list updates to be sent over federation, so we
|
||||||
# advance to clear the queue.
|
# advance to clear the queue.
|
||||||
|
@ -580,7 +577,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
|
||||||
|
|
||||||
This case tests the behaviour when the server has never been reachable.
|
This case tests the behaviour when the server has never been reachable.
|
||||||
"""
|
"""
|
||||||
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
|
mock_send_txn = self.federation_transport_client.send_transaction
|
||||||
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
|
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
|
||||||
|
|
||||||
# create devices
|
# create devices
|
||||||
|
@ -590,9 +587,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
|
||||||
self.login("user", "pass", device_id="D3")
|
self.login("user", "pass", device_id="D3")
|
||||||
|
|
||||||
# delete them again
|
# delete them again
|
||||||
self.get_success(
|
self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
|
||||||
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# We queue up device list updates to be sent over federation, so we
|
# We queue up device list updates to be sent over federation, so we
|
||||||
# advance to clear the queue.
|
# advance to clear the queue.
|
||||||
|
@ -640,7 +635,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
|
||||||
self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
|
self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
|
||||||
|
|
||||||
# now the server goes offline
|
# now the server goes offline
|
||||||
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
|
mock_send_txn = self.federation_transport_client.send_transaction
|
||||||
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
|
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
|
||||||
|
|
||||||
self.login("user", "pass", device_id="D2")
|
self.login("user", "pass", device_id="D2")
|
||||||
|
@ -651,9 +646,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
|
||||||
self.reactor.advance(1)
|
self.reactor.advance(1)
|
||||||
|
|
||||||
# delete them again
|
# delete them again
|
||||||
self.get_success(
|
self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
|
||||||
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertGreaterEqual(mock_send_txn.call_count, 3)
|
self.assertGreaterEqual(mock_send_txn.call_count, 3)
|
||||||
|
|
||||||
|
|
|
@ -899,7 +899,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
|
||||||
|
|
||||||
# Mock out application services, and allow defining our own in tests
|
# Mock out application services, and allow defining our own in tests
|
||||||
self._services: List[ApplicationService] = []
|
self._services: List[ApplicationService] = []
|
||||||
self.hs.get_datastores().main.get_app_services = Mock(
|
self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
|
||||||
return_value=self._services
|
return_value=self._services
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -61,7 +61,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
||||||
|
|
||||||
cas_response = CasResponse("test_user", {})
|
cas_response = CasResponse("test_user", {})
|
||||||
request = _mock_request()
|
request = _mock_request()
|
||||||
|
@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
||||||
|
|
||||||
# Map a user via SSO.
|
# Map a user via SSO.
|
||||||
cas_response = CasResponse("test_user", {})
|
cas_response = CasResponse("test_user", {})
|
||||||
|
@ -129,7 +129,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
||||||
|
|
||||||
cas_response = CasResponse("föö", {})
|
cas_response = CasResponse("föö", {})
|
||||||
request = _mock_request()
|
request = _mock_request()
|
||||||
|
@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
||||||
|
|
||||||
# The response doesn't have the proper userGroup or department.
|
# The response doesn't have the proper userGroup or department.
|
||||||
cas_response = CasResponse("test_user", {})
|
cas_response = CasResponse("test_user", {})
|
||||||
|
|
|
@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import RoomEncryptionAlgorithms
|
from synapse.api.constants import RoomEncryptionAlgorithms
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
|
from synapse.handlers.device import DeviceHandler
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -187,37 +188,37 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# we should now have an unused alg1 key
|
# we should now have an unused alg1 key
|
||||||
res = self.get_success(
|
fallback_res = self.get_success(
|
||||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||||
)
|
)
|
||||||
self.assertEqual(res, ["alg1"])
|
self.assertEqual(fallback_res, ["alg1"])
|
||||||
|
|
||||||
# claiming an OTK when no OTKs are available should return the fallback
|
# claiming an OTK when no OTKs are available should return the fallback
|
||||||
# key
|
# key
|
||||||
res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res,
|
claim_res,
|
||||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
|
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
|
||||||
)
|
)
|
||||||
|
|
||||||
# we shouldn't have any unused fallback keys again
|
# we shouldn't have any unused fallback keys again
|
||||||
res = self.get_success(
|
unused_res = self.get_success(
|
||||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||||
)
|
)
|
||||||
self.assertEqual(res, [])
|
self.assertEqual(unused_res, [])
|
||||||
|
|
||||||
# claiming an OTK again should return the same fallback key
|
# claiming an OTK again should return the same fallback key
|
||||||
res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res,
|
claim_res,
|
||||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
|
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -231,10 +232,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
res = self.get_success(
|
unused_res = self.get_success(
|
||||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||||
)
|
)
|
||||||
self.assertEqual(res, [])
|
self.assertEqual(unused_res, [])
|
||||||
|
|
||||||
# uploading a new fallback key should result in an unused fallback key
|
# uploading a new fallback key should result in an unused fallback key
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -245,10 +246,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
res = self.get_success(
|
unused_res = self.get_success(
|
||||||
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
|
||||||
)
|
)
|
||||||
self.assertEqual(res, ["alg1"])
|
self.assertEqual(unused_res, ["alg1"])
|
||||||
|
|
||||||
# if the user uploads a one-time key, the next claim should fetch the
|
# if the user uploads a one-time key, the next claim should fetch the
|
||||||
# one-time key, and then go back to the fallback
|
# one-time key, and then go back to the fallback
|
||||||
|
@ -258,23 +259,23 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res,
|
claim_res,
|
||||||
{"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
|
{"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
|
||||||
)
|
)
|
||||||
|
|
||||||
res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res,
|
claim_res,
|
||||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
|
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -287,13 +288,13 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
res = self.get_success(
|
claim_res = self.get_success(
|
||||||
self.handler.claim_one_time_keys(
|
self.handler.claim_one_time_keys(
|
||||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
res,
|
claim_res,
|
||||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
|
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -366,7 +367,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
|
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
|
||||||
|
|
||||||
# upload two device keys, which will be signed later by the self-signing key
|
# upload two device keys, which will be signed later by the self-signing key
|
||||||
device_key_1 = {
|
device_key_1: JsonDict = {
|
||||||
"user_id": local_user,
|
"user_id": local_user,
|
||||||
"device_id": "abc",
|
"device_id": "abc",
|
||||||
"algorithms": [
|
"algorithms": [
|
||||||
|
@ -379,7 +380,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
"signatures": {local_user: {"ed25519:abc": "base64+signature"}},
|
"signatures": {local_user: {"ed25519:abc": "base64+signature"}},
|
||||||
}
|
}
|
||||||
device_key_2 = {
|
device_key_2: JsonDict = {
|
||||||
"user_id": local_user,
|
"user_id": local_user,
|
||||||
"device_id": "def",
|
"device_id": "def",
|
||||||
"algorithms": [
|
"algorithms": [
|
||||||
|
@ -451,8 +452,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
}
|
}
|
||||||
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
|
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
|
||||||
|
|
||||||
|
device_handler = self.hs.get_device_handler()
|
||||||
|
assert isinstance(device_handler, DeviceHandler)
|
||||||
e = self.get_failure(
|
e = self.get_failure(
|
||||||
self.hs.get_device_handler().check_device_registered(
|
device_handler.check_device_registered(
|
||||||
user_id=local_user,
|
user_id=local_user,
|
||||||
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
|
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
|
||||||
initial_device_display_name="new display name",
|
initial_device_display_name="new display name",
|
||||||
|
@ -475,7 +478,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
device_id = "xyz"
|
device_id = "xyz"
|
||||||
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
|
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
|
||||||
device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY"
|
device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY"
|
||||||
device_key = {
|
device_key: JsonDict = {
|
||||||
"user_id": local_user,
|
"user_id": local_user,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
"algorithms": [
|
"algorithms": [
|
||||||
|
@ -497,7 +500,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
|
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
|
||||||
master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
|
master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
|
||||||
master_key = {
|
master_key: JsonDict = {
|
||||||
"user_id": local_user,
|
"user_id": local_user,
|
||||||
"usage": ["master"],
|
"usage": ["master"],
|
||||||
"keys": {"ed25519:" + master_pubkey: master_pubkey},
|
"keys": {"ed25519:" + master_pubkey: master_pubkey},
|
||||||
|
@ -540,7 +543,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
# the first user
|
# the first user
|
||||||
other_user = "@otherboris:" + self.hs.hostname
|
other_user = "@otherboris:" + self.hs.hostname
|
||||||
other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM"
|
other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM"
|
||||||
other_master_key = {
|
other_master_key: JsonDict = {
|
||||||
# private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI
|
# private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI
|
||||||
"user_id": other_user,
|
"user_id": other_user,
|
||||||
"usage": ["master"],
|
"usage": ["master"],
|
||||||
|
@ -702,7 +705,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
|
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
|
||||||
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
|
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
|
||||||
|
|
||||||
self.hs.get_federation_client().query_client_keys = mock.Mock(
|
self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(
|
return_value=make_awaitable(
|
||||||
{
|
{
|
||||||
"device_keys": {remote_user_id: {}},
|
"device_keys": {remote_user_id: {}},
|
||||||
|
@ -782,7 +785,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||||
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
|
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
|
||||||
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
|
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
|
||||||
|
|
||||||
self.hs.get_federation_client().query_user_devices = mock.Mock(
|
self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(
|
return_value=make_awaitable(
|
||||||
{
|
{
|
||||||
"user_id": remote_user_id,
|
"user_id": remote_user_id,
|
||||||
|
|
|
@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
# We mock out the FederationClient.backfill method, to pretend that a remote
|
# We mock out the FederationClient.backfill method, to pretend that a remote
|
||||||
# server has returned our fake event.
|
# server has returned our fake event.
|
||||||
federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
|
federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
|
||||||
self.hs.get_federation_client().backfill = federation_client_backfill_mock
|
self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment]
|
||||||
|
|
||||||
# We also mock the persist method with a side effect of itself. This allows us
|
# We also mock the persist method with a side effect of itself. This allows us
|
||||||
# to track when it has been called while preserving its function.
|
# to track when it has been called while preserving its function.
|
||||||
persist_events_and_notify_mock = Mock(
|
persist_events_and_notify_mock = Mock(
|
||||||
side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
|
side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
|
||||||
)
|
)
|
||||||
self.hs.get_federation_event_handler().persist_events_and_notify = (
|
self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment]
|
||||||
persist_events_and_notify_mock
|
persist_events_and_notify_mock
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
|
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
|
||||||
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
|
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
|
||||||
# Start the partial state sync.
|
# Start the partial state sync.
|
||||||
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
|
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
|
||||||
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
|
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
|
||||||
|
|
||||||
# Try to start another partial state sync.
|
# Try to start another partial state sync.
|
||||||
# Nothing should happen.
|
# Nothing should happen.
|
||||||
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
|
fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
|
||||||
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
|
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
|
||||||
|
|
||||||
# End the partial state sync
|
# End the partial state sync
|
||||||
|
@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
|
|
||||||
# The next attempt to start the partial state sync should work.
|
# The next attempt to start the partial state sync should work.
|
||||||
is_partial_state = True
|
is_partial_state = True
|
||||||
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
|
fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
|
||||||
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
|
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
|
||||||
|
|
||||||
def test_partial_state_room_sync_restart(self) -> None:
|
def test_partial_state_room_sync_restart(self) -> None:
|
||||||
|
@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
|
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
|
||||||
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
|
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
|
||||||
# Start the partial state sync.
|
# Start the partial state sync.
|
||||||
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
|
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
|
||||||
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
|
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
|
||||||
|
|
||||||
# Fail the partial state sync.
|
# Fail the partial state sync.
|
||||||
|
@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
|
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
|
||||||
|
|
||||||
# Start the partial state sync again.
|
# Start the partial state sync again.
|
||||||
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
|
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
|
||||||
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
|
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
|
||||||
|
|
||||||
# Deduplicate another partial state sync.
|
# Deduplicate another partial state sync.
|
||||||
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
|
fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
|
||||||
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
|
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
|
||||||
|
|
||||||
# Fail the partial state sync.
|
# Fail the partial state sync.
|
||||||
|
@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
|
||||||
self.assertEqual(mock_sync_partial_state_room.call_count, 3)
|
self.assertEqual(mock_sync_partial_state_room.call_count, 3)
|
||||||
mock_sync_partial_state_room.assert_called_with(
|
mock_sync_partial_state_room.assert_called_with(
|
||||||
initial_destination="hs3",
|
initial_destination="hs3",
|
||||||
other_destinations=["hs2"],
|
other_destinations={"hs2"},
|
||||||
room_id="room_id",
|
room_id="room_id",
|
||||||
)
|
)
|
||||||
|
|
|
@ -29,6 +29,7 @@ from synapse.logging.context import LoggingContext
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.state import StateResolutionStore
|
||||||
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
|
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -161,6 +162,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||||
if prev_exists_as_outlier:
|
if prev_exists_as_outlier:
|
||||||
prev_event.internal_metadata.outlier = True
|
prev_event.internal_metadata.outlier = True
|
||||||
persistence = self.hs.get_storage_controllers().persistence
|
persistence = self.hs.get_storage_controllers().persistence
|
||||||
|
assert persistence is not None
|
||||||
self.get_success(
|
self.get_success(
|
||||||
persistence.persist_event(
|
persistence.persist_event(
|
||||||
prev_event,
|
prev_event,
|
||||||
|
@ -861,7 +863,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||||
bert_member_event.event_id: bert_member_event,
|
bert_member_event.event_id: bert_member_event,
|
||||||
rejected_kick_event.event_id: rejected_kick_event,
|
rejected_kick_event.event_id: rejected_kick_event,
|
||||||
},
|
},
|
||||||
state_res_store=main_store,
|
state_res_store=StateResolutionStore(main_store),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
[bert_member_event.event_id, rejected_kick_event.event_id],
|
[bert_member_event.event_id, rejected_kick_event.event_id],
|
||||||
|
@ -906,7 +908,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||||
rejected_power_levels_event.event_id,
|
rejected_power_levels_event.event_id,
|
||||||
],
|
],
|
||||||
event_map={},
|
event_map={},
|
||||||
state_res_store=main_store,
|
state_res_store=StateResolutionStore(main_store),
|
||||||
full_conflicted_set=set(),
|
full_conflicted_set=set(),
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
|
|
|
@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.handler = self.hs.get_event_creation_handler()
|
self.handler = self.hs.get_event_creation_handler()
|
||||||
self._persist_event_storage_controller = (
|
persistence = self.hs.get_storage_controllers().persistence
|
||||||
self.hs.get_storage_controllers().persistence
|
assert persistence is not None
|
||||||
)
|
self._persist_event_storage_controller = persistence
|
||||||
|
|
||||||
self.user_id = self.register_user("tester", "foobar")
|
self.user_id = self.register_user("tester", "foobar")
|
||||||
self.access_token = self.login("tester", "foobar")
|
self.access_token = self.login("tester", "foobar")
|
||||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
||||||
|
|
||||||
self.info = self.get_success(
|
info = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(
|
self.hs.get_datastores().main.get_user_by_access_token(
|
||||||
self.access_token,
|
self.access_token,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.token_id = self.info.token_id
|
assert info is not None
|
||||||
|
self.token_id = info.token_id
|
||||||
|
|
||||||
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
|
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
|
||||||
|
|
||||||
|
|
|
@ -852,7 +852,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||||
username: The username to use for the test.
|
username: The username to use for the test.
|
||||||
registration: Whether to test with registration URLs.
|
registration: Whether to test with registration URLs.
|
||||||
"""
|
"""
|
||||||
self.hs.get_identity_handler().send_threepid_validation = Mock(
|
self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(0),
|
return_value=make_awaitable(0),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -203,7 +203,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
@override_config({"limit_usage_by_mau": True})
|
@override_config({"limit_usage_by_mau": True})
|
||||||
def test_get_or_create_user_mau_not_blocked(self) -> None:
|
def test_get_or_create_user_mau_not_blocked(self) -> None:
|
||||||
self.store.count_monthly_users = Mock(
|
self.store.count_monthly_users = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
|
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
|
||||||
)
|
)
|
||||||
# Ensure does not throw exception
|
# Ensure does not throw exception
|
||||||
|
@ -304,7 +304,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
|
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
|
||||||
room_alias_str = "#room:test"
|
room_alias_str = "#room:test"
|
||||||
|
|
||||||
self.store.count_real_users = Mock(return_value=make_awaitable(1))
|
self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment]
|
||||||
self.store.is_real_user = Mock(return_value=make_awaitable(True))
|
self.store.is_real_user = Mock(return_value=make_awaitable(True))
|
||||||
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
||||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||||
|
@ -319,7 +319,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
|
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
|
||||||
self,
|
self,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.store.count_real_users = Mock(return_value=make_awaitable(2))
|
self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment]
|
||||||
self.store.is_real_user = Mock(return_value=make_awaitable(True))
|
self.store.is_real_user = Mock(return_value=make_awaitable(True))
|
||||||
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
user_id = self.get_success(self.handler.register_user(localpart="real"))
|
||||||
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
|
||||||
|
@ -346,6 +346,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Ensure the room is properly not federated.
|
# Ensure the room is properly not federated.
|
||||||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
|
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
|
||||||
|
assert room is not None
|
||||||
self.assertFalse(room["federatable"])
|
self.assertFalse(room["federatable"])
|
||||||
self.assertFalse(room["public"])
|
self.assertFalse(room["public"])
|
||||||
self.assertEqual(room["join_rules"], "public")
|
self.assertEqual(room["join_rules"], "public")
|
||||||
|
@ -375,6 +376,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Ensure the room is properly a public room.
|
# Ensure the room is properly a public room.
|
||||||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
|
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
|
||||||
|
assert room is not None
|
||||||
self.assertEqual(room["join_rules"], "public")
|
self.assertEqual(room["join_rules"], "public")
|
||||||
|
|
||||||
# Both users should be in the room.
|
# Both users should be in the room.
|
||||||
|
@ -413,6 +415,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Ensure the room is properly a private room.
|
# Ensure the room is properly a private room.
|
||||||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
|
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
|
||||||
|
assert room is not None
|
||||||
self.assertFalse(room["public"])
|
self.assertFalse(room["public"])
|
||||||
self.assertEqual(room["join_rules"], "invite")
|
self.assertEqual(room["join_rules"], "invite")
|
||||||
self.assertEqual(room["guest_access"], "can_join")
|
self.assertEqual(room["guest_access"], "can_join")
|
||||||
|
@ -456,6 +459,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Ensure the room is properly a private room.
|
# Ensure the room is properly a private room.
|
||||||
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
|
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
|
||||||
|
assert room is not None
|
||||||
self.assertFalse(room["public"])
|
self.assertFalse(room["public"])
|
||||||
self.assertEqual(room["join_rules"], "invite")
|
self.assertEqual(room["join_rules"], "invite")
|
||||||
self.assertEqual(room["guest_access"], "can_join")
|
self.assertEqual(room["guest_access"], "can_join")
|
||||||
|
|
|
@ -134,7 +134,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
||||||
|
|
||||||
# send a mocked-up SAML response to the callback
|
# send a mocked-up SAML response to the callback
|
||||||
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
|
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
|
||||||
|
@ -164,7 +164,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
||||||
|
|
||||||
# Map a user via SSO.
|
# Map a user via SSO.
|
||||||
saml_response = FakeAuthnResponse(
|
saml_response = FakeAuthnResponse(
|
||||||
|
@ -206,11 +206,11 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
||||||
|
|
||||||
# mock out the error renderer too
|
# mock out the error renderer too
|
||||||
sso_handler = self.hs.get_sso_handler()
|
sso_handler = self.hs.get_sso_handler()
|
||||||
sso_handler.render_error = Mock(return_value=None)
|
sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
|
||||||
|
|
||||||
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
|
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
|
||||||
request = _mock_request()
|
request = _mock_request()
|
||||||
|
@ -227,9 +227,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler and error renderer
|
# stub out the auth handler and error renderer
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
||||||
sso_handler = self.hs.get_sso_handler()
|
sso_handler = self.hs.get_sso_handler()
|
||||||
sso_handler.render_error = Mock(return_value=None)
|
sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
|
||||||
|
|
||||||
# register a user to occupy the first-choice MXID
|
# register a user to occupy the first-choice MXID
|
||||||
store = self.hs.get_datastores().main
|
store = self.hs.get_datastores().main
|
||||||
|
@ -312,7 +312,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
# stub out the auth handler
|
# stub out the auth handler
|
||||||
auth_handler = self.hs.get_auth_handler()
|
auth_handler = self.hs.get_auth_handler()
|
||||||
auth_handler.complete_sso_login = simple_async_mock()
|
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
|
||||||
|
|
||||||
# The response doesn't have the proper userGroup or department.
|
# The response doesn't have the proper userGroup or department.
|
||||||
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
|
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
|
||||||
|
|
|
@ -74,8 +74,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
|
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
|
||||||
|
|
||||||
# we mock out the federation client too
|
# we mock out the federation client too
|
||||||
mock_federation_client = Mock(spec=["put_json"])
|
self.mock_federation_client = Mock(spec=["put_json"])
|
||||||
mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
|
self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
|
||||||
|
|
||||||
# the tests assume that we are starting at unix time 1000
|
# the tests assume that we are starting at unix time 1000
|
||||||
reactor.pump((1000,))
|
reactor.pump((1000,))
|
||||||
|
@ -83,7 +83,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
self.mock_hs_notifier = Mock()
|
self.mock_hs_notifier = Mock()
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
notifier=self.mock_hs_notifier,
|
notifier=self.mock_hs_notifier,
|
||||||
federation_http_client=mock_federation_client,
|
federation_http_client=self.mock_federation_client,
|
||||||
keyring=mock_keyring,
|
keyring=mock_keyring,
|
||||||
replication_streams={},
|
replication_streams={},
|
||||||
)
|
)
|
||||||
|
@ -233,8 +233,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
put_json = self.hs.get_federation_http_client().put_json
|
self.mock_federation_client.put_json.assert_called_once_with(
|
||||||
put_json.assert_called_once_with(
|
|
||||||
"farm",
|
"farm",
|
||||||
path="/_matrix/federation/v1/send/1000000",
|
path="/_matrix/federation/v1/send/1000000",
|
||||||
data=_expect_edu_transaction(
|
data=_expect_edu_transaction(
|
||||||
|
@ -349,8 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
|
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
|
||||||
|
|
||||||
put_json = self.hs.get_federation_http_client().put_json
|
self.mock_federation_client.put_json.assert_called_once_with(
|
||||||
put_json.assert_called_once_with(
|
|
||||||
"farm",
|
"farm",
|
||||||
path="/_matrix/federation/v1/send/1000000",
|
path="/_matrix/federation/v1/send/1000000",
|
||||||
data=_expect_edu_transaction(
|
data=_expect_edu_transaction(
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Tuple
|
from typing import Any, Tuple
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ from synapse.appservice import ApplicationService
|
||||||
from synapse.rest.client import login, register, room, user_directory
|
from synapse.rest.client import login, register, room, user_directory
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.roommember import ProfileInfo
|
from synapse.storage.roommember import ProfileInfo
|
||||||
from synapse.types import create_requester
|
from synapse.types import UserProfile, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -34,6 +34,12 @@ from tests.test_utils.event_injection import inject_member_event
|
||||||
from tests.unittest import override_config
|
from tests.unittest import override_config
|
||||||
|
|
||||||
|
|
||||||
|
# A spam checker which doesn't implement anything, so create a bare object.
|
||||||
|
class UselessSpamChecker:
|
||||||
|
def __init__(self, config: Any):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
"""Tests the UserDirectoryHandler.
|
"""Tests the UserDirectoryHandler.
|
||||||
|
|
||||||
|
@ -773,7 +779,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
||||||
self.assertEqual(len(s["results"]), 1)
|
self.assertEqual(len(s["results"]), 1)
|
||||||
|
|
||||||
async def allow_all(user_profile: ProfileInfo) -> bool:
|
async def allow_all(user_profile: UserProfile) -> bool:
|
||||||
# Allow all users.
|
# Allow all users.
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -787,7 +793,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(s["results"]), 1)
|
self.assertEqual(len(s["results"]), 1)
|
||||||
|
|
||||||
# Configure a spam checker that filters all users.
|
# Configure a spam checker that filters all users.
|
||||||
async def block_all(user_profile: ProfileInfo) -> bool:
|
async def block_all(user_profile: UserProfile) -> bool:
|
||||||
# All users are spammy.
|
# All users are spammy.
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -797,6 +803,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
||||||
self.assertEqual(len(s["results"]), 0)
|
self.assertEqual(len(s["results"]), 0)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"spam_checker": {
|
||||||
|
"module": "tests.handlers.test_user_directory.UselessSpamChecker"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
def test_legacy_spam_checker(self) -> None:
|
def test_legacy_spam_checker(self) -> None:
|
||||||
"""
|
"""
|
||||||
A spam checker without the expected method should be ignored.
|
A spam checker without the expected method should be ignored.
|
||||||
|
@ -825,11 +838,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
|
self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
|
||||||
self.assertEqual(public_users, set())
|
self.assertEqual(public_users, set())
|
||||||
|
|
||||||
# Configure a spam checker.
|
|
||||||
spam_checker = self.hs.get_spam_checker()
|
|
||||||
# The spam checker doesn't need any methods, so create a bare object.
|
|
||||||
spam_checker.spam_checker = object()
|
|
||||||
|
|
||||||
# We get one search result when searching for user2 by user1.
|
# We get one search result when searching for user2 by user1.
|
||||||
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
s = self.get_success(self.handler.search_users(u1, "user2", 10))
|
||||||
self.assertEqual(len(s["results"]), 1)
|
self.assertEqual(len(s["results"]), 1)
|
||||||
|
@ -954,10 +962,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
context = self.get_success(unpersisted_context.persist(event))
|
context = self.get_success(unpersisted_context.persist(event))
|
||||||
|
persistence = self.hs.get_storage_controllers().persistence
|
||||||
self.get_success(
|
assert persistence is not None
|
||||||
self.hs.get_storage_controllers().persistence.persist_event(event, context)
|
self.get_success(persistence.persist_event(event, context))
|
||||||
)
|
|
||||||
|
|
||||||
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
|
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
|
||||||
"""We've chosen to simplify the user directory's implementation by
|
"""We've chosen to simplify the user directory's implementation by
|
||||||
|
|
|
@ -68,11 +68,11 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
||||||
|
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
# Mock out the calls over federation.
|
# Mock out the calls over federation.
|
||||||
fed_transport_client = Mock(spec=["send_transaction"])
|
self.fed_transport_client = Mock(spec=["send_transaction"])
|
||||||
fed_transport_client.send_transaction = simple_async_mock({})
|
self.fed_transport_client.send_transaction = simple_async_mock({})
|
||||||
|
|
||||||
return self.setup_test_homeserver(
|
return self.setup_test_homeserver(
|
||||||
federation_transport_client=fed_transport_client,
|
federation_transport_client=self.fed_transport_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_can_register_user(self) -> None:
|
def test_can_register_user(self) -> None:
|
||||||
|
@ -417,7 +417,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
||||||
#
|
#
|
||||||
# Thus we reset the mock, and try sending online local user
|
# Thus we reset the mock, and try sending online local user
|
||||||
# presence again
|
# presence again
|
||||||
self.hs.get_federation_transport_client().send_transaction.reset_mock()
|
self.fed_transport_client.send_transaction.reset_mock()
|
||||||
|
|
||||||
# Broadcast local user online presence
|
# Broadcast local user online presence
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -429,9 +429,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
||||||
|
|
||||||
# Check that a presence update was sent as part of a federation transaction
|
# Check that a presence update was sent as part of a federation transaction
|
||||||
found_update = False
|
found_update = False
|
||||||
calls = (
|
calls = self.fed_transport_client.send_transaction.call_args_list
|
||||||
self.hs.get_federation_transport_client().send_transaction.call_args_list
|
|
||||||
)
|
|
||||||
for call in calls:
|
for call in calls:
|
||||||
call_args = call[0]
|
call_args = call[0]
|
||||||
federation_transaction: Transaction = call_args[0]
|
federation_transaction: Transaction = call_args[0]
|
||||||
|
@ -581,7 +579,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
|
||||||
mocked_remote_join = simple_async_mock(
|
mocked_remote_join = simple_async_mock(
|
||||||
return_value=("fake-event-id", fake_stream_id)
|
return_value=("fake-event-id", fake_stream_id)
|
||||||
)
|
)
|
||||||
self.hs.get_room_member_handler()._remote_join = mocked_remote_join
|
self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment]
|
||||||
fake_remote_host = f"{self.module_api.server_name}-remote"
|
fake_remote_host = f"{self.module_api.server_name}-remote"
|
||||||
|
|
||||||
# Given that the join is to be faked, we expect the relevant join event not to
|
# Given that the join is to be faked, we expect the relevant join event not to
|
||||||
|
|
|
@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
|
from synapse.push.emailpusher import EmailPusher
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -105,6 +106,7 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(self.access_token)
|
self.hs.get_datastores().main.get_user_by_access_token(self.access_token)
|
||||||
)
|
)
|
||||||
|
assert user_tuple is not None
|
||||||
self.token_id = user_tuple.token_id
|
self.token_id = user_tuple.token_id
|
||||||
|
|
||||||
# We need to add email to account before we can create a pusher.
|
# We need to add email to account before we can create a pusher.
|
||||||
|
@ -114,7 +116,7 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.pusher = self.get_success(
|
pusher = self.get_success(
|
||||||
self.hs.get_pusherpool().add_or_update_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
access_token=self.token_id,
|
access_token=self.token_id,
|
||||||
|
@ -127,6 +129,8 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
data={},
|
data={},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
assert isinstance(pusher, EmailPusher)
|
||||||
|
self.pusher = pusher
|
||||||
|
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
@ -375,10 +379,13 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# check that the pusher for that email address has been deleted
|
# check that the pusher for that email address has been deleted
|
||||||
pushers = self.get_success(
|
pushers = list(
|
||||||
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
|
self.get_success(
|
||||||
|
self.hs.get_datastores().main.get_pushers_by(
|
||||||
|
{"user_name": self.user_id}
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
|
||||||
self.assertEqual(len(pushers), 0)
|
self.assertEqual(len(pushers), 0)
|
||||||
|
|
||||||
def test_remove_unlinked_pushers_background_job(self) -> None:
|
def test_remove_unlinked_pushers_background_job(self) -> None:
|
||||||
|
@ -413,10 +420,13 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
self.wait_for_background_updates()
|
self.wait_for_background_updates()
|
||||||
|
|
||||||
# Check that all pushers with unlinked addresses were deleted
|
# Check that all pushers with unlinked addresses were deleted
|
||||||
pushers = self.get_success(
|
pushers = list(
|
||||||
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
|
self.get_success(
|
||||||
|
self.hs.get_datastores().main.get_pushers_by(
|
||||||
|
{"user_name": self.user_id}
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
|
||||||
self.assertEqual(len(pushers), 0)
|
self.assertEqual(len(pushers), 0)
|
||||||
|
|
||||||
def _check_for_mail(self) -> Tuple[Sequence, Dict]:
|
def _check_for_mail(self) -> Tuple[Sequence, Dict]:
|
||||||
|
@ -428,10 +438,13 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
that notification.
|
that notification.
|
||||||
"""
|
"""
|
||||||
# Get the stream ordering before it gets sent
|
# Get the stream ordering before it gets sent
|
||||||
pushers = self.get_success(
|
pushers = list(
|
||||||
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
|
self.get_success(
|
||||||
|
self.hs.get_datastores().main.get_pushers_by(
|
||||||
|
{"user_name": self.user_id}
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
last_stream_ordering = pushers[0].last_stream_ordering
|
last_stream_ordering = pushers[0].last_stream_ordering
|
||||||
|
|
||||||
|
@ -439,10 +452,13 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
self.pump(10)
|
self.pump(10)
|
||||||
|
|
||||||
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
|
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
|
||||||
pushers = self.get_success(
|
pushers = list(
|
||||||
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
|
self.get_success(
|
||||||
|
self.hs.get_datastores().main.get_pushers_by(
|
||||||
|
{"user_name": self.user_id}
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
|
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
|
||||||
|
|
||||||
|
@ -458,10 +474,13 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
self.assertEqual(len(self.email_attempts), 1)
|
self.assertEqual(len(self.email_attempts), 1)
|
||||||
|
|
||||||
# The stream ordering has increased
|
# The stream ordering has increased
|
||||||
pushers = self.get_success(
|
pushers = list(
|
||||||
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
|
self.get_success(
|
||||||
|
self.hs.get_datastores().main.get_pushers_by(
|
||||||
|
{"user_name": self.user_id}
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
|
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import List, Optional, Tuple
|
from typing import Any, List, Tuple
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
|
@ -22,7 +22,6 @@ from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.push import PusherConfig, PusherConfigException
|
from synapse.push import PusherConfig, PusherConfigException
|
||||||
from synapse.rest.client import login, push_rule, pusher, receipts, room
|
from synapse.rest.client import login, push_rule, pusher, receipts, room
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
@ -67,9 +66,10 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
|
assert user_tuple is not None
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
def test_data(data: Optional[JsonDict]) -> None:
|
def test_data(data: Any) -> None:
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.hs.get_pusherpool().add_or_update_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -113,6 +113,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
|
assert user_tuple is not None
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -140,10 +141,11 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
self.helper.send(room, body="There!", tok=other_access_token)
|
self.helper.send(room, body="There!", tok=other_access_token)
|
||||||
|
|
||||||
# Get the stream ordering before it gets sent
|
# Get the stream ordering before it gets sent
|
||||||
pushers = self.get_success(
|
pushers = list(
|
||||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
|
self.get_success(
|
||||||
|
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
|
||||||
|
)
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
last_stream_ordering = pushers[0].last_stream_ordering
|
last_stream_ordering = pushers[0].last_stream_ordering
|
||||||
|
|
||||||
|
@ -151,10 +153,11 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
self.pump()
|
self.pump()
|
||||||
|
|
||||||
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
|
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
|
||||||
pushers = self.get_success(
|
pushers = list(
|
||||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
|
self.get_success(
|
||||||
|
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
|
||||||
|
)
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
|
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
|
||||||
|
|
||||||
|
@ -172,10 +175,11 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
self.pump()
|
self.pump()
|
||||||
|
|
||||||
# The stream ordering has increased
|
# The stream ordering has increased
|
||||||
pushers = self.get_success(
|
pushers = list(
|
||||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
|
self.get_success(
|
||||||
|
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
|
||||||
|
)
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
|
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
|
||||||
last_stream_ordering = pushers[0].last_stream_ordering
|
last_stream_ordering = pushers[0].last_stream_ordering
|
||||||
|
@ -194,10 +198,11 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
self.pump()
|
self.pump()
|
||||||
|
|
||||||
# The stream ordering has increased, again
|
# The stream ordering has increased, again
|
||||||
pushers = self.get_success(
|
pushers = list(
|
||||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
|
self.get_success(
|
||||||
|
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
|
||||||
|
)
|
||||||
)
|
)
|
||||||
pushers = list(pushers)
|
|
||||||
self.assertEqual(len(pushers), 1)
|
self.assertEqual(len(pushers), 1)
|
||||||
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
|
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
|
||||||
|
|
||||||
|
@ -229,6 +234,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
|
assert user_tuple is not None
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -349,6 +355,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
|
assert user_tuple is not None
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -435,6 +442,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
|
assert user_tuple is not None
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -512,6 +520,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
|
assert user_tuple is not None
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -618,6 +627,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
|
assert user_tuple is not None
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -753,6 +763,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
|
assert user_tuple is not None
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
@ -895,6 +906,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
|
assert user_tuple is not None
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
device_id = user_tuple.device_id
|
device_id = user_tuple.device_id
|
||||||
|
|
||||||
|
@ -941,9 +953,10 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Look up the user info for the access token so we can compare the device ID.
|
# Look up the user info for the access token so we can compare the device ID.
|
||||||
lookup_result: TokenLookupResult = self.get_success(
|
lookup_result = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
|
assert lookup_result is not None
|
||||||
|
|
||||||
# Get the user's devices and check it has the correct device ID.
|
# Get the user's devices and check it has the correct device ID.
|
||||||
channel = self.make_request("GET", "/pushers", access_token=access_token)
|
channel = self.make_request("GET", "/pushers", access_token=access_token)
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional, Sequence
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# this is the point in the DAG where we make a fork
|
# this is the point in the DAG where we make a fork
|
||||||
fork_point: List[str] = self.get_success(
|
fork_point: Sequence[str] = self.get_success(
|
||||||
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
|
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -168,7 +168,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
pl_event = self.get_success(
|
pl_event = self.get_success(
|
||||||
inject_event(
|
inject_event(
|
||||||
self.hs,
|
self.hs,
|
||||||
prev_event_ids=prev_events,
|
prev_event_ids=list(prev_events),
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key="",
|
state_key="",
|
||||||
sender=self.user_id,
|
sender=self.user_id,
|
||||||
|
@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# this is the point in the DAG where we make a fork
|
# this is the point in the DAG where we make a fork
|
||||||
fork_point: List[str] = self.get_success(
|
fork_point: Sequence[str] = self.get_success(
|
||||||
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
|
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -323,7 +323,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
|
||||||
e = self.get_success(
|
e = self.get_success(
|
||||||
inject_event(
|
inject_event(
|
||||||
self.hs,
|
self.hs,
|
||||||
prev_event_ids=prev_events,
|
prev_event_ids=list(prev_events),
|
||||||
type=EventTypes.PowerLevels,
|
type=EventTypes.PowerLevels,
|
||||||
state_key="",
|
state_key="",
|
||||||
sender=self.user_id,
|
sender=self.user_id,
|
||||||
|
|
|
@ -37,7 +37,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
room_id = self.helper.create_room_as("@bob:test")
|
room_id = self.helper.create_room_as("@bob:test")
|
||||||
# Mark the room as partial-stated.
|
# Mark the room as partial-stated.
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1")
|
self.store.store_partial_state_room(room_id, {"serv1", "serv2"}, 0, "serv1")
|
||||||
)
|
)
|
||||||
|
|
||||||
worker = self.make_worker_hs("synapse.app.generic_worker")
|
worker = self.make_worker_hs("synapse.app.generic_worker")
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from synapse.handlers.typing import RoomMember
|
from synapse.handlers.typing import RoomMember, TypingWriterHandler
|
||||||
from synapse.replication.tcp.streams import TypingStream
|
from synapse.replication.tcp.streams import TypingStream
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
|
@ -33,6 +33,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
||||||
|
|
||||||
def test_typing(self) -> None:
|
def test_typing(self) -> None:
|
||||||
typing = self.hs.get_typing_handler()
|
typing = self.hs.get_typing_handler()
|
||||||
|
assert isinstance(typing, TypingWriterHandler)
|
||||||
|
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
|
|
||||||
|
@ -88,6 +89,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
|
||||||
sends the proper position and RDATA).
|
sends the proper position and RDATA).
|
||||||
"""
|
"""
|
||||||
typing = self.hs.get_typing_handler()
|
typing = self.hs.get_typing_handler()
|
||||||
|
assert isinstance(typing, TypingWriterHandler)
|
||||||
|
|
||||||
self.reconnect()
|
self.reconnect()
|
||||||
|
|
||||||
|
|
|
@ -127,6 +127,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
|
|
||||||
# ... updating the cache ID gen on the master still shouldn't cause the
|
# ... updating the cache ID gen on the master still shouldn't cause the
|
||||||
# deferred to wake up.
|
# deferred to wake up.
|
||||||
|
assert store._cache_id_gen is not None
|
||||||
ctx = store._cache_id_gen.get_next()
|
ctx = store._cache_id_gen.get_next()
|
||||||
self.get_success(ctx.__aenter__())
|
self.get_success(ctx.__aenter__())
|
||||||
self.get_success(ctx.__aexit__(None, None, None))
|
self.get_success(ctx.__aexit__(None, None, None))
|
||||||
|
|
|
@ -16,6 +16,7 @@ from unittest.mock import Mock
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.events.builder import EventBuilderFactory
|
from synapse.events.builder import EventBuilderFactory
|
||||||
|
from synapse.handlers.typing import TypingWriterHandler
|
||||||
from synapse.rest.admin import register_servlets_for_client_rest_resource
|
from synapse.rest.admin import register_servlets_for_client_rest_resource
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
|
@ -174,6 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
token = self.login("user3", "pass")
|
token = self.login("user3", "pass")
|
||||||
|
|
||||||
typing_handler = self.hs.get_typing_handler()
|
typing_handler = self.hs.get_typing_handler()
|
||||||
|
assert isinstance(typing_handler, TypingWriterHandler)
|
||||||
|
|
||||||
sent_on_1 = False
|
sent_on_1 = False
|
||||||
sent_on_2 = False
|
sent_on_2 = False
|
||||||
|
|
|
@ -50,6 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
user_dict = self.get_success(
|
user_dict = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
|
assert user_dict is not None
|
||||||
token_id = user_dict.token_id
|
token_id = user_dict.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
|
|
@ -2913,7 +2913,8 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
||||||
other_user_tok = self.login("user", "pass")
|
other_user_tok = self.login("user", "pass")
|
||||||
event_builder_factory = self.hs.get_event_builder_factory()
|
event_builder_factory = self.hs.get_event_builder_factory()
|
||||||
event_creation_handler = self.hs.get_event_creation_handler()
|
event_creation_handler = self.hs.get_event_creation_handler()
|
||||||
storage_controllers = self.hs.get_storage_controllers()
|
persistence = self.hs.get_storage_controllers().persistence
|
||||||
|
assert persistence is not None
|
||||||
|
|
||||||
# Create two rooms, one with a local user only and one with both a local
|
# Create two rooms, one with a local user only and one with both a local
|
||||||
# and remote user.
|
# and remote user.
|
||||||
|
@ -2940,7 +2941,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
context = self.get_success(unpersisted_context.persist(event))
|
context = self.get_success(unpersisted_context.persist(event))
|
||||||
|
|
||||||
self.get_success(storage_controllers.persistence.persist_event(event, context))
|
self.get_success(persistence.persist_event(event, context))
|
||||||
|
|
||||||
# Now get rooms
|
# Now get rooms
|
||||||
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
|
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
|
||||||
|
|
|
@ -11,6 +11,8 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
|
@ -33,9 +35,14 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
|
||||||
self.register_user("admin", "pass", admin=True)
|
self.register_user("admin", "pass", admin=True)
|
||||||
self.admin_user_tok = self.login("admin", "pass")
|
self.admin_user_tok = self.login("admin", "pass")
|
||||||
|
|
||||||
async def check_username(username: str) -> bool:
|
async def check_username(
|
||||||
if username == "allowed":
|
localpart: str,
|
||||||
return True
|
guest_access_token: Optional[str] = None,
|
||||||
|
assigned_user_id: Optional[str] = None,
|
||||||
|
inhibit_user_in_use_error: bool = False,
|
||||||
|
) -> None:
|
||||||
|
if localpart == "allowed":
|
||||||
|
return
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
"User ID already taken.",
|
"User ID already taken.",
|
||||||
|
@ -43,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
handler = self.hs.get_registration_handler()
|
handler = self.hs.get_registration_handler()
|
||||||
handler.check_username = check_username
|
handler.check_username = check_username # type: ignore[assignment]
|
||||||
|
|
||||||
def test_username_available(self) -> None:
|
def test_username_available(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1193,7 +1193,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# Register a mock that will return the expected result depending on the remote.
|
# Register a mock that will return the expected result depending on the remote.
|
||||||
self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
|
self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment]
|
||||||
|
|
||||||
# Check that we've got the correct response from the client-side endpoint.
|
# Check that we've got the correct response from the client-side endpoint.
|
||||||
self._test_status(
|
self._test_status(
|
||||||
|
|
|
@ -63,14 +63,14 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def test_add_filter_non_local_user(self) -> None:
|
def test_add_filter_non_local_user(self) -> None:
|
||||||
_is_mine = self.hs.is_mine
|
_is_mine = self.hs.is_mine
|
||||||
self.hs.is_mine = lambda target_user: False
|
self.hs.is_mine = lambda target_user: False # type: ignore[assignment]
|
||||||
channel = self.make_request(
|
channel = self.make_request(
|
||||||
"POST",
|
"POST",
|
||||||
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
|
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
|
||||||
self.EXAMPLE_FILTER_JSON,
|
self.EXAMPLE_FILTER_JSON,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hs.is_mine = _is_mine
|
self.hs.is_mine = _is_mine # type: ignore[assignment]
|
||||||
self.assertEqual(channel.code, 403)
|
self.assertEqual(channel.code, 403)
|
||||||
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
|
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
|
|
@ -36,14 +36,14 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
|
||||||
presence_handler = Mock(spec=PresenceHandler)
|
self.presence_handler = Mock(spec=PresenceHandler)
|
||||||
presence_handler.set_state.return_value = make_awaitable(None)
|
self.presence_handler.set_state.return_value = make_awaitable(None)
|
||||||
|
|
||||||
hs = self.setup_test_homeserver(
|
hs = self.setup_test_homeserver(
|
||||||
"red",
|
"red",
|
||||||
federation_http_client=None,
|
federation_http_client=None,
|
||||||
federation_client=Mock(),
|
federation_client=Mock(),
|
||||||
presence_handler=presence_handler,
|
presence_handler=self.presence_handler,
|
||||||
)
|
)
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
@ -61,7 +61,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||||
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
|
self.assertEqual(self.presence_handler.set_state.call_count, 1)
|
||||||
|
|
||||||
@unittest.override_config({"use_presence": False})
|
@unittest.override_config({"use_presence": False})
|
||||||
def test_put_presence_disabled(self) -> None:
|
def test_put_presence_disabled(self) -> None:
|
||||||
|
@ -76,4 +76,4 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||||
self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)
|
self.assertEqual(self.presence_handler.set_state.call_count, 0)
|
||||||
|
|
|
@ -151,7 +151,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||||
|
|
||||||
def test_POST_guest_registration(self) -> None:
|
def test_POST_guest_registration(self) -> None:
|
||||||
self.hs.config.key.macaroon_secret_key = "test"
|
self.hs.config.key.macaroon_secret_key = b"test"
|
||||||
self.hs.config.registration.allow_guest_access = True
|
self.hs.config.registration.allow_guest_access = True
|
||||||
|
|
||||||
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||||
|
@ -1166,12 +1166,15 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
user_id = self.register_user("kermit_delta", "user")
|
user_id = self.register_user("kermit_delta", "user")
|
||||||
|
|
||||||
self.hs.config.account_validity.startup_job_max_delta = self.max_delta
|
self.hs.config.account_validity.account_validity_startup_job_max_delta = (
|
||||||
|
self.max_delta
|
||||||
|
)
|
||||||
|
|
||||||
now_ms = self.hs.get_clock().time_msec()
|
now_ms = self.hs.get_clock().time_msec()
|
||||||
self.get_success(self.store._set_expiration_date_when_missing())
|
self.get_success(self.store._set_expiration_date_when_missing())
|
||||||
|
|
||||||
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
|
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
|
||||||
|
assert res is not None
|
||||||
|
|
||||||
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
|
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
|
||||||
self.assertLessEqual(res, now_ms + self.validity_period)
|
self.assertLessEqual(res, now_ms + self.validity_period)
|
||||||
|
|
|
@ -136,6 +136,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||||
# Send a first event, which should be filtered out at the end of the test.
|
# Send a first event, which should be filtered out at the end of the test.
|
||||||
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
|
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
|
||||||
first_event_id = resp.get("event_id")
|
first_event_id = resp.get("event_id")
|
||||||
|
assert isinstance(first_event_id, str)
|
||||||
|
|
||||||
# Advance the time by 2 days. We're using the default retention policy, therefore
|
# Advance the time by 2 days. We're using the default retention policy, therefore
|
||||||
# after this the first event will still be valid.
|
# after this the first event will still be valid.
|
||||||
|
@ -144,6 +145,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||||
# Send another event, which shouldn't get filtered out.
|
# Send another event, which shouldn't get filtered out.
|
||||||
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
|
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
|
||||||
valid_event_id = resp.get("event_id")
|
valid_event_id = resp.get("event_id")
|
||||||
|
assert isinstance(valid_event_id, str)
|
||||||
|
|
||||||
# Advance the time by another 2 days. After this, the first event should be
|
# Advance the time by another 2 days. After this, the first event should be
|
||||||
# outdated but not the second one.
|
# outdated but not the second one.
|
||||||
|
@ -229,7 +231,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Check that we can still access state events that were sent before the event that
|
# Check that we can still access state events that were sent before the event that
|
||||||
# has been purged.
|
# has been purged.
|
||||||
self.get_event(room_id, create_event.event_id)
|
self.get_event(room_id, bool(create_event))
|
||||||
|
|
||||||
def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict:
|
def get_event(self, event_id: str, expect_none: bool = False) -> JsonDict:
|
||||||
event = self.get_success(self.store.get_event(event_id, allow_none=True))
|
event = self.get_success(self.store.get_event(event_id, allow_none=True))
|
||||||
|
@ -238,7 +240,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertIsNone(event)
|
self.assertIsNone(event)
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
self.assertIsNotNone(event)
|
assert event is not None
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
serialized = self.serializer.serialize_event(event, time_now)
|
serialized = self.serializer.serialize_event(event, time_now)
|
||||||
|
|
|
@ -3382,8 +3382,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
|
||||||
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
|
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
|
||||||
# can check its call_count later on during the test.
|
# can check its call_count later on during the test.
|
||||||
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
|
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
|
||||||
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
|
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
|
||||||
self.hs.get_identity_handler().lookup_3pid = Mock(
|
self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(None),
|
return_value=make_awaitable(None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -3443,8 +3443,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
|
||||||
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
|
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
|
||||||
# can check its call_count later on during the test.
|
# can check its call_count later on during the test.
|
||||||
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
|
make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
|
||||||
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock
|
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
|
||||||
self.hs.get_identity_handler().lookup_3pid = Mock(
|
self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(None),
|
return_value=make_awaitable(None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -3563,8 +3563,10 @@ class TimestampLookupTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
event.internal_metadata.outlier = True
|
event.internal_metadata.outlier = True
|
||||||
|
persistence = self._storage_controllers.persistence
|
||||||
|
assert persistence is not None
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self._storage_controllers.persistence.persist_event(
|
persistence.persist_event(
|
||||||
event, EventContext.for_outlier(self._storage_controllers)
|
event, EventContext.for_outlier(self._storage_controllers)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase):
|
||||||
def test_invite_3pid(self) -> None:
|
def test_invite_3pid(self) -> None:
|
||||||
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
|
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
|
||||||
identity_handler = self.hs.get_identity_handler()
|
identity_handler = self.hs.get_identity_handler()
|
||||||
identity_handler.lookup_3pid = Mock(
|
identity_handler.lookup_3pid = Mock( # type: ignore[assignment]
|
||||||
side_effect=AssertionError("This should not get called")
|
side_effect=AssertionError("This should not get called")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -222,7 +222,7 @@ class RoomTestCase(_ShadowBannedBase):
|
||||||
event_source.get_new_events(
|
event_source.get_new_events(
|
||||||
user=UserID.from_string(self.other_user_id),
|
user=UserID.from_string(self.other_user_id),
|
||||||
from_key=0,
|
from_key=0,
|
||||||
limit=None,
|
limit=10,
|
||||||
room_ids=[room_id],
|
room_ids=[room_id],
|
||||||
is_guest=False,
|
is_guest=False,
|
||||||
)
|
)
|
||||||
|
@ -286,6 +286,7 @@ class ProfileTestCase(_ShadowBannedBase):
|
||||||
self.banned_user_id,
|
self.banned_user_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
assert event is not None
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
event.content, {"membership": "join", "displayname": original_display_name}
|
event.content, {"membership": "join", "displayname": original_display_name}
|
||||||
)
|
)
|
||||||
|
@ -321,6 +322,7 @@ class ProfileTestCase(_ShadowBannedBase):
|
||||||
self.banned_user_id,
|
self.banned_user_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
assert event is not None
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
event.content, {"membership": "join", "displayname": original_display_name}
|
event.content, {"membership": "join", "displayname": original_display_name}
|
||||||
)
|
)
|
||||||
|
|
|
@ -84,7 +84,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
|
||||||
self.room_id, EventTypes.Tombstone, ""
|
self.room_id, EventTypes.Tombstone, ""
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(tombstone_event)
|
assert tombstone_event is not None
|
||||||
self.assertEqual(new_room_id, tombstone_event.content["replacement_room"])
|
self.assertEqual(new_room_id, tombstone_event.content["replacement_room"])
|
||||||
|
|
||||||
# Check that the new room exists.
|
# Check that the new room exists.
|
||||||
|
|
|
@ -24,6 +24,7 @@ from synapse.server import HomeServer
|
||||||
from synapse.server_notices.resource_limits_server_notices import (
|
from synapse.server_notices.resource_limits_server_notices import (
|
||||||
ResourceLimitsServerNotices,
|
ResourceLimitsServerNotices,
|
||||||
)
|
)
|
||||||
|
from synapse.server_notices.server_notices_sender import ServerNoticesSender
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
@ -58,14 +59,15 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.server_notices_sender = self.hs.get_server_notices_sender()
|
server_notices_sender = self.hs.get_server_notices_sender()
|
||||||
|
assert isinstance(server_notices_sender, ServerNoticesSender)
|
||||||
|
|
||||||
# relying on [1] is far from ideal, but the only case where
|
# relying on [1] is far from ideal, but the only case where
|
||||||
# ResourceLimitsServerNotices class needs to be isolated is this test,
|
# ResourceLimitsServerNotices class needs to be isolated is this test,
|
||||||
# general code should never have a reason to do so ...
|
# general code should never have a reason to do so ...
|
||||||
self._rlsn = self.server_notices_sender._server_notices[1]
|
rlsn = list(server_notices_sender._server_notices)[1]
|
||||||
if not isinstance(self._rlsn, ResourceLimitsServerNotices):
|
assert isinstance(rlsn, ResourceLimitsServerNotices)
|
||||||
raise Exception("Failed to find reference to ResourceLimitsServerNotices")
|
self._rlsn = rlsn
|
||||||
|
|
||||||
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
||||||
return_value=make_awaitable(1000)
|
return_value=make_awaitable(1000)
|
||||||
|
@ -101,25 +103,29 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
|
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
|
||||||
"""Test when user has blocked notice, but should have it removed"""
|
"""Test when user has blocked notice, but should have it removed"""
|
||||||
|
|
||||||
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(None)
|
return_value=make_awaitable(None)
|
||||||
)
|
)
|
||||||
mock_event = Mock(
|
mock_event = Mock(
|
||||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||||
)
|
)
|
||||||
self._rlsn._store.get_events = Mock(
|
self._rlsn._store.get_events = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable({"123": mock_event})
|
return_value=make_awaitable({"123": mock_event})
|
||||||
)
|
)
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
# Would be better to check the content, but once == remove blocking event
|
# Would be better to check the content, but once == remove blocking event
|
||||||
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once()
|
maybe_get_notice_room_for_user = (
|
||||||
|
self._rlsn._server_notices_manager.maybe_get_notice_room_for_user
|
||||||
|
)
|
||||||
|
assert isinstance(maybe_get_notice_room_for_user, Mock)
|
||||||
|
maybe_get_notice_room_for_user.assert_called_once()
|
||||||
self._send_notice.assert_called_once()
|
self._send_notice.assert_called_once()
|
||||||
|
|
||||||
def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None:
|
def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test when user has blocked notice, but notice ought to be there (NOOP)
|
Test when user has blocked notice, but notice ought to be there (NOOP)
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(None),
|
return_value=make_awaitable(None),
|
||||||
side_effect=ResourceLimitError(403, "foo"),
|
side_effect=ResourceLimitError(403, "foo"),
|
||||||
)
|
)
|
||||||
|
@ -127,7 +133,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
mock_event = Mock(
|
mock_event = Mock(
|
||||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||||
)
|
)
|
||||||
self._rlsn._store.get_events = Mock(
|
self._rlsn._store.get_events = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable({"123": mock_event})
|
return_value=make_awaitable({"123": mock_event})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -139,7 +145,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
Test when user does not have blocked notice, but should have one
|
Test when user does not have blocked notice, but should have one
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(None),
|
return_value=make_awaitable(None),
|
||||||
side_effect=ResourceLimitError(403, "foo"),
|
side_effect=ResourceLimitError(403, "foo"),
|
||||||
)
|
)
|
||||||
|
@ -152,7 +158,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
Test when user does not have blocked notice, nor should they (NOOP)
|
Test when user does not have blocked notice, nor should they (NOOP)
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(None)
|
return_value=make_awaitable(None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -165,7 +171,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
Test when user is not part of the MAU cohort - this should not ever
|
Test when user is not part of the MAU cohort - this should not ever
|
||||||
happen - but ...
|
happen - but ...
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(None)
|
return_value=make_awaitable(None)
|
||||||
)
|
)
|
||||||
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
self._rlsn._store.user_last_seen_monthly_active = Mock(
|
||||||
|
@ -183,7 +189,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
Test that when server is over MAU limit and alerting is suppressed, then
|
Test that when server is over MAU limit and alerting is suppressed, then
|
||||||
an alert message is not sent into the room
|
an alert message is not sent into the room
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(None),
|
return_value=make_awaitable(None),
|
||||||
side_effect=ResourceLimitError(
|
side_effect=ResourceLimitError(
|
||||||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
||||||
|
@ -198,7 +204,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
Test that when a server is disabled, that MAU limit alerting is ignored.
|
Test that when a server is disabled, that MAU limit alerting is ignored.
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(None),
|
return_value=make_awaitable(None),
|
||||||
side_effect=ResourceLimitError(
|
side_effect=ResourceLimitError(
|
||||||
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
|
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
|
||||||
|
@ -217,21 +223,21 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||||
When the room is already in a blocked state, test that when alerting
|
When the room is already in a blocked state, test that when alerting
|
||||||
is suppressed that the room is returned to an unblocked state.
|
is suppressed that the room is returned to an unblocked state.
|
||||||
"""
|
"""
|
||||||
self._rlsn._auth_blocking.check_auth_blocking = Mock(
|
self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(None),
|
return_value=make_awaitable(None),
|
||||||
side_effect=ResourceLimitError(
|
side_effect=ResourceLimitError(
|
||||||
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
|
self._rlsn._is_room_currently_blocked = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable((True, []))
|
return_value=make_awaitable((True, []))
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_event = Mock(
|
mock_event = Mock(
|
||||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||||
)
|
)
|
||||||
self._rlsn._store.get_events = Mock(
|
self._rlsn._store.get_events = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable({"123": mock_event})
|
return_value=make_awaitable({"123": mock_event})
|
||||||
)
|
)
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
|
@ -262,16 +268,18 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
self.server_notices_sender = self.hs.get_server_notices_sender()
|
|
||||||
self.server_notices_manager = self.hs.get_server_notices_manager()
|
self.server_notices_manager = self.hs.get_server_notices_manager()
|
||||||
self.event_source = self.hs.get_event_sources()
|
self.event_source = self.hs.get_event_sources()
|
||||||
|
|
||||||
|
server_notices_sender = self.hs.get_server_notices_sender()
|
||||||
|
assert isinstance(server_notices_sender, ServerNoticesSender)
|
||||||
|
|
||||||
# relying on [1] is far from ideal, but the only case where
|
# relying on [1] is far from ideal, but the only case where
|
||||||
# ResourceLimitsServerNotices class needs to be isolated is this test,
|
# ResourceLimitsServerNotices class needs to be isolated is this test,
|
||||||
# general code should never have a reason to do so ...
|
# general code should never have a reason to do so ...
|
||||||
self._rlsn = self.server_notices_sender._server_notices[1]
|
rlsn = list(server_notices_sender._server_notices)[1]
|
||||||
if not isinstance(self._rlsn, ResourceLimitsServerNotices):
|
assert isinstance(rlsn, ResourceLimitsServerNotices)
|
||||||
raise Exception("Failed to find reference to ResourceLimitsServerNotices")
|
self._rlsn = rlsn
|
||||||
|
|
||||||
self.user_id = "@user_id:test"
|
self.user_id = "@user_id:test"
|
||||||
|
|
||||||
|
|
|
@ -120,6 +120,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
|
||||||
# Persist the event which should invalidate or prefill the
|
# Persist the event which should invalidate or prefill the
|
||||||
# `have_seen_event` cache so we don't return stale values.
|
# `have_seen_event` cache so we don't return stale values.
|
||||||
persistence = self.hs.get_storage_controllers().persistence
|
persistence = self.hs.get_storage_controllers().persistence
|
||||||
|
assert persistence is not None
|
||||||
self.get_success(
|
self.get_success(
|
||||||
persistence.persist_event(
|
persistence.persist_event(
|
||||||
event,
|
event,
|
||||||
|
|
|
@ -389,6 +389,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
persist_events_store = self.hs.get_datastores().persist_events
|
persist_events_store = self.hs.get_datastores().persist_events
|
||||||
|
assert persist_events_store is not None
|
||||||
|
|
||||||
for e in events:
|
for e in events:
|
||||||
e.internal_metadata.stream_ordering = self._next_stream_ordering
|
e.internal_metadata.stream_ordering = self._next_stream_ordering
|
||||||
|
@ -397,6 +398,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||||
def _persist(txn: LoggingTransaction) -> None:
|
def _persist(txn: LoggingTransaction) -> None:
|
||||||
# We need to persist the events to the events and state_events
|
# We need to persist the events to the events and state_events
|
||||||
# tables.
|
# tables.
|
||||||
|
assert persist_events_store is not None
|
||||||
persist_events_store._store_event_txn(
|
persist_events_store._store_event_txn(
|
||||||
txn,
|
txn,
|
||||||
[(e, EventContext(self.hs.get_storage_controllers())) for e in events],
|
[(e, EventContext(self.hs.get_storage_controllers())) for e in events],
|
||||||
|
@ -540,7 +542,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
||||||
self.requester, events_and_context=[(event, context)]
|
self.requester, events_and_context=[(event, context)]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
state1 = set(self.get_success(context.get_current_state_ids()).values())
|
state_ids1 = self.get_success(context.get_current_state_ids())
|
||||||
|
assert state_ids1 is not None
|
||||||
|
state1 = set(state_ids1.values())
|
||||||
|
|
||||||
event, context = self.get_success(
|
event, context = self.get_success(
|
||||||
event_handler.create_event(
|
event_handler.create_event(
|
||||||
|
@ -560,7 +564,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
|
||||||
self.requester, events_and_context=[(event, context)]
|
self.requester, events_and_context=[(event, context)]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
state2 = set(self.get_success(context.get_current_state_ids()).values())
|
state_ids2 = self.get_success(context.get_current_state_ids())
|
||||||
|
assert state_ids2 is not None
|
||||||
|
state2 = set(state_ids2.values())
|
||||||
|
|
||||||
# Delete the chain cover info.
|
# Delete the chain cover info.
|
||||||
|
|
||||||
|
|
|
@ -54,6 +54,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
|
|
||||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
|
persist_events = hs.get_datastores().persist_events
|
||||||
|
assert persist_events is not None
|
||||||
|
self.persist_events = persist_events
|
||||||
|
|
||||||
def test_get_prev_events_for_room(self) -> None:
|
def test_get_prev_events_for_room(self) -> None:
|
||||||
room_id = "@ROOM:local"
|
room_id = "@ROOM:local"
|
||||||
|
@ -226,7 +229,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
|
self.persist_events._persist_event_auth_chain_txn(
|
||||||
txn,
|
txn,
|
||||||
[
|
[
|
||||||
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
|
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
|
||||||
|
@ -445,7 +448,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Insert all events apart from 'B'
|
# Insert all events apart from 'B'
|
||||||
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
|
self.persist_events._persist_event_auth_chain_txn(
|
||||||
txn,
|
txn,
|
||||||
[
|
[
|
||||||
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
|
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
|
||||||
|
@ -464,7 +467,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
updatevalues={"has_auth_chain_index": False},
|
updatevalues={"has_auth_chain_index": False},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hs.datastores.persist_events._persist_event_auth_chain_txn(
|
self.persist_events._persist_event_auth_chain_txn(
|
||||||
txn,
|
txn,
|
||||||
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
|
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
|
||||||
)
|
)
|
||||||
|
|
|
@ -40,7 +40,9 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||||
) -> None:
|
) -> None:
|
||||||
self.state = self.hs.get_state_handler()
|
self.state = self.hs.get_state_handler()
|
||||||
self._persistence = self.hs.get_storage_controllers().persistence
|
persistence = self.hs.get_storage_controllers().persistence
|
||||||
|
assert persistence is not None
|
||||||
|
self._persistence = persistence
|
||||||
self._state_storage_controller = self.hs.get_storage_controllers().state
|
self._state_storage_controller = self.hs.get_storage_controllers().state
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
|
||||||
|
@ -374,7 +376,9 @@ class InvalideUsersInRoomCacheTestCase(HomeserverTestCase):
|
||||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||||
) -> None:
|
) -> None:
|
||||||
self.state = self.hs.get_state_handler()
|
self.state = self.hs.get_state_handler()
|
||||||
self._persistence = self.hs.get_storage_controllers().persistence
|
persistence = self.hs.get_storage_controllers().persistence
|
||||||
|
assert persistence is not None
|
||||||
|
self._persistence = persistence
|
||||||
self.store = self.hs.get_datastores().main
|
self.store = self.hs.get_datastores().main
|
||||||
|
|
||||||
def test_remote_user_rooms_cache_invalidated(self) -> None:
|
def test_remote_user_rooms_cache_invalidated(self) -> None:
|
||||||
|
|
|
@ -16,8 +16,6 @@ import signedjson.key
|
||||||
import signedjson.types
|
import signedjson.types
|
||||||
import unpaddedbase64
|
import unpaddedbase64
|
||||||
|
|
||||||
from twisted.internet.defer import Deferred
|
|
||||||
|
|
||||||
from synapse.storage.keys import FetchKeyResult
|
from synapse.storage.keys import FetchKeyResult
|
||||||
|
|
||||||
import tests.unittest
|
import tests.unittest
|
||||||
|
@ -44,20 +42,26 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
|
|
||||||
key_id_1 = "ed25519:key1"
|
key_id_1 = "ed25519:key1"
|
||||||
key_id_2 = "ed25519:KEY_ID_2"
|
key_id_2 = "ed25519:KEY_ID_2"
|
||||||
d = store.store_server_verify_keys(
|
self.get_success(
|
||||||
"from_server",
|
store.store_server_verify_keys(
|
||||||
10,
|
"from_server",
|
||||||
[
|
10,
|
||||||
("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
|
[
|
||||||
("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
|
("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
|
||||||
],
|
("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
|
||||||
|
],
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.get_success(d)
|
|
||||||
|
|
||||||
d = store.get_server_verify_keys(
|
res = self.get_success(
|
||||||
[("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
|
store.get_server_verify_keys(
|
||||||
|
[
|
||||||
|
("server1", key_id_1),
|
||||||
|
("server1", key_id_2),
|
||||||
|
("server1", "ed25519:key3"),
|
||||||
|
]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
res = self.get_success(d)
|
|
||||||
|
|
||||||
self.assertEqual(len(res.keys()), 3)
|
self.assertEqual(len(res.keys()), 3)
|
||||||
res1 = res[("server1", key_id_1)]
|
res1 = res[("server1", key_id_1)]
|
||||||
|
@ -82,18 +86,20 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
key_id_1 = "ed25519:key1"
|
key_id_1 = "ed25519:key1"
|
||||||
key_id_2 = "ed25519:key2"
|
key_id_2 = "ed25519:key2"
|
||||||
|
|
||||||
d = store.store_server_verify_keys(
|
self.get_success(
|
||||||
"from_server",
|
store.store_server_verify_keys(
|
||||||
0,
|
"from_server",
|
||||||
[
|
0,
|
||||||
("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
|
[
|
||||||
("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
|
("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
|
||||||
],
|
("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
|
||||||
|
],
|
||||||
|
)
|
||||||
)
|
)
|
||||||
self.get_success(d)
|
|
||||||
|
|
||||||
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
|
res = self.get_success(
|
||||||
res = self.get_success(d)
|
store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
|
||||||
|
)
|
||||||
self.assertEqual(len(res.keys()), 2)
|
self.assertEqual(len(res.keys()), 2)
|
||||||
|
|
||||||
res1 = res[("srv1", key_id_1)]
|
res1 = res[("srv1", key_id_1)]
|
||||||
|
@ -105,9 +111,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
self.assertEqual(res2.valid_until_ts, 200)
|
self.assertEqual(res2.valid_until_ts, 200)
|
||||||
|
|
||||||
# we should be able to look up the same thing again without a db hit
|
# we should be able to look up the same thing again without a db hit
|
||||||
res = store.get_server_verify_keys([("srv1", key_id_1)])
|
res = self.get_success(store.get_server_verify_keys([("srv1", key_id_1)]))
|
||||||
if isinstance(res, Deferred):
|
|
||||||
res = self.successResultOf(res)
|
|
||||||
self.assertEqual(len(res.keys()), 1)
|
self.assertEqual(len(res.keys()), 1)
|
||||||
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
|
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
|
||||||
|
|
||||||
|
@ -119,8 +123,9 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.get_success(d)
|
self.get_success(d)
|
||||||
|
|
||||||
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
|
res = self.get_success(
|
||||||
res = self.get_success(d)
|
store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
|
||||||
|
)
|
||||||
self.assertEqual(len(res.keys()), 2)
|
self.assertEqual(len(res.keys()), 2)
|
||||||
|
|
||||||
res1 = res[("srv1", key_id_1)]
|
res1 = res[("srv1", key_id_1)]
|
||||||
|
|
|
@ -112,7 +112,7 @@ class PurgeTests(HomeserverTestCase):
|
||||||
self.room_id, "m.room.create", ""
|
self.room_id, "m.room.create", ""
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.assertIsNotNone(create_event)
|
assert create_event is not None
|
||||||
|
|
||||||
# Purge everything before this topological token
|
# Purge everything before this topological token
|
||||||
self.get_success(
|
self.get_success(
|
||||||
|
|
|
@ -37,9 +37,9 @@ class ReceiptTestCase(HomeserverTestCase):
|
||||||
self.store = homeserver.get_datastores().main
|
self.store = homeserver.get_datastores().main
|
||||||
|
|
||||||
self.room_creator = homeserver.get_room_creation_handler()
|
self.room_creator = homeserver.get_room_creation_handler()
|
||||||
self.persist_event_storage_controller = (
|
persist_event_storage_controller = self.hs.get_storage_controllers().persistence
|
||||||
self.hs.get_storage_controllers().persistence
|
assert persist_event_storage_controller is not None
|
||||||
)
|
self.persist_event_storage_controller = persist_event_storage_controller
|
||||||
|
|
||||||
# Create a test user
|
# Create a test user
|
||||||
self.ourUser = UserID.from_string(OUR_USER_ID)
|
self.ourUser = UserID.from_string(OUR_USER_ID)
|
||||||
|
|
|
@ -119,7 +119,6 @@ class EventSearchInsertionTest(HomeserverTestCase):
|
||||||
"content": {"msgtype": "m.text", "body": 2},
|
"content": {"msgtype": "m.text", "body": 2},
|
||||||
"room_id": room_id,
|
"room_id": room_id,
|
||||||
"sender": user_id,
|
"sender": user_id,
|
||||||
"depth": prev_event.depth + 1,
|
|
||||||
"prev_events": prev_event_ids,
|
"prev_events": prev_event_ids,
|
||||||
"origin_server_ts": self.clock.time_msec(),
|
"origin_server_ts": self.clock.time_msec(),
|
||||||
}
|
}
|
||||||
|
@ -134,7 +133,7 @@ class EventSearchInsertionTest(HomeserverTestCase):
|
||||||
prev_state_map,
|
prev_state_map,
|
||||||
for_verification=False,
|
for_verification=False,
|
||||||
),
|
),
|
||||||
depth=event_dict["depth"],
|
depth=prev_event.depth + 1,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from typing import List
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, RelationTypes
|
from synapse.api.constants import Direction, EventTypes, RelationTypes
|
||||||
from synapse.api.filtering import Filter
|
from synapse.api.filtering import Filter
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import login, room
|
from synapse.rest.client import login, room
|
||||||
|
@ -128,7 +128,7 @@ class PaginationTestCase(HomeserverTestCase):
|
||||||
room_id=self.room_id,
|
room_id=self.room_id,
|
||||||
from_key=self.from_token.room_key,
|
from_key=self.from_token.room_key,
|
||||||
to_key=None,
|
to_key=None,
|
||||||
direction="f",
|
direction=Direction.FORWARDS,
|
||||||
limit=10,
|
limit=10,
|
||||||
event_filter=Filter(self.hs, filter),
|
event_filter=Filter(self.hs, filter),
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from synapse.storage.database import make_conn
|
from synapse.storage.database import make_conn
|
||||||
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.engines._base import IncorrectDatabaseSetup
|
from synapse.storage.engines._base import IncorrectDatabaseSetup
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
@ -38,6 +39,7 @@ class UnsafeLocaleTest(HomeserverTestCase):
|
||||||
|
|
||||||
def test_safe_locale(self) -> None:
|
def test_safe_locale(self) -> None:
|
||||||
database = self.hs.get_datastores().databases[0]
|
database = self.hs.get_datastores().databases[0]
|
||||||
|
assert isinstance(database.engine, PostgresEngine)
|
||||||
|
|
||||||
db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
|
db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
|
||||||
with db_conn.cursor() as txn:
|
with db_conn.cursor() as txn:
|
||||||
|
|
|
@ -12,17 +12,17 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Optional, Union
|
from typing import Collection, List, Optional, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
from twisted.internet.defer import succeed
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
from synapse.api.errors import FederationError
|
from synapse.api.errors import FederationError
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersion, RoomVersions
|
||||||
from synapse.events import EventBase, make_event_from_dict
|
from synapse.events import EventBase, make_event_from_dict
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.federation.federation_base import event_from_pdu_json
|
from synapse.federation.federation_base import event_from_pdu_json
|
||||||
|
from synapse.handlers.device import DeviceListUpdater
|
||||||
from synapse.http.types import QueryParams
|
from synapse.http.types import QueryParams
|
||||||
from synapse.logging.context import LoggingContext
|
from synapse.logging.context import LoggingContext
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
@ -81,11 +81,15 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
federation_event_handler._check_event_auth = _check_event_auth
|
federation_event_handler._check_event_auth = _check_event_auth # type: ignore[assignment]
|
||||||
self.client = self.hs.get_federation_client()
|
self.client = self.hs.get_federation_client()
|
||||||
self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
|
|
||||||
lambda dest, pdus, **k: succeed(pdus)
|
async def _check_sigs_and_hash_for_pulled_events_and_fetch(
|
||||||
)
|
dest: str, pdus: Collection[EventBase], room_version: RoomVersion
|
||||||
|
) -> List[EventBase]:
|
||||||
|
return list(pdus)
|
||||||
|
|
||||||
|
self.client._check_sigs_and_hash_for_pulled_events_and_fetch = _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment]
|
||||||
|
|
||||||
# Send the join, it should return None (which is not an error)
|
# Send the join, it should return None (which is not an error)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
|
@ -187,7 +191,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Register the mock on the federation client.
|
# Register the mock on the federation client.
|
||||||
federation_client = self.hs.get_federation_client()
|
federation_client = self.hs.get_federation_client()
|
||||||
federation_client.query_user_devices = Mock(side_effect=query_user_devices)
|
federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[assignment]
|
||||||
|
|
||||||
# Register a mock on the store so that the incoming update doesn't fail because
|
# Register a mock on the store so that the incoming update doesn't fail because
|
||||||
# we don't share a room with the user.
|
# we don't share a room with the user.
|
||||||
|
@ -197,6 +201,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
# Manually inject a fake device list update. We need this update to include at
|
# Manually inject a fake device list update. We need this update to include at
|
||||||
# least one prev_id so that the user's device list will need to be retried.
|
# least one prev_id so that the user's device list will need to be retried.
|
||||||
device_list_updater = self.hs.get_device_handler().device_list_updater
|
device_list_updater = self.hs.get_device_handler().device_list_updater
|
||||||
|
assert isinstance(device_list_updater, DeviceListUpdater)
|
||||||
self.get_success(
|
self.get_success(
|
||||||
device_list_updater.incoming_device_list_update(
|
device_list_updater.incoming_device_list_update(
|
||||||
origin=remote_origin,
|
origin=remote_origin,
|
||||||
|
@ -236,7 +241,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
|
|
||||||
# Register mock device list retrieval on the federation client.
|
# Register mock device list retrieval on the federation client.
|
||||||
federation_client = self.hs.get_federation_client()
|
federation_client = self.hs.get_federation_client()
|
||||||
federation_client.query_user_devices = Mock(
|
federation_client.query_user_devices = Mock( # type: ignore[assignment]
|
||||||
return_value=make_awaitable(
|
return_value=make_awaitable(
|
||||||
{
|
{
|
||||||
"user_id": remote_user_id,
|
"user_id": remote_user_id,
|
||||||
|
@ -269,16 +274,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
||||||
keys = self.get_success(
|
keys = self.get_success(
|
||||||
self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
|
self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]),
|
||||||
)
|
)
|
||||||
self.assertTrue(remote_user_id in keys)
|
self.assertIn(remote_user_id, keys)
|
||||||
|
key = keys[remote_user_id]
|
||||||
|
assert key is not None
|
||||||
|
|
||||||
# Check that the master key is the one returned by the mock.
|
# Check that the master key is the one returned by the mock.
|
||||||
master_key = keys[remote_user_id]["master"]
|
master_key = key["master"]
|
||||||
self.assertEqual(len(master_key["keys"]), 1)
|
self.assertEqual(len(master_key["keys"]), 1)
|
||||||
self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
|
self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys())
|
||||||
self.assertTrue(remote_master_key in master_key["keys"].values())
|
self.assertTrue(remote_master_key in master_key["keys"].values())
|
||||||
|
|
||||||
# Check that the self-signing key is the one returned by the mock.
|
# Check that the self-signing key is the one returned by the mock.
|
||||||
self_signing_key = keys[remote_user_id]["self_signing"]
|
self_signing_key = key["self_signing"]
|
||||||
self.assertEqual(len(self_signing_key["keys"]), 1)
|
self.assertEqual(len(self_signing_key["keys"]), 1)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
"ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
|
"ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(),
|
||||||
|
|
|
@ -33,7 +33,7 @@ class PhoneHomeStatsTestCase(HomeserverTestCase):
|
||||||
If time doesn't move, don't error out.
|
If time doesn't move, don't error out.
|
||||||
"""
|
"""
|
||||||
past_stats = [
|
past_stats = [
|
||||||
(self.hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF))
|
(int(self.hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
|
||||||
]
|
]
|
||||||
stats: JsonDict = {}
|
stats: JsonDict = {}
|
||||||
self.get_success(phone_stats_home(self.hs, stats, past_stats))
|
self.get_success(phone_stats_home(self.hs, stats, past_stats))
|
||||||
|
|
|
@ -35,6 +35,8 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
self.event_creation_handler = self.hs.get_event_creation_handler()
|
self.event_creation_handler = self.hs.get_event_creation_handler()
|
||||||
self.event_builder_factory = self.hs.get_event_builder_factory()
|
self.event_builder_factory = self.hs.get_event_builder_factory()
|
||||||
self._storage_controllers = self.hs.get_storage_controllers()
|
self._storage_controllers = self.hs.get_storage_controllers()
|
||||||
|
assert self._storage_controllers.persistence is not None
|
||||||
|
self._persistence = self._storage_controllers.persistence
|
||||||
|
|
||||||
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
|
self.get_success(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
|
||||||
|
|
||||||
|
@ -179,9 +181,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
context = self.get_success(unpersisted_context.persist(event))
|
context = self.get_success(unpersisted_context.persist(event))
|
||||||
self.get_success(
|
self.get_success(self._persistence.persist_event(event, context))
|
||||||
self._storage_controllers.persistence.persist_event(event, context)
|
|
||||||
)
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def _inject_room_member(
|
def _inject_room_member(
|
||||||
|
@ -208,9 +208,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
context = self.get_success(unpersisted_context.persist(event))
|
context = self.get_success(unpersisted_context.persist(event))
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(self._persistence.persist_event(event, context))
|
||||||
self._storage_controllers.persistence.persist_event(event, context)
|
|
||||||
)
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def _inject_message(
|
def _inject_message(
|
||||||
|
@ -233,9 +231,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
)
|
)
|
||||||
context = self.get_success(unpersisted_context.persist(event))
|
context = self.get_success(unpersisted_context.persist(event))
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(self._persistence.persist_event(event, context))
|
||||||
self._storage_controllers.persistence.persist_event(event, context)
|
|
||||||
)
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def _inject_outlier(self) -> EventBase:
|
def _inject_outlier(self) -> EventBase:
|
||||||
|
@ -253,7 +249,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||||
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
|
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
|
||||||
event.internal_metadata.outlier = True
|
event.internal_metadata.outlier = True
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self._storage_controllers.persistence.persist_event(
|
self._persistence.persist_event(
|
||||||
event, EventContext.for_outlier(self._storage_controllers)
|
event, EventContext.for_outlier(self._storage_controllers)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -361,7 +361,9 @@ class HomeserverTestCase(TestCase):
|
||||||
store.db_pool.updates.do_next_background_update(False), by=0.1
|
store.db_pool.updates.do_next_background_update(False), by=0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
def make_homeserver(self, reactor: ThreadedMemoryReactorClock, clock: Clock):
|
def make_homeserver(
|
||||||
|
self, reactor: ThreadedMemoryReactorClock, clock: Clock
|
||||||
|
) -> HomeServer:
|
||||||
"""
|
"""
|
||||||
Make and return a homeserver.
|
Make and return a homeserver.
|
||||||
|
|
||||||
|
|
|
@ -54,6 +54,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
|
||||||
self.pump()
|
self.pump()
|
||||||
|
|
||||||
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
|
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
|
||||||
|
assert new_timings is not None
|
||||||
self.assertEqual(new_timings.failure_ts, failure_ts)
|
self.assertEqual(new_timings.failure_ts, failure_ts)
|
||||||
self.assertEqual(new_timings.retry_last_ts, failure_ts)
|
self.assertEqual(new_timings.retry_last_ts, failure_ts)
|
||||||
self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL)
|
self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL)
|
||||||
|
@ -82,6 +83,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
|
||||||
self.pump()
|
self.pump()
|
||||||
|
|
||||||
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
|
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
|
||||||
|
assert new_timings is not None
|
||||||
self.assertEqual(new_timings.failure_ts, failure_ts)
|
self.assertEqual(new_timings.failure_ts, failure_ts)
|
||||||
self.assertEqual(new_timings.retry_last_ts, retry_ts)
|
self.assertEqual(new_timings.retry_last_ts, retry_ts)
|
||||||
self.assertGreaterEqual(
|
self.assertGreaterEqual(
|
||||||
|
|
Loading…
Reference in New Issue