Merge branch 'develop' of github.com:matrix-org/synapse into erikj/paginate_sync

Branch: erikj/paginate_sync
Erik Johnston, 2016-06-24 14:19:16 +01:00
Commit 434c51d538
16 changed files with 512 additions and 302 deletions

jenkins-dendron-postgres.sh

@@ -70,6 +70,7 @@ cd sytest
 git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)

 : ${PORT_BASE:=8000}
+: ${PORT_COUNT=20}

 ./jenkins/prep_sytest_for_postgres.sh
@@ -81,6 +82,6 @@ echo >&2 "Running sytest with PostgreSQL";
     --dendron $WORKSPACE/dendron/bin/dendron \
     --pusher \
     --synchrotron \
-    --port-base $PORT_BASE
+    --port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1))

 cd ..

jenkins-postgres.sh

@@ -44,6 +44,7 @@ cd sytest
 git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)

 : ${PORT_BASE:=8000}
+: ${PORT_COUNT=20}

 ./jenkins/prep_sytest_for_postgres.sh
@@ -51,7 +52,7 @@ echo >&2 "Running sytest with PostgreSQL";
 ./jenkins/install_and_run.sh --coverage \
     --python $TOX_BIN/python \
     --synapse-directory $WORKSPACE \
-    --port-base $PORT_BASE
+    --port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \

 cd ..
 cp sytest/.coverage.* .

jenkins.sh

@@ -41,11 +41,12 @@ cd sytest
 git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)

-: ${PORT_BASE:=8500}
+: ${PORT_COUNT=20}
+: ${PORT_BASE:=8000}

 ./jenkins/install_and_run.sh --coverage \
     --python $TOX_BIN/python \
     --synapse-directory $WORKSPACE \
-    --port-base $PORT_BASE
+    --port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \

 cd ..
 cp sytest/.coverage.* .
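All three Jenkins scripts now derive sytest's port allocation from the same pair of variables instead of a single base port. A quick shell sketch of the expansion, using the defaults set above:

    PORT_BASE=8000
    PORT_COUNT=20
    echo "${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1))"   # prints 8000:8019
    # i.e. --port-range hands sytest PORT_COUNT consecutive ports starting at PORT_BASE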

res/templates/notif_mail.html

@@ -36,7 +36,7 @@
     <div class="debug">
         Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
         an event was received at {{ reason.received_at|format_ts("%c") }}
-        which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago,
+        which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
         {% if reason.last_sent_ts %}
             and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
             which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
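The fix interpolates the raw millisecond value instead of echoing the parameter name. For example, were reason.delay_before_mail_ms set to 600000, the old template rendered "more than 10.0 (delay_before_mail_ms) mins ago", while the new one renders "more than 10.0 (600000) mins ago", since 600000 / (60*1000) = 10.0.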

synapse/app/homeserver.py

@@ -147,7 +147,7 @@ class SynapseHomeServer(HomeServer):
                 MEDIA_PREFIX: media_repo,
                 LEGACY_MEDIA_PREFIX: media_repo,
                 CONTENT_REPO_PREFIX: ContentRepoResource(
-                    self, self.config.uploads_path, self.auth, self.content_addr
+                    self, self.config.uploads_path
                 ),
             })
@@ -301,7 +301,6 @@ def setup(config_options):
         db_config=config.database_config,
         tls_server_context_factory=tls_server_context_factory,
         config=config,
-        content_addr=config.content_addr,
         version_string=version_string,
         database_engine=database_engine,
     )

synctl

@@ -14,11 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import sys
+import argparse
+import collections
+import glob
 import os
 import os.path
-import subprocess
 import signal
+import subprocess
+import sys

 import yaml

 SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
@@ -28,60 +31,181 @@ RED = "\x1b[1;31m"
 NORMAL = "\x1b[m"


+def write(message, colour=NORMAL, stream=sys.stdout):
+    if colour == NORMAL:
+        stream.write(message + "\n")
+    else:
+        stream.write(colour + message + NORMAL + "\n")
+
+
 def start(configfile):
-    print ("Starting ...")
+    write("Starting ...")
     args = SYNAPSE
     args.extend(["--daemonize", "-c", configfile])
+
     try:
         subprocess.check_call(args)
-        print (GREEN + "started" + NORMAL)
+        write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
     except subprocess.CalledProcessError as e:
-        print (
-            RED +
-            "error starting (exit code: %d); see above for logs" % e.returncode +
-            NORMAL
+        write(
+            "error starting (exit code: %d); see above for logs" % e.returncode,
+            colour=RED,
         )


-def stop(pidfile):
+def start_worker(app, configfile, worker_configfile):
+    args = [
+        "python", "-B",
+        "-m", app,
+        "-c", configfile,
+        "-c", worker_configfile
+    ]
+
+    try:
+        subprocess.check_call(args)
+        write("started %s(%r)" % (app, worker_configfile), colour=GREEN)
+    except subprocess.CalledProcessError as e:
+        write(
+            "error starting %s(%r) (exit code: %d); see above for logs" % (
+                app, worker_configfile, e.returncode,
+            ),
+            colour=RED,
+        )
+
+
+def stop(pidfile, app):
     if os.path.exists(pidfile):
         pid = int(open(pidfile).read())
         os.kill(pid, signal.SIGTERM)
-        print (GREEN + "stopped" + NORMAL)
+        write("stopped %s" % (app,), colour=GREEN)
+
+
+Worker = collections.namedtuple("Worker", [
+    "app", "configfile", "pidfile", "cache_factor"
+])


 def main():
-    configfile = sys.argv[2] if len(sys.argv) == 3 else "homeserver.yaml"
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "action",
+        choices=["start", "stop", "restart"],
+        help="whether to start, stop or restart the synapse",
+    )
+    parser.add_argument(
+        "configfile",
+        nargs="?",
+        default="homeserver.yaml",
+        help="the homeserver config file, defaults to homeserver.yaml",
+    )
+    parser.add_argument(
+        "-w", "--worker",
+        metavar="WORKERCONFIG",
+        help="start or stop a single worker",
+    )
+    parser.add_argument(
+        "-a", "--all-processes",
+        metavar="WORKERCONFIGDIR",
+        help="start or stop all the workers in the given directory"
+             " and the main synapse process",
+    )
+
+    options = parser.parse_args()
+
+    if options.worker and options.all_processes:
+        write(
+            'Cannot use "--worker" with "--all-processes"',
+            stream=sys.stderr
+        )
+        sys.exit(1)
+
+    configfile = options.configfile

     if not os.path.exists(configfile):
-        sys.stderr.write(
+        write(
             "No config file found\n"
             "To generate a config file, run '%s -c %s --generate-config"
             " --server-name=<server name>'\n" % (
-                " ".join(SYNAPSE), configfile
-            )
+                " ".join(SYNAPSE), options.configfile
+            ),
+            stream=sys.stderr,
         )
         sys.exit(1)

-    config = yaml.load(open(configfile))
+    with open(configfile) as stream:
+        config = yaml.load(stream)
+
     pidfile = config["pid_file"]
-    cache_factor = config.get("synctl_cache_factor", None)
+    cache_factor = config.get("synctl_cache_factor")
+    start_stop_synapse = True

     if cache_factor:
         os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)

-    action = sys.argv[1] if sys.argv[1:] else "usage"
-    if action == "start":
-        start(configfile)
-    elif action == "stop":
-        stop(pidfile)
-    elif action == "restart":
-        stop(pidfile)
-        start(configfile)
-    else:
-        sys.stderr.write("Usage: %s [start|stop|restart] [configfile]\n" % (sys.argv[0],))
-        sys.exit(1)
+    worker_configfiles = []
+    if options.worker:
+        start_stop_synapse = False
+        worker_configfile = options.worker
+        if not os.path.exists(worker_configfile):
+            write(
+                "No worker config found at %r" % (worker_configfile,),
+                stream=sys.stderr,
+            )
+            sys.exit(1)
+        worker_configfiles.append(worker_configfile)
+
+    if options.all_processes:
+        worker_configdir = options.all_processes
+        if not os.path.isdir(worker_configdir):
+            write(
+                "No worker config directory found at %r" % (worker_configdir,),
+                stream=sys.stderr,
+            )
+            sys.exit(1)
+        worker_configfiles.extend(sorted(glob.glob(
+            os.path.join(worker_configdir, "*.yaml")
+        )))
+
+    workers = []
+    for worker_configfile in worker_configfiles:
+        with open(worker_configfile) as stream:
+            worker_config = yaml.load(stream)
+        worker_app = worker_config["worker_app"]
+        worker_pidfile = worker_config["worker_pid_file"]
+        worker_daemonize = worker_config["worker_daemonize"]
+        assert worker_daemonize  # TODO print something more user friendly
+        worker_cache_factor = worker_config.get("synctl_cache_factor")
+        workers.append(Worker(
+            worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
+        ))
+
+    action = options.action
+
+    if action == "stop" or action == "restart":
+        for worker in workers:
+            stop(worker.pidfile, worker.app)
+
+        if start_stop_synapse:
+            stop(pidfile, "synapse.app.homeserver")
+
+    # TODO: Wait for synapse to actually shutdown before starting it again
+
+    if action == "start" or action == "restart":
+        if start_stop_synapse:
+            start(configfile)
+
+        for worker in workers:
+            if worker.cache_factor:
+                os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor)
+
+            start_worker(worker.app, configfile, worker.configfile)
+
+            if cache_factor:
+                os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
+            else:
+                os.environ.pop("SYNAPSE_CACHE_FACTOR", None)


 if __name__ == "__main__":
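With the argparse interface above, the original positional invocation still works and worker control is layered on top. A hypothetical session (the worker config paths are illustrative, not defaults):

    # start, stop or restart the main synapse, as before
    synctl restart homeserver.yaml

    # restart a single worker; start_stop_synapse is False, so the main process is untouched
    synctl restart homeserver.yaml -w /etc/synapse/synchrotron.yaml

    # start the main synapse plus every *.yaml worker config in a directory
    synctl start homeserver.yaml -a /etc/synapse/workers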

synapse/config/captcha.py

@@ -27,6 +27,7 @@ class CaptchaConfig(Config):
     def default_config(self, **kwargs):
         return """\
         ## Captcha ##
+        # See docs/CAPTCHA_SETUP for full details of configuring this.

         # This Home Server's ReCAPTCHA public key.
         recaptcha_public_key: "YOUR_PUBLIC_KEY"

synapse/config/ldap.py

@@ -13,40 +13,88 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from ._base import Config
+from ._base import Config, ConfigError
+
+
+MISSING_LDAP3 = (
+    "Missing ldap3 library. This is required for LDAP Authentication."
+)
+
+
+class LDAPMode(object):
+    SIMPLE = "simple",
+    SEARCH = "search",
+
+    LIST = (SIMPLE, SEARCH)


 class LDAPConfig(Config):
     def read_config(self, config):
-        ldap_config = config.get("ldap_config", None)
-        if ldap_config:
-            self.ldap_enabled = ldap_config.get("enabled", False)
-            self.ldap_server = ldap_config["server"]
-            self.ldap_port = ldap_config["port"]
-            self.ldap_tls = ldap_config.get("tls", False)
-            self.ldap_search_base = ldap_config["search_base"]
-            self.ldap_search_property = ldap_config["search_property"]
-            self.ldap_email_property = ldap_config["email_property"]
-            self.ldap_full_name_property = ldap_config["full_name_property"]
-        else:
-            self.ldap_enabled = False
-            self.ldap_server = None
-            self.ldap_port = None
-            self.ldap_tls = False
-            self.ldap_search_base = None
-            self.ldap_search_property = None
-            self.ldap_email_property = None
-            self.ldap_full_name_property = None
+        ldap_config = config.get("ldap_config", {})
+
+        self.ldap_enabled = ldap_config.get("enabled", False)
+
+        if self.ldap_enabled:
+            # verify dependencies are available
+            try:
+                import ldap3
+                ldap3  # to stop unused lint
+            except ImportError:
+                raise ConfigError(MISSING_LDAP3)
+
+            self.ldap_mode = LDAPMode.SIMPLE
+
+            # verify config sanity
+            self.require_keys(ldap_config, [
+                "uri",
+                "base",
+                "attributes",
+            ])
+
+            self.ldap_uri = ldap_config["uri"]
+            self.ldap_start_tls = ldap_config.get("start_tls", False)
+            self.ldap_base = ldap_config["base"]
+            self.ldap_attributes = ldap_config["attributes"]
+
+            if "bind_dn" in ldap_config:
+                self.ldap_mode = LDAPMode.SEARCH
+                self.require_keys(ldap_config, [
+                    "bind_dn",
+                    "bind_password",
+                ])
+
+                self.ldap_bind_dn = ldap_config["bind_dn"]
+                self.ldap_bind_password = ldap_config["bind_password"]
+                self.ldap_filter = ldap_config.get("filter", None)
+
+            # verify attribute lookup
+            self.require_keys(ldap_config['attributes'], [
+                "uid",
+                "name",
+                "mail",
+            ])
+
+    def require_keys(self, config, required):
+        missing = [key for key in required if key not in config]
+        if missing:
+            raise ConfigError(
+                "LDAP enabled but missing required config values: {}".format(
+                    ", ".join(missing)
+                )
+            )

     def default_config(self, **kwargs):
         return """\
         # ldap_config:
         #   enabled: true
-        #   server: "ldap://localhost"
-        #   port: 389
-        #   tls: false
-        #   search_base: "ou=Users,dc=example,dc=com"
-        #   search_property: "cn"
-        #   email_property: "email"
-        #   full_name_property: "givenName"
+        #   uri: "ldap://ldap.example.com:389"
+        #   start_tls: true
+        #   base: "ou=users,dc=example,dc=com"
+        #   attributes:
+        #      uid: "cn"
+        #      mail: "email"
+        #      name: "givenName"
+        #   #bind_dn:
+        #   #bind_password:
+        #   #filter: "(objectClass=posixAccount)"
         """

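Read together with the AuthHandler changes below, a hypothetical homeserver.yaml stanza for the new schema might look like this (all values invented for illustration). Supplying bind_dn and bind_password is what flips LDAPConfig from SIMPLE into SEARCH mode:

    ldap_config:
      enabled: true
      uri: "ldap://ldap.example.com:389"
      start_tls: true
      base: "ou=users,dc=example,dc=com"
      attributes:
        uid: "cn"
        mail: "email"
        name: "givenName"
      # optional: presence of bind_dn switches to SEARCH mode
      bind_dn: "cn=synapse,ou=services,dc=example,dc=com"
      bind_password: "secret"
      filter: "(objectClass=posixAccount)"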
synapse/config/server.py

@@ -107,26 +107,6 @@ class ServerConfig(Config):
             ]
         })

-        # Attempt to guess the content_addr for the v0 content repostitory
-        content_addr = config.get("content_addr")
-        if not content_addr:
-            for listener in self.listeners:
-                if listener["type"] == "http" and not listener.get("tls", False):
-                    unsecure_port = listener["port"]
-                    break
-            else:
-                raise RuntimeError("Could not determine 'content_addr'")
-
-            host = self.server_name
-            if ':' not in host:
-                host = "%s:%d" % (host, unsecure_port)
-            else:
-                host = host.split(':')[0]
-                host = "%s:%d" % (host, unsecure_port)
-
-            content_addr = "http://%s" % (host,)
-
-        self.content_addr = content_addr
-
     def default_config(self, server_name, **kwargs):
         if ":" in server_name:
             bind_port = int(server_name.split(":")[1])

synapse/federation/federation_server.py

@@ -49,6 +49,7 @@ class FederationServer(FederationBase):
         super(FederationServer, self).__init__(hs)

         self._room_pdu_linearizer = Linearizer()
+        self._server_linearizer = Linearizer()

     def set_handler(self, handler):
         """Sets the handler that the replication layer will use to communicate
@@ -89,11 +90,14 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_backfill_request(self, origin, room_id, versions, limit):
-        pdus = yield self.handler.on_backfill_request(
-            origin, room_id, versions, limit
-        )
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            pdus = yield self.handler.on_backfill_request(
+                origin, room_id, versions, limit
+            )
+
+            res = self._transaction_from_pdus(pdus).get_dict()

-        defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+        defer.returnValue((200, res))

     @defer.inlineCallbacks
     @log_function
@@ -184,27 +188,28 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_context_state_request(self, origin, room_id, event_id):
-        if event_id:
-            pdus = yield self.handler.get_state_for_pdu(
-                origin, room_id, event_id,
-            )
-            auth_chain = yield self.store.get_auth_chain(
-                [pdu.event_id for pdu in pdus]
-            )
-
-            for event in auth_chain:
-                # We sign these again because there was a bug where we
-                # incorrectly signed things the first time round
-                if self.hs.is_mine_id(event.event_id):
-                    event.signatures.update(
-                        compute_event_signature(
-                            event,
-                            self.hs.hostname,
-                            self.hs.config.signing_key[0]
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            if event_id:
+                pdus = yield self.handler.get_state_for_pdu(
+                    origin, room_id, event_id,
+                )
+                auth_chain = yield self.store.get_auth_chain(
+                    [pdu.event_id for pdu in pdus]
+                )
+
+                for event in auth_chain:
+                    # We sign these again because there was a bug where we
+                    # incorrectly signed things the first time round
+                    if self.hs.is_mine_id(event.event_id):
+                        event.signatures.update(
+                            compute_event_signature(
+                                event,
+                                self.hs.hostname,
+                                self.hs.config.signing_key[0]
+                            )
                         )
-                    )
-        else:
-            raise NotImplementedError("Specify an event")
+            else:
+                raise NotImplementedError("Specify an event")

         defer.returnValue((200, {
             "pdus": [pdu.get_pdu_json() for pdu in pdus],
@@ -283,14 +288,16 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_event_auth(self, origin, room_id, event_id):
-        time_now = self._clock.time_msec()
-        auth_pdus = yield self.handler.on_event_auth(event_id)
-        defer.returnValue((200, {
-            "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
-        }))
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            time_now = self._clock.time_msec()
+            auth_pdus = yield self.handler.on_event_auth(event_id)
+            res = {
+                "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
+            }
+        defer.returnValue((200, res))

     @defer.inlineCallbacks
-    def on_query_auth_request(self, origin, content, event_id):
+    def on_query_auth_request(self, origin, content, room_id, event_id):
         """
         Content is a dict with keys::
             auth_chain (list): A list of events that give the auth chain.
@@ -309,32 +316,33 @@ class FederationServer(FederationBase):
         Returns:
             Deferred: Results in `dict` with the same format as `content`
         """
-        auth_chain = [
-            self.event_from_pdu_json(e)
-            for e in content["auth_chain"]
-        ]
-
-        signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            origin, auth_chain, outlier=True
-        )
-
-        ret = yield self.handler.on_query_auth(
-            origin,
-            event_id,
-            signed_auth,
-            content.get("rejects", []),
-            content.get("missing", []),
-        )
-
-        time_now = self._clock.time_msec()
-        send_content = {
-            "auth_chain": [
-                e.get_pdu_json(time_now)
-                for e in ret["auth_chain"]
-            ],
-            "rejects": ret.get("rejects", []),
-            "missing": ret.get("missing", []),
-        }
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            auth_chain = [
+                self.event_from_pdu_json(e)
+                for e in content["auth_chain"]
+            ]
+
+            signed_auth = yield self._check_sigs_and_hash_and_fetch(
+                origin, auth_chain, outlier=True
+            )
+
+            ret = yield self.handler.on_query_auth(
+                origin,
+                event_id,
+                signed_auth,
+                content.get("rejects", []),
+                content.get("missing", []),
+            )
+
+            time_now = self._clock.time_msec()
+            send_content = {
+                "auth_chain": [
+                    e.get_pdu_json(time_now)
+                    for e in ret["auth_chain"]
+                ],
+                "rejects": ret.get("rejects", []),
+                "missing": ret.get("missing", []),
+            }

         defer.returnValue(
             (200, send_content)
@@ -386,21 +394,24 @@ class FederationServer(FederationBase):
     @log_function
     def on_get_missing_events(self, origin, room_id, earliest_events,
                               latest_events, limit, min_depth):
-        logger.info(
-            "on_get_missing_events: earliest_events: %r, latest_events: %r,"
-            " limit: %d, min_depth: %d",
-            earliest_events, latest_events, limit, min_depth
-        )
-        missing_events = yield self.handler.on_get_missing_events(
-            origin, room_id, earliest_events, latest_events, limit, min_depth
-        )
-
-        if len(missing_events) < 5:
-            logger.info("Returning %d events: %r", len(missing_events), missing_events)
-        else:
-            logger.info("Returning %d events", len(missing_events))
-
-        time_now = self._clock.time_msec()
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            logger.info(
+                "on_get_missing_events: earliest_events: %r, latest_events: %r,"
+                " limit: %d, min_depth: %d",
+                earliest_events, latest_events, limit, min_depth
+            )
+            missing_events = yield self.handler.on_get_missing_events(
+                origin, room_id, earliest_events, latest_events, limit, min_depth
+            )
+
+            if len(missing_events) < 5:
+                logger.info(
+                    "Returning %d events: %r", len(missing_events), missing_events
+                )
+            else:
+                logger.info("Returning %d events", len(missing_events))

+            time_now = self._clock.time_msec()

         defer.returnValue({
             "events": [ev.get_pdu_json(time_now) for ev in missing_events],

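The pattern added throughout this file serialises inbound federation requests per (origin, room_id) key: Linearizer.queue(key) yields a context manager, so requests sharing a key run one at a time while requests for other keys proceed concurrently. A minimal sketch of the same shape, using a hypothetical handler method:

    @defer.inlineCallbacks
    @log_function
    def on_some_request(self, origin, room_id):
        # hypothetical method: requests with the same (origin, room_id)
        # are processed sequentially; other keys are unaffected
        with (yield self._server_linearizer.queue((origin, room_id))):
            res = yield self.handler.handle_request(origin, room_id)
        defer.returnValue((200, res))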
synapse/federation/transport/server.py

@@ -388,7 +388,7 @@ class FederationQueryAuthServlet(BaseFederationServlet):
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, context, event_id):
         new_content = yield self.handler.on_query_auth_request(
-            origin, content, event_id
+            origin, content, context, event_id
         )

         defer.returnValue((200, new_content))

synapse/handlers/auth.py

@@ -20,6 +20,7 @@ from synapse.api.constants import LoginType
 from synapse.types import UserID
 from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
 from synapse.util.async import run_on_reactor
+from synapse.config.ldap import LDAPMode

 from twisted.web.client import PartialDownloadError
@@ -28,6 +29,12 @@ import bcrypt
 import pymacaroons
 import simplejson

+try:
+    import ldap3
+except ImportError:
+    ldap3 = None
+    pass
+
 import synapse.util.stringutils as stringutils
@@ -50,17 +57,20 @@ class AuthHandler(BaseHandler):
         self.INVALID_TOKEN_HTTP_STATUS = 401

         self.ldap_enabled = hs.config.ldap_enabled
-        self.ldap_server = hs.config.ldap_server
-        self.ldap_port = hs.config.ldap_port
-        self.ldap_tls = hs.config.ldap_tls
-        self.ldap_search_base = hs.config.ldap_search_base
-        self.ldap_search_property = hs.config.ldap_search_property
-        self.ldap_email_property = hs.config.ldap_email_property
-        self.ldap_full_name_property = hs.config.ldap_full_name_property
-
-        if self.ldap_enabled is True:
-            import ldap
-            logger.info("Import ldap version: %s", ldap.__version__)
+        if self.ldap_enabled:
+            if not ldap3:
+                raise RuntimeError(
+                    'Missing ldap3 library. This is required for LDAP Authentication.'
+                )
+            self.ldap_mode = hs.config.ldap_mode
+            self.ldap_uri = hs.config.ldap_uri
+            self.ldap_start_tls = hs.config.ldap_start_tls
+            self.ldap_base = hs.config.ldap_base
+            self.ldap_filter = hs.config.ldap_filter
+            self.ldap_attributes = hs.config.ldap_attributes
+            if self.ldap_mode == LDAPMode.SEARCH:
+                self.ldap_bind_dn = hs.config.ldap_bind_dn
+                self.ldap_bind_password = hs.config.ldap_bind_password

         self.hs = hs  # FIXME better possibility to access registrationHandler later?
@@ -452,40 +462,167 @@ class AuthHandler(BaseHandler):
     @defer.inlineCallbacks
     def _check_ldap_password(self, user_id, password):
-        if not self.ldap_enabled:
-            logger.debug("LDAP not configured")
+        """ Attempt to authenticate a user against an LDAP Server
+            and register an account if none exists.
+
+            Returns:
+                True if authentication against LDAP was successful
+        """
+
+        if not ldap3 or not self.ldap_enabled:
             defer.returnValue(False)

-        import ldap
-
-        logger.info("Authenticating %s with LDAP" % user_id)
+        if self.ldap_mode not in LDAPMode.LIST:
+            raise RuntimeError(
+                'Invalid ldap mode specified: {mode}'.format(
+                    mode=self.ldap_mode
+                )
+            )
+
         try:
-            ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
-            logger.debug("Connecting LDAP server at %s" % ldap_url)
-            l = ldap.initialize(ldap_url)
-            if self.ldap_tls:
-                logger.debug("Initiating TLS")
-                self._connection.start_tls_s()
-
-            local_name = UserID.from_string(user_id).localpart
-
-            dn = "%s=%s, %s" % (
-                self.ldap_search_property,
-                local_name,
-                self.ldap_search_base)
-            logger.debug("DN for LDAP authentication: %s" % dn)
-
-            l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
-
-            if not (yield self.does_user_exist(user_id)):
-                handler = self.hs.get_handlers().registration_handler
-                user_id, access_token = (
-                    yield handler.register(localpart=local_name)
+            server = ldap3.Server(self.ldap_uri)
+            logger.debug(
+                "Attempting ldap connection with %s",
+                self.ldap_uri
+            )
+
+            localpart = UserID.from_string(user_id).localpart
+            if self.ldap_mode == LDAPMode.SIMPLE:
+                # bind with the local user's ldap credentials
+                bind_dn = "{prop}={value},{base}".format(
+                    prop=self.ldap_attributes['uid'],
+                    value=localpart,
+                    base=self.ldap_base
                 )
+                conn = ldap3.Connection(server, bind_dn, password)
+                logger.debug(
+                    "Established ldap connection in simple mode: %s",
+                    conn
+                )
+
+                if self.ldap_start_tls:
+                    conn.start_tls()
+                    logger.debug(
+                        "Upgraded ldap connection in simple mode through StartTLS: %s",
+                        conn
+                    )
+
+                conn.bind()
+
+            elif self.ldap_mode == LDAPMode.SEARCH:
+                # connect with preconfigured credentials and search for local user
+                conn = ldap3.Connection(
+                    server,
+                    self.ldap_bind_dn,
+                    self.ldap_bind_password
+                )
+                logger.debug(
+                    "Established ldap connection in search mode: %s",
+                    conn
+                )
+
+                if self.ldap_start_tls:
+                    conn.start_tls()
+                    logger.debug(
+                        "Upgraded ldap connection in search mode through StartTLS: %s",
+                        conn
+                    )
+
+                conn.bind()
+
+                # find matching dn
+                query = "({prop}={value})".format(
+                    prop=self.ldap_attributes['uid'],
+                    value=localpart
+                )
+                if self.ldap_filter:
+                    query = "(&{query}{filter})".format(
+                        query=query,
+                        filter=self.ldap_filter
+                    )
+                logger.debug("ldap search filter: %s", query)
+                result = conn.search(self.ldap_base, query)
+
+                if result and len(conn.response) == 1:
+                    # found exactly one result
+                    user_dn = conn.response[0]['dn']
+                    logger.debug('ldap search found dn: %s', user_dn)
+
+                    # unbind and reconnect, rebind with found dn
+                    conn.unbind()
+                    conn = ldap3.Connection(
+                        server,
+                        user_dn,
+                        password,
+                        auto_bind=True
+                    )
+                else:
+                    # found 0 or > 1 results, abort!
+                    logger.warn(
+                        "ldap search returned unexpected (%d!=1) amount of results",
+                        len(conn.response)
+                    )
+                    defer.returnValue(False)
+
+            logger.info(
+                "User authenticated against ldap server: %s",
+                conn
+            )
+
+            # check for existing account, if none exists, create one
+            if not (yield self.does_user_exist(user_id)):
+                # query user metadata for account creation
+                query = "({prop}={value})".format(
+                    prop=self.ldap_attributes['uid'],
+                    value=localpart
+                )
+
+                if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
+                    query = "(&{filter}{user_filter})".format(
+                        filter=query,
+                        user_filter=self.ldap_filter
+                    )
+                logger.debug("ldap registration filter: %s", query)
+
+                result = conn.search(
+                    search_base=self.ldap_base,
+                    search_filter=query,
+                    attributes=[
+                        self.ldap_attributes['name'],
+                        self.ldap_attributes['mail']
+                    ]
+                )
+
+                if len(conn.response) == 1:
+                    attrs = conn.response[0]['attributes']
+                    mail = attrs[self.ldap_attributes['mail']][0]
+                    name = attrs[self.ldap_attributes['name']][0]
+
+                    # create account
+                    registration_handler = self.hs.get_handlers().registration_handler
+                    user_id, access_token = (
+                        yield registration_handler.register(localpart=localpart)
+                    )
+
+                    # TODO: bind email, set displayname with data from ldap directory
+
+                    logger.info(
+                        "ldap registration successful: %s: %s (%s, %s)",
+                        user_id,
+                        localpart,
+                        name,
+                        mail
+                    )
+                else:
+                    logger.warn(
+                        "ldap registration failed: unexpected (%d!=1) amount of results",
+                        len(conn.response)
+                    )
+                    defer.returnValue(False)

             defer.returnValue(True)
-        except ldap.LDAPError, e:
-            logger.warn("LDAP error: %s", e)
+        except ldap3.core.exceptions.LDAPException as e:
+            logger.warn("Error during ldap authentication: %s", e)
             defer.returnValue(False)

     @defer.inlineCallbacks
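For concreteness, a small standalone sketch (invented values) of the DN and search-filter strings the code above builds for a Matrix user @alice:example.com, i.e. localpart "alice":

    # assumptions for illustration only
    attributes = {"uid": "cn", "mail": "email", "name": "givenName"}
    base = "ou=users,dc=example,dc=com"
    localpart = "alice"

    # SIMPLE mode binds directly as the user
    bind_dn = "{prop}={value},{base}".format(
        prop=attributes["uid"], value=localpart, base=base
    )
    assert bind_dn == "cn=alice,ou=users,dc=example,dc=com"

    # SEARCH mode searches for the user first, AND-ing any configured filter
    query = "({prop}={value})".format(prop=attributes["uid"], value=localpart)
    query = "(&{query}{filter})".format(query=query, filter="(objectClass=posixAccount)")
    assert query == "(&(cn=alice)(objectClass=posixAccount))"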

synapse/python_dependencies.py

@@ -49,6 +49,9 @@ CONDITIONAL_REQUIREMENTS = {
         "Jinja2>=2.8": ["Jinja2>=2.8"],
         "bleach>=1.4.2": ["bleach>=1.4.2"],
     },
+    "ldap": {
+        "ldap3>=1.0": ["ldap3>=1.0"],
+    },
 }

synapse/rest/media/v0/content_repository.py

@@ -15,14 +15,12 @@

 from synapse.http.server import respond_with_json_bytes, finish_request
-from synapse.util.stringutils import random_string
 from synapse.api.errors import (
-    cs_exception, SynapseError, CodeMessageException, Codes, cs_error
+    Codes, cs_error
 )

 from twisted.protocols.basic import FileSender
 from twisted.web import server, resource
-from twisted.internet import defer

 import base64
 import simplejson as json
@@ -50,64 +48,10 @@ class ContentRepoResource(resource.Resource):
     """
     isLeaf = True

-    def __init__(self, hs, directory, auth, external_addr):
+    def __init__(self, hs, directory):
         resource.Resource.__init__(self)
         self.hs = hs
         self.directory = directory
-        self.auth = auth
-        self.external_addr = external_addr.rstrip('/')
-        self.max_upload_size = hs.config.max_upload_size
-
-        if not os.path.isdir(self.directory):
-            os.mkdir(self.directory)
-            logger.info("ContentRepoResource : Created %s directory.",
-                        self.directory)
-
-    @defer.inlineCallbacks
-    def map_request_to_name(self, request):
-        # auth the user
-        requester = yield self.auth.get_user_by_req(request)
-
-        # namespace all file uploads on the user
-        prefix = base64.urlsafe_b64encode(
-            requester.user.to_string()
-        ).replace('=', '')
-
-        # use a random string for the main portion
-        main_part = random_string(24)
-
-        # suffix with a file extension if we can make one. This is nice to
-        # provide a hint to clients on the file information. We will also reuse
-        # this info to spit back the content type to the client.
-        suffix = ""
-        if request.requestHeaders.hasHeader("Content-Type"):
-            content_type = request.requestHeaders.getRawHeaders(
-                "Content-Type")[0]
-            suffix = "." + base64.urlsafe_b64encode(content_type)
-            if (content_type.split("/")[0].lower() in
-                    ["image", "video", "audio"]):
-                file_ext = content_type.split("/")[-1]
-                # be a little paranoid and only allow a-z
-                file_ext = re.sub("[^a-z]", "", file_ext)
-                suffix += "." + file_ext
-
-        file_name = prefix + main_part + suffix
-        file_path = os.path.join(self.directory, file_name)
-        logger.info("User %s is uploading a file to path %s",
-                    request.user.user_id.to_string(),
-                    file_path)
-
-        # keep trying to make a non-clashing file, with a sensible max attempts
-        attempts = 0
-        while os.path.exists(file_path):
-            main_part = random_string(24)
-            file_name = prefix + main_part + suffix
-            file_path = os.path.join(self.directory, file_name)
-            attempts += 1
-            if attempts > 25:  # really? Really?
-                raise SynapseError(500, "Unable to create file.")
-
-        defer.returnValue(file_path)

     def render_GET(self, request):
         # no auth here on purpose, to allow anyone to view, even across home
@@ -155,58 +99,6 @@ class ContentRepoResource(resource.Resource):

         return server.NOT_DONE_YET

-    def render_POST(self, request):
-        self._async_render(request)
-        return server.NOT_DONE_YET
-
     def render_OPTIONS(self, request):
         respond_with_json_bytes(request, 200, {}, send_cors=True)
         return server.NOT_DONE_YET

-    @defer.inlineCallbacks
-    def _async_render(self, request):
-        try:
-            # TODO: The checks here are a bit late. The content will have
-            # already been uploaded to a tmp file at this point
-            content_length = request.getHeader("Content-Length")
-            if content_length is None:
-                raise SynapseError(
-                    msg="Request must specify a Content-Length", code=400
-                )
-            if int(content_length) > self.max_upload_size:
-                raise SynapseError(
-                    msg="Upload request body is too large",
-                    code=413,
-                )
-
-            fname = yield self.map_request_to_name(request)
-
-            # TODO I have a suspicious feeling this is just going to block
-            with open(fname, "wb") as f:
-                f.write(request.content.read())
-
-            # FIXME (erikj): These should use constants.
-            file_name = os.path.basename(fname)
-            # FIXME: we can't assume what the repo's public mounted path is
-            # ...plus self-signed SSL won't work to remote clients anyway
-            # ...and we can't assume that it's SSL anyway, as we might want to
-            # serve it via the non-SSL listener...
-            url = "%s/_matrix/content/%s" % (
-                self.external_addr, file_name
-            )
-
-            respond_with_json_bytes(request, 200,
-                                    json.dumps({"content_token": url}),
-                                    send_cors=True)
-        except CodeMessageException as e:
-            logger.exception(e)
-            respond_with_json_bytes(request, e.code,
-                                    json.dumps(cs_exception(e)))
-        except Exception as e:
-            logger.error("Failed to store file: %s" % e)
-            respond_with_json_bytes(
-                request,
-                500,
-                json.dumps({"error": "Internal server error"}),
-                send_cors=True)

synapse/storage/event_push_actions.py

@@ -152,7 +152,7 @@ class EventPushActionsStore(SQLBaseStore):
             if max_stream_ordering is not None:
                 sql += " AND ep.stream_ordering <= ?"
                 args.append(max_stream_ordering)
-            sql += " ORDER BY ep.stream_ordering ASC LIMIT ?"
+            sql += " ORDER BY ep.stream_ordering DESC LIMIT ?"
             args.append(limit)
             txn.execute(sql, args)
             return txn.fetchall()
@@ -176,14 +176,16 @@ class EventPushActionsStore(SQLBaseStore):
             if max_stream_ordering is not None:
                 sql += " AND ep.stream_ordering <= ?"
                 args.append(max_stream_ordering)
-            sql += " ORDER BY ep.stream_ordering ASC"
+            sql += " ORDER BY ep.stream_ordering DESC LIMIT ?"
+            args.append(limit)
             txn.execute(sql, args)
             return txn.fetchall()

         no_read_receipt = yield self.runInteraction(
             "get_unread_push_actions_for_user_in_range", get_no_receipt
         )

-        defer.returnValue([
+        # Make a list of dicts from the two sets of results.
+        notifs = [
             {
                 "event_id": row[0],
                 "room_id": row[1],
@@ -191,7 +193,16 @@ class EventPushActionsStore(SQLBaseStore):
                 "actions": json.loads(row[3]),
                 "received_ts": row[4],
             } for row in after_read_receipt + no_read_receipt
-        ])
+        ]
+
+        # Now sort it so it's ordered correctly, since currently it will
+        # contain results from the first query, correctly ordered, followed
+        # by results from the second query, but we want them all ordered
+        # by received_ts
+        notifs.sort(key=lambda r: -(r['received_ts'] or 0))
+
+        # Now return the first `limit`
+        defer.returnValue(notifs[:limit])

     @defer.inlineCallbacks
     def get_time_of_last_push_action_before(self, stream_ordering):
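Since each query now returns its newest `limit` rows, the merge step only has to sort the combined list newest-first and truncate. A standalone illustration with made-up rows:

    # made-up notification dicts, for illustration only
    after_read_receipt = [{"event_id": "$a", "received_ts": 300},
                          {"event_id": "$b", "received_ts": 100}]
    no_read_receipt = [{"event_id": "$c", "received_ts": 200},
                       {"event_id": "$d", "received_ts": None}]
    limit = 3

    notifs = after_read_receipt + no_read_receipt
    # negating sorts newest first; "or 0" pushes rows lacking received_ts to the end
    notifs.sort(key=lambda r: -(r["received_ts"] or 0))
    print([n["event_id"] for n in notifs[:limit]])   # ['$a', '$c', '$b']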

tests/utils.py

@@ -56,6 +56,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
     config.use_frozen_dicts = True
     config.database_config = {"name": "sqlite3"}
+    config.ldap_enabled = False

     if "clock" not in kargs:
         kargs["clock"] = MockClock()