Remove unused token param from REPLICATE cmd

pull/7024/head
Erik Johnston 2020-03-03 16:51:34 +00:00
parent 1f83255de1
commit 8734b75ca8
3 changed files with 14 additions and 35 deletions

View File

@ -183,35 +183,22 @@ class ReplicateCommand(Command):
Format:: Format::
REPLICATE <stream_name> <token> REPLICATE <stream_name>
Where <token> may be either: The <stream_name> can be "ALL" to subscribe to all known streams
* a numeric stream_id to stream updates from
* "NOW" to stream all subsequent updates.
The <stream_name> can be "ALL" to subscribe to all known streams, in which
case the <token> must be set to "NOW", i.e.::
REPLICATE ALL NOW
""" """
NAME = "REPLICATE" NAME = "REPLICATE"
def __init__(self, stream_name, token): def __init__(self, stream_name):
self.stream_name = stream_name self.stream_name = stream_name
self.token = token
@classmethod @classmethod
def from_line(cls, line): def from_line(cls, line):
stream_name, token = line.split(" ", 1) return cls(line)
if token in ("NOW", "now"):
token = "NOW"
else:
token = int(token)
return cls(stream_name, token)
def to_line(self): def to_line(self):
return " ".join((self.stream_name, str(self.token))) return self.stream_name
def get_logcontext_id(self): def get_logcontext_id(self):
return "REPLICATE-" + self.stream_name return "REPLICATE-" + self.stream_name

View File

@ -435,12 +435,11 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
async def on_REPLICATE(self, cmd): async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name stream_name = cmd.stream_name
token = cmd.token
if stream_name == "ALL": if stream_name == "ALL":
# Subscribe to all streams we're publishing to. # Subscribe to all streams we're publishing to.
deferreds = [ deferreds = [
run_in_background(self.subscribe_to_stream, stream, token) run_in_background(self.subscribe_to_stream, stream)
for stream in iterkeys(self.streamer.streams_by_name) for stream in iterkeys(self.streamer.streams_by_name)
] ]
@ -448,7 +447,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
defer.gatherResults(deferreds, consumeErrors=True) defer.gatherResults(deferreds, consumeErrors=True)
) )
else: else:
await self.subscribe_to_stream(stream_name, token) await self.subscribe_to_stream(stream_name)
async def on_FEDERATION_ACK(self, cmd): async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token) self.streamer.federation_ack(cmd.token)
@ -472,12 +471,8 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
cmd.last_seen, cmd.last_seen,
) )
async def subscribe_to_stream(self, stream_name, token): async def subscribe_to_stream(self, stream_name):
"""Subscribe the remote to a stream. """Subscribe the remote to a stream.
This invloves checking if they've missed anything and sending those
updates down if they have. During that time new updates for the stream
are queued and sent once we've sent down any missed updates.
""" """
self.replication_streams.discard(stream_name) self.replication_streams.discard(stream_name)
@ -620,8 +615,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
BaseReplicationStreamProtocol.connectionMade(self) BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams # Once we've connected subscribe to the necessary streams
for stream_name, token in iteritems(self.handler.get_streams_to_replicate()): for stream_name in self.handler.get_streams_to_replicate():
self.replicate(stream_name, token) self.replicate(stream_name)
# Tell the server if we have any users currently syncing (should only # Tell the server if we have any users currently syncing (should only
# happen on synchrotrons) # happen on synchrotrons)
@ -711,22 +706,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand): async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.handler.on_remote_server_up(cmd.data) self.handler.on_remote_server_up(cmd.data)
def replicate(self, stream_name, token): def replicate(self, stream_name):
"""Send the subscription request to the server """Send the subscription request to the server
""" """
if stream_name not in STREAMS_MAP: if stream_name not in STREAMS_MAP:
raise Exception("Invalid stream name %r" % (stream_name,)) raise Exception("Invalid stream name %r" % (stream_name,))
logger.info( logger.info(
"[%s] Subscribing to replication stream: %r from %r", "[%s] Subscribing to replication stream: %r", self.id(), stream_name,
self.id(),
stream_name,
token,
) )
self.streams_connecting.add(stream_name) self.streams_connecting.add(stream_name)
self.send_command(ReplicateCommand(stream_name, token)) self.send_command(ReplicateCommand(stream_name))
def on_connection_closed(self): def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self) BaseReplicationStreamProtocol.on_connection_closed(self)

View File

@ -71,7 +71,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
def replicate_stream(self, stream, token="NOW"): def replicate_stream(self, stream, token="NOW"):
"""Make the client end a REPLICATE command to set up a subscription to a stream""" """Make the client end a REPLICATE command to set up a subscription to a stream"""
self.client.send_command(ReplicateCommand(stream, token)) self.client.send_command(ReplicateCommand(stream))
class TestReplicationClientHandler(object): class TestReplicationClientHandler(object):