Move registration's bg updates to a dedicated store
							parent
							
								
									54f87e0734
								
							
						
					
					
						commit
						81e6ffb536
					
				| 
						 | 
				
			
			@ -37,7 +37,57 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
 | 
			
		|||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RegistrationWorkerStore(SQLBaseStore):
 | 
			
		||||
class RegistrationDeactivationStore(SQLBaseStore):
 | 
			
		||||
    @cachedInlineCallbacks()
 | 
			
		||||
    def get_user_deactivated_status(self, user_id):
 | 
			
		||||
        """Retrieve the value for the `deactivated` property for the provided user.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            user_id (str): The ID of the user to retrieve the status for.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            defer.Deferred(bool): The requested value.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        res = yield self._simple_select_one_onecol(
 | 
			
		||||
            table="users",
 | 
			
		||||
            keyvalues={"name": user_id},
 | 
			
		||||
            retcol="deactivated",
 | 
			
		||||
            desc="get_user_deactivated_status",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Convert the integer into a boolean.
 | 
			
		||||
        return res == 1
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def set_user_deactivated_status(self, user_id, deactivated):
 | 
			
		||||
        """Set the `deactivated` property for the provided user to the provided value.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            user_id (str): The ID of the user to set the status for.
 | 
			
		||||
            deactivated (bool): The value to set for `deactivated`.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        yield self.runInteraction(
 | 
			
		||||
            "set_user_deactivated_status",
 | 
			
		||||
            self.set_user_deactivated_status_txn,
 | 
			
		||||
            user_id,
 | 
			
		||||
            deactivated,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
 | 
			
		||||
        self._simple_update_one_txn(
 | 
			
		||||
            txn=txn,
 | 
			
		||||
            table="users",
 | 
			
		||||
            keyvalues={"name": user_id},
 | 
			
		||||
            updatevalues={"deactivated": 1 if deactivated else 0},
 | 
			
		||||
        )
 | 
			
		||||
        self._invalidate_cache_and_stream(
 | 
			
		||||
            txn, self.get_user_deactivated_status, (user_id,)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RegistrationWorkerStore(RegistrationDeactivationStore):
 | 
			
		||||
    def __init__(self, db_conn, hs):
 | 
			
		||||
        super(RegistrationWorkerStore, self).__init__(db_conn, hs)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -673,27 +723,6 @@ class RegistrationWorkerStore(SQLBaseStore):
 | 
			
		|||
            desc="get_id_servers_user_bound",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @cachedInlineCallbacks()
 | 
			
		||||
    def get_user_deactivated_status(self, user_id):
 | 
			
		||||
        """Retrieve the value for the `deactivated` property for the provided user.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            user_id (str): The ID of the user to retrieve the status for.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            defer.Deferred(bool): The requested value.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        res = yield self._simple_select_one_onecol(
 | 
			
		||||
            table="users",
 | 
			
		||||
            keyvalues={"name": user_id},
 | 
			
		||||
            retcol="deactivated",
 | 
			
		||||
            desc="get_user_deactivated_status",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Convert the integer into a boolean.
 | 
			
		||||
        return res == 1
 | 
			
		||||
 | 
			
		||||
    def get_threepid_validation_session(
 | 
			
		||||
        self, medium, client_secret, address=None, sid=None, validated=True
 | 
			
		||||
    ):
 | 
			
		||||
| 
						 | 
				
			
			@ -787,13 +816,14 @@ class RegistrationWorkerStore(SQLBaseStore):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RegistrationStore(
 | 
			
		||||
    RegistrationWorkerStore, background_updates.BackgroundUpdateStore
 | 
			
		||||
class RegistrationBackgroundUpdateStore(
 | 
			
		||||
    RegistrationDeactivationStore, background_updates.BackgroundUpdateStore
 | 
			
		||||
):
 | 
			
		||||
    def __init__(self, db_conn, hs):
 | 
			
		||||
        super(RegistrationStore, self).__init__(db_conn, hs)
 | 
			
		||||
        super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
 | 
			
		||||
 | 
			
		||||
        self.clock = hs.get_clock()
 | 
			
		||||
        self.config = hs.config
 | 
			
		||||
 | 
			
		||||
        self.register_background_index_update(
 | 
			
		||||
            "access_tokens_device_index",
 | 
			
		||||
| 
						 | 
				
			
			@ -809,8 +839,6 @@ class RegistrationStore(
 | 
			
		|||
            columns=["creation_ts"],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self._account_validity = hs.config.account_validity
 | 
			
		||||
 | 
			
		||||
        # we no longer use refresh tokens, but it's possible that some people
 | 
			
		||||
        # might have a background update queued to build this index. Just
 | 
			
		||||
        # clear the background update.
 | 
			
		||||
| 
						 | 
				
			
			@ -824,17 +852,6 @@ class RegistrationStore(
 | 
			
		|||
            "users_set_deactivated_flag", self._background_update_set_deactivated_flag
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Create a background job for culling expired 3PID validity tokens
 | 
			
		||||
        def start_cull():
 | 
			
		||||
            # run as a background process to make sure that the database transactions
 | 
			
		||||
            # have a logcontext to report to
 | 
			
		||||
            return run_as_background_process(
 | 
			
		||||
                "cull_expired_threepid_validation_tokens",
 | 
			
		||||
                self.cull_expired_threepid_validation_tokens,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _background_update_set_deactivated_flag(self, progress, batch_size):
 | 
			
		||||
        """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
 | 
			
		||||
| 
						 | 
				
			
			@ -896,6 +913,54 @@ class RegistrationStore(
 | 
			
		|||
 | 
			
		||||
        return nb_processed
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _bg_user_threepids_grandfather(self, progress, batch_size):
 | 
			
		||||
        """We now track which identity servers a user binds their 3PID to, so
 | 
			
		||||
        we need to handle the case of existing bindings where we didn't track
 | 
			
		||||
        this.
 | 
			
		||||
 | 
			
		||||
        We do this by grandfathering in existing user threepids assuming that
 | 
			
		||||
        they used one of the server configured trusted identity servers.
 | 
			
		||||
        """
 | 
			
		||||
        id_servers = set(self.config.trusted_third_party_id_servers)
 | 
			
		||||
 | 
			
		||||
        def _bg_user_threepids_grandfather_txn(txn):
 | 
			
		||||
            sql = """
 | 
			
		||||
                INSERT INTO user_threepid_id_server
 | 
			
		||||
                    (user_id, medium, address, id_server)
 | 
			
		||||
                SELECT user_id, medium, address, ?
 | 
			
		||||
                FROM user_threepids
 | 
			
		||||
            """
 | 
			
		||||
 | 
			
		||||
            txn.executemany(sql, [(id_server,) for id_server in id_servers])
 | 
			
		||||
 | 
			
		||||
        if id_servers:
 | 
			
		||||
            yield self.runInteraction(
 | 
			
		||||
                "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        yield self._end_background_update("user_threepids_grandfather")
 | 
			
		||||
 | 
			
		||||
        return 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RegistrationStore(RegistrationWorkerStore, RegistrationBackgroundUpdateStore):
 | 
			
		||||
    def __init__(self, db_conn, hs):
 | 
			
		||||
        super(RegistrationStore, self).__init__(db_conn, hs)
 | 
			
		||||
 | 
			
		||||
        self._account_validity = hs.config.account_validity
 | 
			
		||||
 | 
			
		||||
        # Create a background job for culling expired 3PID validity tokens
 | 
			
		||||
        def start_cull():
 | 
			
		||||
            # run as a background process to make sure that the database transactions
 | 
			
		||||
            # have a logcontext to report to
 | 
			
		||||
            return run_as_background_process(
 | 
			
		||||
                "cull_expired_threepid_validation_tokens",
 | 
			
		||||
                self.cull_expired_threepid_validation_tokens,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
 | 
			
		||||
        """Adds an access token for the given user.
 | 
			
		||||
| 
						 | 
				
			
			@ -1244,36 +1309,6 @@ class RegistrationStore(
 | 
			
		|||
            desc="get_users_pending_deactivation",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _bg_user_threepids_grandfather(self, progress, batch_size):
 | 
			
		||||
        """We now track which identity servers a user binds their 3PID to, so
 | 
			
		||||
        we need to handle the case of existing bindings where we didn't track
 | 
			
		||||
        this.
 | 
			
		||||
 | 
			
		||||
        We do this by grandfathering in existing user threepids assuming that
 | 
			
		||||
        they used one of the server configured trusted identity servers.
 | 
			
		||||
        """
 | 
			
		||||
        id_servers = set(self.config.trusted_third_party_id_servers)
 | 
			
		||||
 | 
			
		||||
        def _bg_user_threepids_grandfather_txn(txn):
 | 
			
		||||
            sql = """
 | 
			
		||||
                INSERT INTO user_threepid_id_server
 | 
			
		||||
                    (user_id, medium, address, id_server)
 | 
			
		||||
                SELECT user_id, medium, address, ?
 | 
			
		||||
                FROM user_threepids
 | 
			
		||||
            """
 | 
			
		||||
 | 
			
		||||
            txn.executemany(sql, [(id_server,) for id_server in id_servers])
 | 
			
		||||
 | 
			
		||||
        if id_servers:
 | 
			
		||||
            yield self.runInteraction(
 | 
			
		||||
                "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        yield self._end_background_update("user_threepids_grandfather")
 | 
			
		||||
 | 
			
		||||
        return 1
 | 
			
		||||
 | 
			
		||||
    def validate_threepid_session(self, session_id, client_secret, token, current_ts):
 | 
			
		||||
        """Attempt to validate a threepid session using a token
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1464,30 +1499,3 @@ class RegistrationStore(
 | 
			
		|||
            cull_expired_threepid_validation_tokens_txn,
 | 
			
		||||
            self.clock.time_msec(),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
 | 
			
		||||
        self._simple_update_one_txn(
 | 
			
		||||
            txn=txn,
 | 
			
		||||
            table="users",
 | 
			
		||||
            keyvalues={"name": user_id},
 | 
			
		||||
            updatevalues={"deactivated": 1 if deactivated else 0},
 | 
			
		||||
        )
 | 
			
		||||
        self._invalidate_cache_and_stream(
 | 
			
		||||
            txn, self.get_user_deactivated_status, (user_id,)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def set_user_deactivated_status(self, user_id, deactivated):
 | 
			
		||||
        """Set the `deactivated` property for the provided user to the provided value.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            user_id (str): The ID of the user to set the status for.
 | 
			
		||||
            deactivated (bool): The value to set for `deactivated`.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        yield self.runInteraction(
 | 
			
		||||
            "set_user_deactivated_status",
 | 
			
		||||
            self.set_user_deactivated_status_txn,
 | 
			
		||||
            user_id,
 | 
			
		||||
            deactivated,
 | 
			
		||||
        )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue