Remove handling of multiple rows per ID
parent
59ad93d2a4
commit
f70f44abc7
|
@ -112,23 +112,13 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
if not has_changed:
|
if not has_changed:
|
||||||
return now_stream_id, []
|
return now_stream_id, []
|
||||||
|
|
||||||
# We retrieve n+1 devices from the list of outbound pokes where n is
|
|
||||||
# our outbound device update limit. We then check if the very last
|
|
||||||
# device has the same stream_id as the second-to-last device. If so,
|
|
||||||
# then we ignore all devices with that stream_id and only send the
|
|
||||||
# devices with a lower stream_id.
|
|
||||||
#
|
|
||||||
# If when culling the list we end up with no devices afterwards, we
|
|
||||||
# consider the device update to be too large, and simply skip the
|
|
||||||
# stream_id; the rationale being that such a large device list update
|
|
||||||
# is likely an error.
|
|
||||||
updates = yield self.db.runInteraction(
|
updates = yield self.db.runInteraction(
|
||||||
"get_device_updates_by_remote",
|
"get_device_updates_by_remote",
|
||||||
self._get_device_updates_by_remote_txn,
|
self._get_device_updates_by_remote_txn,
|
||||||
destination,
|
destination,
|
||||||
from_stream_id,
|
from_stream_id,
|
||||||
now_stream_id,
|
now_stream_id,
|
||||||
limit + 1,
|
limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return an empty list if there are no updates
|
# Return an empty list if there are no updates
|
||||||
|
@ -166,14 +156,6 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
"device_id": verify_key.version,
|
"device_id": verify_key.version,
|
||||||
}
|
}
|
||||||
|
|
||||||
# if we have exceeded the limit, we need to exclude any results with the
|
|
||||||
# same stream_id as the last row.
|
|
||||||
if len(updates) > limit:
|
|
||||||
stream_id_cutoff = updates[-1][2]
|
|
||||||
now_stream_id = stream_id_cutoff - 1
|
|
||||||
else:
|
|
||||||
stream_id_cutoff = None
|
|
||||||
|
|
||||||
# Perform the equivalent of a GROUP BY
|
# Perform the equivalent of a GROUP BY
|
||||||
#
|
#
|
||||||
# Iterate through the updates list and copy non-duplicate
|
# Iterate through the updates list and copy non-duplicate
|
||||||
|
@ -192,10 +174,6 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
query_map = {}
|
query_map = {}
|
||||||
cross_signing_keys_by_user = {}
|
cross_signing_keys_by_user = {}
|
||||||
for user_id, device_id, update_stream_id, update_context in updates:
|
for user_id, device_id, update_stream_id, update_context in updates:
|
||||||
if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
|
|
||||||
# Stop processing updates
|
|
||||||
break
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
user_id in master_key_by_user
|
user_id in master_key_by_user
|
||||||
and device_id == master_key_by_user[user_id]["device_id"]
|
and device_id == master_key_by_user[user_id]["device_id"]
|
||||||
|
@ -218,17 +196,6 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||||
if update_stream_id > previous_update_stream_id:
|
if update_stream_id > previous_update_stream_id:
|
||||||
query_map[key] = (update_stream_id, update_context)
|
query_map[key] = (update_stream_id, update_context)
|
||||||
|
|
||||||
# If we didn't find any updates with a stream_id lower than the cutoff, it
|
|
||||||
# means that there are more than limit updates all of which have the same
|
|
||||||
# steam_id.
|
|
||||||
|
|
||||||
# That should only happen if a client is spamming the server with new
|
|
||||||
# devices, in which case E2E isn't going to work well anyway. We'll just
|
|
||||||
# skip that stream_id and return an empty list, and continue with the next
|
|
||||||
# stream_id next time.
|
|
||||||
if not query_map and not cross_signing_keys_by_user:
|
|
||||||
return stream_id_cutoff, []
|
|
||||||
|
|
||||||
results = yield self._get_device_update_edus_by_remote(
|
results = yield self._get_device_update_edus_by_remote(
|
||||||
destination, from_stream_id, query_map
|
destination, from_stream_id, query_map
|
||||||
)
|
)
|
||||||
|
|
|
@ -88,51 +88,6 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||||
# Check original device_ids are contained within these updates
|
# Check original device_ids are contained within these updates
|
||||||
self._check_devices_in_updates(device_ids, device_updates)
|
self._check_devices_in_updates(device_ids, device_updates)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_get_device_updates_by_remote_limited(self):
|
|
||||||
# Test breaking the update limit in 1, 101, and 1 device_id segments
|
|
||||||
|
|
||||||
# first add one device
|
|
||||||
device_ids1 = ["device_id0"]
|
|
||||||
yield self.store.add_device_change_to_streams(
|
|
||||||
"user_id", device_ids1, ["someotherhost"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# then add 101
|
|
||||||
device_ids2 = ["device_id" + str(i + 1) for i in range(101)]
|
|
||||||
yield self.store.add_device_change_to_streams(
|
|
||||||
"user_id", device_ids2, ["someotherhost"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# then one more
|
|
||||||
device_ids3 = ["newdevice"]
|
|
||||||
yield self.store.add_device_change_to_streams(
|
|
||||||
"user_id", device_ids3, ["someotherhost"]
|
|
||||||
)
|
|
||||||
|
|
||||||
#
|
|
||||||
# now read them back.
|
|
||||||
#
|
|
||||||
|
|
||||||
# first we should get a single update
|
|
||||||
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
|
||||||
"someotherhost", -1, limit=100
|
|
||||||
)
|
|
||||||
self._check_devices_in_updates(device_ids1, device_updates)
|
|
||||||
|
|
||||||
# Then we should get an empty list back as the 101 devices broke the limit
|
|
||||||
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
|
||||||
"someotherhost", now_stream_id, limit=100
|
|
||||||
)
|
|
||||||
self.assertEqual(len(device_updates), 0)
|
|
||||||
|
|
||||||
# The 101 devices should've been cleared, so we should now just get one device
|
|
||||||
# update
|
|
||||||
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
|
||||||
"someotherhost", now_stream_id, limit=100
|
|
||||||
)
|
|
||||||
self._check_devices_in_updates(device_ids3, device_updates)
|
|
||||||
|
|
||||||
def _check_devices_in_updates(self, expected_device_ids, device_updates):
|
def _check_devices_in_updates(self, expected_device_ids, device_updates):
|
||||||
"""Check that an specific device ids exist in a list of device update EDUs"""
|
"""Check that an specific device ids exist in a list of device update EDUs"""
|
||||||
self.assertEqual(len(device_updates), len(expected_device_ids))
|
self.assertEqual(len(device_updates), len(expected_device_ids))
|
||||||
|
|
Loading…
Reference in New Issue