Remove redundant `get_success` calls in test code (#12346)

There are a bunch of places we call get_success on an immediate value, which is unnecessary. Let's rip them out, and remove the redundant functionality in get_success and friends.
pull/12355/head
Richard van der Hoff 2022-04-01 16:10:31 +01:00 committed by GitHub
parent c4cf916ed7
commit 33ebee47e4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 76 additions and 123 deletions

1
changelog.d/12346.misc Normal file
View File

@ -0,0 +1 @@
Remove redundant `get_success` calls in test code.

View File

@ -44,21 +44,20 @@ class DeactivateAccountTestCase(HomeserverTestCase):
Deactivates the account `self.user` using `self.token` and asserts Deactivates the account `self.user` using `self.token` and asserts
that it returns a 200 success code. that it returns a 200 success code.
""" """
req = self.get_success( req = self.make_request(
self.make_request( "POST",
"POST", "account/deactivate",
"account/deactivate", {
{ "auth": {
"auth": { "type": "m.login.password",
"type": "m.login.password", "user": self.user,
"user": self.user, "password": "pass",
"password": "pass",
},
"erase": True,
}, },
access_token=self.token, "erase": True,
) },
access_token=self.token,
) )
self.assertEqual(req.code, HTTPStatus.OK, req) self.assertEqual(req.code, HTTPStatus.OK, req)
def test_global_account_data_deleted_upon_deactivation(self) -> None: def test_global_account_data_deleted_upon_deactivation(self) -> None:

View File

@ -59,7 +59,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
self.bob = UserID.from_string("@4567:test") self.bob = UserID.from_string("@4567:test")
self.alice = UserID.from_string("@alice:remote") self.alice = UserID.from_string("@alice:remote")
self.get_success(self.register_user(self.frank.localpart, "frankpassword")) self.register_user(self.frank.localpart, "frankpassword")
self.handler = hs.get_profile_handler() self.handler = hs.get_profile_handler()

View File

@ -158,9 +158,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
) )
# Blow away caches (supported room versions can only change due to a restart). # Blow away caches (supported room versions can only change due to a restart).
self.get_success( self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
)
self.store._get_event_cache.clear() self.store._get_event_cache.clear()
# The rooms should be excluded from the sync response. # The rooms should be excluded from the sync response.

View File

@ -87,24 +87,22 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(displayname, "Bobberino") self.assertEqual(displayname, "Bobberino")
def test_can_register_admin_user(self): def test_can_register_admin_user(self):
user_id = self.get_success( user_id = self.register_user(
self.register_user( "bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
"bob_module_admin", "1234", displayname="Bobberino Admin", admin=True
)
) )
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
self.assertEqual(found_user.user_id.to_string(), user_id) self.assertEqual(found_user.user_id.to_string(), user_id)
self.assertIdentical(found_user.is_admin, True) self.assertIdentical(found_user.is_admin, True)
def test_can_set_admin(self): def test_can_set_admin(self):
user_id = self.get_success( user_id = self.register_user(
self.register_user( "alice_wants_admin",
"alice_wants_admin", "1234",
"1234", displayname="Alice Powerhungry",
displayname="Alice Powerhungry", admin=False,
admin=False,
)
) )
self.get_success(self.module_api.set_user_admin(user_id, True)) self.get_success(self.module_api.set_user_admin(user_id, True))
found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id))
self.assertEqual(found_user.user_id.to_string(), user_id) self.assertEqual(found_user.user_id.to_string(), user_id)

View File

@ -268,7 +268,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event_source = RoomEventSource(self.hs) event_source = RoomEventSource(self.hs)
event_source.store = self.slaved_store event_source.store = self.slaved_store
current_token = self.get_success(event_source.get_current_key()) current_token = event_source.get_current_key()
# gradually stream out the replication # gradually stream out the replication
while repl_transport.buffer: while repl_transport.buffer:
@ -277,7 +277,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.pump(0) self.pump(0)
prev_token = current_token prev_token = current_token
current_token = self.get_success(event_source.get_current_key()) current_token = event_source.get_current_key()
# attempt to replicate the behaviour of the sync handler. # attempt to replicate the behaviour of the sync handler.
# #

View File

@ -214,9 +214,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
self.assertEqual(messages[0]["sender"], "@notices:test") self.assertEqual(messages[0]["sender"], "@notices:test")
# invalidate cache of server notices room_ids # invalidate cache of server notices room_ids
self.get_success( self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
)
# send second message # send second message
channel = self.make_request( channel = self.make_request(
@ -291,9 +289,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
# invalidate cache of server notices room_ids # invalidate cache of server notices room_ids
# if server tries to send to a cached room_id the user gets the message # if server tries to send to a cached room_id the user gets the message
# in old room # in old room
self.get_success( self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
)
# send second message # send second message
channel = self.make_request( channel = self.make_request(
@ -380,9 +376,7 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase):
# invalidate cache of server notices room_ids # invalidate cache of server notices room_ids
# if server tries to send to a cached room_id it gives an error # if server tries to send to a cached room_id it gives an error
self.get_success( self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
)
# send second message # send second message
channel = self.make_request( channel = self.make_request(

View File

@ -982,7 +982,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
super().prepare(reactor, clock, hs) super().prepare(reactor, clock, hs)
# profile changes expect that the user is actually registered # profile changes expect that the user is actually registered
user = UserID.from_string(self.user_id) user = UserID.from_string(self.user_id)
self.get_success(self.register_user(user.localpart, "supersecretpassword")) self.register_user(user.localpart, "supersecretpassword")
@unittest.override_config( @unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}

View File

@ -157,10 +157,10 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 7}) self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
ctx1 = self.get_success(id_gen.get_next()) ctx1 = id_gen.get_next()
ctx2 = self.get_success(id_gen.get_next()) ctx2 = id_gen.get_next()
ctx3 = self.get_success(id_gen.get_next()) ctx3 = id_gen.get_next()
ctx4 = self.get_success(id_gen.get_next()) ctx4 = id_gen.get_next()
s1 = self.get_success(ctx1.__aenter__()) s1 = self.get_success(ctx1.__aenter__())
s2 = self.get_success(ctx2.__aenter__()) s2 = self.get_success(ctx2.__aenter__())
@ -362,8 +362,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7) self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
# Persist two rows at once # Persist two rows at once
ctx1 = self.get_success(id_gen.get_next()) ctx1 = id_gen.get_next()
ctx2 = self.get_success(id_gen.get_next()) ctx2 = id_gen.get_next()
s1 = self.get_success(ctx1.__aenter__()) s1 = self.get_success(ctx1.__aenter__())
s2 = self.get_success(ctx2.__aenter__()) s2 = self.get_success(ctx2.__aenter__())

View File

@ -119,11 +119,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
return event return event
def test_redact(self): def test_redact(self):
self.get_success( self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
)
msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) msg_event = self.inject_message(self.room1, self.u_alice, "t")
# Check event has not been redacted: # Check event has not been redacted:
event = self.get_success(self.store.get_event(msg_event.event_id)) event = self.get_success(self.store.get_event(msg_event.event_id))
@ -141,9 +139,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
# Redact event # Redact event
reason = "Because I said so" reason = "Because I said so"
self.get_success( self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
)
event = self.get_success(self.store.get_event(msg_event.event_id)) event = self.get_success(self.store.get_event(msg_event.event_id))
@ -170,14 +166,10 @@ class RedactionTestCase(unittest.HomeserverTestCase):
) )
def test_redact_join(self): def test_redact_join(self):
self.get_success( self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
)
msg_event = self.get_success( msg_event = self.inject_room_member(
self.inject_room_member( self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
self.room1, self.u_bob, Membership.JOIN, extra_content={"blue": "red"}
)
) )
event = self.get_success(self.store.get_event(msg_event.event_id)) event = self.get_success(self.store.get_event(msg_event.event_id))
@ -195,9 +187,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
# Redact event # Redact event
reason = "Because I said so" reason = "Because I said so"
self.get_success( self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
)
# Check redaction # Check redaction
@ -311,11 +301,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def test_redact_censor(self): def test_redact_censor(self):
"""Test that a redacted event gets censored in the DB after a month""" """Test that a redacted event gets censored in the DB after a month"""
self.get_success( self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
)
msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) msg_event = self.inject_message(self.room1, self.u_alice, "t")
# Check event has not been redacted: # Check event has not been redacted:
event = self.get_success(self.store.get_event(msg_event.event_id)) event = self.get_success(self.store.get_event(msg_event.event_id))
@ -333,9 +321,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
# Redact event # Redact event
reason = "Because I said so" reason = "Because I said so"
self.get_success( self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
self.inject_redaction(self.room1, msg_event.event_id, self.u_alice, reason)
)
event = self.get_success(self.store.get_event(msg_event.event_id)) event = self.get_success(self.store.get_event(msg_event.event_id))
@ -381,25 +367,19 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def test_redact_redaction(self): def test_redact_redaction(self):
"""Tests that we can redact a redaction and can fetch it again.""" """Tests that we can redact a redaction and can fetch it again."""
self.get_success( self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
msg_event = self.inject_message(self.room1, self.u_alice, "t")
first_redact_event = self.inject_redaction(
self.room1, msg_event.event_id, self.u_alice, "Redacting message"
) )
msg_event = self.get_success(self.inject_message(self.room1, self.u_alice, "t")) self.inject_redaction(
self.room1,
first_redact_event = self.get_success( first_redact_event.event_id,
self.inject_redaction( self.u_alice,
self.room1, msg_event.event_id, self.u_alice, "Redacting message" "Redacting redaction",
)
)
self.get_success(
self.inject_redaction(
self.room1,
first_redact_event.event_id,
self.u_alice,
"Redacting redaction",
)
) )
# Now lets jump to the future where we have censored the redaction event # Now lets jump to the future where we have censored the redaction event
@ -414,9 +394,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def test_store_redacted_redaction(self): def test_store_redacted_redaction(self):
"""Tests that we can store a redacted redaction.""" """Tests that we can store a redacted redaction."""
self.get_success( self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
self.inject_room_member(self.room1, self.u_alice, Membership.JOIN)
)
builder = self.event_builder_factory.for_room_version( builder = self.event_builder_factory.for_room_version(
RoomVersions.V1, RoomVersions.V1,

View File

@ -110,9 +110,7 @@ class PaginationTestCase(HomeserverTestCase):
def _filter_messages(self, filter: JsonDict) -> List[EventBase]: def _filter_messages(self, filter: JsonDict) -> List[EventBase]:
"""Make a request to /messages with a filter, returns the chunk of events.""" """Make a request to /messages with a filter, returns the chunk of events."""
from_token = self.get_success( from_token = self.hs.get_event_sources().get_current_token_for_pagination()
self.hs.get_event_sources().get_current_token_for_pagination()
)
events, next_key = self.get_success( events, next_key = self.get_success(
self.hs.get_datastores().main.paginate_room_events( self.hs.get_datastores().main.paginate_room_events(

View File

@ -48,17 +48,15 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# #
# before we do that, we persist some other events to act as state. # before we do that, we persist some other events to act as state.
self.get_success(self._inject_visibility("@admin:hs", "joined")) self._inject_visibility("@admin:hs", "joined")
for i in range(0, 10): for i in range(0, 10):
self.get_success(self._inject_room_member("@resident%i:hs" % i)) self._inject_room_member("@resident%i:hs" % i)
events_to_filter = [] events_to_filter = []
for i in range(0, 10): for i in range(0, 10):
user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
evt = self.get_success( evt = self._inject_room_member(user, extra_content={"a": "b"})
self._inject_room_member(user, extra_content={"a": "b"})
)
events_to_filter.append(evt) events_to_filter.append(evt)
filtered = self.get_success( filtered = self.get_success(
@ -76,10 +74,10 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
def test_filter_outlier(self) -> None: def test_filter_outlier(self) -> None:
# outlier events must be returned, for the good of the collective federation # outlier events must be returned, for the good of the collective federation
self.get_success(self._inject_room_member("@resident:remote_hs")) self._inject_room_member("@resident:remote_hs")
self.get_success(self._inject_visibility("@resident:remote_hs", "joined")) self._inject_visibility("@resident:remote_hs", "joined")
outlier = self.get_success(self._inject_outlier()) outlier = self._inject_outlier()
self.assertEqual( self.assertEqual(
self.get_success( self.get_success(
filter_events_for_server(self.storage, "remote_hs", [outlier]) filter_events_for_server(self.storage, "remote_hs", [outlier])
@ -88,7 +86,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
) )
# it should also work when there are other events in the list # it should also work when there are other events in the list
evt = self.get_success(self._inject_message("@unerased:local_hs")) evt = self._inject_message("@unerased:local_hs")
filtered = self.get_success( filtered = self.get_success(
filter_events_for_server(self.storage, "remote_hs", [outlier, evt]) filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
@ -112,19 +110,19 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# change in the middle of them. # change in the middle of them.
events_to_filter = [] events_to_filter = []
evt = self.get_success(self._inject_message("@unerased:local_hs")) evt = self._inject_message("@unerased:local_hs")
events_to_filter.append(evt) events_to_filter.append(evt)
evt = self.get_success(self._inject_message("@erased:local_hs")) evt = self._inject_message("@erased:local_hs")
events_to_filter.append(evt) events_to_filter.append(evt)
evt = self.get_success(self._inject_room_member("@joiner:remote_hs")) evt = self._inject_room_member("@joiner:remote_hs")
events_to_filter.append(evt) events_to_filter.append(evt)
evt = self.get_success(self._inject_message("@unerased:local_hs")) evt = self._inject_message("@unerased:local_hs")
events_to_filter.append(evt) events_to_filter.append(evt)
evt = self.get_success(self._inject_message("@erased:local_hs")) evt = self._inject_message("@erased:local_hs")
events_to_filter.append(evt) events_to_filter.append(evt)
# the erasey user gets erased # the erasey user gets erased

View File

@ -16,7 +16,6 @@
import gc import gc
import hashlib import hashlib
import hmac import hmac
import inspect
import json import json
import logging import logging
import secrets import secrets
@ -519,33 +518,23 @@ class HomeserverTestCase(TestCase):
self.reactor.pump([by] * 100) self.reactor.pump([by] * 100)
def get_success(self, d, by=0.0): def get_success(self, d, by=0.0):
if inspect.isawaitable(d): deferred: Deferred[TV] = ensureDeferred(d)
d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump(by=by) self.pump(by=by)
return self.successResultOf(d) return self.successResultOf(deferred)
def get_failure(self, d, exc): def get_failure(self, d, exc):
""" """
Run a Deferred and get a Failure from it. The failure must be of the type `exc`. Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
""" """
if inspect.isawaitable(d): deferred: Deferred[Any] = ensureDeferred(d)
d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump() self.pump()
return self.failureResultOf(d, exc) return self.failureResultOf(deferred, exc)
def get_success_or_raise(self, d, by=0.0): def get_success_or_raise(self, d, by=0.0):
"""Drive deferred to completion and return result or raise exception """Drive deferred to completion and return result or raise exception
on failure. on failure.
""" """
deferred: Deferred[TV] = ensureDeferred(d)
if inspect.isawaitable(d):
deferred = ensureDeferred(d)
if not isinstance(deferred, Deferred):
return d
results: list = [] results: list = []
deferred.addBoth(results.append) deferred.addBoth(results.append)