Change slave storage to use new replication interface

As the TCP replication uses a slightly different API and streams than
the HTTP replication.

This breaks HTTP replication.
pull/2097/head
Erik Johnston 2017-03-27 16:32:07 +01:00
parent 52bfa604e1
commit 3a1f3f8388
11 changed files with 128 additions and 179 deletions

View File

@ -15,7 +15,6 @@
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer
from ._slaved_id_tracker import SlavedIdTracker
@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore):
else:
self._cache_id_gen = None
self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
self.http_client = hs.get_simple_http_client()
self.hs = hs
def stream_positions(self):
pos = {}
@ -43,35 +41,20 @@ class BaseSlavedStore(SQLBaseStore):
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
def process_replication(self, result):
stream = result.get("caches")
if stream:
for row in stream["rows"]:
(
position, cache_func, keys, invalidation_ts,
) = row
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
self._cache_id_gen.advance(token)
for row in rows:
try:
getattr(self, cache_func).invalidate(tuple(keys))
getattr(self, row.cache_func).invalidate(tuple(row.keys))
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
pass
self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys)
txn.call_after(self._send_invalidation_poke, cache_func, keys)
@defer.inlineCallbacks
def _send_invalidation_poke(self, cache_func, keys):
try:
yield self.http_client.post_json_get_json(self.expire_cache_url, {
"invalidate": [{
"name": cache_func.__name__,
"keys": list(keys),
}]
})
except:
logger.exception("Failed to poke on expire_cache")
self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)

View File

@ -69,38 +69,25 @@ class SlavedAccountDataStore(BaseSlavedStore):
result["tag_account_data"] = position
return result
def process_replication(self, result):
stream = result.get("user_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id, data_type = row[:3]
self.get_global_account_data_by_type_for_user.invalidate(
(data_type, user_id,)
)
self.get_account_data_for_user.invalidate((user_id,))
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "tag_account_data":
self._account_data_id_gen.advance(token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
row.user_id, token
)
stream = result.get("room_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.get_account_data_for_user.invalidate((user_id,))
elif stream_name == "account_data":
self._account_data_id_gen.advance(token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
(row.data_type, row.user_id,)
)
self.get_account_data_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
row.user_id, token
)
stream = result.get("tag_account_data")
if stream:
self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.get_tags_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
return super(SlavedAccountDataStore, self).process_replication(result)
return super(SlavedAccountDataStore, self).process_replication_rows(
stream_name, token, rows
)

View File

@ -53,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
result["to_device"] = self._device_inbox_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("to_device")
if stream:
self._device_inbox_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
stream_id = row[0]
entity = row[1]
if entity.startswith("@"):
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "to_device":
self._device_inbox_id_gen.advance(token)
for row in rows:
if row.entity.startswith("@"):
self._device_inbox_stream_cache.entity_has_changed(
entity, stream_id
row.entity, token
)
else:
self._device_federation_outbox_stream_cache.entity_has_changed(
entity, stream_id
row.entity, token
)
return super(SlavedDeviceInboxStore, self).process_replication(result)
return super(SlavedDeviceInboxStore, self).process_replication_rows(
stream_name, token, rows
)

View File

@ -51,22 +51,18 @@ class SlavedDeviceStore(BaseSlavedStore):
result["device_lists"] = self._device_list_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("device_lists")
if stream:
self._device_list_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
stream_id = row[0]
user_id = row[1]
destination = row[2]
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "device_lists":
self._device_list_id_gen.advance(token)
for row in rows:
self._device_list_stream_cache.entity_has_changed(
user_id, stream_id
row.user_id, token
)
if destination:
if row.destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, stream_id
row.destination, token
)
return super(SlavedDeviceStore, self).process_replication(result)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)

View File

@ -201,48 +201,25 @@ class SlavedEventStore(BaseSlavedStore):
result["backfill"] = -self._backfill_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("events")
if stream:
self._stream_id_gen.advance(int(stream["position"]))
if stream["rows"]:
logger.info("Got %d event rows", len(stream["rows"]))
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=False,
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "events":
self._stream_id_gen.advance(token)
for row in rows:
self.invalidate_caches_for_event(
token, row.event_id, row.room_id, row.type, row.state_key,
row.redacts,
backfilled=False,
)
stream = result.get("backfill")
if stream:
self._backfill_id_gen.advance(-int(stream["position"]))
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=True,
elif stream_name == "backfill":
self._backfill_id_gen.advance(-token)
for row in rows:
self.invalidate_caches_for_event(
-token, row.event_id, row.room_id, row.type, row.state_key,
row.redacts,
backfilled=True,
)
stream = result.get("forward_ex_outliers")
if stream:
self._stream_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
stream = result.get("backward_ex_outliers")
if stream:
self._backfill_id_gen.advance(-int(stream["position"]))
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
return super(SlavedEventStore, self).process_replication(result)
def _process_replication_row(self, row, backfilled):
stream_ordering = row[0] if not backfilled else -row[0]
self.invalidate_caches_for_event(
stream_ordering, row[1], row[2], row[3], row[4], row[5],
backfilled=backfilled,
return super(SlavedEventStore, self).process_replication_rows(
stream_name, token, rows
)
def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,

View File

@ -48,15 +48,14 @@ class SlavedPresenceStore(BaseSlavedStore):
result["presence"] = position
return result
def process_replication(self, result):
stream = result.get("presence")
if stream:
self._presence_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "presence":
self._presence_id_gen.advance(token)
for row in rows:
self.presence_stream_cache.entity_has_changed(
user_id, position
row.user_id, token
)
self._get_presence_for_user.invalidate((user_id,))
return super(SlavedPresenceStore, self).process_replication(result)
self._get_presence_for_user.invalidate((row.user_id,))
return super(SlavedPresenceStore, self).process_replication_rows(
stream_name, token, rows
)

View File

@ -50,18 +50,15 @@ class SlavedPushRuleStore(SlavedEventStore):
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("push_rules")
if stream:
for row in stream["rows"]:
position = row[0]
user_id = row[2]
self.get_push_rules_for_user.invalidate((user_id,))
self.get_push_rules_enabled_for_user.invalidate((user_id,))
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "push_rules":
self._push_rules_stream_id_gen.advance(token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
self.push_rules_stream_cache.entity_has_changed(
user_id, position
row.user_id, token
)
self._push_rules_stream_id_gen.advance(int(stream["position"]))
return super(SlavedPushRuleStore, self).process_replication(result)
return super(SlavedPushRuleStore, self).process_replication_rows(
stream_name, token, rows
)

View File

@ -40,13 +40,9 @@ class SlavedPusherStore(BaseSlavedStore):
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("pushers")
if stream:
self._pushers_id_gen.advance(int(stream["position"]))
stream = result.get("deleted_pushers")
if stream:
self._pushers_id_gen.advance(int(stream["position"]))
return super(SlavedPusherStore, self).process_replication(result)
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "pushers":
self._pushers_id_gen.advance(token)
return super(SlavedPusherStore, self).process_replication_rows(
stream_name, token, rows
)

View File

@ -65,20 +65,22 @@ class SlavedReceiptsStore(BaseSlavedStore):
result["receipts"] = self._receipts_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("receipts")
if stream:
self._receipts_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, room_id, receipt_type, user_id = row[:4]
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
self._receipts_stream_cache.entity_has_changed(room_id, position)
return super(SlavedReceiptsStore, self).process_replication(result)
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self.get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "receipts":
self._receipts_id_gen.advance(token)
for row in rows:
self.invalidate_caches_for_receipt(
row.room_id, row.receipt_type, row.user_id
)
self._receipts_stream_cache.entity_has_changed(row.room_id, token)
return super(SlavedReceiptsStore, self).process_replication_rows(
stream_name, token, rows
)

View File

@ -46,9 +46,10 @@ class RoomStore(BaseSlavedStore):
result["public_rooms"] = self._public_room_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("public_rooms")
if stream:
self._public_room_id_gen.advance(int(stream["position"]))
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "public_rooms":
self._public_room_id_gen.advance(token)
return super(RoomStore, self).process_replication(result)
return super(RoomStore, self).process_replication_rows(
stream_name, token, rows
)

View File

@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from twisted.internet import defer, reactor
from tests import unittest
from mock import Mock, NonCallableMock
from tests.utils import setup_test_homeserver
from synapse.replication.resource import ReplicationResource
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.replication.tcp.client import (
ReplicationClientHandler, ReplicationClientFactory,
)
class BaseSlavedStoreTestCase(unittest.TestCase):
@ -33,18 +36,29 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
)
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
self.replication = ReplicationResource(self.hs)
self.master_store = self.hs.get_datastore()
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
listener = reactor.listenUNIX("\0xxx", server_factory)
self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer
self.replication_handler = ReplicationClientHandler(self.slaved_store)
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
client_connector = reactor.connectUNIX("\0xxx", client_factory)
self.addCleanup(client_factory.stopTrying)
self.addCleanup(client_connector.disconnect)
@defer.inlineCallbacks
def replicate(self):
streams = self.slaved_store.stream_positions()
writer = yield self.replication.replicate(streams, 100)
result = writer.finish()
yield self.slaved_store.process_replication(result)
yield self.streamer.on_notifier_poke()
d = self.replication_handler.await_sync("replication_test")
self.streamer.send_sync_to_all_connections("replication_test")
yield d
@defer.inlineCallbacks
def check(self, method, args, expected_result=None):