Move connecting logic into ClientReplicationStreamProtocol
							parent
							
								
									09fc34c935
								
							
						
					
					
						commit
						6870fc496f
					
				|  | @ -89,11 +89,6 @@ class ReplicationClientHandler(object): | |||
|         # Used for tests. | ||||
|         self.awaiting_syncs = {} | ||||
| 
 | ||||
|         # Set of stream names that have been subscribe to, but haven't yet | ||||
|         # caught up with. This is used to track when the client has been fully | ||||
|         # connected to the remote. | ||||
|         self.streams_connecting = None | ||||
| 
 | ||||
|         # The factory used to create connections. | ||||
|         self.factory = None | ||||
| 
 | ||||
|  | @ -122,12 +117,6 @@ class ReplicationClientHandler(object): | |||
| 
 | ||||
|         Can be overriden in subclasses to handle more. | ||||
|         """ | ||||
|         # When we get a `POSITION` command it means we've finished getting | ||||
|         # missing updates for the given stream, and are now up to date. | ||||
|         self.streams_connecting.discard(stream_name) | ||||
|         if not self.streams_connecting: | ||||
|             self.finished_connecting() | ||||
| 
 | ||||
|         return self.store.process_replication_rows(stream_name, token, []) | ||||
| 
 | ||||
|     def on_sync(self, data): | ||||
|  | @ -154,9 +143,6 @@ class ReplicationClientHandler(object): | |||
|         elif room_account_data: | ||||
|             args["account_data"] = room_account_data | ||||
| 
 | ||||
|         # Record which streams we're in the process of subscribing to | ||||
|         self.streams_connecting = set(args.keys()) | ||||
| 
 | ||||
|         return args | ||||
| 
 | ||||
|     def get_currently_syncing_users(self): | ||||
|  | @ -222,10 +208,6 @@ class ReplicationClientHandler(object): | |||
|                 connection.send_command(cmd) | ||||
|             self.pending_commands = [] | ||||
| 
 | ||||
|             # This will happen if we don't actually subscribe to any streams | ||||
|             if not self.streams_connecting: | ||||
|                 self.finished_connecting() | ||||
| 
 | ||||
|     def finished_connecting(self): | ||||
|         """Called when we have successfully subscribed and caught up to all | ||||
|         streams we're interested in. | ||||
|  |  | |||
|  | @ -511,6 +511,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|         self.server_name = server_name | ||||
|         self.handler = handler | ||||
| 
 | ||||
|         # Set of stream names that have been subscribe to, but haven't yet | ||||
|         # caught up with. This is used to track when the client has been fully | ||||
|         # connected to the remote. | ||||
|         self.streams_connecting = set() | ||||
| 
 | ||||
|         # Map of stream to batched updates. See RdataCommand for info on how | ||||
|         # batching works. | ||||
|         self.pending_batches = {} | ||||
|  | @ -533,6 +538,10 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|         # We've now finished connecting to so inform the client handler | ||||
|         self.handler.update_connection(self) | ||||
| 
 | ||||
|         # This will happen if we don't actually subscribe to any streams | ||||
|         if not self.streams_connecting: | ||||
|             self.handler.finished_connecting() | ||||
| 
 | ||||
|     def on_SERVER(self, cmd): | ||||
|         if cmd.data != self.server_name: | ||||
|             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data) | ||||
|  | @ -562,6 +571,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|             return self.handler.on_rdata(stream_name, cmd.token, rows) | ||||
| 
 | ||||
|     def on_POSITION(self, cmd): | ||||
|         # When we get a `POSITION` command it means we've finished getting | ||||
|         # missing updates for the given stream, and are now up to date. | ||||
|         self.streams_connecting.discard(cmd.stream_name) | ||||
|         if not self.streams_connecting: | ||||
|             self.handler.finished_connecting() | ||||
| 
 | ||||
|         return self.handler.on_position(cmd.stream_name, cmd.token) | ||||
| 
 | ||||
|     def on_SYNC(self, cmd): | ||||
|  | @ -578,6 +593,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|             self.id(), stream_name, token | ||||
|         ) | ||||
| 
 | ||||
|         self.streams_connecting.add(stream_name) | ||||
| 
 | ||||
|         self.send_command(ReplicateCommand(stream_name, token)) | ||||
| 
 | ||||
|     def on_connection_closed(self): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Erik Johnston
						Erik Johnston