Merge branch 'develop' into markjh/split_pusher

markjh/split_pusher
Mark Haines 2016-04-21 16:21:49 +01:00
commit 712030aeef
21 changed files with 733 additions and 659 deletions

View File

@ -25,7 +25,9 @@ rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27 tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install psycopg2 $TOX_BIN/pip install psycopg2
$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"} : ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}

View File

@ -24,6 +24,8 @@ rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27 tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"} : ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}

View File

@ -1,86 +0,0 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
TOX_BIN=$WORKSPACE/.tox/py27/bin
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PERL5LIB:=$WORKSPACE/perl5/lib/perl5}
: ${PERL_MB_OPT:=--install_base=$WORKSPACE/perl5}
: ${PERL_MM_OPT:=INSTALL_BASE=$WORKSPACE/perl5}
export PERL5LIB PERL_MB_OPT PERL_MM_OPT
./install-deps.pl
: ${PORT_BASE:=8000}
echo >&2 "Running sytest with SQLite3";
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
--python $TOX_BIN/python --all --port-base $PORT_BASE > results-sqlite3.tap
RUN_POSTGRES=""
for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
if psql synapse_jenkins_$port <<< ""; then
RUN_POSTGRES="$RUN_POSTGRES:$port"
cat > localhost-$port/database.yaml << EOF
name: psycopg2
args:
database: synapse_jenkins_$port
EOF
fi
done
# Run if both postgresql databases exist
if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
echo >&2 "Running sytest with PostgreSQL";
$TOX_BIN/pip install psycopg2
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
--python $TOX_BIN/python --all --port-base $PORT_BASE > results-postgresql.tap
else
echo >&2 "Skipping running sytest with PostgreSQL, $RUN_POSTGRES"
fi
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml

View File

@ -179,7 +179,8 @@ class TransportLayerClient(object):
content = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
retry_on_dns_fail=True, retry_on_dns_fail=False,
timeout=20000,
) )
defer.returnValue(content) defer.returnValue(content)

View File

@ -428,24 +428,31 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password(self, user_id, password): def _check_password(self, user_id, password):
defer.returnValue( """
not ( Returns:
(yield self._check_ldap_password(user_id, password)) True if the user_id successfully authenticated
or """
(yield self._check_local_password(user_id, password)) valid_ldap = yield self._check_ldap_password(user_id, password)
)) if valid_ldap:
defer.returnValue(True)
valid_local_password = yield self._check_local_password(user_id, password)
if valid_local_password:
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_local_password(self, user_id, password): def _check_local_password(self, user_id, password):
try: try:
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(not self.validate_hash(password, password_hash)) defer.returnValue(self.validate_hash(password, password_hash))
except LoginError: except LoginError:
defer.returnValue(False) defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_ldap_password(self, user_id, password): def _check_ldap_password(self, user_id, password):
if self.ldap_enabled is not True: if not self.ldap_enabled:
logger.debug("LDAP not configured") logger.debug("LDAP not configured")
defer.returnValue(False) defer.returnValue(False)

View File

@ -462,5 +462,8 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
self._context = SSL.Context(SSL.SSLv23_METHOD) self._context = SSL.Context(SSL.SSLv23_METHOD)
self._context.set_verify(VERIFY_NONE, lambda *_: None) self._context.set_verify(VERIFY_NONE, lambda *_: None)
def getContext(self, hostname, port): def getContext(self, hostname=None, port=None):
return self._context return self._context
def creatorForNetloc(self, hostname, port):
return self

View File

@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import respond_with_json, finish_request
from synapse.api.errors import (
cs_error, Codes, SynapseError
)
from twisted.internet import defer
from twisted.protocols.basic import FileSender
from synapse.util.stringutils import is_ascii
import os
import logging
import urllib
import urlparse
logger = logging.getLogger(__name__)
def parse_media_id(request):
try:
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
server_name, media_id = request.postpath[:2]
file_name = None
if len(request.postpath) > 2:
try:
file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
except UnicodeDecodeError:
pass
return server_name, media_id, file_name
except:
raise SynapseError(
404,
"Invalid media id token %r" % (request.postpath,),
Codes.UNKNOWN,
)
def respond_404(request):
respond_with_json(
request, 404,
cs_error(
"Not found %r" % (request.postpath,),
code=Codes.NOT_FOUND,
),
send_cors=True
)
@defer.inlineCallbacks
def respond_with_file(request, media_type, file_path,
file_size=None, upload_name=None):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
if upload_name:
if is_ascii(upload_name):
request.setHeader(
b"Content-Disposition",
b"inline; filename=%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
else:
request.setHeader(
b"Content-Disposition",
b"inline; filename*=utf-8''%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
)
if file_size is None:
stat = os.stat(file_path)
file_size = stat.st_size
request.setHeader(
b"Content-Length", b"%d" % (file_size,)
)
with open(file_path, "rb") as f:
yield FileSender().beginFileTransfer(f, request)
finish_request(request)
else:
respond_404(request)

View File

@ -1,460 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .thumbnailer import Thumbnailer
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.http.server import respond_with_json, finish_request
from synapse.util.stringutils import random_string
from synapse.api.errors import (
cs_error, Codes, SynapseError
)
from twisted.internet import defer, threads
from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import preserve_context_over_fn
import os
import cgi
import logging
import urllib
import urlparse
logger = logging.getLogger(__name__)
def parse_media_id(request):
try:
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
server_name, media_id = request.postpath[:2]
file_name = None
if len(request.postpath) > 2:
try:
file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
except UnicodeDecodeError:
pass
return server_name, media_id, file_name
except:
raise SynapseError(
404,
"Invalid media id token %r" % (request.postpath,),
Codes.UNKNOWN,
)
class BaseMediaResource(Resource):
isLeaf = True
def __init__(self, hs, filepaths):
Resource.__init__(self)
self.auth = hs.get_auth()
self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
self.max_spider_size = hs.config.max_spider_size
self.filepaths = filepaths
self.version_string = hs.version_string
self.downloads = {}
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
def _respond_404(self, request):
respond_with_json(
request, 404,
cs_error(
"Not found %r" % (request.postpath,),
code=Codes.NOT_FOUND,
),
send_cors=True
)
@staticmethod
def _makedirs(filepath):
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
os.makedirs(dirname)
def _get_remote_media(self, server_name, media_id):
key = (server_name, media_id)
download = self.downloads.get(key)
if download is None:
download = self._get_remote_media_impl(server_name, media_id)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.downloads[key] = download
@download.addBoth
def callback(media_info):
del self.downloads[key]
return media_info
return download.observe()
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
if not media_info:
media_info = yield self._download_remote_file(
server_name, media_id
)
defer.returnValue(media_info)
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id):
file_id = random_string(24)
fname = self.filepaths.remote_media_filepath(
server_name, file_id
)
self._makedirs(fname)
try:
with open(fname, "wb") as f:
request_path = "/".join((
"/_matrix/media/v1/download", server_name, media_id,
))
length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f,
max_size=self.max_upload_size,
)
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
upload_name = None
# First check if there is a valid UTF-8 filename
upload_name_utf8 = params.get("filename*", None)
if upload_name_utf8:
if upload_name_utf8.lower().startswith("utf-8''"):
upload_name = upload_name_utf8[7:]
# If there isn't check for an ascii name.
if not upload_name:
upload_name_ascii = params.get("filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
upload_name = upload_name_ascii
if upload_name:
upload_name = urlparse.unquote(upload_name)
try:
upload_name = upload_name.decode("utf-8")
except UnicodeDecodeError:
upload_name = None
else:
upload_name = None
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
except:
os.remove(fname)
raise
media_info = {
"media_type": media_type,
"media_length": length,
"upload_name": upload_name,
"created_ts": time_now_ms,
"filesystem_id": file_id,
}
yield self._generate_remote_thumbnails(
server_name, media_id, media_info
)
defer.returnValue(media_info)
@defer.inlineCallbacks
def _respond_with_file(self, request, media_type, file_path,
file_size=None, upload_name=None):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
if upload_name:
if is_ascii(upload_name):
request.setHeader(
b"Content-Disposition",
b"inline; filename=%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
else:
request.setHeader(
b"Content-Disposition",
b"inline; filename*=utf-8''%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
)
if file_size is None:
stat = os.stat(file_path)
file_size = stat.st_size
request.setHeader(
b"Content-Length", b"%d" % (file_size,)
)
with open(file_path, "rb") as f:
yield FileSender().beginFileTransfer(f, request)
finish_request(request)
else:
self._respond_404(request)
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
t_method, t_type):
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
if t_method == "crop":
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
elif t_method == "scale":
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
else:
t_len = None
return t_len
@defer.inlineCallbacks
def _generate_local_exact_thumbnail(self, media_id, t_width, t_height,
t_method, t_type):
input_path = self.filepaths.local_media_filepath(media_id)
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
if t_len:
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
defer.returnValue(t_path)
@defer.inlineCallbacks
def _generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type):
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
if t_len:
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
defer.returnValue(t_path)
@defer.inlineCallbacks
def _generate_local_thumbnails(self, media_id, media_info):
media_type = media_info["media_type"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
input_path = self.filepaths.local_media_filepath(media_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
local_thumbnails = []
def generate_thumbnails():
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for l in local_thumbnails:
yield self.store.store_local_thumbnail(*l)
defer.returnValue({
"width": m_width,
"height": m_height,
})
@defer.inlineCallbacks
def _generate_remote_thumbnails(self, server_name, media_id, media_info):
media_type = media_info["media_type"]
file_id = media_info["filesystem_id"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
remote_thumbnails = []
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
def generate_thumbnails():
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
remote_thumbnails.append([
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
])
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
remote_thumbnails.append([
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
])
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for r in remote_thumbnails:
yield self.store.store_remote_media_thumbnail(*r)
defer.returnValue({
"width": m_width,
"height": m_height,
})

View File

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .base_resource import BaseMediaResource, parse_media_id from ._base import parse_media_id, respond_with_file, respond_404
from twisted.web.resource import Resource
from synapse.http.server import request_handler from synapse.http.server import request_handler
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
@ -24,7 +25,18 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DownloadResource(BaseMediaResource): class DownloadResource(Resource):
isLeaf = True
def __init__(self, hs, media_repo):
Resource.__init__(self)
self.filepaths = media_repo.filepaths
self.media_repo = media_repo
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.version_string = hs.version_string
def render_GET(self, request): def render_GET(self, request):
self._async_render_GET(request) self._async_render_GET(request)
return NOT_DONE_YET return NOT_DONE_YET
@ -44,7 +56,7 @@ class DownloadResource(BaseMediaResource):
def _respond_local_file(self, request, media_id, name): def _respond_local_file(self, request, media_id, name):
media_info = yield self.store.get_local_media(media_id) media_info = yield self.store.get_local_media(media_id)
if not media_info: if not media_info:
self._respond_404(request) respond_404(request)
return return
media_type = media_info["media_type"] media_type = media_info["media_type"]
@ -52,14 +64,14 @@ class DownloadResource(BaseMediaResource):
upload_name = name if name else media_info["upload_name"] upload_name = name if name else media_info["upload_name"]
file_path = self.filepaths.local_media_filepath(media_id) file_path = self.filepaths.local_media_filepath(media_id)
yield self._respond_with_file( yield respond_with_file(
request, media_type, file_path, media_length, request, media_type, file_path, media_length,
upload_name=upload_name, upload_name=upload_name,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_remote_file(self, request, server_name, media_id, name): def _respond_remote_file(self, request, server_name, media_id, name):
media_info = yield self._get_remote_media(server_name, media_id) media_info = yield self.media_repo.get_remote_media(server_name, media_id)
media_type = media_info["media_type"] media_type = media_info["media_type"]
media_length = media_info["media_length"] media_length = media_info["media_length"]
@ -70,7 +82,7 @@ class DownloadResource(BaseMediaResource):
server_name, filesystem_id server_name, filesystem_id
) )
yield self._respond_with_file( yield respond_with_file(
request, media_type, file_path, media_length, request, media_type, file_path, media_length,
upload_name=upload_name, upload_name=upload_name,
) )

View File

@ -22,11 +22,395 @@ from .filepath import MediaFilePaths
from twisted.web.resource import Resource from twisted.web.resource import Resource
from .thumbnailer import Thumbnailer
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string
from twisted.internet import defer, threads
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import preserve_context_over_fn
import os
import cgi
import logging import logging
import urlparse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MediaRepository(object):
def __init__(self, hs, filepaths):
self.auth = hs.get_auth()
self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
self.filepaths = filepaths
self.downloads = {}
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
@staticmethod
def _makedirs(filepath):
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
os.makedirs(dirname)
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
auth_user):
media_id = random_string(24)
fname = self.filepaths.local_media_filepath(media_id)
self._makedirs(fname)
# This shouldn't block for very long because the content will have
# already been uploaded at this point.
with open(fname, "wb") as f:
f.write(content)
yield self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
)
media_info = {
"media_type": media_type,
"media_length": content_length,
}
yield self._generate_local_thumbnails(media_id, media_info)
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
def get_remote_media(self, server_name, media_id):
key = (server_name, media_id)
download = self.downloads.get(key)
if download is None:
download = self._get_remote_media_impl(server_name, media_id)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.downloads[key] = download
@download.addBoth
def callback(media_info):
del self.downloads[key]
return media_info
return download.observe()
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
if not media_info:
media_info = yield self._download_remote_file(
server_name, media_id
)
defer.returnValue(media_info)
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id):
file_id = random_string(24)
fname = self.filepaths.remote_media_filepath(
server_name, file_id
)
self._makedirs(fname)
try:
with open(fname, "wb") as f:
request_path = "/".join((
"/_matrix/media/v1/download", server_name, media_id,
))
length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f,
max_size=self.max_upload_size,
)
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
upload_name = None
# First check if there is a valid UTF-8 filename
upload_name_utf8 = params.get("filename*", None)
if upload_name_utf8:
if upload_name_utf8.lower().startswith("utf-8''"):
upload_name = upload_name_utf8[7:]
# If there isn't check for an ascii name.
if not upload_name:
upload_name_ascii = params.get("filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
upload_name = upload_name_ascii
if upload_name:
upload_name = urlparse.unquote(upload_name)
try:
upload_name = upload_name.decode("utf-8")
except UnicodeDecodeError:
upload_name = None
else:
upload_name = None
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
except:
os.remove(fname)
raise
media_info = {
"media_type": media_type,
"media_length": length,
"upload_name": upload_name,
"created_ts": time_now_ms,
"filesystem_id": file_id,
}
yield self._generate_remote_thumbnails(
server_name, media_id, media_info
)
defer.returnValue(media_info)
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
t_method, t_type):
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
if t_method == "crop":
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
elif t_method == "scale":
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
else:
t_len = None
return t_len
@defer.inlineCallbacks
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
t_method, t_type):
input_path = self.filepaths.local_media_filepath(media_id)
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
if t_len:
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
defer.returnValue(t_path)
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type):
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
if t_len:
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
defer.returnValue(t_path)
@defer.inlineCallbacks
def _generate_local_thumbnails(self, media_id, media_info):
media_type = media_info["media_type"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
input_path = self.filepaths.local_media_filepath(media_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
local_thumbnails = []
def generate_thumbnails():
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for l in local_thumbnails:
yield self.store.store_local_thumbnail(*l)
defer.returnValue({
"width": m_width,
"height": m_height,
})
@defer.inlineCallbacks
def _generate_remote_thumbnails(self, server_name, media_id, media_info):
media_type = media_info["media_type"]
file_id = media_info["filesystem_id"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
remote_thumbnails = []
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
def generate_thumbnails():
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
remote_thumbnails.append([
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
])
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
remote_thumbnails.append([
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
])
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for r in remote_thumbnails:
yield self.store.store_remote_media_thumbnail(*r)
defer.returnValue({
"width": m_width,
"height": m_height,
})
class MediaRepositoryResource(Resource): class MediaRepositoryResource(Resource):
"""File uploading and downloading. """File uploading and downloading.
@ -75,9 +459,12 @@ class MediaRepositoryResource(Resource):
def __init__(self, hs): def __init__(self, hs):
Resource.__init__(self) Resource.__init__(self)
filepaths = MediaFilePaths(hs.config.media_store_path) filepaths = MediaFilePaths(hs.config.media_store_path)
self.putChild("upload", UploadResource(hs, filepaths))
self.putChild("download", DownloadResource(hs, filepaths)) media_repo = MediaRepository(hs, filepaths)
self.putChild("thumbnail", ThumbnailResource(hs, filepaths))
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))
self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
self.putChild("identicon", IdenticonResource()) self.putChild("identicon", IdenticonResource())
if hs.config.url_preview_enabled: if hs.config.url_preview_enabled:
self.putChild("preview_url", PreviewUrlResource(hs, filepaths)) self.putChild("preview_url", PreviewUrlResource(hs, media_repo))

View File

@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .base_resource import BaseMediaResource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import ( from synapse.api.errors import (
SynapseError, Codes, SynapseError, Codes,
@ -41,12 +40,22 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PreviewUrlResource(BaseMediaResource): class PreviewUrlResource(Resource):
isLeaf = True isLeaf = True
def __init__(self, hs, filepaths): def __init__(self, hs, media_repo):
BaseMediaResource.__init__(self, hs, filepaths) Resource.__init__(self)
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.version_string = hs.version_string
self.filepaths = media_repo.filepaths
self.max_spider_size = hs.config.max_spider_size
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.client = SpiderHttpClient(hs) self.client = SpiderHttpClient(hs)
self.media_repo = media_repo
if hasattr(hs.config, "url_preview_url_blacklist"): if hasattr(hs.config, "url_preview_url_blacklist"):
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
@ -156,7 +165,7 @@ class PreviewUrlResource(BaseMediaResource):
logger.debug("got media_info of '%s'" % media_info) logger.debug("got media_info of '%s'" % media_info)
if self._is_media(media_info['media_type']): if self._is_media(media_info['media_type']):
dims = yield self._generate_local_thumbnails( dims = yield self.media_repo._generate_local_thumbnails(
media_info['filesystem_id'], media_info media_info['filesystem_id'], media_info
) )
@ -179,23 +188,27 @@ class PreviewUrlResource(BaseMediaResource):
elif self._is_html(media_info['media_type']): elif self._is_html(media_info['media_type']):
# TODO: somehow stop a big HTML tree from exploding synapse's RAM # TODO: somehow stop a big HTML tree from exploding synapse's RAM
from lxml import html from lxml import etree
file = open(media_info['filename'])
body = file.read()
file.close()
# clobber the encoding from the content-type, or default to utf-8
# XXX: this overrides any <meta/> or XML charset headers in the body
# which may pose problems, but so far seems to work okay.
match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
encoding = match.group(1) if match else "utf-8"
try: try:
tree = html.parse(media_info['filename']) parser = etree.HTMLParser(recover=True, encoding=encoding)
tree = etree.fromstring(body, parser)
og = yield self._calc_og(tree, media_info, requester) og = yield self._calc_og(tree, media_info, requester)
except UnicodeDecodeError: except UnicodeDecodeError:
# XXX: evil evil bodge # blindly try decoding the body as utf-8, which seems to fix
# Empirically, sites like google.com mix Latin-1 and utf-8 # the charset mismatches on https://google.com
# encodings in the same page. The rogue Latin-1 characters parser = etree.HTMLParser(recover=True, encoding=encoding)
# cause lxml to choke with a UnicodeDecodeError, so if we tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
# see this we go and do a manual decode of the HTML before
# handing it to lxml as utf-8 encoding, counter-intuitively,
# which seems to make it happier...
file = open(media_info['filename'])
body = file.read()
file.close()
tree = html.fromstring(body.decode('utf-8', 'ignore'))
og = yield self._calc_og(tree, media_info, requester) og = yield self._calc_og(tree, media_info, requester)
else: else:
@ -287,7 +300,7 @@ class PreviewUrlResource(BaseMediaResource):
if self._is_media(image_info['media_type']): if self._is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images # TODO: make sure we don't choke on white-on-transparent images
dims = yield self._generate_local_thumbnails( dims = yield self.media_repo._generate_local_thumbnails(
image_info['filesystem_id'], image_info image_info['filesystem_id'], image_info
) )
if dims: if dims:
@ -358,7 +371,7 @@ class PreviewUrlResource(BaseMediaResource):
file_id = random_string(24) file_id = random_string(24)
fname = self.filepaths.local_media_filepath(file_id) fname = self.filepaths.local_media_filepath(file_id)
self._makedirs(fname) self.media_repo._makedirs(fname)
try: try:
with open(fname, "wb") as f: with open(fname, "wb") as f:

View File

@ -14,7 +14,8 @@
# limitations under the License. # limitations under the License.
from .base_resource import BaseMediaResource, parse_media_id from ._base import parse_media_id, respond_404, respond_with_file
from twisted.web.resource import Resource
from synapse.http.servlet import parse_string, parse_integer from synapse.http.servlet import parse_string, parse_integer
from synapse.http.server import request_handler from synapse.http.server import request_handler
@ -26,9 +27,19 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThumbnailResource(BaseMediaResource): class ThumbnailResource(Resource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo):
Resource.__init__(self)
self.store = hs.get_datastore()
self.filepaths = media_repo.filepaths
self.media_repo = media_repo
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
self.version_string = hs.version_string
def render_GET(self, request): def render_GET(self, request):
self._async_render_GET(request) self._async_render_GET(request)
return NOT_DONE_YET return NOT_DONE_YET
@ -69,12 +80,12 @@ class ThumbnailResource(BaseMediaResource):
media_info = yield self.store.get_local_media(media_id) media_info = yield self.store.get_local_media(media_id)
if not media_info: if not media_info:
self._respond_404(request) respond_404(request)
return return
# if media_info["media_type"] == "image/svg+xml": # if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.local_media_filepath(media_id) # file_path = self.filepaths.local_media_filepath(media_id)
# yield self._respond_with_file(request, media_info["media_type"], file_path) # yield respond_with_file(request, media_info["media_type"], file_path)
# return # return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
@ -91,7 +102,7 @@ class ThumbnailResource(BaseMediaResource):
file_path = self.filepaths.local_media_thumbnail( file_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method, media_id, t_width, t_height, t_type, t_method,
) )
yield self._respond_with_file(request, t_type, file_path) yield respond_with_file(request, t_type, file_path)
else: else:
yield self._respond_default_thumbnail( yield self._respond_default_thumbnail(
@ -105,12 +116,12 @@ class ThumbnailResource(BaseMediaResource):
media_info = yield self.store.get_local_media(media_id) media_info = yield self.store.get_local_media(media_id)
if not media_info: if not media_info:
self._respond_404(request) respond_404(request)
return return
# if media_info["media_type"] == "image/svg+xml": # if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.local_media_filepath(media_id) # file_path = self.filepaths.local_media_filepath(media_id)
# yield self._respond_with_file(request, media_info["media_type"], file_path) # yield respond_with_file(request, media_info["media_type"], file_path)
# return # return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
@ -124,18 +135,18 @@ class ThumbnailResource(BaseMediaResource):
file_path = self.filepaths.local_media_thumbnail( file_path = self.filepaths.local_media_thumbnail(
media_id, desired_width, desired_height, desired_type, desired_method, media_id, desired_width, desired_height, desired_type, desired_method,
) )
yield self._respond_with_file(request, desired_type, file_path) yield respond_with_file(request, desired_type, file_path)
return return
logger.debug("We don't have a local thumbnail of that size. Generating") logger.debug("We don't have a local thumbnail of that size. Generating")
# Okay, so we generate one. # Okay, so we generate one.
file_path = yield self._generate_local_exact_thumbnail( file_path = yield self.media_repo.generate_local_exact_thumbnail(
media_id, desired_width, desired_height, desired_method, desired_type media_id, desired_width, desired_height, desired_method, desired_type
) )
if file_path: if file_path:
yield self._respond_with_file(request, desired_type, file_path) yield respond_with_file(request, desired_type, file_path)
else: else:
yield self._respond_default_thumbnail( yield self._respond_default_thumbnail(
request, media_info, desired_width, desired_height, request, media_info, desired_width, desired_height,
@ -146,11 +157,11 @@ class ThumbnailResource(BaseMediaResource):
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
desired_width, desired_height, desired_width, desired_height,
desired_method, desired_type): desired_method, desired_type):
media_info = yield self._get_remote_media(server_name, media_id) media_info = yield self.media_repo.get_remote_media(server_name, media_id)
# if media_info["media_type"] == "image/svg+xml": # if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.remote_media_filepath(server_name, media_id) # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
# yield self._respond_with_file(request, media_info["media_type"], file_path) # yield respond_with_file(request, media_info["media_type"], file_path)
# return # return
thumbnail_infos = yield self.store.get_remote_media_thumbnails( thumbnail_infos = yield self.store.get_remote_media_thumbnails(
@ -170,19 +181,19 @@ class ThumbnailResource(BaseMediaResource):
server_name, file_id, desired_width, desired_height, server_name, file_id, desired_width, desired_height,
desired_type, desired_method, desired_type, desired_method,
) )
yield self._respond_with_file(request, desired_type, file_path) yield respond_with_file(request, desired_type, file_path)
return return
logger.debug("We don't have a local thumbnail of that size. Generating") logger.debug("We don't have a local thumbnail of that size. Generating")
# Okay, so we generate one. # Okay, so we generate one.
file_path = yield self._generate_remote_exact_thumbnail( file_path = yield self.media_repo.generate_remote_exact_thumbnail(
server_name, file_id, media_id, desired_width, server_name, file_id, media_id, desired_width,
desired_height, desired_method, desired_type desired_height, desired_method, desired_type
) )
if file_path: if file_path:
yield self._respond_with_file(request, desired_type, file_path) yield respond_with_file(request, desired_type, file_path)
else: else:
yield self._respond_default_thumbnail( yield self._respond_default_thumbnail(
request, media_info, desired_width, desired_height, request, media_info, desired_width, desired_height,
@ -194,11 +205,11 @@ class ThumbnailResource(BaseMediaResource):
height, method, m_type): height, method, m_type):
# TODO: Don't download the whole remote file # TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead. # We should proxy the thumbnail from the remote server instead.
media_info = yield self._get_remote_media(server_name, media_id) media_info = yield self.media_repo.get_remote_media(server_name, media_id)
# if media_info["media_type"] == "image/svg+xml": # if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.remote_media_filepath(server_name, media_id) # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
# yield self._respond_with_file(request, media_info["media_type"], file_path) # yield respond_with_file(request, media_info["media_type"], file_path)
# return # return
thumbnail_infos = yield self.store.get_remote_media_thumbnails( thumbnail_infos = yield self.store.get_remote_media_thumbnails(
@ -219,7 +230,7 @@ class ThumbnailResource(BaseMediaResource):
file_path = self.filepaths.remote_media_thumbnail( file_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method, server_name, file_id, t_width, t_height, t_type, t_method,
) )
yield self._respond_with_file(request, t_type, file_path, t_length) yield respond_with_file(request, t_type, file_path, t_length)
else: else:
yield self._respond_default_thumbnail( yield self._respond_default_thumbnail(
request, media_info, width, height, method, m_type, request, media_info, width, height, method, m_type,
@ -245,7 +256,7 @@ class ThumbnailResource(BaseMediaResource):
"_default", "_default", "_default", "_default",
) )
if not thumbnail_infos: if not thumbnail_infos:
self._respond_404(request) respond_404(request)
return return
thumbnail_info = self._select_thumbnail( thumbnail_info = self._select_thumbnail(
@ -261,7 +272,7 @@ class ThumbnailResource(BaseMediaResource):
file_path = self.filepaths.default_thumbnail( file_path = self.filepaths.default_thumbnail(
top_level_type, sub_type, t_width, t_height, t_type, t_method, top_level_type, sub_type, t_width, t_height, t_type, t_method,
) )
yield self.respond_with_file(request, t_type, file_path, t_length) yield respond_with_file(request, t_type, file_path, t_length)
def _select_thumbnail(self, desired_width, desired_height, desired_method, def _select_thumbnail(self, desired_width, desired_height, desired_method,
desired_type, thumbnail_infos): desired_type, thumbnail_infos):

View File

@ -15,20 +15,33 @@
from synapse.http.server import respond_with_json, request_handler from synapse.http.server import respond_with_json, request_handler
from synapse.util.stringutils import random_string
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer from twisted.internet import defer
from .base_resource import BaseMediaResource from twisted.web.resource import Resource
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UploadResource(BaseMediaResource): class UploadResource(Resource):
isLeaf = True
def __init__(self, hs, media_repo):
Resource.__init__(self)
self.media_repo = media_repo
self.filepaths = media_repo.filepaths
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.auth = hs.get_auth()
self.max_upload_size = hs.config.max_upload_size
self.version_string = hs.version_string
def render_POST(self, request): def render_POST(self, request):
self._async_render_POST(request) self._async_render_POST(request)
return NOT_DONE_YET return NOT_DONE_YET
@ -37,36 +50,6 @@ class UploadResource(BaseMediaResource):
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET return NOT_DONE_YET
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
auth_user):
media_id = random_string(24)
fname = self.filepaths.local_media_filepath(media_id)
self._makedirs(fname)
# This shouldn't block for very long because the content will have
# already been uploaded at this point.
with open(fname, "wb") as f:
f.write(content)
yield self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
)
media_info = {
"media_type": media_type,
"media_length": content_length,
}
yield self._generate_local_thumbnails(media_id, media_info)
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@request_handler @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_POST(self, request): def _async_render_POST(self, request):
@ -108,7 +91,7 @@ class UploadResource(BaseMediaResource):
# disposition = headers.getRawHeaders("Content-Disposition")[0] # disposition = headers.getRawHeaders("Content-Disposition")[0]
# TODO(markjh): parse content-dispostion # TODO(markjh): parse content-dispostion
content_uri = yield self.create_content( content_uri = yield self.media_repo.create_content(
media_type, upload_name, request.content.read(), media_type, upload_name, request.content.read(),
content_length, requester.user content_length, requester.user
) )

View File

@ -214,7 +214,7 @@ class StateHandler(object):
if self._state_cache is not None: if self._state_cache is not None:
cache = self._state_cache.get(group_names, None) cache = self._state_cache.get(group_names, None)
if cache and cache.state_group: if cache:
cache.ts = self.clock.time_msec() cache.ts = self.clock.time_msec()
event_dict = yield self.store.get_events(cache.state.values()) event_dict = yield self.store.get_events(cache.state.values())
@ -230,22 +230,34 @@ class StateHandler(object):
(cache.state_group, state, prev_states) (cache.state_group, state, prev_states)
) )
logger.info("Resolving state for %s with %d groups", room_id, len(state_groups))
new_state, prev_states = self._resolve_events( new_state, prev_states = self._resolve_events(
state_groups.values(), event_type, state_key state_groups.values(), event_type, state_key
) )
state_group = None
new_state_event_ids = frozenset(e.event_id for e in new_state.values())
for sg, events in state_groups.items():
if new_state_event_ids == frozenset(e.event_id for e in events):
state_group = sg
break
if self._state_cache is not None: if self._state_cache is not None:
cache = _StateCacheEntry( cache = _StateCacheEntry(
state={key: event.event_id for key, event in new_state.items()}, state={key: event.event_id for key, event in new_state.items()},
state_group=None, state_group=state_group,
ts=self.clock.time_msec() ts=self.clock.time_msec()
) )
self._state_cache[group_names] = cache self._state_cache[group_names] = cache
defer.returnValue((None, new_state, prev_states)) defer.returnValue((state_group, new_state, prev_states))
def resolve_events(self, state_sets, event): def resolve_events(self, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
if event.is_state(): if event.is_state():
return self._resolve_events( return self._resolve_events(
state_sets, event.type, event.state_key state_sets, event.type, event.state_key

View File

@ -174,6 +174,12 @@ class StateStore(SQLBaseStore):
return [r[0] for r in results] return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f) return self.runInteraction("get_current_state_for_key", f)
@cached(num_args=2, lru=True, max_entries=1000)
def _get_state_group_from_group(self, group, types):
raise NotImplementedError()
@cachedList(cached_method_name="_get_state_group_from_group",
list_name="groups", num_args=2, inlineCallbacks=True)
def _get_state_groups_from_groups(self, groups, types): def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id) """Returns dictionary state_group -> (dict of (type, state_key) -> event id)
""" """
@ -201,18 +207,23 @@ class StateStore(SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
results = {} results = {group: {} for group in groups}
for row in rows: for row in rows:
key = (row["type"], row["state_key"]) key = (row["type"], row["state_key"])
results.setdefault(row["state_group"], {})[key] = row["event_id"] results[row["state_group"]][key] = row["event_id"]
return results return results
results = {}
chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)] chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
for chunk in chunks: for chunk in chunks:
return self.runInteraction( res = yield self.runInteraction(
"_get_state_groups_from_groups", "_get_state_groups_from_groups",
f, chunk f, chunk
) )
results.update(res)
defer.returnValue(results)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_events(self, event_ids, types): def get_state_for_events(self, event_ids, types):
@ -359,6 +370,8 @@ class StateStore(SQLBaseStore):
a `state_key` of None matches all state_keys. If `types` is None then a `state_key` of None matches all state_keys. If `types` is None then
all events are returned. all events are returned.
""" """
if types:
types = frozenset(types)
results = {} results = {}
missing_groups = [] missing_groups = []
if types is not None: if types is not None:

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.api.errors import SynapseError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task from twisted.internet import defer, reactor, task
@ -80,7 +81,7 @@ class Clock(object):
def timed_out_fn(): def timed_out_fn():
try: try:
ret_deferred.errback(RuntimeError("Timed out")) ret_deferred.errback(SynapseError(504, "Timed out"))
except: except:
pass pass

View File

@ -50,7 +50,7 @@ block_db_txn_duration = metrics.register_distribution(
class Measure(object): class Measure(object):
__slots__ = [ __slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime", "clock", "name", "start_context", "start", "new_context", "ru_utime",
"ru_stime", "db_txn_count", "db_txn_duration" "ru_stime", "db_txn_count", "db_txn_duration", "created_context"
] ]
def __init__(self, clock, name): def __init__(self, clock, name):
@ -58,14 +58,20 @@ class Measure(object):
self.name = name self.name = name
self.start_context = None self.start_context = None
self.start = None self.start = None
self.created_context = False
def __enter__(self): def __enter__(self):
self.start = self.clock.time_msec() self.start = self.clock.time_msec()
self.start_context = LoggingContext.current_context() self.start_context = LoggingContext.current_context()
if self.start_context: if not self.start_context:
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() logger.warn("Entered Measure without log context: %s", self.name)
self.db_txn_count = self.start_context.db_txn_count self.start_context = LoggingContext("Measure")
self.db_txn_duration = self.start_context.db_txn_duration self.start_context.__enter__()
self.created_context = True
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
self.db_txn_count = self.start_context.db_txn_count
self.db_txn_duration = self.start_context.db_txn_duration
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None or not self.start_context: if exc_type is not None or not self.start_context:
@ -91,7 +97,12 @@ class Measure(object):
block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name) block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name) block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name) block_db_txn_count.inc_by(
context.db_txn_count - self.db_txn_count, self.name
)
block_db_txn_duration.inc_by( block_db_txn_duration.inc_by(
context.db_txn_duration - self.db_txn_duration, self.name context.db_txn_duration - self.db_txn_duration, self.name
) )
if self.created_context:
self.start_context.__exit__(exc_type, exc_val, exc_tb)

View File

@ -15,8 +15,6 @@
from twisted.internet import defer from twisted.internet import defer
from tests import unittest from tests import unittest
from synapse.replication.slave.storage.events import SlavedEventStore
from mock import Mock, NonCallableMock from mock import Mock, NonCallableMock
from tests.utils import setup_test_homeserver from tests.utils import setup_test_homeserver
from synapse.replication.resource import ReplicationResource from synapse.replication.resource import ReplicationResource
@ -38,7 +36,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
self.replication = ReplicationResource(self.hs) self.replication = ReplicationResource(self.hs)
self.master_store = self.hs.get_datastore() self.master_store = self.hs.get_datastore()
self.slaved_store = SlavedEventStore(self.hs.get_db_conn(), self.hs) self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0 self.event_id = 0
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -16,6 +16,7 @@ from ._base import BaseSlavedStoreTestCase
from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.storage.roommember import RoomsForUser from synapse.storage.roommember import RoomsForUser
from twisted.internet import defer from twisted.internet import defer
@ -43,6 +44,8 @@ def patch__eq__(cls):
class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedEventStore
def setUp(self): def setUp(self):
# Patch up the equality operator for events so that we can check # Patch up the equality operator for events so that we can check
# whether lists of events match using assertEquals # whether lists of events match using assertEquals
@ -251,6 +254,18 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict()) redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted) yield self.check("get_event", [msg.event_id], redacted)
@defer.inlineCallbacks
def test_invites(self):
yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
event = yield self.persist(
type="m.room.member", key=USER_ID_2, membership="invite"
)
yield self.replicate()
yield self.check("get_invited_rooms_for_user", [USER_ID_2], [RoomsForUser(
ROOM_ID, USER_ID, "invite", event.event_id,
event.internal_metadata.stream_ordering
)])
event_id = 0 event_id = 0
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -0,0 +1,39 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStoreTestCase
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from twisted.internet import defer
USER_ID = "@feeling:blue"
ROOM_ID = "!room:blue"
EVENT_ID = "$event:blue"
class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedReceiptsStore
@defer.inlineCallbacks
def test_receipt(self):
yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
yield self.master_store.insert_receipt(
ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
)
yield self.replicate()
yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {
ROOM_ID: EVENT_ID
})

View File

@ -140,13 +140,13 @@ class StateTestCase(unittest.TestCase):
"add_event_hashes", "add_event_hashes",
] ]
) )
hs = Mock(spec=[ hs = Mock(spec_set=[
"get_datastore", "get_auth", "get_state_handler", "get_clock", "get_datastore", "get_auth", "get_state_handler", "get_clock",
]) ])
hs.get_datastore.return_value = self.store hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None hs.get_state_handler.return_value = None
hs.get_auth.return_value = Auth(hs)
hs.get_clock.return_value = MockClock() hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
self.state = StateHandler(hs) self.state = StateHandler(hs)
self.event_id = 0 self.event_id = 0