Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

pull/2587/head
Erik Johnston 2017-05-05 11:00:45 +01:00
commit 85999aade0
12 changed files with 222 additions and 63 deletions

synapse/api/errors.py

@@ -66,6 +66,17 @@ class CodeMessageException(RuntimeError):
         return cs_error(self.msg)


+class MatrixCodeMessageException(CodeMessageException):
+    """An error from a general Matrix endpoint, e.g. from a proxied Matrix API call.
+
+    Attributes:
+        errcode (str): Matrix error code, e.g. 'M_FORBIDDEN'
+    """
+    def __init__(self, code, msg, errcode=Codes.UNKNOWN):
+        super(MatrixCodeMessageException, self).__init__(code, msg)
+        self.errcode = errcode
+
+
 class SynapseError(CodeMessageException):
     """A base exception type for matrix errors which have an errcode and error
     message (as well as an HTTP status code).
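
Note: the pattern this new class enables, reduced to a standalone Python sketch. The base class below is a stand-in for Synapse's, and the errcode default is inlined as a string literal rather than taken from Codes:

class CodeMessageException(RuntimeError):
    """Stand-in for Synapse's base exception: an HTTP status plus a message."""
    def __init__(self, code, msg):
        super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
        self.code = code
        self.msg = msg


class MatrixCodeMessageException(CodeMessageException):
    """As above, but also carries a Matrix errcode such as 'M_FORBIDDEN'."""
    def __init__(self, code, msg, errcode="M_UNKNOWN"):
        super(MatrixCodeMessageException, self).__init__(code, msg)
        self.errcode = errcode


# A handler proxying a Matrix API response can now surface the remote errcode:
try:
    raise MatrixCodeMessageException(403, "Forbidden", "M_FORBIDDEN")
except MatrixCodeMessageException as e:
    assert (e.code, e.msg, e.errcode) == (403, "Forbidden", "M_FORBIDDEN")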

synapse/federation/transaction_queue.py

@@ -24,7 +24,6 @@ from synapse.util.async import run_on_reactor
 from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 from synapse.util.metrics import measure_func
-from synapse.types import get_domain_from_id
 from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
 import synapse.metrics
@@ -183,15 +182,12 @@ class TransactionQueue(object):
                 # Otherwise if the last member on a server in a room is
                 # banned then it won't receive the event because it won't
                 # be in the room after the ban.
-                users_in_room = yield self.state.get_current_user_in_room(
+                destinations = yield self.state.get_current_hosts_in_room(
                     event.room_id, latest_event_ids=[
                         prev_id for prev_id, _ in event.prev_events
                     ],
                 )

-                destinations = set(
-                    get_domain_from_id(user_id) for user_id in users_in_room
-                )
-
                 if send_on_behalf_of is not None:
                     # If we are sending the event on behalf of another server
                     # then it already has the event and there is no reason to
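
For context, a minimal sketch of the domain-set computation this hunk pushes down into the state layer. The helper below is a simplification; Synapse's real get_domain_from_id lives in synapse.types and validates more strictly:

# A Matrix user ID is '@localpart:domain'; the federation sender only needs
# the set of distinct domains, since it sends one transaction per server.
def get_domain_from_id(user_id):
    idx = user_id.find(":")
    if idx == -1:
        raise ValueError("Invalid user ID: %r" % (user_id,))
    return user_id[idx + 1:]


users_in_room = ["@alice:example.com", "@bob:example.com", "@carol:matrix.org"]
destinations = set(get_domain_from_id(u) for u in users_in_room)
assert destinations == set(["example.com", "matrix.org"])

Computing the host set inside the state/store layer also lets it be cached per state group, which is what the get_joined_hosts changes further down implement.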

synapse/handlers/identity.py

@@ -18,7 +18,7 @@
 from twisted.internet import defer

 from synapse.api.errors import (
-    CodeMessageException
+    MatrixCodeMessageException, CodeMessageException, SynapseError
 )
 from ._base import BaseHandler
 from synapse.util.async import run_on_reactor
@@ -90,6 +90,9 @@ class IdentityHandler(BaseHandler):
                 ),
                 {'sid': creds['sid'], 'client_secret': client_secret}
             )
+        except MatrixCodeMessageException as e:
+            logger.info("getValidated3pid failed with Matrix error: %r", e)
+            raise SynapseError(e.code, e.msg, e.errcode)
         except CodeMessageException as e:
             data = json.loads(e.msg)
@@ -159,6 +162,9 @@ class IdentityHandler(BaseHandler):
                 params
             )
             defer.returnValue(data)
+        except MatrixCodeMessageException as e:
+            logger.info("Proxied requestToken failed with Matrix error: %r", e)
+            raise SynapseError(e.code, e.msg, e.errcode)
         except CodeMessageException as e:
             logger.info("Proxied requestToken failed: %r", e)
             raise e
@@ -193,6 +199,9 @@ class IdentityHandler(BaseHandler):
                 params
             )
             defer.returnValue(data)
+        except MatrixCodeMessageException as e:
+            logger.info("Proxied requestToken failed with Matrix error: %r", e)
+            raise SynapseError(e.code, e.msg, e.errcode)
         except CodeMessageException as e:
             logger.info("Proxied requestToken failed: %r", e)
             raise e
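
One ordering detail worth noting: MatrixCodeMessageException subclasses CodeMessageException, so the new, more specific clause must come before the existing one or it would never be reached. A standalone demonstration:

class CodeMessageException(RuntimeError):
    pass


class MatrixCodeMessageException(CodeMessageException):
    pass


def classify(exc):
    try:
        raise exc
    except MatrixCodeMessageException:
        return "matrix"      # matched only because this clause comes first
    except CodeMessageException:
        return "generic"


assert classify(MatrixCodeMessageException()) == "matrix"
assert classify(CodeMessageException()) == "generic"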

synapse/http/client.py

@@ -16,7 +16,7 @@
 from OpenSSL import SSL
 from OpenSSL.SSL import VERIFY_NONE
 from synapse.api.errors import (
-    CodeMessageException, SynapseError, Codes,
+    CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
 )
 from synapse.util.logcontext import preserve_context_over_fn
 import synapse.metrics
@@ -145,6 +145,11 @@ class SimpleHttpClient(object):
         body = yield preserve_context_over_fn(readBody, response)

-        defer.returnValue(json.loads(body))
+        if 200 <= response.code < 300:
+            defer.returnValue(json.loads(body))
+        else:
+            raise self._exceptionFromFailedRequest(response.code, body)
@@ -164,8 +169,11 @@ class SimpleHttpClient(object):
             On a non-2xx HTTP response. The response body will be used as the
             error message.
         """
-        body = yield self.get_raw(uri, args)
-        defer.returnValue(json.loads(body))
+        try:
+            body = yield self.get_raw(uri, args)
+            defer.returnValue(json.loads(body))
+        except CodeMessageException as e:
+            raise self._exceptionFromFailedRequest(e.code, e.msg)

     @defer.inlineCallbacks
     def put_json(self, uri, json_body, args={}):
@@ -246,6 +254,15 @@ class SimpleHttpClient(object):
         else:
             raise CodeMessageException(response.code, body)

+    def _exceptionFromFailedRequest(self, code, body):
+        try:
+            jsonBody = json.loads(body)
+            errcode = jsonBody['errcode']
+            error = jsonBody['error']
+            return MatrixCodeMessageException(code, error, errcode)
+        except (ValueError, KeyError):
+            return CodeMessageException(code, body)
+
     # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
     # The two should be factored out.
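
The helper's parsing contract, as a standalone sketch (the function name below is illustrative, not Synapse's): a standard Matrix error body is JSON of the form {"errcode": ..., "error": ...}; anything else falls back to the raw body.

import json


def args_from_failed_request(code, body):
    """Return (code, message, errcode); errcode is None for non-Matrix bodies."""
    try:
        json_body = json.loads(body)
        return (code, json_body['error'], json_body['errcode'])
    except (ValueError, KeyError):
        # ValueError: body was not JSON; KeyError: JSON but not an error shape.
        return (code, body, None)


assert args_from_failed_request(
    403, '{"errcode": "M_FORBIDDEN", "error": "denied"}'
) == (403, "denied", "M_FORBIDDEN")

assert args_from_failed_request(
    502, "<html>Bad Gateway</html>"
) == (502, "<html>Bad Gateway</html>", None)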

synapse/replication/slave/storage/events.py

@@ -144,6 +144,9 @@ class SlavedEventStore(BaseSlavedStore):
         RoomMemberStore.__dict__["_get_joined_users_from_context"]
     )

+    get_joined_hosts = DataStore.get_joined_hosts.__func__
+    _get_joined_hosts = RoomMemberStore.__dict__["_get_joined_hosts"]
+
     get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
     get_room_events_stream_for_rooms = (
         DataStore.get_room_events_stream_for_rooms.__func__
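
Two different borrowing idioms appear here: plain methods are copied via __func__ (unwrapping the Python 2 unbound method), while cached methods are pulled out of the class __dict__ so the cache descriptor itself is copied rather than whatever its __get__ would return. A toy sketch of the distinction, under Python 2 semantics to match the codebase of the time:

class Descriptor(object):
    """Stand-in for Synapse's @cached descriptor."""
    def __init__(self, func):
        self.func = func

    def __get__(self, obj, objtype=None):
        # Normal attribute access returns a bound wrapper, not the descriptor.
        return lambda *args: self.func(obj, *args)


class Master(object):
    def plain(self):
        return "plain result"

    @Descriptor
    def cached(self):
        return "cached result"


class Slave(object):
    plain = Master.plain.__func__        # unwrap the py2 unbound method
    cached = Master.__dict__["cached"]   # copy the descriptor itself


assert Slave().plain() == "plain result"
assert Slave().cached() == "cached result"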

synapse/rest/client/v2_alpha/register.py

@@ -31,6 +31,7 @@ import logging
 import hmac
 from hashlib import sha1

 from synapse.util.async import run_on_reactor
+from synapse.util.ratelimitutils import FederationRateLimiter

 # We ought to be using hmac.compare_digest() but on older pythons it doesn't
@@ -115,6 +116,45 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
         defer.returnValue((200, ret))


+class UsernameAvailabilityRestServlet(RestServlet):
+    PATTERNS = client_v2_patterns("/register/available")
+
+    def __init__(self, hs):
+        """
+        Args:
+            hs (synapse.server.HomeServer): server
+        """
+        super(UsernameAvailabilityRestServlet, self).__init__()
+        self.hs = hs
+        self.registration_handler = hs.get_handlers().registration_handler
+        self.ratelimiter = FederationRateLimiter(
+            hs.get_clock(),
+            # Time window of 2s
+            window_size=2000,
+            # Artificially delay requests if rate > sleep_limit/window_size
+            sleep_limit=1,
+            # Amount of artificial delay to apply
+            sleep_msec=1000,
+            # Error with 429 if more than reject_limit requests are queued
+            reject_limit=1,
+            # Allow 1 request at a time
+            concurrent_requests=1,
+        )
+
+    @defer.inlineCallbacks
+    def on_GET(self, request):
+        ip = self.hs.get_ip_from_request(request)
+        with self.ratelimiter.ratelimit(ip) as wait_deferred:
+            yield wait_deferred
+
+            body = parse_json_object_from_request(request)
+
+            assert_params_in_request(body, ['username'])
+
+            yield self.registration_handler.check_username(body['username'])
+
+            defer.returnValue((200, {"available": True}))
+
+
 class RegisterRestServlet(RestServlet):
     PATTERNS = client_v2_patterns("/register$")
@@ -555,4 +595,5 @@ class RegisterRestServlet(RestServlet):
 def register_servlets(hs, http_server):
     EmailRegisterRequestTokenRestServlet(hs).register(http_server)
     MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
+    UsernameAvailabilityRestServlet(hs).register(http_server)
     RegisterRestServlet(hs).register(http_server)
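
The ratelimiter guards an unauthenticated endpoint keyed by client IP. A minimal standalone sketch of the context-manager pattern, assuming a reject-only limiter; Synapse's FederationRateLimiter additionally sleeps and queues requests, and the class below is purely illustrative:

import time
from contextlib import contextmanager


class TooManyRequests(Exception):
    code = 429


class WindowRateLimiter(object):
    """Reject-only toy: at most reject_limit requests per IP per window."""
    def __init__(self, window_size_ms=2000, reject_limit=1):
        self.window_size_ms = window_size_ms
        self.reject_limit = reject_limit
        self.requests = {}  # ip -> timestamps (ms) of recent requests

    @contextmanager
    def ratelimit(self, ip):
        now = time.time() * 1000
        recent = [t for t in self.requests.get(ip, [])
                  if now - t < self.window_size_ms]
        if len(recent) >= self.reject_limit:
            raise TooManyRequests()
        recent.append(now)
        self.requests[ip] = recent
        yield


limiter = WindowRateLimiter()
with limiter.ratelimit("10.0.0.1"):
    pass  # first request in the window is allowed
try:
    with limiter.ratelimit("10.0.0.1"):
        pass
except TooManyRequests:
    pass  # second request within 2s is rejected with a 429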

synapse/state.py

@@ -175,6 +175,17 @@ class StateHandler(object):
         )
         defer.returnValue(joined_users)

+    @defer.inlineCallbacks
+    def get_current_hosts_in_room(self, room_id, latest_event_ids=None):
+        if not latest_event_ids:
+            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+        logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
+        entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+        joined_hosts = yield self.store.get_joined_hosts(
+            room_id, entry.state_id, entry.state
+        )
+        defer.returnValue(joined_hosts)
+
     @defer.inlineCallbacks
     def compute_event_context(self, event, old_state=None):
         """Build an EventContext structure for the event.

synapse/storage/_base.py

@@ -60,12 +60,12 @@ class LoggingTransaction(object):
         object.__setattr__(self, "database_engine", database_engine)
         object.__setattr__(self, "after_callbacks", after_callbacks)

-    def call_after(self, callback, *args, **kwargs):
+    def call_after(self, callback, *args):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
         """
-        self.after_callbacks.append((callback, args, kwargs))
+        self.after_callbacks.append((callback, args))

     def __getattr__(self, name):
         return getattr(self.txn, name)
@@ -319,8 +319,8 @@ class SQLBaseStore(object):
                 inner_func, *args, **kwargs
             )
         finally:
-            for after_callback, after_args, after_kwargs in after_callbacks:
-                after_callback(*after_args, **after_kwargs)
+            for after_callback, after_args in after_callbacks:
+                after_callback(*after_args)
         defer.returnValue(result)

     @defer.inlineCallbacks
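
The kwargs support appears to be droppable because its only kwargs-using caller, the state-group cache prefill in the storage/state change further down, is removed in this same commit. The pattern itself, as a standalone sketch: callbacks queued during a database transaction run only once the transaction has finished, so caches are not invalidated on the wrong thread or for work still in flight.

def run_interaction(func):
    """Run func with a list for queuing (callback, args) pairs; fire them after."""
    after_callbacks = []
    try:
        result = func(after_callbacks)
    finally:
        for after_callback, after_args in after_callbacks:
            after_callback(*after_args)
    return result


log = []


def txn_body(after_callbacks):
    # ... SQL statements would run here ...
    after_callbacks.append((log.append, ("invalidate cache",)))
    return 42


assert run_interaction(txn_body) == 42
assert log == ["invalidate cache"]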

synapse/storage/events.py

@@ -374,12 +374,6 @@ class EventsStore(SQLBaseStore):
                     new_forward_extremeties=new_forward_extremeties,
                 )
                 persist_event_counter.inc_by(len(chunk))

-                for room_id, (_, _, new_state) in current_state_for_room.iteritems():
-                    self.get_current_state_ids.prefill(
-                        (room_id, ), new_state
-                    )
-
                 for event, context in chunk:
                     if context.app_service:
                         origin_type = "local"
@@ -441,10 +435,10 @@ class EventsStore(SQLBaseStore):
         Assumes that we are only persisting events for one room at a time.

         Returns:
-            3-tuple (to_delete, to_insert, new_state) where both are state dicts,
-            i.e. (type, state_key) -> event_id. `to_delete` are the entries to
+            2-tuple (to_delete, to_insert) where both are state dicts, i.e.
+            (type, state_key) -> event_id. `to_delete` are the entries to
             first be deleted from current_state_events, `to_insert` are entries
-            to insert. `new_state` is the full set of state.
+            to insert.

             May return None if there are no changes to be applied.
         """
@@ -551,7 +545,7 @@ class EventsStore(SQLBaseStore):
             if ev_id in events_to_insert
         }

-        defer.returnValue((to_delete, to_insert, current_state))
+        defer.returnValue((to_delete, to_insert))

     @defer.inlineCallbacks
     def get_event(self, event_id, check_redacted=True,
@@ -704,7 +698,7 @@ class EventsStore(SQLBaseStore):
     def _update_current_state_txn(self, txn, state_delta_by_room):
         for room_id, current_state_tuple in state_delta_by_room.iteritems():
-            to_delete, to_insert, _ = current_state_tuple
+            to_delete, to_insert = current_state_tuple
             txn.executemany(
                 "DELETE FROM current_state_events WHERE event_id = ?",
                 [(ev_id,) for ev_id in to_delete.itervalues()],
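
The shape being passed around, illustrated with made-up event IDs (real ones look more like '$1493...:example.com'):

# Both dicts map (event type, state_key) -> event ID.
to_delete = {
    ("m.room.member", "@alice:example.com"): "$old_join_event",
}
to_insert = {
    ("m.room.member", "@alice:example.com"): "$new_leave_event",
    ("m.room.topic", ""): "$topic_event",
}

# _update_current_state_txn deletes rows by the superseded event IDs and
# inserts the new mappings; the DELETE above runs once per to_delete value.
delete_params = [(ev_id,) for ev_id in to_delete.values()]
assert delete_params == [("$old_join_event",)]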

synapse/storage/roommember.py

@@ -18,6 +18,7 @@ from twisted.internet import defer
 from collections import namedtuple

 from ._base import SQLBaseStore
+from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.stringutils import to_ascii
@@ -147,7 +148,7 @@ class RoomMemberStore(SQLBaseStore):
         hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
         defer.returnValue(hosts)

-    @cached(max_entries=500000, iterable=True)
+    @cached(max_entries=100000, iterable=True)
     def get_users_in_room(self, room_id):
         def f(txn):
             sql = (
@@ -160,7 +161,7 @@ class RoomMemberStore(SQLBaseStore):
             )
             txn.execute(sql, (room_id, Membership.JOIN,))
-            return [r[0] for r in txn]
+            return [to_ascii(r[0]) for r in txn]
         return self.runInteraction("get_users_in_room", f)

     @cached()
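
Both changes here trade a little CPU for memory in a very large cache: the entry cap drops from 500k to 100k, and cached user IDs are downcast to byte strings. A sketch of what to_ascii does, assuming roughly Synapse's implementation (under Python 2, unicode strings cost two to four times the memory of pure-ASCII byte strings):

def to_ascii(s):
    """Return a byte string if s is pure ASCII, otherwise return s unchanged."""
    try:
        return s.encode("ascii")
    except UnicodeEncodeError:
        return s


assert to_ascii(u"@alice:example.com") == b"@alice:example.com"
assert to_ascii(u"@\xe5lice:example.com") == u"@\xe5lice:example.com"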
@@ -417,25 +418,51 @@ class RoomMemberStore(SQLBaseStore):
                 if key[0] == EventTypes.Member
             ]

-        rows = yield self._simple_select_many_batch(
-            table="room_memberships",
-            column="event_id",
-            iterable=member_event_ids,
-            retcols=['user_id', 'display_name', 'avatar_url'],
-            keyvalues={
-                "membership": Membership.JOIN,
-            },
-            batch_size=500,
-            desc="_get_joined_users_from_context",
-        )
-
-        users_in_room = {
-            to_ascii(row["user_id"]): ProfileInfo(
-                avatar_url=to_ascii(row["avatar_url"]),
-                display_name=to_ascii(row["display_name"]),
-            )
-            for row in rows
-        }
+        # We check if we have any of the member event ids in the event cache
+        # before we ask the DB
+        event_map = self._get_events_from_cache(
+            member_event_ids,
+            allow_rejected=False,
+        )
+
+        missing_member_event_ids = []
+        users_in_room = {}
+        for event_id in member_event_ids:
+            ev_entry = event_map.get(event_id)
+            if ev_entry:
+                if ev_entry.event.membership == Membership.JOIN:
+                    users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
+                        display_name=to_ascii(
+                            ev_entry.event.content.get("displayname", None)
+                        ),
+                        avatar_url=to_ascii(
+                            ev_entry.event.content.get("avatar_url", None)
+                        ),
+                    )
+            else:
+                missing_member_event_ids.append(event_id)
+
+        if missing_member_event_ids:
+            rows = yield self._simple_select_many_batch(
+                table="room_memberships",
+                column="event_id",
+                iterable=missing_member_event_ids,
+                retcols=('user_id', 'display_name', 'avatar_url',),
+                keyvalues={
+                    "membership": Membership.JOIN,
+                },
+                batch_size=500,
+                desc="_get_joined_users_from_context",
+            )
+
+            users_in_room.update({
+                to_ascii(row["user_id"]): ProfileInfo(
+                    avatar_url=to_ascii(row["avatar_url"]),
+                    display_name=to_ascii(row["display_name"]),
+                )
+                for row in rows
+            })

         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
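
The restructuring implements a cache-first batch lookup: consult the in-memory event cache, collect the misses, and issue a single batched query for only those. Reduced to a standalone sketch (all names below are invented for illustration):

def get_profiles(member_event_ids, event_cache, fetch_from_db):
    users_in_room = {}
    missing = []
    for event_id in member_event_ids:
        entry = event_cache.get(event_id)
        if entry is not None:
            user_id, profile = entry
            users_in_room[user_id] = profile
        else:
            missing.append(event_id)
    if missing:
        users_in_room.update(fetch_from_db(missing))  # one batched query
    return users_in_room


cache = {"$event1": ("@alice:hs", "Alice")}
db_calls = []


def fake_db(event_ids):
    db_calls.append(event_ids)
    return {"@bob:hs": "Bob"}


result = get_profiles(["$event1", "$event2"], cache, fake_db)
assert result == {"@alice:hs": "Alice", "@bob:hs": "Bob"}
assert db_calls == [["$event2"]]  # only the cache miss hit the database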
@@ -482,6 +509,44 @@ class RoomMemberStore(SQLBaseStore):
         defer.returnValue(False)

+    def get_joined_hosts(self, room_id, state_group, state_ids):
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        return self._get_joined_hosts(
+            room_id, state_group, state_ids
+        )
+
+    @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
+    def _get_joined_hosts(self, room_id, state_group, current_state_ids):
+        # We don't use `state_group`, it's there so that we can cache based
+        # on it. However, it's important that it's never None, since two
+        # current_states with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        joined_hosts = set()
+        for (etype, state_key), event_id in current_state_ids.items():
+            if etype == EventTypes.Member:
+                try:
+                    host = get_domain_from_id(state_key)
+                except Exception:
+                    logger.warn("state_key not user_id: %s", state_key)
+                    continue
+
+                if host in joined_hosts:
+                    continue
+
+                event = yield self.get_event(event_id, allow_none=True)
+                if event and event.content["membership"] == Membership.JOIN:
+                    joined_hosts.add(intern_string(host))
+
+        defer.returnValue(joined_hosts)
+
     @defer.inlineCallbacks
     def _background_add_membership_profile(self, progress, batch_size):
         target_min_stream_id = progress.get(
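
The object() sentinel deserves a note: the cache is keyed on (room_id, state_group), and a shared None key would make different un-grouped states collide. A fresh object() compares equal only to itself, so every such call gets a key no other call can ever hit. Standalone demonstration:

sentinel_a = object()
sentinel_b = object()
assert sentinel_a == sentinel_a      # equal only to itself...
assert sentinel_a != sentinel_b      # ...never to another sentinel

cache = {}
cache[("!room:hs", sentinel_a)] = set(["example.com"])

# A later un-grouped call carries its own sentinel, so it can never
# accidentally read the earlier entry the way a shared None key would:
assert ("!room:hs", sentinel_b) not in cache
assert ("!room:hs", None) != ("!room:hs", sentinel_a)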

synapse/storage/state.py

@@ -227,18 +227,6 @@ class StateStore(SQLBaseStore):
             ],
         )

-        # Prefill the state group cache with this group.
-        # It's fine to use the sequence like this as the state group map
-        # is immutable. (If the map wasn't immutable then this prefill could
-        # race with another update)
-        txn.call_after(
-            self._state_group_cache.update,
-            self._state_group_cache.sequence,
-            key=context.state_group,
-            value=context.current_state_ids,
-            full=True,
-        )
-
         self._simple_insert_many_txn(
             txn,
             table="event_to_state_groups",

synapse/util/caches/descriptors.py

@@ -18,6 +18,7 @@ from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError, logcontext
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
+from synapse.util.stringutils import to_ascii

 from . import register_cache
@@ -163,10 +164,6 @@ class Cache(object):
     def invalidate(self, key):
         self.check_thread()
-        if not isinstance(key, tuple):
-            raise TypeError(
-                "The cache key must be a tuple not %r" % (type(key),)
-            )

         # Increment the sequence number so that any SELECT statements that
         # raced with the INSERT don't update the cache (SYN-369)
@@ -312,7 +309,7 @@ class CacheDescriptor(_CacheDescriptorBase):
             iterable=self.iterable,
         )

-        def get_cache_key(args, kwargs):
+        def get_cache_key_gen(args, kwargs):
             """Given some args/kwargs return a generator that resolves into
             the cache_key.
@@ -330,13 +327,29 @@ class CacheDescriptor(_CacheDescriptorBase):
                 else:
                     yield self.arg_defaults[nm]

+        # By default our cache key is a tuple, but if there is only one item
+        # then don't bother wrapping in a tuple. This is to save memory.
+        if self.num_args == 1:
+            nm = self.arg_names[0]
+
+            def get_cache_key(args, kwargs):
+                if nm in kwargs:
+                    return kwargs[nm]
+                elif len(args):
+                    return args[0]
+                else:
+                    return self.arg_defaults[nm]
+        else:
+            def get_cache_key(args, kwargs):
+                return tuple(get_cache_key_gen(args, kwargs))
+
         @functools.wraps(self.orig)
         def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)

-            cache_key = tuple(get_cache_key(args, kwargs))
+            cache_key = get_cache_key(args, kwargs)

             # Add our own `cache_context` to argument list if the wrapped function
             # has asked for one
@@ -363,6 +376,11 @@ class CacheDescriptor(_CacheDescriptorBase):
                 ret.addErrback(onErr)

+                # If our cache_key is a string, try to convert to ascii to save
+                # a bit of space in large caches
+                if isinstance(cache_key, basestring):
+                    cache_key = to_ascii(cache_key)
+
                 result_d = ObservableDeferred(ret, consumeErrors=True)
                 cache.set(cache_key, result_d, callback=invalidate_callback)
                 observer = result_d.observe()
@@ -372,10 +390,16 @@ class CacheDescriptor(_CacheDescriptorBase):
             else:
                 return observer

-        wrapped.invalidate = cache.invalidate
-        wrapped.invalidate_all = cache.invalidate_all
-        wrapped.invalidate_many = cache.invalidate_many
-        wrapped.prefill = cache.prefill
+        if self.num_args == 1:
+            wrapped.invalidate = lambda key: cache.invalidate(key[0])
+            wrapped.prefill = lambda key, val: cache.prefill(key[0], val)
+        else:
+            wrapped.invalidate = cache.invalidate
+            wrapped.invalidate_all = cache.invalidate_all
+            wrapped.invalidate_many = cache.invalidate_many
+            wrapped.prefill = cache.prefill
+
         wrapped.cache = cache
         obj.__dict__[self.orig.__name__] = wrapped
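
Two memory tricks meet here. First, single-argument caches skip the 1-tuple wrapper around each key, which in CPython costs on the order of 56-64 bytes per entry (figures indicative, 64-bit builds); across caches capped at hundreds of thousands of entries that is megabytes. This is also why the tuple type check in Cache.invalidate above had to go: bare keys are no longer tuples. Second, the lambdas keep the external API uniform: callers still pass tuple keys to invalidate() and prefill(), and the wrapper unwraps them. A standalone sketch of the unwrapping:

cache = {}

# For a num_args == 1 cache, entries are keyed on the bare argument...
prefill = lambda key, val: cache.__setitem__(key[0], val)
invalidate = lambda key: cache.pop(key[0], None)

# ...while callers keep passing 1-tuples, exactly as for multi-arg caches.
prefill(("!room:example.com",), "cached-value")
assert cache == {"!room:example.com": "cached-value"}

invalidate(("!room:example.com",))
assert cache == {}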