Send out emails with links to extend an account's validity period
parent
747aa9f8ca
commit
20f0617e87
|
@ -0,0 +1 @@
|
|||
Add time-based account expiration.
|
|
@ -643,11 +643,31 @@ uploads_path: "DATADIR/uploads"
|
|||
#
|
||||
#enable_registration: false
|
||||
|
||||
# Optional account validity parameter. This allows for, e.g., accounts to
|
||||
# be denied any request after a given period.
|
||||
# Optional account validity configuration. This allows for accounts to be denied
|
||||
# any request after a given period.
|
||||
#
|
||||
# ``enabled`` defines whether the account validity feature is enabled. Defaults
|
||||
# to False.
|
||||
#
|
||||
# ``period`` allows setting the period after which an account is valid
|
||||
# after its registration. When renewing the account, its validity period
|
||||
# will be extended by this amount of time. This parameter is required when using
|
||||
# the account validity feature.
|
||||
#
|
||||
# ``renew_at`` is the amount of time before an account's expiry date at which
|
||||
# Synapse will send an email to the account's email address with a renewal link.
|
||||
# This needs the ``email`` and ``public_baseurl`` configuration sections to be
|
||||
# filled.
|
||||
#
|
||||
# ``renew_email_subject`` is the subject of the email sent out with the renewal
|
||||
# link. ``%(app)s`` can be used as a placeholder for the ``app_name`` parameter
|
||||
# from the ``email`` section.
|
||||
#
|
||||
#account_validity:
|
||||
# enabled: True
|
||||
# period: 6w
|
||||
# renew_at: 1w
|
||||
# renew_email_subject: "Renew your %(app)s account"
|
||||
|
||||
# The user must provide all of the below types of 3PID when registering.
|
||||
#
|
||||
|
@ -890,7 +910,7 @@ password_config:
|
|||
|
||||
|
||||
|
||||
# Enable sending emails for notification events
|
||||
# Enable sending emails for notification events or expiry notices
|
||||
# Defining a custom URL for Riot is only needed if email notifications
|
||||
# should contain links to a self-hosted installation of Riot; when set
|
||||
# the "app_name" setting is ignored.
|
||||
|
@ -912,6 +932,9 @@ password_config:
|
|||
# #template_dir: res/templates
|
||||
# notif_template_html: notif_mail.html
|
||||
# notif_template_text: notif_mail.txt
|
||||
# # Templates for account expiry notices.
|
||||
# expiry_template_html: notice_expiry.html
|
||||
# expiry_template_text: notice_expiry.txt
|
||||
# notif_for_new_users: True
|
||||
# riot_base_url: "http://localhost/riot"
|
||||
|
||||
|
|
|
@ -230,8 +230,9 @@ class Auth(object):
|
|||
|
||||
# Deny the request if the user account has expired.
|
||||
if self._account_validity.enabled:
|
||||
expiration_ts = yield self.store.get_expiration_ts_for_user(user)
|
||||
if self.clock.time_msec() >= expiration_ts:
|
||||
user_id = user.to_string()
|
||||
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
|
||||
if expiration_ts and self.clock.time_msec() >= expiration_ts:
|
||||
raise AuthError(
|
||||
403,
|
||||
"User account has expired",
|
||||
|
|
|
@ -71,6 +71,8 @@ class EmailConfig(Config):
|
|||
self.email_notif_from = email_config["notif_from"]
|
||||
self.email_notif_template_html = email_config["notif_template_html"]
|
||||
self.email_notif_template_text = email_config["notif_template_text"]
|
||||
self.email_expiry_template_html = email_config["expiry_template_html"]
|
||||
self.email_expiry_template_text = email_config["expiry_template_text"]
|
||||
|
||||
template_dir = email_config.get("template_dir")
|
||||
# we need an absolute path, because we change directory after starting (and
|
||||
|
@ -120,7 +122,7 @@ class EmailConfig(Config):
|
|||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
return """
|
||||
# Enable sending emails for notification events
|
||||
# Enable sending emails for notification events or expiry notices
|
||||
# Defining a custom URL for Riot is only needed if email notifications
|
||||
# should contain links to a self-hosted installation of Riot; when set
|
||||
# the "app_name" setting is ignored.
|
||||
|
@ -142,6 +144,9 @@ class EmailConfig(Config):
|
|||
# #template_dir: res/templates
|
||||
# notif_template_html: notif_mail.html
|
||||
# notif_template_text: notif_mail.txt
|
||||
# # Templates for account expiry notices.
|
||||
# expiry_template_html: notice_expiry.html
|
||||
# expiry_template_text: notice_expiry.txt
|
||||
# notif_for_new_users: True
|
||||
# riot_base_url: "http://localhost/riot"
|
||||
"""
|
||||
|
|
|
@ -21,12 +21,26 @@ from synapse.util.stringutils import random_string_with_symbols
|
|||
|
||||
|
||||
class AccountValidityConfig(Config):
|
||||
def __init__(self, config):
|
||||
self.enabled = (len(config) > 0)
|
||||
def __init__(self, config, synapse_config):
|
||||
self.enabled = config.get("enabled", False)
|
||||
self.renew_by_email_enabled = ("renew_at" in config)
|
||||
|
||||
period = config.get("period", None)
|
||||
if period:
|
||||
self.period = self.parse_duration(period)
|
||||
if self.enabled:
|
||||
if "period" in config:
|
||||
self.period = self.parse_duration(config["period"])
|
||||
else:
|
||||
raise ConfigError("'period' is required when using account validity")
|
||||
|
||||
if "renew_at" in config:
|
||||
self.renew_at = self.parse_duration(config["renew_at"])
|
||||
|
||||
if "renew_email_subject" in config:
|
||||
self.renew_email_subject = config["renew_email_subject"]
|
||||
else:
|
||||
self.renew_email_subject = "Renew your %(app)s account"
|
||||
|
||||
if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
|
||||
raise ConfigError("Can't send renewal emails without 'public_baseurl'")
|
||||
|
||||
|
||||
class RegistrationConfig(Config):
|
||||
|
@ -40,7 +54,9 @@ class RegistrationConfig(Config):
|
|||
strtobool(str(config["disable_registration"]))
|
||||
)
|
||||
|
||||
self.account_validity = AccountValidityConfig(config.get("account_validity", {}))
|
||||
self.account_validity = AccountValidityConfig(
|
||||
config.get("account_validity", {}), config,
|
||||
)
|
||||
|
||||
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
|
||||
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
|
||||
|
@ -86,11 +102,31 @@ class RegistrationConfig(Config):
|
|||
#
|
||||
#enable_registration: false
|
||||
|
||||
# Optional account validity parameter. This allows for, e.g., accounts to
|
||||
# be denied any request after a given period.
|
||||
# Optional account validity configuration. This allows for accounts to be denied
|
||||
# any request after a given period.
|
||||
#
|
||||
# ``enabled`` defines whether the account validity feature is enabled. Defaults
|
||||
# to False.
|
||||
#
|
||||
# ``period`` allows setting the period after which an account is valid
|
||||
# after its registration. When renewing the account, its validity period
|
||||
# will be extended by this amount of time. This parameter is required when using
|
||||
# the account validity feature.
|
||||
#
|
||||
# ``renew_at`` is the amount of time before an account's expiry date at which
|
||||
# Synapse will send an email to the account's email address with a renewal link.
|
||||
# This needs the ``email`` and ``public_baseurl`` configuration sections to be
|
||||
# filled.
|
||||
#
|
||||
# ``renew_email_subject`` is the subject of the email sent out with the renewal
|
||||
# link. ``%%(app)s`` can be used as a placeholder for the ``app_name`` parameter
|
||||
# from the ``email`` section.
|
||||
#
|
||||
#account_validity:
|
||||
# enabled: True
|
||||
# period: 6w
|
||||
# renew_at: 1w
|
||||
# renew_email_subject: "Renew your %%(app)s account"
|
||||
|
||||
# The user must provide all of the below types of 3PID when registering.
|
||||
#
|
||||
|
|
|
@ -0,0 +1,228 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import email.mime.multipart
|
||||
import email.utils
|
||||
import logging
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.types import UserID
|
||||
from synapse.util import stringutils
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
|
||||
try:
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
except ImportError:
|
||||
load_jinja2_templates = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccountValidityHandler(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = self.hs.get_datastore()
|
||||
self.sendmail = self.hs.get_sendmail()
|
||||
self.clock = self.hs.get_clock()
|
||||
|
||||
self._account_validity = self.hs.config.account_validity
|
||||
|
||||
if self._account_validity.renew_by_email_enabled and load_jinja2_templates:
|
||||
# Don't do email-specific configuration if renewal by email is disabled.
|
||||
try:
|
||||
app_name = self.hs.config.email_app_name
|
||||
|
||||
self._subject = self._account_validity.renew_email_subject % {
|
||||
"app": app_name,
|
||||
}
|
||||
|
||||
self._from_string = self.hs.config.email_notif_from % {
|
||||
"app": app_name,
|
||||
}
|
||||
except Exception:
|
||||
# If substitution failed, fall back to the bare strings.
|
||||
self._subject = self._account_validity.renew_email_subject
|
||||
self._from_string = self.hs.config.email_notif_from
|
||||
|
||||
self._raw_from = email.utils.parseaddr(self._from_string)[1]
|
||||
|
||||
self._template_html, self._template_text = load_jinja2_templates(
|
||||
config=self.hs.config,
|
||||
template_html_name=self.hs.config.email_expiry_template_html,
|
||||
template_text_name=self.hs.config.email_expiry_template_text,
|
||||
)
|
||||
|
||||
# Check the renewal emails to send and send them every 30min.
|
||||
self.clock.looping_call(
|
||||
self.send_renewal_emails,
|
||||
30 * 60 * 1000,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_renewal_emails(self):
|
||||
"""Gets the list of users whose account is expiring in the amount of time
|
||||
configured in the ``renew_at`` parameter from the ``account_validity``
|
||||
configuration, and sends renewal emails to all of these users as long as they
|
||||
have an email 3PID attached to their account.
|
||||
"""
|
||||
expiring_users = yield self.store.get_users_expiring_soon()
|
||||
|
||||
if expiring_users:
|
||||
for user in expiring_users:
|
||||
yield self._send_renewal_email(
|
||||
user_id=user["user_id"],
|
||||
expiration_ts=user["expiration_ts_ms"],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _send_renewal_email(self, user_id, expiration_ts):
|
||||
"""Sends out a renewal email to every email address attached to the given user
|
||||
with a unique link allowing them to renew their account.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user to send email(s) to.
|
||||
expiration_ts (int): Timestamp in milliseconds for the expiration date of
|
||||
this user's account (used in the email templates).
|
||||
"""
|
||||
addresses = yield self._get_email_addresses_for_user(user_id)
|
||||
|
||||
# Stop right here if the user doesn't have at least one email address.
|
||||
# In this case, they will have to ask their server admin to renew their
|
||||
# account manually.
|
||||
if not addresses:
|
||||
return
|
||||
|
||||
try:
|
||||
user_display_name = yield self.store.get_profile_displayname(
|
||||
UserID.from_string(user_id).localpart
|
||||
)
|
||||
if user_display_name is None:
|
||||
user_display_name = user_id
|
||||
except StoreError:
|
||||
user_display_name = user_id
|
||||
|
||||
renewal_token = yield self._get_renewal_token(user_id)
|
||||
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
|
||||
self.hs.config.public_baseurl,
|
||||
renewal_token,
|
||||
)
|
||||
|
||||
template_vars = {
|
||||
"display_name": user_display_name,
|
||||
"expiration_ts": expiration_ts,
|
||||
"url": url,
|
||||
}
|
||||
|
||||
html_text = self._template_html.render(**template_vars)
|
||||
html_part = MIMEText(html_text, "html", "utf8")
|
||||
|
||||
plain_text = self._template_text.render(**template_vars)
|
||||
text_part = MIMEText(plain_text, "plain", "utf8")
|
||||
|
||||
for address in addresses:
|
||||
raw_to = email.utils.parseaddr(address)[1]
|
||||
|
||||
multipart_msg = MIMEMultipart('alternative')
|
||||
multipart_msg['Subject'] = self._subject
|
||||
multipart_msg['From'] = self._from_string
|
||||
multipart_msg['To'] = address
|
||||
multipart_msg['Date'] = email.utils.formatdate()
|
||||
multipart_msg['Message-ID'] = email.utils.make_msgid()
|
||||
multipart_msg.attach(text_part)
|
||||
multipart_msg.attach(html_part)
|
||||
|
||||
logger.info("Sending renewal email to %s", address)
|
||||
|
||||
yield make_deferred_yieldable(self.sendmail(
|
||||
self.hs.config.email_smtp_host,
|
||||
self._raw_from, raw_to, multipart_msg.as_string().encode('utf8'),
|
||||
reactor=self.hs.get_reactor(),
|
||||
port=self.hs.config.email_smtp_port,
|
||||
requireAuthentication=self.hs.config.email_smtp_user is not None,
|
||||
username=self.hs.config.email_smtp_user,
|
||||
password=self.hs.config.email_smtp_pass,
|
||||
requireTransportSecurity=self.hs.config.require_transport_security
|
||||
))
|
||||
|
||||
yield self.store.set_renewal_mail_status(
|
||||
user_id=user_id,
|
||||
email_sent=True,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_email_addresses_for_user(self, user_id):
|
||||
"""Retrieve the list of email addresses attached to a user's account.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user to lookup email addresses for.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[list[str]]: Email addresses for this account.
|
||||
"""
|
||||
threepids = yield self.store.user_get_threepids(user_id)
|
||||
|
||||
addresses = []
|
||||
for threepid in threepids:
|
||||
if threepid["medium"] == "email":
|
||||
addresses.append(threepid["address"])
|
||||
|
||||
defer.returnValue(addresses)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_renewal_token(self, user_id):
|
||||
"""Generates a 32-byte long random string that will be inserted into the
|
||||
user's renewal email's unique link, then saves it into the database.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user to generate a string for.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[str]: The generated string.
|
||||
|
||||
Raises:
|
||||
StoreError(500): Couldn't generate a unique string after 5 attempts.
|
||||
"""
|
||||
attempts = 0
|
||||
while attempts < 5:
|
||||
try:
|
||||
renewal_token = stringutils.random_string(32)
|
||||
yield self.store.set_renewal_token_for_user(user_id, renewal_token)
|
||||
defer.returnValue(renewal_token)
|
||||
except StoreError:
|
||||
attempts += 1
|
||||
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def renew_account(self, renewal_token):
|
||||
"""Renews the account attached to a given renewal token by pushing back the
|
||||
expiration date by the current validity period in the server's configuration.
|
||||
|
||||
Args:
|
||||
renewal_token (str): Token sent with the renewal request.
|
||||
"""
|
||||
user_id = yield self.store.get_user_from_renewal_token(renewal_token)
|
||||
|
||||
logger.debug("Renewing an account for user %s", user_id)
|
||||
|
||||
new_expiration_date = self.clock.time_msec() + self._account_validity.period
|
||||
|
||||
yield self.store.renew_account_for_user(
|
||||
user_id=user_id,
|
||||
new_expiration_ts=new_expiration_date,
|
||||
)
|
|
@ -521,11 +521,11 @@ def format_ts_filter(value, format):
|
|||
return time.strftime(format, time.localtime(value / 1000))
|
||||
|
||||
|
||||
def load_jinja2_templates(config):
|
||||
def load_jinja2_templates(config, template_html_name, template_text_name):
|
||||
"""Load the jinja2 email templates from disk
|
||||
|
||||
Returns:
|
||||
(notif_template_html, notif_template_text)
|
||||
(template_html, template_text)
|
||||
"""
|
||||
logger.info("loading email templates from '%s'", config.email_template_dir)
|
||||
loader = jinja2.FileSystemLoader(config.email_template_dir)
|
||||
|
@ -533,14 +533,10 @@ def load_jinja2_templates(config):
|
|||
env.filters["format_ts"] = format_ts_filter
|
||||
env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)
|
||||
|
||||
notif_template_html = env.get_template(
|
||||
config.email_notif_template_html
|
||||
)
|
||||
notif_template_text = env.get_template(
|
||||
config.email_notif_template_text
|
||||
)
|
||||
template_html = env.get_template(template_html_name)
|
||||
template_text = env.get_template(template_text_name)
|
||||
|
||||
return notif_template_html, notif_template_text
|
||||
return template_html, template_text
|
||||
|
||||
|
||||
def _create_mxc_to_http_filter(config):
|
||||
|
|
|
@ -44,7 +44,11 @@ class PusherFactory(object):
|
|||
if hs.config.email_enable_notifs:
|
||||
self.mailers = {} # app_name -> Mailer
|
||||
|
||||
templates = load_jinja2_templates(hs.config)
|
||||
templates = load_jinja2_templates(
|
||||
config=hs.config,
|
||||
template_html_name=hs.config.email_notif_template_html,
|
||||
template_text_name=hs.config.email_notif_template_text,
|
||||
)
|
||||
self.notif_template_html, self.notif_template_text = templates
|
||||
|
||||
self.pusher_types["email"] = self._create_email_pusher
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
.noticetext {
|
||||
margin-top: 10px;
|
||||
margin-bottom: 10px;
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<style type="text/css">
|
||||
{% include 'mail.css' without context %}
|
||||
{% include "mail-%s.css" % app_name ignore missing without context %}
|
||||
{% include 'mail-expiry.css' without context %}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<table id="page">
|
||||
<tr>
|
||||
<td> </td>
|
||||
<td id="inner">
|
||||
<table class="header">
|
||||
<tr>
|
||||
<td>
|
||||
<div class="salutation">Hi {{ display_name }},</div>
|
||||
</td>
|
||||
<td class="logo">
|
||||
{% if app_name == "Riot" %}
|
||||
<img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
|
||||
{% elif app_name == "Vector" %}
|
||||
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
|
||||
{% else %}
|
||||
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
|
||||
{% endif %}
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td colspan="2">
|
||||
<div class="noticetext">Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.</div>
|
||||
<div class="noticetext">To extend the validity of your account, please click on the link bellow (or copy and paste it into a new browser tab):</div>
|
||||
<div class="noticetext"><a href="{{ url }}">{{ url }}</a></div>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</td>
|
||||
<td> </td>
|
||||
</tr>
|
||||
</table>
|
||||
</body>
|
||||
</html>
|
|
@ -0,0 +1,7 @@
|
|||
Hi {{ display_name }},
|
||||
|
||||
Your account will expire on {{ expiration_ts|format_ts("%d-%m-%Y") }}. This means that you will lose access to your account after this date.
|
||||
|
||||
To extend the validity of your account, please click on the link bellow (or copy and paste it to a new browser tab):
|
||||
|
||||
{{ url }}
|
|
@ -33,6 +33,7 @@ from synapse.rest.client.v1 import (
|
|||
from synapse.rest.client.v2_alpha import (
|
||||
account,
|
||||
account_data,
|
||||
account_validity,
|
||||
auth,
|
||||
capabilities,
|
||||
devices,
|
||||
|
@ -109,3 +110,4 @@ class ClientRestResource(JsonResource):
|
|||
groups.register_servlets(hs, client_resource)
|
||||
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
|
||||
capabilities.register_servlets(hs, client_resource)
|
||||
account_validity.register_servlets(hs, client_resource)
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import finish_request
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AccountValidityRenewServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account_validity/renew$")
|
||||
SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>"
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(AccountValidityRenewServlet, self).__init__()
|
||||
|
||||
self.hs = hs
|
||||
self.account_activity_handler = hs.get_account_validity_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
if b"token" not in request.args:
|
||||
raise SynapseError(400, "Missing renewal token")
|
||||
renewal_token = request.args[b"token"][0]
|
||||
|
||||
yield self.account_activity_handler.renew_account(renewal_token.decode('utf8'))
|
||||
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (
|
||||
len(AccountValidityRenewServlet.SUCCESS_HTML),
|
||||
))
|
||||
request.write(AccountValidityRenewServlet.SUCCESS_HTML)
|
||||
finish_request(request)
|
||||
defer.returnValue(None)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
AccountValidityRenewServlet(hs).register(http_server)
|
|
@ -47,6 +47,7 @@ from synapse.federation.transport.client import TransportLayerClient
|
|||
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
|
||||
from synapse.groups.groups_server import GroupsServerHandler
|
||||
from synapse.handlers import Handlers
|
||||
from synapse.handlers.account_validity import AccountValidityHandler
|
||||
from synapse.handlers.acme import AcmeHandler
|
||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
|
||||
|
@ -183,6 +184,7 @@ class HomeServer(object):
|
|||
'room_context_handler',
|
||||
'sendmail',
|
||||
'registration_handler',
|
||||
'account_validity_handler',
|
||||
]
|
||||
|
||||
REQUIRED_ON_MASTER_STARTUP = [
|
||||
|
@ -506,6 +508,9 @@ class HomeServer(object):
|
|||
def build_registration_handler(self):
|
||||
return RegistrationHandler(self)
|
||||
|
||||
def build_account_validity_handler(self):
|
||||
return AccountValidityHandler(self)
|
||||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
super(RegistrationWorkerStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.config = hs.config
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@cached()
|
||||
def get_user_by_id(self, user_id):
|
||||
|
@ -87,25 +88,156 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_expiration_ts_for_user(self, user):
|
||||
def get_expiration_ts_for_user(self, user_id):
|
||||
"""Get the expiration timestamp for the account bearing a given user ID.
|
||||
|
||||
Args:
|
||||
user (str): The ID of the user.
|
||||
user_id (str): The ID of the user.
|
||||
Returns:
|
||||
defer.Deferred: None, if the account has no expiration timestamp,
|
||||
otherwise int representation of the timestamp (as a number of
|
||||
milliseconds since epoch).
|
||||
otherwise int representation of the timestamp (as a number of
|
||||
milliseconds since epoch).
|
||||
"""
|
||||
res = yield self._simple_select_one_onecol(
|
||||
table="account_validity",
|
||||
keyvalues={"user_id": user.to_string()},
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="expiration_ts_ms",
|
||||
allow_none=True,
|
||||
desc="get_expiration_date_for_user",
|
||||
desc="get_expiration_ts_for_user",
|
||||
)
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def renew_account_for_user(self, user_id, new_expiration_ts):
|
||||
"""Updates the account validity table with a new timestamp for a given
|
||||
user, removes the existing renewal token from this user, and unsets the
|
||||
flag indicating that an email has been sent for renewing this account.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user whose account validity to renew.
|
||||
new_expiration_ts: New expiration date, as a timestamp in milliseconds
|
||||
since epoch.
|
||||
"""
|
||||
def renew_account_for_user_txn(txn):
|
||||
self._simple_update_txn(
|
||||
txn=txn,
|
||||
table="account_validity",
|
||||
keyvalues={"user_id": user_id},
|
||||
updatevalues={
|
||||
"expiration_ts_ms": new_expiration_ts,
|
||||
"email_sent": False,
|
||||
"renewal_token": None,
|
||||
},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_expiration_ts_for_user, (user_id,),
|
||||
)
|
||||
|
||||
yield self.runInteraction(
|
||||
"renew_account_for_user",
|
||||
renew_account_for_user_txn,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_renewal_token_for_user(self, user_id, renewal_token):
|
||||
"""Defines a renewal token for a given user.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user to set the renewal token for.
|
||||
renewal_token (str): Random unique string that will be used to renew the
|
||||
user's account.
|
||||
|
||||
Raises:
|
||||
StoreError: The provided token is already set for another user.
|
||||
"""
|
||||
yield self._simple_update_one(
|
||||
table="account_validity",
|
||||
keyvalues={"user_id": user_id},
|
||||
updatevalues={"renewal_token": renewal_token},
|
||||
desc="set_renewal_token_for_user",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_from_renewal_token(self, renewal_token):
|
||||
"""Get a user ID from a renewal token.
|
||||
|
||||
Args:
|
||||
renewal_token (str): The renewal token to perform the lookup with.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[str]: The ID of the user to which the token belongs.
|
||||
"""
|
||||
res = yield self._simple_select_one_onecol(
|
||||
table="account_validity",
|
||||
keyvalues={"renewal_token": renewal_token},
|
||||
retcol="user_id",
|
||||
desc="get_user_from_renewal_token",
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_renewal_token_for_user(self, user_id):
|
||||
"""Get the renewal token associated with a given user ID.
|
||||
|
||||
Args:
|
||||
user_id (str): The user ID to lookup a token for.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[str]: The renewal token associated with this user ID.
|
||||
"""
|
||||
res = yield self._simple_select_one_onecol(
|
||||
table="account_validity",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="renewal_token",
|
||||
desc="get_renewal_token_for_user",
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_users_expiring_soon(self):
|
||||
"""Selects users whose account will expire in the [now, now + renew_at] time
|
||||
window (see configuration for account_validity for information on what renew_at
|
||||
refers to).
|
||||
|
||||
Returns:
|
||||
Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
|
||||
"""
|
||||
def select_users_txn(txn, now_ms, renew_at):
|
||||
sql = (
|
||||
"SELECT user_id, expiration_ts_ms FROM account_validity"
|
||||
" WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
|
||||
)
|
||||
values = [False, now_ms, renew_at]
|
||||
txn.execute(sql, values)
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
res = yield self.runInteraction(
|
||||
"get_users_expiring_soon",
|
||||
select_users_txn,
|
||||
self.clock.time_msec(), self.config.account_validity.renew_at,
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_renewal_mail_status(self, user_id, email_sent):
|
||||
"""Sets or unsets the flag that indicates whether a renewal email has been sent
|
||||
to the user (and the user hasn't renewed their account yet).
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user to set/unset the flag for.
|
||||
email_sent (bool): Flag which indicates whether a renewal email has been sent
|
||||
to this user.
|
||||
"""
|
||||
yield self._simple_update_one(
|
||||
table="account_validity",
|
||||
keyvalues={"user_id": user_id},
|
||||
updatevalues={"email_sent": email_sent},
|
||||
desc="set_renewal_mail_status",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_server_admin(self, user):
|
||||
res = yield self._simple_select_one_onecol(
|
||||
|
@ -508,22 +640,24 @@ class RegistrationStore(RegistrationWorkerStore,
|
|||
}
|
||||
)
|
||||
|
||||
if self._account_validity.enabled:
|
||||
now_ms = self.clock.time_msec()
|
||||
expiration_ts = now_ms + self._account_validity.period
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"account_validity",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"expiration_ts_ms": expiration_ts,
|
||||
}
|
||||
)
|
||||
except self.database_engine.module.IntegrityError:
|
||||
raise StoreError(
|
||||
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
||||
)
|
||||
|
||||
if self._account_validity.enabled:
|
||||
now_ms = self.clock.time_msec()
|
||||
expiration_ts = now_ms + self._account_validity.period
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"account_validity",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"expiration_ts_ms": expiration_ts,
|
||||
"email_sent": False,
|
||||
}
|
||||
)
|
||||
|
||||
if token:
|
||||
# it's possible for this to get a conflict, but only for a single user
|
||||
# since tokens are namespaced based on their user ID
|
||||
|
|
|
@ -13,8 +13,15 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
DROP TABLE IF EXISTS account_validity;
|
||||
|
||||
-- Track what users are in public rooms.
|
||||
CREATE TABLE IF NOT EXISTS account_validity (
|
||||
user_id TEXT PRIMARY KEY,
|
||||
expiration_ts_ms BIGINT NOT NULL
|
||||
expiration_ts_ms BIGINT NOT NULL,
|
||||
email_sent BOOLEAN NOT NULL,
|
||||
renewal_token TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX account_validity_email_sent_idx ON account_validity(email_sent, expiration_ts_ms)
|
||||
CREATE UNIQUE INDEX account_validity_renewal_string_idx ON account_validity(renewal_token)
|
||||
|
|
|
@ -1,14 +1,22 @@
|
|||
import datetime
|
||||
import json
|
||||
import os
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.rest.client.v1 import admin, login
|
||||
from synapse.rest.client.v2_alpha import register, sync
|
||||
from synapse.rest.client.v2_alpha import account_validity, register, sync
|
||||
|
||||
from tests import unittest
|
||||
|
||||
try:
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
except ImportError:
|
||||
load_jinja2_templates = None
|
||||
|
||||
|
||||
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
|
@ -197,6 +205,7 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
|||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
config = self.default_config()
|
||||
# Test for account expiring after a week.
|
||||
config.enable_registration = True
|
||||
config.account_validity.enabled = True
|
||||
config.account_validity.period = 604800000 # Time in ms for 1 week
|
||||
|
@ -228,3 +237,92 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
|||
self.assertEquals(
|
||||
channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result,
|
||||
)
|
||||
|
||||
|
||||
class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
skip = "No Jinja installed" if not load_jinja2_templates else None
|
||||
servlets = [
|
||||
register.register_servlets,
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
sync.register_servlets,
|
||||
account_validity.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
config = self.default_config()
|
||||
# Test for account expiring after a week and renewal emails being sent 2
|
||||
# days before expiry.
|
||||
config.enable_registration = True
|
||||
config.account_validity.enabled = True
|
||||
config.account_validity.renew_by_email_enabled = True
|
||||
config.account_validity.period = 604800000 # Time in ms for 1 week
|
||||
config.account_validity.renew_at = 172800000 # Time in ms for 2 days
|
||||
config.account_validity.renew_email_subject = "Renew your account"
|
||||
|
||||
# Email config.
|
||||
self.email_attempts = []
|
||||
|
||||
def sendmail(*args, **kwargs):
|
||||
self.email_attempts.append((args, kwargs))
|
||||
return
|
||||
|
||||
config.email_template_dir = os.path.abspath(
|
||||
pkg_resources.resource_filename('synapse', 'res/templates')
|
||||
)
|
||||
config.email_expiry_template_html = "notice_expiry.html"
|
||||
config.email_expiry_template_text = "notice_expiry.txt"
|
||||
config.email_smtp_host = "127.0.0.1"
|
||||
config.email_smtp_port = 20
|
||||
config.require_transport_security = False
|
||||
config.email_smtp_user = None
|
||||
config.email_smtp_pass = None
|
||||
config.email_notif_from = "test@example.com"
|
||||
|
||||
self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
|
||||
|
||||
self.store = self.hs.get_datastore()
|
||||
|
||||
return self.hs
|
||||
|
||||
def test_renewal_email(self):
|
||||
user_id = self.register_user("kermit", "monkey")
|
||||
tok = self.login("kermit", "monkey")
|
||||
# We need to manually add an email address otherwise the handler will do
|
||||
# nothing.
|
||||
now = self.hs.clock.time_msec()
|
||||
self.get_success(self.store.user_add_threepid(
|
||||
user_id=user_id, medium="email", address="kermit@example.com",
|
||||
validated_at=now, added_at=now,
|
||||
))
|
||||
|
||||
# The specific endpoint doesn't matter, all we need is an authenticated
|
||||
# endpoint.
|
||||
request, channel = self.make_request(
|
||||
b"GET", "/sync", access_token=tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
# Move 6 days forward. This should trigger a renewal email to be sent.
|
||||
self.reactor.advance(datetime.timedelta(days=6).total_seconds())
|
||||
self.assertEqual(len(self.email_attempts), 1)
|
||||
|
||||
# Retrieving the URL from the email is too much pain for now, so we
|
||||
# retrieve the token from the DB.
|
||||
renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
|
||||
url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
|
||||
request, channel = self.make_request(b"GET", url)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
# Move 3 days forward. If the renewal failed, every authed request with
|
||||
# our access token should be denied from now, otherwise they should
|
||||
# succeed.
|
||||
self.reactor.advance(datetime.timedelta(days=3).total_seconds())
|
||||
request, channel = self.make_request(
|
||||
b"GET", "/sync", access_token=tok,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
|
Loading…
Reference in New Issue