Don't push if an user account has expired
Signed-off-by: Mathieu Velten <matmaul@gmail.com>pull/8353/head
parent
4f3096d866
commit
eb088c6aa5
|
@ -0,0 +1 @@
|
||||||
|
Don't push if an user account has expired.
|
|
@ -218,11 +218,7 @@ class Auth:
|
||||||
# Deny the request if the user account has expired.
|
# Deny the request if the user account has expired.
|
||||||
if self._account_validity.enabled and not allow_expired:
|
if self._account_validity.enabled and not allow_expired:
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
|
if await self.store.is_account_expired(user_id, self.clock.time_msec()):
|
||||||
if (
|
|
||||||
expiration_ts is not None
|
|
||||||
and self.clock.time_msec() >= expiration_ts
|
|
||||||
):
|
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
|
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
|
||||||
)
|
)
|
||||||
|
|
|
@ -60,6 +60,8 @@ class PusherPool:
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.clock = self.hs.get_clock()
|
self.clock = self.hs.get_clock()
|
||||||
|
|
||||||
|
self._account_validity = hs.config.account_validity
|
||||||
|
|
||||||
# We shard the handling of push notifications by user ID.
|
# We shard the handling of push notifications by user ID.
|
||||||
self._pusher_shard_config = hs.config.push.pusher_shard_config
|
self._pusher_shard_config = hs.config.push.pusher_shard_config
|
||||||
self._instance_name = hs.get_instance_name()
|
self._instance_name = hs.get_instance_name()
|
||||||
|
@ -202,6 +204,14 @@ class PusherPool:
|
||||||
)
|
)
|
||||||
|
|
||||||
for u in users_affected:
|
for u in users_affected:
|
||||||
|
# Don't push if the user account has expired
|
||||||
|
if self._account_validity.enabled:
|
||||||
|
expired = await self.store.is_account_expired(
|
||||||
|
u, self.clock.time_msec()
|
||||||
|
)
|
||||||
|
if expired:
|
||||||
|
continue
|
||||||
|
|
||||||
if u in self.pushers:
|
if u in self.pushers:
|
||||||
for p in self.pushers[u].values():
|
for p in self.pushers[u].values():
|
||||||
p.on_new_notifications(max_stream_id)
|
p.on_new_notifications(max_stream_id)
|
||||||
|
@ -222,6 +232,14 @@ class PusherPool:
|
||||||
)
|
)
|
||||||
|
|
||||||
for u in users_affected:
|
for u in users_affected:
|
||||||
|
# Don't push if the user account has expired
|
||||||
|
if self._account_validity.enabled:
|
||||||
|
expired = await self.store.is_account_expired(
|
||||||
|
u, self.clock.time_msec()
|
||||||
|
)
|
||||||
|
if expired:
|
||||||
|
continue
|
||||||
|
|
||||||
if u in self.pushers:
|
if u in self.pushers:
|
||||||
for p in self.pushers[u].values():
|
for p in self.pushers[u].values():
|
||||||
p.on_new_receipts(min_stream_id, max_stream_id)
|
p.on_new_receipts(min_stream_id, max_stream_id)
|
||||||
|
|
|
@ -116,6 +116,22 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||||
desc="get_expiration_ts_for_user",
|
desc="get_expiration_ts_for_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
|
||||||
|
"""
|
||||||
|
Returns whether an user account is expired.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user's ID
|
||||||
|
current_ts: The current timestamp
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[bool]: whether the user account has expired
|
||||||
|
"""
|
||||||
|
expiration_ts = await self.get_expiration_ts_for_user(user_id)
|
||||||
|
if expiration_ts is not None and current_ts >= expiration_ts:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
async def set_account_validity_for_user(
|
async def set_account_validity_for_user(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|
Loading…
Reference in New Issue