Refactor getting replication updates from database v2. (#7740)

pull/7800/head
Erik Johnston 2020-07-07 12:11:35 +01:00 committed by GitHub
parent d378c3da78
commit 67d7756fcf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 336 additions and 195 deletions

1
changelog.d/7740.misc Normal file
View File

@ -0,0 +1 @@
Refactor getting replication updates from database.

View File

@ -294,6 +294,9 @@ class TypingHandler(object):
rows.sort() rows.sort()
limited = False limited = False
# We, unusually, use a strict limit here as we have all the rows in
# memory rather than pulling them out of the database with a `LIMIT ?`
# clause.
if len(rows) > limit: if len(rows) > limit:
rows = rows[:limit] rows = rows[:limit]
current_id = rows[-1][0] current_id = rows[-1][0]

View File

@ -198,26 +198,6 @@ def current_token_without_instance(
return lambda instance_name: current_token() return lambda instance_name: current_token()
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
async def update_function(instance_name, from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return update_function
def make_http_update_function(hs, stream_name: str) -> UpdateFunction: def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
"""Makes a suitable function for use as an `update_function` that queries """Makes a suitable function for use as an `update_function` that queries
the master process for updates. the master process for updates.
@ -393,7 +373,7 @@ class PushersStream(Stream):
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_pushers_stream_token), current_token_without_instance(store.get_pushers_stream_token),
db_query_to_update_function(store.get_all_updated_pushers_rows), store.get_all_updated_pushers_rows,
) )
@ -421,27 +401,13 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow ROW_TYPE = CachesStreamRow
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() store = hs.get_datastore()
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
self.store.get_cache_stream_token, store.get_cache_stream_token,
self._update_function, store.get_all_updated_caches,
) )
async def _update_function(
self, instance_name: str, from_token: int, upto_token: int, limit: int
):
rows = await self.store.get_all_updated_caches(
instance_name, from_token, upto_token, limit
)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
class PublicRoomsStream(Stream): class PublicRoomsStream(Stream):
"""The public rooms list changed """The public rooms list changed
@ -465,7 +431,7 @@ class PublicRoomsStream(Stream):
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_current_public_room_stream_id), current_token_without_instance(store.get_current_public_room_stream_id),
db_query_to_update_function(store.get_all_new_public_rooms), store.get_all_new_public_rooms,
) )
@ -486,7 +452,7 @@ class DeviceListsStream(Stream):
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token), current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(store.get_all_device_list_changes_for_remotes), store.get_all_device_list_changes_for_remotes,
) )
@ -504,7 +470,7 @@ class ToDeviceStream(Stream):
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_to_device_stream_token), current_token_without_instance(store.get_to_device_stream_token),
db_query_to_update_function(store.get_all_new_device_messages), store.get_all_new_device_messages,
) )
@ -524,7 +490,7 @@ class TagAccountDataStream(Stream):
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_max_account_data_stream_id), current_token_without_instance(store.get_max_account_data_stream_id),
db_query_to_update_function(store.get_all_updated_tags), store.get_all_updated_tags,
) )
@ -612,7 +578,7 @@ class GroupServerStream(Stream):
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_group_stream_token), current_token_without_instance(store.get_group_stream_token),
db_query_to_update_function(store.get_all_groups_changes), store.get_all_groups_changes,
) )
@ -630,7 +596,5 @@ class UserSignatureStream(Stream):
super().__init__( super().__init__(
hs.get_instance_name(), hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token), current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function( store.get_all_user_signature_changes_for_remotes,
store.get_all_user_signature_changes_for_remotes
),
) )

View File

@ -16,7 +16,7 @@
import itertools import itertools
import logging import logging
from typing import Any, Iterable, Optional, Tuple from typing import Any, Iterable, List, Optional, Tuple
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream from synapse.replication.tcp.streams import BackfillStream, CachesStream
@ -46,13 +46,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
async def get_all_updated_caches( async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int self, instance_name: str, last_id: int, current_id: int, limit: int
): ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Fetches cache invalidation rows between the two given IDs written """Get updates for caches replication stream.
by the given instance. Returns at most `limit` rows.
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
""" """
if last_id == current_id: if last_id == current_id:
return [] return [], current_id, False
def get_all_updated_caches_txn(txn): def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to # We purposefully don't bound by the current token, as we want to
@ -66,7 +83,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
LIMIT ? LIMIT ?
""" """
txn.execute(sql, (last_id, instance_name, limit)) txn.execute(sql, (last_id, instance_name, limit))
return txn.fetchall() updates = [(row[0], row[1:]) for row in txn]
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return await self.db.runInteraction( return await self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn "get_all_updated_caches", get_all_updated_caches_txn

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Tuple
from canonicaljson import json from canonicaljson import json
@ -207,31 +208,46 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
) )
def get_all_new_device_messages(self, last_pos, current_pos, limit): async def get_all_new_device_messages(
""" self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for to device replication stream.
Args: Args:
last_pos(int): instance_name: The writer we want to fetch updates from. Unused
current_pos(int): here since there is only ever one writer.
limit(int): last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns: Returns:
A deferred list of rows from the device inbox A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
""" """
if last_pos == current_pos:
return defer.succeed([]) if last_id == current_id:
return [], current_id, False
def get_all_new_device_messages_txn(txn): def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and # We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id # we want to make sure we always get all entries for any stream_id
# we return. # we return.
upper_pos = min(current_pos, last_pos + limit) upper_pos = min(current_id, last_id + limit)
sql = ( sql = (
"SELECT max(stream_id), user_id" "SELECT max(stream_id), user_id"
" FROM device_inbox" " FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?" " WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id" " GROUP BY user_id"
) )
txn.execute(sql, (last_pos, upper_pos)) txn.execute(sql, (last_id, upper_pos))
rows = txn.fetchall() updates = [(row[0], row[1:]) for row in txn]
sql = ( sql = (
"SELECT max(stream_id), destination" "SELECT max(stream_id), destination"
@ -239,15 +255,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" WHERE ? < stream_id AND stream_id <= ?" " WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination" " GROUP BY destination"
) )
txn.execute(sql, (last_pos, upper_pos)) txn.execute(sql, (last_id, upper_pos))
rows.extend(txn) updates.extend((row[0], row[1:]) for row in txn)
# Order by ascending stream ordering # Order by ascending stream ordering
rows.sort() updates.sort()
return rows limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return self.db.runInteraction( return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn "get_all_new_device_messages", get_all_new_device_messages_txn
) )

View File

@ -582,13 +582,33 @@ class DeviceWorkerStore(SQLBaseStore):
return set() return set()
async def get_all_device_list_changes_for_remotes( async def get_all_device_list_changes_for_remotes(
self, from_key: int, to_key: int, limit: int, self, instance_name: str, last_id: int, current_id: int, limit: int
) -> List[Tuple[int, str]]: ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Return a list of `(stream_id, entity)` which is the combined list of """Get updates for device lists replication stream.
changes to devices and which destinations need to be poked. Entity is
either a user ID (starting with '@') or a remote destination. Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
""" """
if last_id == current_id:
return [], current_id, False
def _get_all_device_list_changes_for_remotes(txn):
# This query Does The Right Thing where it'll correctly apply the # This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries. # bounds to the inner queries.
sql = """ sql = """
@ -601,13 +621,19 @@ class DeviceWorkerStore(SQLBaseStore):
LIMIT ? LIMIT ?
""" """
return await self.db.execute( txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], row[1:]) for row in txn]
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_device_list_changes_for_remotes", "get_all_device_list_changes_for_remotes",
None, _get_all_device_list_changes_for_remotes,
sql,
from_key,
to_key,
limit,
) )
@cached(max_entries=10000) @cached(max_entries=10000)

View File

@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, List from typing import Dict, List, Tuple
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json, json
@ -479,20 +479,39 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return result return result
def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit): async def get_all_user_signature_changes_for_remotes(
"""Return a list of changes from the user signature stream to notify remotes. self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for groups replication stream.
Note that the user signature stream represents when a user signs their Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other device with their user-signing key, which is not published to other
users or servers, so no `destination` is needed in the returned users or servers, so no `destination` is needed in the returned
list. However, this is needed to poke workers. list. However, this is needed to poke workers.
Args: Args:
from_key (int): the stream ID to start at (exclusive) instance_name: The writer we want to fetch updates from. Unused
to_key (int): the stream ID to end at (inclusive) here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns: Returns:
Deferred[list[(int,str)]] a list of `(stream_id, user_id)` A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
""" """
if last_id == current_id:
return [], current_id, False
def _get_all_user_signature_changes_for_remotes_txn(txn):
sql = """ sql = """
SELECT stream_id, from_user_id AS user_id SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream FROM user_signature_stream
@ -500,13 +519,21 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
ORDER BY stream_id ASC ORDER BY stream_id ASC
LIMIT ? LIMIT ?
""" """
return self.db.execute( txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], (row[1:])) for row in txn]
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_user_signature_changes_for_remotes", "get_all_user_signature_changes_for_remotes",
None, _get_all_user_signature_changes_for_remotes_txn,
sql,
from_key,
to_key,
limit,
) )

View File

@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Tuple
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
@ -526,13 +528,35 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_groups_changes_for_user", _get_groups_changes_for_user_txn "get_groups_changes_for_user", _get_groups_changes_for_user_txn
) )
def get_all_groups_changes(self, from_token, to_token, limit): async def get_all_groups_changes(
from_token = int(from_token) self, instance_name: str, last_id: int, current_id: int, limit: int
has_changed = self._group_updates_stream_cache.has_any_entity_changed( ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
from_token """Get updates for groups replication stream.
)
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
"""
last_id = int(last_id)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
if not has_changed: if not has_changed:
return defer.succeed([]) return [], current_id, False
def _get_all_groups_changes_txn(txn): def _get_all_groups_changes_txn(txn):
sql = """ sql = """
@ -541,13 +565,21 @@ class GroupServerWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ? WHERE ? < stream_id AND stream_id <= ?
LIMIT ? LIMIT ?
""" """
txn.execute(sql, (from_token, to_token, limit)) txn.execute(sql, (last_id, current_id, limit))
return [ updates = [
(stream_id, group_id, user_id, gtype, json.loads(content_json)) (stream_id, (group_id, user_id, gtype, json.loads(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn for stream_id, group_id, user_id, gtype, content_json in txn
] ]
return self.db.runInteraction( limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn "get_all_groups_changes", _get_all_groups_changes_txn
) )

View File

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Iterable, Iterator from typing import Iterable, Iterator, List, Tuple
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json, json
@ -98,77 +98,69 @@ class PusherWorkerStore(SQLBaseStore):
rows = yield self.db.runInteraction("get_all_pushers", get_pushers) rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
return rows return rows
def get_all_updated_pushers(self, last_id, current_id, limit): async def get_all_updated_pushers_rows(
if last_id == current_id: self, instance_name: str, last_id: int, current_id: int, limit: int
return defer.succeed(([], [])) ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for pushers replication stream.
def get_all_updated_pushers_txn(txn): Args:
sql = ( instance_name: The writer we want to fetch updates from. Unused
"SELECT id, user_name, access_token, profile_tag, kind," here since there is only ever one writer.
" app_id, app_display_name, device_display_name, pushkey, ts," last_id: The token to fetch updates from. Exclusive.
" lang, data" current_id: The token to fetch updates up to. Inclusive.
" FROM pushers" limit: The requested limit for the number of rows to return. The
" WHERE ? < id AND id <= ?" function may return more or fewer rows.
" ORDER BY id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
updated = txn.fetchall()
sql = (
"SELECT stream_id, user_id, app_id, pushkey"
" FROM deleted_pushers"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
deleted = txn.fetchall()
return updated, deleted
return self.db.runInteraction(
"get_all_updated_pushers", get_all_updated_pushers_txn
)
def get_all_updated_pushers_rows(self, last_id, current_id, limit):
"""Get all the pushers that have changed between the given tokens.
Returns: Returns:
Deferred(list(tuple)): each tuple consists of: A tuple consisting of: the updates, a token to use to fetch
stream_id (str) subsequent updates, and whether we returned fewer rows than exists
user_id (str) between the requested tokens due to the limit.
app_id (str)
pushkey (str) The token returned can be used in a subsequent call to this
was_deleted (bool): whether the pusher was added/updated (False) function to get further updatees.
or deleted (True)
The updates are a list of 2-tuples of stream ID and the row data
""" """
if last_id == current_id: if last_id == current_id:
return defer.succeed([]) return [], current_id, False
def get_all_updated_pushers_rows_txn(txn): def get_all_updated_pushers_rows_txn(txn):
sql = ( sql = """
"SELECT id, user_name, app_id, pushkey" SELECT id, user_name, app_id, pushkey
" FROM pushers" FROM pushers
" WHERE ? < id AND id <= ?" WHERE ? < id AND id <= ?
" ORDER BY id ASC LIMIT ?" ORDER BY id ASC LIMIT ?
) """
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
results = [list(row) + [False] for row in txn] updates = [
(stream_id, (user_name, app_id, pushkey, False))
for stream_id, user_name, app_id, pushkey in txn
]
sql = ( sql = """
"SELECT stream_id, user_id, app_id, pushkey" SELECT stream_id, user_id, app_id, pushkey
" FROM deleted_pushers" FROM deleted_pushers
" WHERE ? < stream_id AND stream_id <= ?" WHERE ? < stream_id AND stream_id <= ?
" ORDER BY stream_id ASC LIMIT ?" ORDER BY stream_id ASC LIMIT ?
) """
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
updates.extend(
(stream_id, (user_name, app_id, pushkey, True))
for stream_id, user_name, app_id, pushkey in txn
)
results.extend(list(row) + [True] for row in txn) updates.sort() # Sort so that they're ordered by stream id
results.sort() # Sort so that they're ordered by stream id
return results limited = False
upper_bound = current_id
if len(updates) >= limit:
limited = True
upper_bound = updates[-1][0]
return self.db.runInteraction( return updates, upper_bound, limited
return await self.db.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
) )

View File

@ -803,7 +803,32 @@ class RoomWorkerStore(SQLBaseStore):
return total_media_quarantined return total_media_quarantined
def get_all_new_public_rooms(self, prev_id, current_id, limit): async def get_all_new_public_rooms(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for public rooms replication stream.
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
return [], current_id, False
def get_all_new_public_rooms(txn): def get_all_new_public_rooms(txn):
sql = """ sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id SELECT stream_id, room_id, visibility, appservice_id, network_id
@ -813,13 +838,17 @@ class RoomWorkerStore(SQLBaseStore):
LIMIT ? LIMIT ?
""" """
txn.execute(sql, (prev_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall() updates = [(row[0], row[1:]) for row in txn]
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
if prev_id == current_id: return updates, upto_token, limited
return defer.succeed([])
return self.db.runInteraction( return await self.db.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms "get_all_new_public_rooms", get_all_new_public_rooms
) )

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Tuple
from canonicaljson import json from canonicaljson import json
@ -53,18 +54,32 @@ class TagsWorkerStore(AccountDataWorkerStore):
return deferred return deferred
@defer.inlineCallbacks async def get_all_updated_tags(
def get_all_updated_tags(self, last_id, current_id, limit): self, instance_name: str, last_id: int, current_id: int, limit: int
"""Get all the client tags that have changed on the server ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for tags replication stream.
Args: Args:
last_id(int): The position to fetch from. instance_name: The writer we want to fetch updates from. Unused
current_id(int): The position to fetch up to. here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns: Returns:
A deferred list of tuples of stream_id int, user_id string, A tuple consisting of: the updates, a token to use to fetch
room_id string, tag string and content string. subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
""" """
if last_id == current_id: if last_id == current_id:
return [] return [], current_id, False
def get_all_updated_tags_txn(txn): def get_all_updated_tags_txn(txn):
sql = ( sql = (
@ -76,7 +91,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall() return txn.fetchall()
tag_ids = yield self.db.runInteraction( tag_ids = await self.db.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn "get_all_updated_tags", get_all_updated_tags_txn
) )
@ -89,21 +104,27 @@ class TagsWorkerStore(AccountDataWorkerStore):
for tag, content in txn: for tag, content in txn:
tags.append(json.dumps(tag) + ":" + content) tags.append(json.dumps(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}" tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, user_id, room_id, tag_json)) results.append((stream_id, (user_id, room_id, tag_json)))
return results return results
batch_size = 50 batch_size = 50
results = [] results = []
for i in range(0, len(tag_ids), batch_size): for i in range(0, len(tag_ids), batch_size):
tags = yield self.db.runInteraction( tags = await self.db.runInteraction(
"get_all_updated_tag_content", "get_all_updated_tag_content",
get_tag_content, get_tag_content,
tag_ids[i : i + batch_size], tag_ids[i : i + batch_size],
) )
results.extend(tags) results.extend(tags)
return results limited = False
upto_token = current_id
if len(results) >= limit:
upto_token = results[-1][0]
limited = True
return results, upto_token, limited
@defer.inlineCallbacks @defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id): def get_updated_tags(self, user_id, stream_id):