Factor out `MultiWriter` token from `RoomStreamToken` (#16427)

parent ab9c1e8f39
commit 009b47badf

@@ -0,0 +1 @@
+Factor out `MultiWriter` token from `RoomStreamToken`.

@@ -171,8 +171,8 @@ class AdminHandler:
             else:
                 stream_ordering = room.stream_ordering
 
-            from_key = RoomStreamToken(0, 0)
-            to_key = RoomStreamToken(None, stream_ordering)
+            from_key = RoomStreamToken(topological=0, stream=0)
+            to_key = RoomStreamToken(stream=stream_ordering)
 
             # Events that we've processed in this room
             written_events: Set[str] = set()

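The call-site changes in this and the following hunks all follow one pattern: `RoomStreamToken` is now constructed with keyword arguments, since its fields become keyword-only in the `attrs` definitions later in this diff. A minimal sketch of the new style, with an illustrative value and assuming Synapse is importable:

    from synapse.types import RoomStreamToken

    stream_ordering = 42  # hypothetical value, e.g. room.stream_ordering above

    # Old style: RoomStreamToken(0, 0) and RoomStreamToken(None, stream_ordering)
    from_key = RoomStreamToken(topological=0, stream=0)
    to_key = RoomStreamToken(stream=stream_ordering)  # topological defaults to None
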
@@ -192,8 +192,7 @@ class InitialSyncHandler:
                     )
                 elif event.membership == Membership.LEAVE:
                     room_end_token = RoomStreamToken(
-                        None,
-                        event.stream_ordering,
+                        stream=event.stream_ordering,
                     )
                     deferred_room_state = run_in_background(
                         self._state_storage_controller.get_state_for_events,

@@ -1708,7 +1708,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
 
         if from_key.topological:
             logger.warning("Stream has topological part!!!! %r", from_key)
-            from_key = RoomStreamToken(None, from_key.stream)
+            from_key = RoomStreamToken(stream=from_key.stream)
 
         app_service = self.store.get_app_service_by_user_id(user.to_string())
         if app_service:

@@ -2333,7 +2333,7 @@ class SyncHandler:
                             continue
 
                 leave_token = now_token.copy_and_replace(
-                    StreamKeyType.ROOM, RoomStreamToken(None, event.stream_ordering)
+                    StreamKeyType.ROOM, RoomStreamToken(stream=event.stream_ordering)
                 )
                 room_entries.append(
                     RoomSyncResultBuilder(

@@ -146,7 +146,7 @@ class PurgeHistoryRestServlet(RestServlet):
             # RoomStreamToken expects [int] not Optional[int]
             assert event.internal_metadata.stream_ordering is not None
             room_token = RoomStreamToken(
-                event.depth, event.internal_metadata.stream_ordering
+                topological=event.depth, stream=event.internal_metadata.stream_ordering
             )
             token = await room_token.to_string(self.store)
 

@@ -266,7 +266,7 @@ def generate_next_token(
         # when we are going backwards so we subtract one from the
         # stream part.
         last_stream_ordering -= 1
-    return RoomStreamToken(last_topo_ordering, last_stream_ordering)
+    return RoomStreamToken(topological=last_topo_ordering, stream=last_stream_ordering)
 
 
 def _make_generic_sql_bound(

@@ -558,7 +558,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 if p > min_pos
             }
 
-        return RoomStreamToken(None, min_pos, immutabledict(positions))
+        return RoomStreamToken(stream=min_pos, instance_map=immutabledict(positions))
 
     async def get_room_events_stream_for_rooms(
         self,

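This hunk assembles the "current" room token from per-writer stream positions: the default `stream` is the minimum position, and only writers ahead of that minimum are kept in `instance_map`. A rough sketch of the same idea with made-up writer names and positions, assuming Synapse and `immutabledict` are importable:

    from immutabledict import immutabledict
    from synapse.types import RoomStreamToken

    positions = {"writer1": 10, "writer2": 12}  # hypothetical per-writer positions
    min_pos = min(positions.values())
    ahead = {name: pos for name, pos in positions.items() if pos > min_pos}

    token = RoomStreamToken(stream=min_pos, instance_map=immutabledict(ahead))
    assert token.get_max_stream_pos() == 12
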
@@ -708,7 +708,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             ret.reverse()
 
         if rows:
-            key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
+            key = RoomStreamToken(stream=min(r.stream_ordering for r in rows))
         else:
             # Assume we didn't get anything because there was nothing to
             # get.

@@ -969,7 +969,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         topo = await self.db_pool.runInteraction(
             "_get_max_topological_txn", self._get_max_topological_txn, room_id
         )
-        return RoomStreamToken(topo, stream_ordering)
+        return RoomStreamToken(topological=topo, stream=stream_ordering)
 
     @overload
     def get_stream_id_for_event_txn(

@@ -1033,7 +1033,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             retcols=("stream_ordering", "topological_ordering"),
             desc="get_topological_token_for_event",
         )
-        return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])
+        return RoomStreamToken(
+            topological=row["topological_ordering"], stream=row["stream_ordering"]
+        )
 
     async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
         """Gets the topological token in a room after or at the given stream

@@ -1114,8 +1116,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             else:
                 topo = None
             internal = event.internal_metadata
-            internal.before = RoomStreamToken(topo, stream - 1)
-            internal.after = RoomStreamToken(topo, stream)
+            internal.before = RoomStreamToken(topological=topo, stream=stream - 1)
+            internal.after = RoomStreamToken(topological=topo, stream=stream)
             internal.order = (int(topo) if topo else 0, int(stream))
 
     async def get_events_around(

@@ -1191,11 +1193,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         # Paginating backwards includes the event at the token, but paginating
         # forward doesn't.
         before_token = RoomStreamToken(
-            results["topological_ordering"] - 1, results["stream_ordering"]
+            topological=results["topological_ordering"] - 1,
+            stream=results["stream_ordering"],
         )
 
         after_token = RoomStreamToken(
-            results["topological_ordering"], results["stream_ordering"]
+            topological=results["topological_ordering"],
+            stream=results["stream_ordering"],
         )
 
         rows, start_token = self._paginate_room_events_txn(

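This pagination hunk only reformats the constructor calls: as the inline comment says, paginating backwards includes the event at the token while paginating forwards does not, which is why `before_token` steps the topological part back by one. A sketch with hypothetical row values, assuming Synapse is importable:

    from synapse.types import RoomStreamToken

    results = {"topological_ordering": 10, "stream_ordering": 250}  # hypothetical row

    before_token = RoomStreamToken(
        topological=results["topological_ordering"] - 1,
        stream=results["stream_ordering"],
    )
    after_token = RoomStreamToken(
        topological=results["topological_ordering"],
        stream=results["stream_ordering"],
    )
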
@@ -61,6 +61,8 @@ from synapse.util.cancellation import cancellable
 from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
+    from typing_extensions import Self
+
     from synapse.appservice.api import ApplicationService
     from synapse.storage.databases.main import DataStore, PurgeEventsStore
     from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore

@@ -437,7 +439,78 @@ def map_username_to_mxid_localpart(
 
 
 @attr.s(frozen=True, slots=True, order=False)
-class RoomStreamToken:
+class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
+    """An abstract stream token class for streams that support multiple
+    writers.
+
+    This works by keeping track of the stream position of each writer,
+    represented by a default `stream` attribute and a map of instance name to
+    stream position of any writers that are ahead of the default stream
+    position.
+    """
+
+    stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
+
+    instance_map: "immutabledict[str, int]" = attr.ib(
+        factory=immutabledict,
+        validator=attr.validators.deep_mapping(
+            key_validator=attr.validators.instance_of(str),
+            value_validator=attr.validators.instance_of(int),
+            mapping_validator=attr.validators.instance_of(immutabledict),
+        ),
+        kw_only=True,
+    )
+
+    @classmethod
+    @abc.abstractmethod
+    async def parse(cls, store: "DataStore", string: str) -> "Self":
+        """Parse the string representation of the token."""
+        ...
+
+    @abc.abstractmethod
+    async def to_string(self, store: "DataStore") -> str:
+        """Serialize the token into its string representation."""
+        ...
+
+    def copy_and_advance(self, other: "Self") -> "Self":
+        """Return a new token such that if an event is after both this token and
+        the other token, then it's after the returned token too.
+        """
+
+        max_stream = max(self.stream, other.stream)
+
+        instance_map = {
+            instance: max(
+                self.instance_map.get(instance, self.stream),
+                other.instance_map.get(instance, other.stream),
+            )
+            for instance in set(self.instance_map).union(other.instance_map)
+        }
+
+        return attr.evolve(
+            self, stream=max_stream, instance_map=immutabledict(instance_map)
+        )
+
+    def get_max_stream_pos(self) -> int:
+        """Get the maximum stream position referenced in this token.
+
+        The corresponding "min" position is, by definition, just `self.stream`.
+
+        This is used to handle tokens that have a non-empty `instance_map`, and so
+        reference stream positions after the `self.stream` position.
+        """
+        return max(self.instance_map.values(), default=self.stream)
+
+    def get_stream_pos_for_instance(self, instance_name: str) -> int:
+        """Get the stream position that the given writer was at, at this token."""
+
+        # If we don't have an entry for the instance we can assume that it was
+        # at `self.stream`.
+        return self.instance_map.get(instance_name, self.stream)
+
+
+@attr.s(frozen=True, slots=True, order=False)
+class RoomStreamToken(AbstractMultiWriterStreamToken):
     """Tokens are positions between events. The token "s1" comes after event 1.
 
             s0    s1

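The new base class is where the multi-writer bookkeeping now lives. A worked sketch of its semantics, mirroring the methods shown above, with illustrative writer names and positions and assuming Synapse and `immutabledict` are importable:

    from immutabledict import immutabledict
    from synapse.types import RoomStreamToken

    # Most writers are at stream position 5; "writer2" has raced ahead to 8.
    token = RoomStreamToken(stream=5, instance_map=immutabledict({"writer2": 8}))

    assert token.get_stream_pos_for_instance("writer1") == 5  # falls back to `stream`
    assert token.get_stream_pos_for_instance("writer2") == 8
    assert token.get_max_stream_pos() == 8

    # copy_and_advance keeps the per-writer maximum of the two tokens.
    other = RoomStreamToken(stream=6, instance_map=immutabledict({"writer3": 9}))
    merged = token.copy_and_advance(other)
    assert merged.stream == 6
    assert merged.get_stream_pos_for_instance("writer2") == 8
    assert merged.get_stream_pos_for_instance("writer3") == 9
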
@@ -514,16 +587,8 @@ class RoomStreamToken:
 
     topological: Optional[int] = attr.ib(
         validator=attr.validators.optional(attr.validators.instance_of(int)),
-    )
-    stream: int = attr.ib(validator=attr.validators.instance_of(int))
-
-    instance_map: "immutabledict[str, int]" = attr.ib(
-        factory=immutabledict,
-        validator=attr.validators.deep_mapping(
-            key_validator=attr.validators.instance_of(str),
-            value_validator=attr.validators.instance_of(int),
-            mapping_validator=attr.validators.instance_of(immutabledict),
-        ),
         kw_only=True,
+        default=None,
     )
 
     def __attrs_post_init__(self) -> None:

@@ -583,17 +648,7 @@ class RoomStreamToken:
         if self.topological or other.topological:
             raise Exception("Can't advance topological tokens")
 
-        max_stream = max(self.stream, other.stream)
-
-        instance_map = {
-            instance: max(
-                self.instance_map.get(instance, self.stream),
-                other.instance_map.get(instance, other.stream),
-            )
-            for instance in set(self.instance_map).union(other.instance_map)
-        }
-
-        return RoomStreamToken(None, max_stream, immutabledict(instance_map))
+        return super().copy_and_advance(other)
 
     def as_historical_tuple(self) -> Tuple[int, int]:
         """Returns a tuple of `(topological, stream)` for historical tokens.

@@ -619,16 +674,6 @@ class RoomStreamToken:
         # at `self.stream`.
         return self.instance_map.get(instance_name, self.stream)
 
-    def get_max_stream_pos(self) -> int:
-        """Get the maximum stream position referenced in this token.
-
-        The corresponding "min" position is, by definition just `self.stream`.
-
-        This is used to handle tokens that have non-empty `instance_map`, and so
-        reference stream positions after the `self.stream` position.
-        """
-        return max(self.instance_map.values(), default=self.stream)
-
     async def to_string(self, store: "DataStore") -> str:
         if self.topological is not None:
             return "t%d-%d" % (self.topological, self.stream)

@@ -838,23 +883,28 @@ class StreamToken:
         return getattr(self, key.value)
 
 
-StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
+StreamToken.START = StreamToken(RoomStreamToken(stream=0), 0, 0, 0, 0, 0, 0, 0, 0, 0)
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
-class PersistedEventPosition:
+class PersistedPosition:
+    """Position of a newly persisted row with instance that persisted it."""
+
+    instance_name: str
+    stream: int
+
+    def persisted_after(self, token: AbstractMultiWriterStreamToken) -> bool:
+        return token.get_stream_pos_for_instance(self.instance_name) < self.stream
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class PersistedEventPosition(PersistedPosition):
     """Position of a newly persisted event with instance that persisted it.
 
     This can be used to test whether the event is persisted before or after a
     RoomStreamToken.
     """
 
-    instance_name: str
-    stream: int
-
-    def persisted_after(self, token: RoomStreamToken) -> bool:
-        return token.get_stream_pos_for_instance(self.instance_name) < self.stream
-
     def to_room_stream_token(self) -> RoomStreamToken:
         """Converts the position to a room stream token such that events
         persisted in the same room after this position will be after the

@@ -865,7 +915,7 @@ class PersistedEventPosition:
         """
         # Doing the naive thing satisfies the desired properties described in
         # the docstring.
-        return RoomStreamToken(None, self.stream)
+        return RoomStreamToken(stream=self.stream)
 
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)

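With these hunks, the `persisted_after` comparison that used to be specific to events now works against any `AbstractMultiWriterStreamToken`, and `PersistedEventPosition` adds only the event-specific `to_room_stream_token` helper. A short sketch with made-up positions, assuming Synapse is importable:

    from synapse.types import PersistedEventPosition, RoomStreamToken

    pos = PersistedEventPosition(instance_name="writer1", stream=7)

    # The token only knows "writer1" up to stream 5, so position 7 is newer.
    assert pos.persisted_after(RoomStreamToken(stream=5))

    # ...but not newer than a token where that writer had already reached 7.
    assert not pos.persisted_after(RoomStreamToken(stream=7))

    # Events persisted in the same room after this position sort after this token.
    assert pos.to_room_stream_token() == RoomStreamToken(stream=7)
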
@@ -86,7 +86,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
                 [event],
             ]
         )
-        self.handler.notify_interested_services(RoomStreamToken(None, 1))
+        self.handler.notify_interested_services(RoomStreamToken(stream=1))
 
         self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
             interested_service, events=[event]

@@ -107,7 +107,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             ]
         )
         self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
-        self.handler.notify_interested_services(RoomStreamToken(None, 0))
+        self.handler.notify_interested_services(RoomStreamToken(stream=0))
 
         self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
 

@@ -126,7 +126,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             ]
         )
 
-        self.handler.notify_interested_services(RoomStreamToken(None, 0))
+        self.handler.notify_interested_services(RoomStreamToken(stream=0))
 
         self.assertFalse(
             self.mock_as_api.query_user.called,

@@ -441,7 +441,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         self.get_success(
             self.hs.get_application_service_handler()._notify_interested_services(
                 RoomStreamToken(
-                    None, self.hs.get_application_service_handler().current_max
+                    stream=self.hs.get_application_service_handler().current_max
                 )
            )
        )

Erik Johnston