Convert simple_select_one and simple_select_one_onecol to async (#8162)

pull/8171/head
Patrick Cloke 2020-08-26 07:19:32 -04:00 committed by GitHub
parent 56efa9ec71
commit 4c6c56dc58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 220 additions and 113 deletions

1
changelog.d/8162.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -29,9 +29,11 @@ from typing import (
Tuple,
TypeVar,
Union,
overload,
)
from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
from twisted.internet import defer
@ -1020,14 +1022,36 @@ class DatabasePool(object):
return txn.execute_batch(sql, args)
def simple_select_one(
@overload
async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
...
@overload
async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
...
async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> defer.Deferred:
) -> Optional[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@ -1038,18 +1062,18 @@ class DatabasePool(object):
allow_none: If true, return None instead of failing if the SELECT
statement returns no rows
"""
return self.runInteraction(
return await self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
def simple_select_one_onecol(
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one_onecol",
) -> defer.Deferred:
) -> Optional[Any]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@ -1061,7 +1085,7 @@ class DatabasePool(object):
statement returns no rows
desc: description of the transaction, for logging and metrics
"""
return self.runInteraction(
return await self.runInteraction(
desc,
self.simple_select_one_onecol_txn,
table,

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, Iterable, List, Optional, Set, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id: str, device_id: str):
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore):
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
A dict containing the device information
Raises:
StoreError: if the device is not found
"""
return self.db_pool.simple_select_one(
return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id: str):
async def get_device_list_last_stream_id_for_remote(
self, user_id: str
) -> Optional[Any]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",

View File

@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias):
return self.db_pool.simple_select_one_onecol(
async def get_room_alias_creator(self, room_alias: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",

View File

@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return ret
def count_e2e_room_keys(self, user_id, version):
async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
"""Get the number of keys in a backup version.
Args:
user_id (str): the user whose backup we're querying
version (str): the version ID of the backup we're querying about
user_id: the user whose backup we're querying
version: the version ID of the backup we're querying about
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)",

View File

@ -119,19 +119,19 @@ class EventsWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
def get_received_ts(self, event_id):
async def get_received_ts(self, event_id: str) -> Optional[int]:
"""Get received_ts (when it was persisted) for the event.
Raises an exception for unknown events.
Args:
event_id (str)
event_id: The event ID to query.
Returns:
Deferred[int|None]: Timestamp in milliseconds, or None for events
that were persisted before received_ts was implemented.
Timestamp in milliseconds, or None for events that were persisted
before received_ts was implemented.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
class GroupServerWorkerStore(SQLBaseStore):
def get_group(self, group_id):
return self.db_pool.simple_select_one(
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore):
)
return bool(result)
def is_user_admin_in_group(self, group_id, user_id):
return self.db_pool.simple_select_one_onecol(
async def is_user_admin_in_group(
self, group_id: str, user_id: str
) -> Optional[bool]:
return await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="is_user_admin_in_group",
)
def is_user_invited_to_local_group(self, group_id, user_id):
async def is_user_invited_to_local_group(
self, group_id: str, user_id: str
) -> Optional[bool]:
"""Has the group server invited a user?
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",

View File

@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
def get_local_media(self, media_id):
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
Returns:
None if the media_id doesn't exist.
"""
return self.db_pool.simple_select_one(
return await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_thumbnail",
)
def get_cached_remote_media(self, origin, media_id):
return self.db_pool.simple_select_one(
async def get_cached_remote_media(
self, origin, media_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(

View File

@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
return users
@cached(num_args=1)
def user_last_seen_monthly_active(self, user_id):
async def user_last_seen_monthly_active(self, user_id: str) -> int:
"""
Checks if a given user is part of the monthly active user group
Arguments:
user_id (str): user to add/update
Return:
Deferred[int] : timestamp since last seen, None if never seen
Checks if a given user is part of the monthly active user group
Arguments:
user_id: user to add/update
Return:
Timestamp since last seen, None if never seen
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",

View File

@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo
class ProfileWorkerStore(SQLBaseStore):
async def get_profileinfo(self, user_localpart):
async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
def get_profile_displayname(self, user_localpart):
return self.db_pool.simple_select_one_onecol(
async def get_profile_displayname(self, user_localpart: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
def get_profile_avatar_url(self, user_localpart):
return self.db_pool.simple_select_one_onecol(
async def get_profile_avatar_url(self, user_localpart: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
def get_from_remote_profile_cache(self, user_id):
return self.db_pool.simple_select_one(
async def get_from_remote_profile_cache(
self, user_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),

View File

@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
return self.db_pool.simple_select_one_onecol(
async def get_last_receipt_event_id_for_user(
self, user_id: str, room_id: str, receipt_type: str
) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,

View File

@ -17,7 +17,7 @@
import logging
import re
from typing import Awaitable, Dict, List, Optional
from typing import Any, Awaitable, Dict, List, Optional
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@cached()
def get_user_by_id(self, user_id):
return self.db_pool.simple_select_one(
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@ -1259,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="del_user_pending_deactivation",
)
def get_user_pending_deactivation(self):
async def get_user_pending_deactivation(self) -> Optional[str]:
"""
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",

View File

@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import Optional
from synapse.storage._base import SQLBaseStore
@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
def get_rejection_reason(self, event_id):
return self.db_pool.simple_select_one_onecol(
async def get_rejection_reason(self, event_id: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},

View File

@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):
self.config = hs.config
def get_room(self, room_id):
async def get_room(self, room_id: str) -> dict:
"""Retrieve a room.
Args:
room_id (str): The ID of the room to retrieve.
room_id: The ID of the room to retrieve.
Returns:
A dict containing the room information, or None if the room is unknown.
"""
return self.db_pool.simple_select_one(
return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore):
return ret_val
@cached(max_entries=10000)
def is_room_blocked(self, room_id):
return self.db_pool.simple_select_one_onecol(
async def is_room_blocked(self, room_id: str) -> Optional[bool]:
return await self.db_pool.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="1",

View File

@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return event.content.get("canonical_alias")
@cached(max_entries=50000)
def _get_state_group_for_event(self, event_id):
return self.db_pool.simple_select_one_onecol(
async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
return await self.db_pool.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",

View File

@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore):
return len(rooms_to_work_on)
def get_stats_positions(self):
async def get_stats_positions(self) -> int:
"""
Returns the stats processor positions.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="stats_incremental_position",
keyvalues={},
retcol="stream_id",
@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
def get_earliest_token_for_stats(self, stats_type, id):
async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore):
being calculated.
Returns:
Deferred[int]
The earliest token.
"""
table, id_col = TYPE_TO_TABLE[stats_type]
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
"%s_current" % (table,),
keyvalues={id_col: id},
retcol="completed_delta_stream_id",

View File

@ -15,6 +15,7 @@
import logging
import re
from typing import Any, Dict, Optional
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
def get_user_in_directory(self, user_id):
return self.db_pool.simple_select_one(
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
def get_user_directory_stream_pos(self):
return self.db_pool.simple_select_one_onecol(
async def get_user_directory_stream_pos(self) -> int:
return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",

View File

@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_name(self):
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
yield defer.ensureDeferred(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.frank)
@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank",
)
@defer.inlineCallbacks
@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
yield defer.ensureDeferred(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank",
)
# Setting displayname a second time is forbidden
@ -158,7 +172,9 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield self.store.set_profile_displayname("caroline", "Caroline")
yield defer.ensureDeferred(
self.store.set_profile_displayname("caroline", "Caroline")
)
response = yield defer.ensureDeferred(
self.query_handlers["profile"](
@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
yield self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
)
)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/pic.gif",
)
@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/me.png",
)
@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
yield self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
)
)
self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/me.png",
)

View File

@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_users_in_room = get_users_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
self.datastore.get_user_directory_stream_pos.side_effect = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
defer.succeed(1)
lambda: make_awaitable(1)
)
self.datastore.get_current_state_deltas.return_value = (0, None)

View File

@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the new user exists with all provided attributes
self.assertEqual(user_id, "@bob:test")
self.assertTrue(access_token)
self.assertTrue(self.store.get_user_by_id(user_id))
self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
# Check that the email was assigned
emails = self.get_success(self.store.user_get_threepids(user_id))

View File

@ -97,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
value = yield self.datastore.db_pool.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
value = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
)
)
self.assertEquals("Value", value)
@ -111,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
ret = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
)
)
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
@ -127,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
allow_none=True,
ret = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
allow_none=True,
)
)
self.assertFalse(ret)

View File

@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name")
)
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name 1")
)
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# check it worked
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks

View File

@ -35,21 +35,34 @@ class ProfileStoreTestCase(unittest.TestCase):
def test_displayname(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
yield defer.ensureDeferred(
self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
)
self.assertEquals(
"Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
"Frank",
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.u_frank.localpart)
)
),
)
@defer.inlineCallbacks
def test_avatar_url(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
)
)
self.assertEquals(
"http://my.site/here",
(yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
),
)

View File

@ -53,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"user_type": None,
"deactivated": 0,
},
(yield self.store.get_user_by_id(self.user_id)),
(yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
)
@defer.inlineCallbacks

View File

@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
(yield self.store.get_room(self.room.to_string())),
(yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
)
@defer.inlineCallbacks
def test_get_room_unknown_room(self):
self.assertIsNone((yield self.store.get_room("!uknown:test")),)
self.assertIsNone(
(yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
)
@defer.inlineCallbacks
def test_get_room_with_stats(self):
@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
(yield self.store.get_room_with_stats(self.room.to_string())),
(
yield defer.ensureDeferred(
self.store.get_room_with_stats(self.room.to_string())
)
),
)
@defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),)
self.assertIsNone(
(
yield defer.ensureDeferred(
self.store.get_room_with_stats("!uknown:test")
)
),
)
class RoomEventsStoreTestCase(unittest.TestCase):