|  |  |  | @ -12,7 +12,7 @@ | 
		
	
		
			
				|  |  |  |  | # See the License for the specific language governing permissions and | 
		
	
		
			
				|  |  |  |  | # limitations under the License. | 
		
	
		
			
				|  |  |  |  | from abc import ABC, abstractmethod | 
		
	
		
			
				|  |  |  |  | from typing import TYPE_CHECKING, List, Optional, Tuple | 
		
	
		
			
				|  |  |  |  | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | import attr | 
		
	
		
			
				|  |  |  |  | from immutabledict import immutabledict | 
		
	
	
		
			
				
					|  |  |  | @ -107,33 +107,32 @@ class EventContext(UnpersistedEventContextBase): | 
		
	
		
			
				|  |  |  |  |         state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None | 
		
	
		
			
				|  |  |  |  |             then this is the delta of the state between the two groups. | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         prev_group: If it is known, ``state_group``'s prev_group. Note that this being | 
		
	
		
			
				|  |  |  |  |             None does not necessarily mean that ``state_group`` does not have | 
		
	
		
			
				|  |  |  |  |             a prev_group! | 
		
	
		
			
				|  |  |  |  |         state_group_deltas: If not empty, this is a dict collecting a mapping of the state | 
		
	
		
			
				|  |  |  |  |             difference between state groups. | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |             If the event is a state event, this is normally the same as | 
		
	
		
			
				|  |  |  |  |             ``state_group_before_event``. | 
		
	
		
			
				|  |  |  |  |             The keys are a tuple of two integers: the initial group and final state group. | 
		
	
		
			
				|  |  |  |  |             The corresponding value is a state map representing the state delta between | 
		
	
		
			
				|  |  |  |  |             these state groups. | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |             If ``state_group`` is None (ie, the event is an outlier), ``prev_group`` | 
		
	
		
			
				|  |  |  |  |             will always also be ``None``. | 
		
	
		
			
				|  |  |  |  |             The dictionary is expected to have at most two entries with state groups of: | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |             Note that this *not* (necessarily) the state group associated with | 
		
	
		
			
				|  |  |  |  |             ``_prev_state_ids``. | 
		
	
		
			
				|  |  |  |  |             1. The state group before the event and after the event. | 
		
	
		
			
				|  |  |  |  |             2. The state group preceding the state group before the event and the | 
		
	
		
			
				|  |  |  |  |                state group before the event. | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group`` | 
		
	
		
			
				|  |  |  |  |             and ``state_group``. | 
		
	
		
			
				|  |  |  |  |             This information is collected and stored as part of an optimization for persisting | 
		
	
		
			
				|  |  |  |  |             events. | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         partial_state: if True, we may be storing this event with a temporary, | 
		
	
		
			
				|  |  |  |  |             incomplete state. | 
		
	
		
			
				|  |  |  |  |     """ | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     _storage: "StorageControllers" | 
		
	
		
			
				|  |  |  |  |     state_group_deltas: Dict[Tuple[int, int], StateMap[str]] | 
		
	
		
			
				|  |  |  |  |     rejected: Optional[str] = None | 
		
	
		
			
				|  |  |  |  |     _state_group: Optional[int] = None | 
		
	
		
			
				|  |  |  |  |     state_group_before_event: Optional[int] = None | 
		
	
		
			
				|  |  |  |  |     _state_delta_due_to_event: Optional[StateMap[str]] = None | 
		
	
		
			
				|  |  |  |  |     prev_group: Optional[int] = None | 
		
	
		
			
				|  |  |  |  |     delta_ids: Optional[StateMap[str]] = None | 
		
	
		
			
				|  |  |  |  |     app_service: Optional[ApplicationService] = None | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     partial_state: bool = False | 
		
	
	
		
			
				
					|  |  |  | @ -145,16 +144,14 @@ class EventContext(UnpersistedEventContextBase): | 
		
	
		
			
				|  |  |  |  |         state_group_before_event: Optional[int], | 
		
	
		
			
				|  |  |  |  |         state_delta_due_to_event: Optional[StateMap[str]], | 
		
	
		
			
				|  |  |  |  |         partial_state: bool, | 
		
	
		
			
				|  |  |  |  |         prev_group: Optional[int] = None, | 
		
	
		
			
				|  |  |  |  |         delta_ids: Optional[StateMap[str]] = None, | 
		
	
		
			
				|  |  |  |  |         state_group_deltas: Dict[Tuple[int, int], StateMap[str]], | 
		
	
		
			
				|  |  |  |  |     ) -> "EventContext": | 
		
	
		
			
				|  |  |  |  |         return EventContext( | 
		
	
		
			
				|  |  |  |  |             storage=storage, | 
		
	
		
			
				|  |  |  |  |             state_group=state_group, | 
		
	
		
			
				|  |  |  |  |             state_group_before_event=state_group_before_event, | 
		
	
		
			
				|  |  |  |  |             state_delta_due_to_event=state_delta_due_to_event, | 
		
	
		
			
				|  |  |  |  |             prev_group=prev_group, | 
		
	
		
			
				|  |  |  |  |             delta_ids=delta_ids, | 
		
	
		
			
				|  |  |  |  |             state_group_deltas=state_group_deltas, | 
		
	
		
			
				|  |  |  |  |             partial_state=partial_state, | 
		
	
		
			
				|  |  |  |  |         ) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
	
		
			
				
					|  |  |  | @ -163,7 +160,7 @@ class EventContext(UnpersistedEventContextBase): | 
		
	
		
			
				|  |  |  |  |         storage: "StorageControllers", | 
		
	
		
			
				|  |  |  |  |     ) -> "EventContext": | 
		
	
		
			
				|  |  |  |  |         """Return an EventContext instance suitable for persisting an outlier event""" | 
		
	
		
			
				|  |  |  |  |         return EventContext(storage=storage) | 
		
	
		
			
				|  |  |  |  |         return EventContext(storage=storage, state_group_deltas={}) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     async def persist(self, event: EventBase) -> "EventContext": | 
		
	
		
			
				|  |  |  |  |         return self | 
		
	
	
		
			
				
					|  |  |  | @ -183,13 +180,15 @@ class EventContext(UnpersistedEventContextBase): | 
		
	
		
			
				|  |  |  |  |             "state_group": self._state_group, | 
		
	
		
			
				|  |  |  |  |             "state_group_before_event": self.state_group_before_event, | 
		
	
		
			
				|  |  |  |  |             "rejected": self.rejected, | 
		
	
		
			
				|  |  |  |  |             "prev_group": self.prev_group, | 
		
	
		
			
				|  |  |  |  |             "state_group_deltas": _encode_state_group_delta(self.state_group_deltas), | 
		
	
		
			
				|  |  |  |  |             "state_delta_due_to_event": _encode_state_dict( | 
		
	
		
			
				|  |  |  |  |                 self._state_delta_due_to_event | 
		
	
		
			
				|  |  |  |  |             ), | 
		
	
		
			
				|  |  |  |  |             "delta_ids": _encode_state_dict(self.delta_ids), | 
		
	
		
			
				|  |  |  |  |             "app_service_id": self.app_service.id if self.app_service else None, | 
		
	
		
			
				|  |  |  |  |             "partial_state": self.partial_state, | 
		
	
		
			
				|  |  |  |  |             # add dummy delta_ids and prev_group for backwards compatibility | 
		
	
		
			
				|  |  |  |  |             "delta_ids": None, | 
		
	
		
			
				|  |  |  |  |             "prev_group": None, | 
		
	
		
			
				|  |  |  |  |         } | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     @staticmethod | 
		
	
	
		
			
				
					|  |  |  | @ -204,17 +203,24 @@ class EventContext(UnpersistedEventContextBase): | 
		
	
		
			
				|  |  |  |  |         Returns: | 
		
	
		
			
				|  |  |  |  |             The event context. | 
		
	
		
			
				|  |  |  |  |         """ | 
		
	
		
			
				|  |  |  |  |         # workaround for backwards/forwards compatibility: if the input doesn't have a value | 
		
	
		
			
				|  |  |  |  |         # for "state_group_deltas" just assign an empty dict | 
		
	
		
			
				|  |  |  |  |         state_group_deltas = input.get("state_group_deltas", None) | 
		
	
		
			
				|  |  |  |  |         if state_group_deltas: | 
		
	
		
			
				|  |  |  |  |             state_group_deltas = _decode_state_group_delta(state_group_deltas) | 
		
	
		
			
				|  |  |  |  |         else: | 
		
	
		
			
				|  |  |  |  |             state_group_deltas = {} | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         context = EventContext( | 
		
	
		
			
				|  |  |  |  |             # We use the state_group and prev_state_id stuff to pull the | 
		
	
		
			
				|  |  |  |  |             # current_state_ids out of the DB and construct prev_state_ids. | 
		
	
		
			
				|  |  |  |  |             storage=storage, | 
		
	
		
			
				|  |  |  |  |             state_group=input["state_group"], | 
		
	
		
			
				|  |  |  |  |             state_group_before_event=input["state_group_before_event"], | 
		
	
		
			
				|  |  |  |  |             prev_group=input["prev_group"], | 
		
	
		
			
				|  |  |  |  |             state_group_deltas=state_group_deltas, | 
		
	
		
			
				|  |  |  |  |             state_delta_due_to_event=_decode_state_dict( | 
		
	
		
			
				|  |  |  |  |                 input["state_delta_due_to_event"] | 
		
	
		
			
				|  |  |  |  |             ), | 
		
	
		
			
				|  |  |  |  |             delta_ids=_decode_state_dict(input["delta_ids"]), | 
		
	
		
			
				|  |  |  |  |             rejected=input["rejected"], | 
		
	
		
			
				|  |  |  |  |             partial_state=input.get("partial_state", False), | 
		
	
		
			
				|  |  |  |  |         ) | 
		
	
	
		
			
				
					|  |  |  | @ -349,7 +355,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase): | 
		
	
		
			
				|  |  |  |  |     _storage: "StorageControllers" | 
		
	
		
			
				|  |  |  |  |     state_group_before_event: Optional[int] | 
		
	
		
			
				|  |  |  |  |     state_group_after_event: Optional[int] | 
		
	
		
			
				|  |  |  |  |     state_delta_due_to_event: Optional[dict] | 
		
	
		
			
				|  |  |  |  |     state_delta_due_to_event: Optional[StateMap[str]] | 
		
	
		
			
				|  |  |  |  |     prev_group_for_state_group_before_event: Optional[int] | 
		
	
		
			
				|  |  |  |  |     delta_ids_to_state_group_before_event: Optional[StateMap[str]] | 
		
	
		
			
				|  |  |  |  |     partial_state: bool | 
		
	
	
		
			
				
					|  |  |  | @ -380,26 +386,16 @@ class UnpersistedEventContext(UnpersistedEventContextBase): | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         events_and_persisted_context = [] | 
		
	
		
			
				|  |  |  |  |         for event, unpersisted_context in amended_events_and_context: | 
		
	
		
			
				|  |  |  |  |             if event.is_state(): | 
		
	
		
			
				|  |  |  |  |                 context = EventContext( | 
		
	
		
			
				|  |  |  |  |                     storage=unpersisted_context._storage, | 
		
	
		
			
				|  |  |  |  |                     state_group=unpersisted_context.state_group_after_event, | 
		
	
		
			
				|  |  |  |  |                     state_group_before_event=unpersisted_context.state_group_before_event, | 
		
	
		
			
				|  |  |  |  |                     state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, | 
		
	
		
			
				|  |  |  |  |                     partial_state=unpersisted_context.partial_state, | 
		
	
		
			
				|  |  |  |  |                     prev_group=unpersisted_context.state_group_before_event, | 
		
	
		
			
				|  |  |  |  |                     delta_ids=unpersisted_context.state_delta_due_to_event, | 
		
	
		
			
				|  |  |  |  |                 ) | 
		
	
		
			
				|  |  |  |  |             else: | 
		
	
		
			
				|  |  |  |  |                 context = EventContext( | 
		
	
		
			
				|  |  |  |  |                     storage=unpersisted_context._storage, | 
		
	
		
			
				|  |  |  |  |                     state_group=unpersisted_context.state_group_after_event, | 
		
	
		
			
				|  |  |  |  |                     state_group_before_event=unpersisted_context.state_group_before_event, | 
		
	
		
			
				|  |  |  |  |                     state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, | 
		
	
		
			
				|  |  |  |  |                     partial_state=unpersisted_context.partial_state, | 
		
	
		
			
				|  |  |  |  |                     prev_group=unpersisted_context.prev_group_for_state_group_before_event, | 
		
	
		
			
				|  |  |  |  |                     delta_ids=unpersisted_context.delta_ids_to_state_group_before_event, | 
		
	
		
			
				|  |  |  |  |                 ) | 
		
	
		
			
				|  |  |  |  |             state_group_deltas = unpersisted_context._build_state_group_deltas() | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |             context = EventContext( | 
		
	
		
			
				|  |  |  |  |                 storage=unpersisted_context._storage, | 
		
	
		
			
				|  |  |  |  |                 state_group=unpersisted_context.state_group_after_event, | 
		
	
		
			
				|  |  |  |  |                 state_group_before_event=unpersisted_context.state_group_before_event, | 
		
	
		
			
				|  |  |  |  |                 state_delta_due_to_event=unpersisted_context.state_delta_due_to_event, | 
		
	
		
			
				|  |  |  |  |                 partial_state=unpersisted_context.partial_state, | 
		
	
		
			
				|  |  |  |  |                 state_group_deltas=state_group_deltas, | 
		
	
		
			
				|  |  |  |  |             ) | 
		
	
		
			
				|  |  |  |  |             events_and_persisted_context.append((event, context)) | 
		
	
		
			
				|  |  |  |  |         return events_and_persisted_context | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
	
		
			
				
					|  |  |  | @ -452,11 +448,11 @@ class UnpersistedEventContext(UnpersistedEventContextBase): | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         # if the event isn't a state event the state group doesn't change | 
		
	
		
			
				|  |  |  |  |         if not self.state_delta_due_to_event: | 
		
	
		
			
				|  |  |  |  |             state_group_after_event = self.state_group_before_event | 
		
	
		
			
				|  |  |  |  |             self.state_group_after_event = self.state_group_before_event | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         # otherwise if it is a state event we need to get a state group for it | 
		
	
		
			
				|  |  |  |  |         else: | 
		
	
		
			
				|  |  |  |  |             state_group_after_event = await self._storage.state.store_state_group( | 
		
	
		
			
				|  |  |  |  |             self.state_group_after_event = await self._storage.state.store_state_group( | 
		
	
		
			
				|  |  |  |  |                 event.event_id, | 
		
	
		
			
				|  |  |  |  |                 event.room_id, | 
		
	
		
			
				|  |  |  |  |                 prev_group=self.state_group_before_event, | 
		
	
	
		
			
				
					|  |  |  | @ -464,16 +460,81 @@ class UnpersistedEventContext(UnpersistedEventContextBase): | 
		
	
		
			
				|  |  |  |  |                 current_state_ids=None, | 
		
	
		
			
				|  |  |  |  |             ) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         state_group_deltas = self._build_state_group_deltas() | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         return EventContext.with_state( | 
		
	
		
			
				|  |  |  |  |             storage=self._storage, | 
		
	
		
			
				|  |  |  |  |             state_group=state_group_after_event, | 
		
	
		
			
				|  |  |  |  |             state_group=self.state_group_after_event, | 
		
	
		
			
				|  |  |  |  |             state_group_before_event=self.state_group_before_event, | 
		
	
		
			
				|  |  |  |  |             state_delta_due_to_event=self.state_delta_due_to_event, | 
		
	
		
			
				|  |  |  |  |             state_group_deltas=state_group_deltas, | 
		
	
		
			
				|  |  |  |  |             partial_state=self.partial_state, | 
		
	
		
			
				|  |  |  |  |             prev_group=self.state_group_before_event, | 
		
	
		
			
				|  |  |  |  |             delta_ids=self.state_delta_due_to_event, | 
		
	
		
			
				|  |  |  |  |         ) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     def _build_state_group_deltas(self) -> Dict[Tuple[int, int], StateMap]: | 
		
	
		
			
				|  |  |  |  |         """ | 
		
	
		
			
				|  |  |  |  |         Collect deltas between the state groups associated with this context | 
		
	
		
			
				|  |  |  |  |         """ | 
		
	
		
			
				|  |  |  |  |         state_group_deltas = {} | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         # if we know the state group before the event and after the event, add them and the | 
		
	
		
			
				|  |  |  |  |         # state delta between them to state_group_deltas | 
		
	
		
			
				|  |  |  |  |         if self.state_group_before_event and self.state_group_after_event: | 
		
	
		
			
				|  |  |  |  |             # if we have the state groups we should have the delta | 
		
	
		
			
				|  |  |  |  |             assert self.state_delta_due_to_event is not None | 
		
	
		
			
				|  |  |  |  |             state_group_deltas[ | 
		
	
		
			
				|  |  |  |  |                 ( | 
		
	
		
			
				|  |  |  |  |                     self.state_group_before_event, | 
		
	
		
			
				|  |  |  |  |                     self.state_group_after_event, | 
		
	
		
			
				|  |  |  |  |                 ) | 
		
	
		
			
				|  |  |  |  |             ] = self.state_delta_due_to_event | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         # the state group before the event may also have a state group which precedes it, if | 
		
	
		
			
				|  |  |  |  |         # we have that and the state group before the event, add them and the state | 
		
	
		
			
				|  |  |  |  |         # delta between them to state_group_deltas | 
		
	
		
			
				|  |  |  |  |         if ( | 
		
	
		
			
				|  |  |  |  |             self.prev_group_for_state_group_before_event | 
		
	
		
			
				|  |  |  |  |             and self.state_group_before_event | 
		
	
		
			
				|  |  |  |  |         ): | 
		
	
		
			
				|  |  |  |  |             # if we have both state groups we should have the delta between them | 
		
	
		
			
				|  |  |  |  |             assert self.delta_ids_to_state_group_before_event is not None | 
		
	
		
			
				|  |  |  |  |             state_group_deltas[ | 
		
	
		
			
				|  |  |  |  |                 ( | 
		
	
		
			
				|  |  |  |  |                     self.prev_group_for_state_group_before_event, | 
		
	
		
			
				|  |  |  |  |                     self.state_group_before_event, | 
		
	
		
			
				|  |  |  |  |                 ) | 
		
	
		
			
				|  |  |  |  |             ] = self.delta_ids_to_state_group_before_event | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |         return state_group_deltas | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | def _encode_state_group_delta( | 
		
	
		
			
				|  |  |  |  |     state_group_delta: Dict[Tuple[int, int], StateMap[str]] | 
		
	
		
			
				|  |  |  |  | ) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]: | 
		
	
		
			
				|  |  |  |  |     if not state_group_delta: | 
		
	
		
			
				|  |  |  |  |         return [] | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     state_group_delta_encoded = [] | 
		
	
		
			
				|  |  |  |  |     for key, value in state_group_delta.items(): | 
		
	
		
			
				|  |  |  |  |         state_group_delta_encoded.append((key[0], key[1], _encode_state_dict(value))) | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     return state_group_delta_encoded | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | def _decode_state_group_delta( | 
		
	
		
			
				|  |  |  |  |     input: List[Tuple[int, int, List[Tuple[str, str, str]]]] | 
		
	
		
			
				|  |  |  |  | ) -> Dict[Tuple[int, int], StateMap[str]]: | 
		
	
		
			
				|  |  |  |  |     if not input: | 
		
	
		
			
				|  |  |  |  |         return {} | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     state_group_deltas = {} | 
		
	
		
			
				|  |  |  |  |     for state_group_1, state_group_2, state_dict in input: | 
		
	
		
			
				|  |  |  |  |         state_map = _decode_state_dict(state_dict) | 
		
	
		
			
				|  |  |  |  |         assert state_map is not None | 
		
	
		
			
				|  |  |  |  |         state_group_deltas[(state_group_1, state_group_2)] = state_map | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |     return state_group_deltas | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  | def _encode_state_dict( | 
		
	
		
			
				|  |  |  |  |     state_dict: Optional[StateMap[str]], | 
		
	
	
		
			
				
					|  |  |  | 
 |