Merge branch 'erikj/filter_refactor' into erikj/search

pull/324/head
Erik Johnston 2015-10-20 16:57:51 +01:00
commit 44e2933bf8
19 changed files with 609 additions and 341 deletions

View File

@ -14,7 +14,8 @@
# limitations under the License. # limitations under the License.
"""This module contains classes for authenticating the user.""" """This module contains classes for authenticating the user."""
from nacl.exceptions import BadSignatureError from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer from twisted.internet import defer
@ -26,7 +27,6 @@ from synapse.util import third_party_invites
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import logging
import nacl.signing
import pymacaroons import pymacaroons
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -308,7 +308,11 @@ class Auth(object):
) )
if Membership.JOIN != membership: if Membership.JOIN != membership:
# JOIN is the only action you can perform if you're not in the room if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
return True
if not caller_in_room: # caller isn't joined if not caller_in_room: # caller isn't joined
raise AuthError( raise AuthError(
403, 403,
@ -416,16 +420,23 @@ class Auth(object):
key_validity_url key_validity_url
) )
return False return False
for _, signature_block in join_third_party_invite["signatures"].items(): signed = join_third_party_invite["signed"]
if signed["mxid"] != event.user_id:
return False
if signed["token"] != token:
return False
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items(): for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"): if not key_name.startswith("ed25519:"):
return False return False
verify_key = nacl.signing.VerifyKey(decode_base64(public_key)) verify_key = decode_verify_key_bytes(
signature = decode_base64(encoded_signature) key_name,
verify_key.verify(token, signature) decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
return True return True
return False return False
except (KeyError, BadSignatureError,): except (KeyError, SignatureVerifyException,):
return False return False
def _get_power_level_event(self, auth_events): def _get_power_level_event(self, auth_events):

View File

@ -24,7 +24,7 @@ class Filtering(object):
def get_user_filter(self, user_localpart, filter_id): def get_user_filter(self, user_localpart, filter_id):
result = self.store.get_user_filter(user_localpart, filter_id) result = self.store.get_user_filter(user_localpart, filter_id)
result.addCallback(Filter) result.addCallback(FilterCollection)
return result return result
def add_user_filter(self, user_localpart, user_filter): def add_user_filter(self, user_localpart, user_filter):
@ -131,125 +131,87 @@ class Filtering(object):
raise SynapseError(400, "Bad bundle_updates: expected bool.") raise SynapseError(400, "Bad bundle_updates: expected bool.")
class FilterCollection(object):
def __init__(self, filter_json):
self.filter_json = filter_json
self.room_timeline_filter = Filter(
self.filter_json.get("room", {}).get("timeline", {})
)
self.room_state_filter = Filter(
self.filter_json.get("room", {}).get("state", {})
)
self.room_ephemeral_filter = Filter(
self.filter_json.get("room", {}).get("ephemeral", {})
)
self.presence_filter = Filter(
self.filter_json.get("presence", {})
)
def timeline_limit(self):
return self.room_timeline_filter.limit()
def presence_limit(self):
return self.presence_filter.limit()
def ephemeral_limit(self):
return self.room_ephemeral_filter.limit()
def filter_presence(self, events):
return self.presence_filter.filter(events)
def filter_room_state(self, events):
return self.room_state_filter.filter(events)
def filter_room_timeline(self, events):
return self.room_timeline_filter.filter(events)
def filter_room_ephemeral(self, events):
return self.room_ephemeral_filter.filter(events)
class Filter(object): class Filter(object):
def __init__(self, filter_json): def __init__(self, filter_json):
self.filter_json = filter_json self.filter_json = filter_json
def timeline_limit(self): def check(self, event):
return self.filter_json.get("room", {}).get("timeline", {}).get("limit", 10) """Checks whether the filter matches the given event.
def presence_limit(self):
return self.filter_json.get("presence", {}).get("limit", 10)
def ephemeral_limit(self):
return self.filter_json.get("room", {}).get("ephemeral", {}).get("limit", 10)
def filter_presence(self, events):
return self._filter_on_key(events, ["presence"])
def filter_room_state(self, events):
return self._filter_on_key(events, ["room", "state"])
def filter_room_timeline(self, events):
return self._filter_on_key(events, ["room", "timeline"])
def filter_room_ephemeral(self, events):
return self._filter_on_key(events, ["room", "ephemeral"])
def _filter_on_key(self, events, keys):
filter_json = self.filter_json
if not filter_json:
return events
try:
# extract the right definition from the filter
definition = filter_json
for key in keys:
definition = definition[key]
return self._filter_with_definition(events, definition)
except KeyError:
# return all events if definition isn't specified.
return events
def _filter_with_definition(self, events, definition):
return [e for e in events if self._passes_definition(definition, e)]
def _passes_definition(self, definition, event):
"""Check if the event passes the filter definition
Args:
definition(dict): The filter definition to check against
event(dict or Event): The event to check
Returns: Returns:
True if the event passes the filter in the definition bool: True if the event matches
""" """
if type(event) is dict: literal_keys = {
room_id = event.get("room_id") "rooms": lambda v: event.room_id == v,
sender = event.get("sender") "senders": lambda v: event.sender == v,
event_type = event["type"] "types": lambda v: _matches_wildcard(event.type, v)
else: }
room_id = getattr(event, "room_id", None)
sender = getattr(event, "sender", None)
event_type = event.type
return self._event_passes_definition(
definition, room_id, sender, event_type
)
def _event_passes_definition(self, definition, room_id, sender, for name, match_func in literal_keys.items():
event_type): not_name = "not_%s" % (name,)
"""Check if the event passes through the given definition. disallowed_values = self.filter_json.get(not_name, [])
if any(map(match_func, disallowed_values)):
Args:
definition(dict): The definition to check against.
room_id(str): The id of the room this event is in or None.
sender(str): The sender of the event
event_type(str): The type of the event.
Returns:
True if the event passes through the filter.
"""
# Algorithm notes:
# For each key in the definition, check the event meets the criteria:
# * For types: Literal match or prefix match (if ends with wildcard)
# * For senders/rooms: Literal match only
# * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
# and 'not_types' then it is treated as only being in 'not_types')
# room checks
if room_id is not None:
allow_rooms = definition.get("rooms", None)
reject_rooms = definition.get("not_rooms", None)
if reject_rooms and room_id in reject_rooms:
return False
if allow_rooms and room_id not in allow_rooms:
return False return False
# sender checks allowed_values = self.filter_json.get(name, None)
if sender is not None: if allowed_values is not None:
allow_senders = definition.get("senders", None) if not any(map(match_func, allowed_values)):
reject_senders = definition.get("not_senders", None)
if reject_senders and sender in reject_senders:
return False
if allow_senders and sender not in allow_senders:
return False
# type checks
if "not_types" in definition:
for def_type in definition["not_types"]:
if self._event_matches_type(event_type, def_type):
return False return False
if "types" in definition:
included = False
for def_type in definition["types"]:
if self._event_matches_type(event_type, def_type):
included = True
break
if not included:
return False
return True return True
def _event_matches_type(self, event_type, def_type): def filter(self, events):
if def_type.endswith("*"): return filter(self.check, events)
type_prefix = def_type[:-1]
return event_type.startswith(type_prefix) def limit(self):
else: return self.filter_json.get("limit", 10)
return event_type == def_type
def _matches_wildcard(actual_value, filter_value):
if filter_value.endswith("*"):
type_prefix = filter_value[:-1]
return actual_value.startswith(type_prefix)
else:
return actual_value == filter_value

View File

@ -33,6 +33,7 @@ class RegistrationConfig(Config):
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key") self.macaroon_secret_key = config.get("macaroon_secret_key")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
def default_config(self, **kwargs): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)
@ -48,6 +49,11 @@ class RegistrationConfig(Config):
registration_shared_secret: "%(registration_shared_secret)s" registration_shared_secret: "%(registration_shared_secret)s"
macaroon_secret_key: "%(macaroon_secret_key)s" macaroon_secret_key: "%(macaroon_secret_key)s"
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12.
bcrypt_rounds: 12
""" % locals() """ % locals()
def add_arguments(self, parser): def add_arguments(self, parser):

View File

@ -66,7 +66,6 @@ def prune_event(event):
"users_default", "users_default",
"events", "events",
"events_default", "events_default",
"events_default",
"state_default", "state_default",
"ban", "ban",
"kick", "kick",

View File

@ -17,6 +17,7 @@
from twisted.internet import defer from twisted.internet import defer
from .federation_base import FederationBase from .federation_base import FederationBase
from synapse.api.constants import Membership
from .units import Edu from .units import Edu
from synapse.api.errors import ( from synapse.api.errors import (
@ -357,7 +358,34 @@ class FederationClient(FederationBase):
defer.returnValue(signed_auth) defer.returnValue(signed_auth)
@defer.inlineCallbacks @defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id, content): def make_membership_event(self, destinations, room_id, user_id, membership, content):
"""
Creates an m.room.member event, with context, without participating in the room.
Does so by asking one of the already participating servers to create an
event with proper context.
Note that this does not append any events to any graphs.
Args:
destinations (str): Candidate homeservers which are probably
participating in the room.
room_id (str): The room in which the event will happen.
user_id (str): The user whose membership is being evented.
membership (str): The "membership" property of the event. Must be
one of "join" or "leave".
content (object): Any additional data to put into the content field
of the event.
Return:
A tuple of (origin (str), event (object)) where origin is the remote
homeserver which generated the event.
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
for destination in destinations: for destination in destinations:
if destination == self.server_name: if destination == self.server_name:
continue continue
@ -368,13 +396,13 @@ class FederationClient(FederationBase):
content["third_party_invite"] content["third_party_invite"]
) )
try: try:
ret = yield self.transport_layer.make_join( ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, args destination, room_id, user_id, membership, args
) )
pdu_dict = ret["event"] pdu_dict = ret["event"]
logger.debug("Got response to make_join: %s", pdu_dict) logger.debug("Got response to make_%s: %s", membership, pdu_dict)
defer.returnValue( defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict)) (destination, self.event_from_pdu_json(pdu_dict))
@ -384,8 +412,8 @@ class FederationClient(FederationBase):
raise raise
except Exception as e: except Exception as e:
logger.warn( logger.warn(
"Failed to make_join via %s: %s", "Failed to make_%s via %s: %s",
destination, e.message membership, destination, e.message
) )
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")
@ -491,6 +519,33 @@ class FederationClient(FederationBase):
defer.returnValue(pdu) defer.returnValue(pdu)
@defer.inlineCallbacks
def send_leave(self, destinations, pdu):
for destination in destinations:
if destination == self.server_name:
continue
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_leave(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_leave via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks @defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth): def query_auth(self, destination, room_id, event_id, local_auth):
""" """

View File

@ -267,6 +267,20 @@ class FederationServer(FederationBase):
], ],
})) }))
@defer.inlineCallbacks
def on_make_leave_request(self, room_id, user_id):
pdu = yield self.handler.on_make_leave_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content)
pdu = self.event_from_pdu_json(content)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {}))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id): def on_event_auth(self, origin, room_id, event_id):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -160,8 +161,14 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def make_join(self, destination, room_id, user_id, args={}): def make_membership_event(self, destination, room_id, user_id, membership, args={}):
path = PREFIX + "/make_join/%s/%s" % (room_id, user_id) valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id)
content = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
@ -185,6 +192,19 @@ class TransportLayerClient(object):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def send_leave(self, destination, room_id, event_id, content):
path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
)
defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_invite(self, destination, room_id, event_id, content): def send_invite(self, destination, room_id, event_id, content):

View File

@ -296,6 +296,24 @@ class FederationMakeJoinServlet(BaseFederationServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class FederationMakeLeaveServlet(BaseFederationServlet):
PATH = "/make_leave/([^/]*)/([^/]*)"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
content = yield self.handler.on_make_leave_request(context, user_id)
defer.returnValue((200, content))
class FederationSendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/([^/]*)/([^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id, txid):
content = yield self.handler.on_send_leave_request(origin, content)
defer.returnValue((200, content))
class FederationEventAuthServlet(BaseFederationServlet): class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth/([^/]*)/([^/]*)" PATH = "/event_auth/([^/]*)/([^/]*)"
@ -385,8 +403,10 @@ SERVLET_CLASSES = (
FederationBackfillServlet, FederationBackfillServlet,
FederationQueryServlet, FederationQueryServlet,
FederationMakeJoinServlet, FederationMakeJoinServlet,
FederationMakeLeaveServlet,
FederationEventServlet, FederationEventServlet,
FederationSendJoinServlet, FederationSendJoinServlet,
FederationSendLeaveServlet,
FederationInviteServlet, FederationInviteServlet,
FederationQueryAuthServlet, FederationQueryAuthServlet,
FederationGetMissingEventsServlet, FederationGetMissingEventsServlet,

View File

@ -44,6 +44,7 @@ class AuthHandler(BaseHandler):
LoginType.EMAIL_IDENTITY: self._check_email_identity, LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.DUMMY: self._check_dummy_auth, LoginType.DUMMY: self._check_dummy_auth,
} }
self.bcrypt_rounds = hs.config.bcrypt_rounds
self.sessions = {} self.sessions = {}
@defer.inlineCallbacks @defer.inlineCallbacks
@ -432,7 +433,7 @@ class AuthHandler(BaseHandler):
Returns: Returns:
Hashed password (str). Hashed password (str).
""" """
return bcrypt.hashpw(password, bcrypt.gensalt()) return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))
def validate_hash(self, password, stored_hash): def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash. """Validates that self.hash(password) == stored_hash.

View File

@ -565,7 +565,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def do_invite_join(self, target_hosts, room_id, joinee, content, snapshot): def do_invite_join(self, target_hosts, room_id, joinee, content):
""" Attempts to join the `joinee` to the room `room_id` via the """ Attempts to join the `joinee` to the room `room_id` via the
server `target_host`. server `target_host`.
@ -581,50 +581,19 @@ class FederationHandler(BaseHandler):
yield self.store.clean_room_for_join(room_id) yield self.store.clean_room_for_join(room_id)
origin, pdu = yield self.replication_layer.make_join( origin, event = yield self._make_and_verify_event(
target_hosts, target_hosts,
room_id, room_id,
joinee, joinee,
"join",
content content
) )
logger.debug("Got response to make_join: %s", pdu)
event = pdu
# We should assert some things.
# FIXME: Do this in a nicer way
assert(event.type == EventTypes.Member)
assert(event.user_id == joinee)
assert(event.state_key == joinee)
assert(event.room_id == room_id)
event.internal_metadata.outlier = False
self.room_queues[room_id] = [] self.room_queues[room_id] = []
builder = self.event_builder_factory.new(
unfreeze(event.get_pdu_json())
)
handled_events = set() handled_events = set()
try: try:
builder.event_id = self.event_builder_factory.create_event_id() new_event = self._sign_event(event)
builder.origin = self.hs.hostname
builder.content = content
if not hasattr(event, "signatures"):
builder.signatures = {}
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0],
)
new_event = builder.build()
# Try the host we successfully got a response to /make_join/ # Try the host we successfully got a response to /make_join/
# request first. # request first.
try: try:
@ -632,11 +601,7 @@ class FederationHandler(BaseHandler):
target_hosts.insert(0, origin) target_hosts.insert(0, origin)
except ValueError: except ValueError:
pass pass
ret = yield self.replication_layer.send_join(target_hosts, new_event)
ret = yield self.replication_layer.send_join(
target_hosts,
new_event
)
origin = ret["origin"] origin = ret["origin"]
state = ret["state"] state = ret["state"]
@ -700,7 +665,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
def on_make_join_request(self, room_id, user_id, query): def on_make_join_request(self, room_id, user_id, query):
""" We've received a /make_join/ request, so we create a partial """ We've received a /make_join/ request, so we create a partial
join event for the room and return that. We don *not* persist or join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back. process it until the other server has signed it and sent it back.
""" """
event_content = {"membership": Membership.JOIN} event_content = {"membership": Membership.JOIN}
@ -859,6 +824,168 @@ class FederationHandler(BaseHandler):
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
origin, event = yield self._make_and_verify_event(
target_hosts,
room_id,
user_id,
"leave",
{}
)
signed_event = self._sign_event(event)
# Try the host we successfully got a response to /make_join/
# request first.
try:
target_hosts.remove(origin)
target_hosts.insert(0, origin)
except ValueError:
pass
yield self.replication_layer.send_leave(
target_hosts,
signed_event
)
defer.returnValue(None)
@defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, content):
origin, pdu = yield self.replication_layer.make_membership_event(
target_hosts,
room_id,
user_id,
membership,
content
)
logger.debug("Got response to make_%s: %s", membership, pdu)
event = pdu
# We should assert some things.
# FIXME: Do this in a nicer way
assert(event.type == EventTypes.Member)
assert(event.user_id == user_id)
assert(event.state_key == user_id)
assert(event.room_id == room_id)
defer.returnValue((origin, event))
def _sign_event(self, event):
event.internal_metadata.outlier = False
builder = self.event_builder_factory.new(
unfreeze(event.get_pdu_json())
)
builder.event_id = self.event_builder_factory.create_event_id()
builder.origin = self.hs.hostname
if not hasattr(event, "signatures"):
builder.signatures = {}
add_hashes_and_signatures(
builder,
self.hs.hostname,
self.hs.config.signing_key[0],
)
return builder.build()
@defer.inlineCallbacks
@log_function
def on_make_leave_request(self, room_id, user_id):
""" We've received a /make_leave/ request, so we create a partial
join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
"""
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"content": {"membership": Membership.LEAVE},
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
})
event, context = yield self._create_new_client_event(
builder=builder,
)
self.auth.check(event, auth_events=context.current_state)
defer.returnValue(event)
@defer.inlineCallbacks
@log_function
def on_send_leave_request(self, origin, pdu):
""" We have received a leave event for a room. Fully process it."""
event = pdu
logger.debug(
"on_send_leave_request: Got event: %s, signatures: %s",
event.event_id,
event.signatures,
)
event.internal_metadata.outlier = False
context, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, event
)
logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
extra_users = []
if event.type == EventTypes.Member:
target_user_id = event.state_key
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
new_pdu = event
destinations = set()
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.LEAVE:
destinations.add(
UserID.from_string(s.state_key).domain
)
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
destinations.discard(origin)
logger.debug(
"on_send_leave_request: Sending event: %s, signatures: %s",
event.event_id,
event.signatures,
)
self.replication_layer.send_pdu(new_pdu, destinations)
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
yield run_on_reactor() yield run_on_reactor()

View File

@ -389,7 +389,22 @@ class RoomMemberHandler(BaseHandler):
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
yield self._do_join(event, context, do_auth=do_auth) yield self._do_join(event, context, do_auth=do_auth)
else: else:
# This is not a JOIN, so we can handle it normally. if event.membership == Membership.LEAVE:
is_host_in_room = yield self.is_host_in_room(room_id, context)
if not is_host_in_room:
# Rejecting an invite, rather than leaving a joined room
handler = self.hs.get_handlers().federation_handler
inviter = yield self.get_inviter(event)
if not inviter:
# return the same error as join_room_alias does
raise SynapseError(404, "No known servers")
yield handler.do_remotely_reject_invite(
[inviter.domain],
room_id,
event.user_id
)
defer.returnValue({"room_id": room_id})
return
# FIXME: This isn't idempotency. # FIXME: This isn't idempotency.
if prev_state and prev_state.membership == event.membership: if prev_state and prev_state.membership == event.membership:
@ -413,7 +428,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
def join_room_alias(self, joinee, room_alias, do_auth=True, content={}): def join_room_alias(self, joinee, room_alias, content={}):
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias) mapping = yield directory_handler.get_association(room_alias)
@ -447,8 +462,6 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_join(self, event, context, room_hosts=None, do_auth=True): def _do_join(self, event, context, room_hosts=None, do_auth=True):
joinee = UserID.from_string(event.state_key)
# room_id = RoomID.from_string(event.room_id, self.hs)
room_id = event.room_id room_id = event.room_id
# XXX: We don't do an auth check if we are doing an invite # XXX: We don't do an auth check if we are doing an invite
@ -456,48 +469,18 @@ class RoomMemberHandler(BaseHandler):
# that we are allowed to join when we decide whether or not we # that we are allowed to join when we decide whether or not we
# need to do the invite/join dance. # need to do the invite/join dance.
is_host_in_room = yield self.auth.check_host_in_room( is_host_in_room = yield self.is_host_in_room(room_id, context)
event.room_id,
self.hs.hostname
)
if not is_host_in_room:
# is *anyone* in the room?
room_member_keys = [
v for (k, v) in context.current_state.keys() if (
k == "m.room.member"
)
]
if len(room_member_keys) == 0:
# has the room been created so we can join it?
create_event = context.current_state.get(("m.room.create", ""))
if create_event:
is_host_in_room = True
if is_host_in_room: if is_host_in_room:
should_do_dance = False should_do_dance = False
elif room_hosts: # TODO: Shouldn't this be remote_room_host? elif room_hosts: # TODO: Shouldn't this be remote_room_host?
should_do_dance = True should_do_dance = True
else: else:
# TODO(markjh): get prev_state from snapshot inviter = yield self.get_inviter(event)
prev_state = yield self.store.get_room_member( if not inviter:
joinee.to_string(), room_id
)
if prev_state and prev_state.membership == Membership.INVITE:
inviter = UserID.from_string(prev_state.user_id)
should_do_dance = not self.hs.is_mine(inviter)
room_hosts = [inviter.domain]
elif "third_party_invite" in event.content:
if "sender" in event.content["third_party_invite"]:
inviter = UserID.from_string(
event.content["third_party_invite"]["sender"]
)
should_do_dance = not self.hs.is_mine(inviter)
room_hosts = [inviter.domain]
else:
# return the same error as join_room_alias does # return the same error as join_room_alias does
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
should_do_dance = not self.hs.is_mine(inviter)
room_hosts = [inviter.domain]
if should_do_dance: if should_do_dance:
handler = self.hs.get_handlers().federation_handler handler = self.hs.get_handlers().federation_handler
@ -505,8 +488,7 @@ class RoomMemberHandler(BaseHandler):
room_hosts, room_hosts,
room_id, room_id,
event.user_id, event.user_id,
event.content, # FIXME To get a non-frozen dict event.content # FIXME To get a non-frozen dict
context
) )
else: else:
logger.debug("Doing normal join") logger.debug("Doing normal join")
@ -523,6 +505,44 @@ class RoomMemberHandler(BaseHandler):
"user_joined_room", user=user, room_id=room_id "user_joined_room", user=user, room_id=room_id
) )
@defer.inlineCallbacks
def get_inviter(self, event):
# TODO(markjh): get prev_state from snapshot
prev_state = yield self.store.get_room_member(
event.user_id, event.room_id
)
if prev_state and prev_state.membership == Membership.INVITE:
defer.returnValue(UserID.from_string(prev_state.user_id))
return
elif "third_party_invite" in event.content:
if "sender" in event.content["third_party_invite"]:
inviter = UserID.from_string(
event.content["third_party_invite"]["sender"]
)
defer.returnValue(inviter)
defer.returnValue(None)
@defer.inlineCallbacks
def is_host_in_room(self, room_id, context):
is_host_in_room = yield self.auth.check_host_in_room(
room_id,
self.hs.hostname
)
if not is_host_in_room:
# is *anyone* in the room?
room_member_keys = [
v for (k, v) in context.current_state.keys() if (
k == "m.room.member"
)
]
if len(room_member_keys) == 0:
# has the room been created so we can join it?
create_event = context.current_state.get(("m.room.create", ""))
if create_event:
is_host_in_room = True
defer.returnValue(is_host_in_room)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_joined_rooms_for_user(self, user): def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given """Returns a list of roomids that the user has any of the given

View File

@ -192,36 +192,6 @@ class LoginRestServlet(ClientV1RestServlet):
return (user, attributes) return (user, attributes)
class LoginFallbackRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/fallback$")
def on_GET(self, request):
# TODO(kegan): This should be returning some HTML which is capable of
# hitting LoginRestServlet
return (200, {})
class PasswordResetRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/reset")
@defer.inlineCallbacks
def on_POST(self, request):
reset_info = _parse_json(request)
try:
email = reset_info["email"]
user_id = reset_info["user_id"]
handler = self.handlers.login_handler
yield handler.reset_password(user_id, email)
# purposefully give no feedback to avoid people hammering different
# combinations.
defer.returnValue((200, {}))
except KeyError:
raise SynapseError(
400,
"Missing keys. Requires 'email' and 'user_id'."
)
class SAML2RestServlet(ClientV1RestServlet): class SAML2RestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/saml2") PATTERN = client_path_pattern("/login/saml2")

View File

@ -23,7 +23,7 @@ from synapse.types import StreamToken
from synapse.events.utils import ( from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_event_id, serialize_event, format_event_for_client_v2_without_event_id,
) )
from synapse.api.filtering import Filter from synapse.api.filtering import FilterCollection
from ._base import client_v2_pattern from ._base import client_v2_pattern
import copy import copy
@ -103,7 +103,7 @@ class SyncRestServlet(RestServlet):
user.localpart, filter_id user.localpart, filter_id
) )
except: except:
filter = Filter({}) filter = FilterCollection({})
sync_config = SyncConfig( sync_config = SyncConfig(
user=user, user=user,

View File

@ -1,71 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module allows you to send out emails.
"""
import email.utils
import smtplib
import twisted.python.log
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
import logging
logger = logging.getLogger(__name__)
class EmailException(Exception):
pass
def send_email(smtp_server, from_addr, to_addr, subject, body):
"""Sends an email.
Args:
smtp_server(str): The SMTP server to use.
from_addr(str): The address to send from.
to_addr(str): The address to send to.
subject(str): The subject of the email.
body(str): The plain text body of the email.
Raises:
EmailException if there was a problem sending the mail.
"""
if not smtp_server or not from_addr or not to_addr:
raise EmailException("Need SMTP server, from and to addresses. Check"
" the config to set these.")
msg = MIMEMultipart('alternative')
msg['Subject'] = subject
msg['From'] = from_addr
msg['To'] = to_addr
plain_part = MIMEText(body)
msg.attach(plain_part)
raw_from = email.utils.parseaddr(from_addr)[1]
raw_to = email.utils.parseaddr(to_addr)[1]
if not raw_from or not raw_to:
raise EmailException("Couldn't parse from/to address.")
logger.info("Sending email to %s on server %s with subject %s",
to_addr, smtp_server, subject)
try:
smtp = smtplib.SMTP(smtp_server)
smtp.sendmail(raw_from, raw_to, msg.as_string())
smtp.quit()
except Exception as origException:
twisted.python.log.err()
ese = EmailException()
ese.cause = origException
raise ese

View File

@ -23,8 +23,8 @@ JOIN_KEYS = {
"token", "token",
"public_key", "public_key",
"key_validity_url", "key_validity_url",
"signatures",
"sender", "sender",
"signed",
} }

View File

@ -23,10 +23,17 @@ from tests.utils import (
) )
from synapse.types import UserID from synapse.types import UserID
from synapse.api.filtering import Filter from synapse.api.filtering import FilterCollection, Filter
user_localpart = "test_user" user_localpart = "test_user"
MockEvent = namedtuple("MockEvent", "sender type room_id") # MockEvent = namedtuple("MockEvent", "sender type room_id")
def MockEvent(**kwargs):
ev = NonCallableMock(spec_set=kwargs.keys())
ev.configure_mock(**kwargs)
return ev
class FilteringTestCase(unittest.TestCase): class FilteringTestCase(unittest.TestCase):
@ -44,7 +51,6 @@ class FilteringTestCase(unittest.TestCase):
) )
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
self.filter = Filter({})
self.datastore = hs.get_datastore() self.datastore = hs.get_datastore()
@ -57,8 +63,9 @@ class FilteringTestCase(unittest.TestCase):
type="m.room.message", type="m.room.message",
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_types_works_with_wildcards(self): def test_definition_types_works_with_wildcards(self):
@ -71,7 +78,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_types_works_with_unknowns(self): def test_definition_types_works_with_unknowns(self):
@ -84,7 +91,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_types_works_with_literals(self): def test_definition_not_types_works_with_literals(self):
@ -97,7 +104,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_types_works_with_wildcards(self): def test_definition_not_types_works_with_wildcards(self):
@ -110,7 +117,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_types_works_with_unknowns(self): def test_definition_not_types_works_with_unknowns(self):
@ -123,7 +130,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_types_takes_priority_over_types(self): def test_definition_not_types_takes_priority_over_types(self):
@ -137,7 +144,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_senders_works_with_literals(self): def test_definition_senders_works_with_literals(self):
@ -150,7 +157,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_senders_works_with_unknowns(self): def test_definition_senders_works_with_unknowns(self):
@ -163,7 +170,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_senders_works_with_literals(self): def test_definition_not_senders_works_with_literals(self):
@ -176,7 +183,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_senders_works_with_unknowns(self): def test_definition_not_senders_works_with_unknowns(self):
@ -189,7 +196,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_senders_takes_priority_over_senders(self): def test_definition_not_senders_takes_priority_over_senders(self):
@ -203,7 +210,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!foo:bar" room_id="!foo:bar"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_rooms_works_with_literals(self): def test_definition_rooms_works_with_literals(self):
@ -216,7 +223,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!secretbase:unknown" room_id="!secretbase:unknown"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_rooms_works_with_unknowns(self): def test_definition_rooms_works_with_unknowns(self):
@ -229,7 +236,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_rooms_works_with_literals(self): def test_definition_not_rooms_works_with_literals(self):
@ -242,7 +249,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_rooms_works_with_unknowns(self): def test_definition_not_rooms_works_with_unknowns(self):
@ -255,7 +262,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!anothersecretbase:unknown" room_id="!anothersecretbase:unknown"
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_not_rooms_takes_priority_over_rooms(self): def test_definition_not_rooms_takes_priority_over_rooms(self):
@ -269,7 +276,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!secretbase:unknown" room_id="!secretbase:unknown"
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_combined_event(self): def test_definition_combined_event(self):
@ -287,7 +294,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup room_id="!stage:unknown" # yup
) )
self.assertTrue( self.assertTrue(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_combined_event_bad_sender(self): def test_definition_combined_event_bad_sender(self):
@ -305,7 +312,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup room_id="!stage:unknown" # yup
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_combined_event_bad_room(self): def test_definition_combined_event_bad_room(self):
@ -323,7 +330,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!piggyshouse:muppets" # nope room_id="!piggyshouse:muppets" # nope
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
def test_definition_combined_event_bad_type(self): def test_definition_combined_event_bad_type(self):
@ -341,7 +348,7 @@ class FilteringTestCase(unittest.TestCase):
room_id="!stage:unknown" # yup room_id="!stage:unknown" # yup
) )
self.assertFalse( self.assertFalse(
self.filter._passes_definition(definition, event) Filter(definition).check(event)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -359,7 +366,6 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
type="m.profile", type="m.profile",
room_id="!foo:bar"
) )
events = [event] events = [event]
@ -386,7 +392,6 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent( event = MockEvent(
sender="@foo:bar", sender="@foo:bar",
type="custom.avatar.3d.crazy", type="custom.avatar.3d.crazy",
room_id="!foo:bar"
) )
events = [event] events = [event]

15
tests/crypto/__init__.py Normal file
View File

@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from synapse.events.builder import EventBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from unpaddedbase64 import decode_base64
import nacl.signing
# Perform these tests using given secret key so we get entirely deterministic
# signatures output that we can test against.
SIGNING_KEY_SEED = decode_base64(
"YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1"
)
KEY_ALG = "ed25519"
KEY_VER = 1
KEY_NAME = "%s:%d" % (KEY_ALG, KEY_VER)
HOSTNAME = "domain"
class EventSigningTestCase(unittest.TestCase):
def setUp(self):
self.signing_key = nacl.signing.SigningKey(SIGNING_KEY_SEED)
self.signing_key.alg = KEY_ALG
self.signing_key.version = KEY_VER
def test_sign_minimal(self):
builder = EventBuilder(
{
'event_id': "$0:domain",
'origin': "domain",
'origin_server_ts': 1000000,
'signatures': {},
'type': "X",
'unsigned': {'age_ts': 1000000},
},
)
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
event = builder.build()
self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes)
self.assertEquals(
event.hashes['sha256'],
"6tJjLpXtggfke8UxFhAKg82QVkJzvKOVOOSjUDK4ZSI",
)
self.assertTrue(hasattr(event, 'signatures'))
self.assertIn(HOSTNAME, event.signatures)
self.assertIn(KEY_NAME, event.signatures["domain"])
self.assertEquals(
event.signatures[HOSTNAME][KEY_NAME],
"2Wptgo4CwmLo/Y8B8qinxApKaCkBG2fjTWB7AbP5Uy+"
"aIbygsSdLOFzvdDjww8zUVKCmI02eP9xtyJxc/cLiBA",
)
def test_sign_message(self):
builder = EventBuilder(
{
'content': {
'body': "Here is the message content",
},
'event_id': "$0:domain",
'origin': "domain",
'origin_server_ts': 1000000,
'type': "m.room.message",
'room_id': "!r:domain",
'sender': "@u:domain",
'signatures': {},
'unsigned': {'age_ts': 1000000},
}
)
add_hashes_and_signatures(builder, HOSTNAME, self.signing_key)
event = builder.build()
self.assertTrue(hasattr(event, 'hashes'))
self.assertIn('sha256', event.hashes)
self.assertEquals(
event.hashes['sha256'],
"onLKD1bGljeBWQhWZ1kaP9SorVmRQNdN5aM2JYU2n/g",
)
self.assertTrue(hasattr(event, 'signatures'))
self.assertIn(HOSTNAME, event.signatures)
self.assertIn(KEY_NAME, event.signatures["domain"])
self.assertEquals(
event.signatures[HOSTNAME][KEY_NAME],
"Wm+VzmOUOz08Ds+0NTWb1d4CZrVsJSikkeRxh6aCcUw"
"u6pNC78FunoD7KNWzqFn241eYHYMGCA5McEiVPdhzBA"
)

View File

@ -277,10 +277,10 @@ class RoomPermissionsTestCase(RestTestCase):
expect_code=403) expect_code=403)
# set [invite/join/left] of self, set [invite/join/left] of other, # set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 403s # expect all 404s because room doesn't exist on any server
for usr in [self.user_id, self.rmcreator_id]: for usr in [self.user_id, self.rmcreator_id]:
yield self.join(room=room, user=usr, expect_code=404) yield self.join(room=room, user=usr, expect_code=404)
yield self.leave(room=room, user=usr, expect_code=403) yield self.leave(room=room, user=usr, expect_code=404)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_membership_private_room_perms(self): def test_membership_private_room_perms(self):