Exchange 3pid invites for m.room.member invites

pull/342/merge
Daniel Wagner-Hall 2015-11-05 16:43:19 +00:00
parent 32fc0737d6
commit 2cebe53545
10 changed files with 230 additions and 180 deletions

View File

@ -24,7 +24,6 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import RoomID, UserID, EventID from synapse.types import RoomID, UserID, EventID
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util import third_party_invites
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import logging
@ -318,6 +317,11 @@ class Auth(object):
} }
) )
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
return True
if Membership.JOIN != membership: if Membership.JOIN != membership:
if (caller_invited if (caller_invited
and Membership.LEAVE == membership and Membership.LEAVE == membership
@ -361,8 +365,7 @@ class Auth(object):
pass pass
elif join_rule == JoinRules.INVITE: elif join_rule == JoinRules.INVITE:
if not caller_in_room and not caller_invited: if not caller_in_room and not caller_invited:
if not self._verify_third_party_invite(event, auth_events): raise AuthError(403, "You are not invited to this room.")
raise AuthError(403, "You are not invited to this room.")
else: else:
# TODO (erikj): may_join list # TODO (erikj): may_join list
# TODO (erikj): private rooms # TODO (erikj): private rooms
@ -390,10 +393,10 @@ class Auth(object):
def _verify_third_party_invite(self, event, auth_events): def _verify_third_party_invite(self, event, auth_events):
""" """
Validates that the join event is authorized by a previous third-party invite. Validates that the invite event is authorized by a previous third-party invite.
Checks that the public key, and keyserver, match those in the invite, Checks that the public key, and keyserver, match those in the third party invite,
and that the join event has a signature issued using that public key. and that the invite event has a signature issued using that public key.
Args: Args:
event: The m.room.member join event being validated. event: The m.room.member join event being validated.
@ -404,35 +407,28 @@ class Auth(object):
True if the event fulfills the expectations of a previous third party True if the event fulfills the expectations of a previous third party
invite event. invite event.
""" """
if not third_party_invites.join_has_third_party_invite(event.content): if "third_party_invite" not in event.content:
return False return False
join_third_party_invite = event.content["third_party_invite"] if "signed" not in event.content["third_party_invite"]:
token = join_third_party_invite["token"] return False
signed = event.content["third_party_invite"]["signed"]
for key in {"mxid", "token"}:
if key not in signed:
return False
token = signed["token"]
invite_event = auth_events.get( invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,) (EventTypes.ThirdPartyInvite, token,)
) )
if not invite_event: if not invite_event:
logger.info("Failing 3pid invite because no invite found for token %s", token) return False
if event.user_id != invite_event.user_id:
return False return False
try: try:
public_key = join_third_party_invite["public_key"] public_key = invite_event.content["public_key"]
key_validity_url = join_third_party_invite["key_validity_url"] if signed["mxid"] != event.state_key:
if invite_event.content["public_key"] != public_key:
logger.info(
"Failing 3pid invite because public key invite: %s != join: %s",
invite_event.content["public_key"],
public_key
)
return False
if invite_event.content["key_validity_url"] != key_validity_url:
logger.info(
"Failing 3pid invite because key_validity_url invite: %s != join: %s",
invite_event.content["key_validity_url"],
key_validity_url
)
return False
signed = join_third_party_invite["signed"]
if signed["mxid"] != event.user_id:
return False return False
if signed["token"] != token: if signed["token"] != token:
return False return False
@ -445,6 +441,11 @@ class Auth(object):
decode_base64(public_key) decode_base64(public_key)
) )
verify_signed_json(signed, server, verify_key) verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True return True
return False return False
except (KeyError, SignatureVerifyException,): except (KeyError, SignatureVerifyException,):
@ -751,17 +752,19 @@ class Auth(object):
if e_type == Membership.JOIN: if e_type == Membership.JOIN:
if member_event and not is_public: if member_event and not is_public:
auth_ids.append(member_event.event_id) auth_ids.append(member_event.event_id)
if third_party_invites.join_has_third_party_invite(event.content): else:
if member_event:
auth_ids.append(member_event.event_id)
if e_type == Membership.INVITE:
if "third_party_invite" in event.content:
key = ( key = (
EventTypes.ThirdPartyInvite, EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["token"] event.content["third_party_invite"]["token"]
) )
invite = current_state.get(key) third_party_invite = current_state.get(key)
if invite: if third_party_invite:
auth_ids.append(invite.event_id) auth_ids.append(third_party_invite.event_id)
else:
if member_event:
auth_ids.append(member_event.event_id)
elif member_event: elif member_event:
if member_event.content["membership"] == Membership.JOIN: if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id) auth_ids.append(member_event.event_id)

View File

@ -26,7 +26,6 @@ from synapse.api.errors import (
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util import third_party_invites
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
@ -358,7 +357,7 @@ class FederationClient(FederationBase):
defer.returnValue(signed_auth) defer.returnValue(signed_auth)
@defer.inlineCallbacks @defer.inlineCallbacks
def make_membership_event(self, destinations, room_id, user_id, membership, content): def make_membership_event(self, destinations, room_id, user_id, membership):
""" """
Creates an m.room.member event, with context, without participating in the room. Creates an m.room.member event, with context, without participating in the room.
@ -390,14 +389,9 @@ class FederationClient(FederationBase):
if destination == self.server_name: if destination == self.server_name:
continue continue
args = {}
if third_party_invites.join_has_third_party_invite(content):
args = third_party_invites.extract_join_keys(
content["third_party_invite"]
)
try: try:
ret = yield self.transport_layer.make_membership_event( ret = yield self.transport_layer.make_membership_event(
destination, room_id, user_id, membership, args destination, room_id, user_id, membership
) )
pdu_dict = ret["event"] pdu_dict = ret["event"]
@ -704,3 +698,26 @@ class FederationClient(FederationBase):
event.internal_metadata.outlier = outlier event.internal_metadata.outlier = outlier
return event return event
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
if destination == self.server_name:
continue
try:
yield self.transport_layer.exchange_third_party_invite(
destination=destination,
room_id=room_id,
event_dict=event_dict,
)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
"Failed to send_third_party_invite via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")

View File

@ -23,12 +23,10 @@ from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
from synapse.api.errors import FederationError, SynapseError, Codes from synapse.api.errors import FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
from synapse.util import third_party_invites
import simplejson as json import simplejson as json
import logging import logging
@ -230,19 +228,8 @@ class FederationServer(FederationBase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id, query): def on_make_join_request(self, room_id, user_id):
threepid_details = {} pdu = yield self.handler.on_make_join_request(room_id, user_id)
if third_party_invites.has_join_keys(query):
for k in third_party_invites.JOIN_KEYS:
if not isinstance(query[k], list) or len(query[k]) != 1:
raise FederationError(
"FATAL",
Codes.MISSING_PARAM,
"key %s value %s" % (k, query[k],),
None
)
threepid_details[k] = query[k][0]
pdu = yield self.handler.on_make_join_request(room_id, user_id, threepid_details)
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)}) defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@ -556,3 +543,15 @@ class FederationServer(FederationBase):
event.internal_metadata.outlier = outlier event.internal_metadata.outlier = outlier
return event return event
@defer.inlineCallbacks
def exchange_third_party_invite(self, invite):
ret = yield self.handler.exchange_third_party_invite(invite)
defer.returnValue(ret)
@defer.inlineCallbacks
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
ret = yield self.handler.on_exchange_third_party_invite_request(
origin, room_id, event_dict
)
defer.returnValue(ret)

View File

@ -161,7 +161,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def make_membership_event(self, destination, room_id, user_id, membership, args={}): def make_membership_event(self, destination, room_id, user_id, membership):
valid_memberships = {Membership.JOIN, Membership.LEAVE} valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships: if membership not in valid_memberships:
raise RuntimeError( raise RuntimeError(
@ -173,7 +173,6 @@ class TransportLayerClient(object):
content = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
args=args,
retry_on_dns_fail=True, retry_on_dns_fail=True,
) )
@ -218,6 +217,19 @@ class TransportLayerClient(object):
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,)
response = yield self.client.put_json(
destination=destination,
path=path,
data=event_dict,
)
defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_event_auth(self, destination, room_id, event_id): def get_event_auth(self, destination, room_id, event_id):

View File

@ -292,7 +292,7 @@ class FederationMakeJoinServlet(BaseFederationServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id): def on_GET(self, origin, content, query, context, user_id):
content = yield self.handler.on_make_join_request(context, user_id, query) content = yield self.handler.on_make_join_request(context, user_id)
defer.returnValue((200, content)) defer.returnValue((200, content))
@ -343,6 +343,17 @@ class FederationInviteServlet(BaseFederationServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/([^/]*)"
@defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id):
content = yield self.handler.on_exchange_third_party_invite_request(
origin, room_id, content
)
defer.returnValue((200, content))
class FederationClientKeysQueryServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query" PATH = "/user/keys/query"
@ -396,6 +407,30 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class On3pidBindServlet(BaseFederationServlet):
PATH = "/3pid/onbind"
@defer.inlineCallbacks
def on_POST(self, request):
content_bytes = request.content.read()
content = json.loads(content_bytes)
if "invites" in content:
last_exception = None
for invite in content["invites"]:
try:
yield self.handler.exchange_third_party_invite(invite)
except Exception as e:
last_exception = e
if last_exception:
raise last_exception
defer.returnValue((200, {}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
SERVLET_CLASSES = ( SERVLET_CLASSES = (
FederationPullServlet, FederationPullServlet,
FederationEventServlet, FederationEventServlet,
@ -413,4 +448,6 @@ SERVLET_CLASSES = (
FederationEventAuthServlet, FederationEventAuthServlet,
FederationClientKeysQueryServlet, FederationClientKeysQueryServlet,
FederationClientKeysClaimServlet, FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
) )

View File

@ -21,7 +21,6 @@ from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias from synapse.types import UserID, RoomAlias
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util import third_party_invites
import logging import logging
@ -192,16 +191,6 @@ class BaseHandler(object):
) )
) )
if (
event.type == EventTypes.Member and
event.content["membership"] == Membership.JOIN and
third_party_invites.join_has_third_party_invite(event.content)
):
yield third_party_invites.check_key_valid(
self.hs.get_simple_http_client(),
event
)
federation_handler = self.hs.get_handlers().federation_handler federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member: if event.type == EventTypes.Member:

View File

@ -21,6 +21,7 @@ from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError, AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
) )
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -39,7 +40,6 @@ from twisted.internet import defer
import itertools import itertools
import logging import logging
from synapse.util import third_party_invites
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -58,6 +58,8 @@ class FederationHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(FederationHandler, self).__init__(hs) super(FederationHandler, self).__init__(hs)
self.hs = hs
self.distributor.observe( self.distributor.observe(
"user_joined_room", "user_joined_room",
self._on_user_joined self._on_user_joined
@ -68,7 +70,6 @@ class FederationHandler(BaseHandler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer() self.replication_layer = hs.get_replication_layer()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
# self.auth_handler = gs.get_auth_handler()
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring() self.keyring = hs.get_keyring()
@ -563,7 +564,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def do_invite_join(self, target_hosts, room_id, joinee, content): def do_invite_join(self, target_hosts, room_id, joinee):
""" Attempts to join the `joinee` to the room `room_id` via the """ Attempts to join the `joinee` to the room `room_id` via the
server `target_host`. server `target_host`.
@ -583,8 +584,7 @@ class FederationHandler(BaseHandler):
target_hosts, target_hosts,
room_id, room_id,
joinee, joinee,
"join", "join"
content
) )
self.room_queues[room_id] = [] self.room_queues[room_id] = []
@ -661,16 +661,12 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_make_join_request(self, room_id, user_id, query): def on_make_join_request(self, room_id, user_id):
""" We've received a /make_join/ request, so we create a partial """ We've received a /make_join/ request, so we create a partial
join event for the room and return that. We do *not* persist or join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back. process it until the other server has signed it and sent it back.
""" """
event_content = {"membership": Membership.JOIN} event_content = {"membership": Membership.JOIN}
if third_party_invites.has_join_keys(query):
event_content["third_party_invite"] = (
third_party_invites.extract_join_keys(query)
)
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new({
"type": EventTypes.Member, "type": EventTypes.Member,
@ -686,9 +682,6 @@ class FederationHandler(BaseHandler):
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
if third_party_invites.join_has_third_party_invite(event.content):
third_party_invites.check_key_valid(self.hs.get_simple_http_client(), event)
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -828,8 +821,7 @@ class FederationHandler(BaseHandler):
target_hosts, target_hosts,
room_id, room_id,
user_id, user_id,
"leave", "leave"
{}
) )
signed_event = self._sign_event(event) signed_event = self._sign_event(event)
@ -848,13 +840,12 @@ class FederationHandler(BaseHandler):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, content): def _make_and_verify_event(self, target_hosts, room_id, user_id, membership):
origin, pdu = yield self.replication_layer.make_membership_event( origin, pdu = yield self.replication_layer.make_membership_event(
target_hosts, target_hosts,
room_id, room_id,
user_id, user_id,
membership, membership
content
) )
logger.debug("Got response to make_%s: %s", membership, pdu) logger.debug("Got response to make_%s: %s", membership, pdu)
@ -1647,3 +1638,75 @@ class FederationHandler(BaseHandler):
}, },
"missing": [e.event_id for e in missing_locals], "missing": [e.event_id for e in missing_locals],
}) })
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, invite):
sender = invite["sender"]
room_id = invite["room_id"]
event_dict = {
"type": EventTypes.Member,
"content": {
"membership": Membership.INVITE,
"third_party_invite": invite,
},
"room_id": room_id,
"sender": sender,
"state_key": invite["mxid"],
}
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
event, context = yield self._create_new_client_event(builder=builder)
self.auth.check(event, context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context)
else:
destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)])
yield self.replication_layer.forward_third_party_invite(
destinations,
room_id,
event_dict,
)
@defer.inlineCallbacks
@log_function
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
builder = self.event_builder_factory.new(event_dict)
event, context = yield self._create_new_client_event(
builder=builder,
)
self.auth.check(event, auth_events=context.current_state)
yield self._validate_keyserver(event, auth_events=context.current_state)
returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context)
@defer.inlineCallbacks
def _validate_keyserver(self, event, auth_events):
token = event.content["third_party_invite"]["signed"]["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
try:
response = yield self.hs.get_simple_http_client().get_json(
invite_event.content["key_validity_url"],
{"public_key": invite_event.content["public_key"]}
)
except Exception:
raise SynapseError(
502,
"Third party certificate could not be checked"
)
if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid")

View File

@ -38,6 +38,8 @@ import string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
id_server_scheme = "https://"
class RoomCreationHandler(BaseHandler): class RoomCreationHandler(BaseHandler):
@ -488,8 +490,7 @@ class RoomMemberHandler(BaseHandler):
yield handler.do_invite_join( yield handler.do_invite_join(
room_hosts, room_hosts,
room_id, room_id,
event.user_id, event.user_id
event.content # FIXME To get a non-frozen dict
) )
else: else:
logger.debug("Doing normal join") logger.debug("Doing normal join")
@ -632,7 +633,7 @@ class RoomMemberHandler(BaseHandler):
""" """
try: try:
data = yield self.hs.get_simple_http_client().get_json( data = yield self.hs.get_simple_http_client().get_json(
"https://%s/_matrix/identity/api/v1/lookup" % (id_server,), "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{ {
"medium": medium, "medium": medium,
"address": address, "address": address,
@ -655,8 +656,8 @@ class RoomMemberHandler(BaseHandler):
raise AuthError(401, "No signature from server %s" % (server_hostname,)) raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items(): for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json( key_data = yield self.hs.get_simple_http_client().get_json(
"https://%s/_matrix/identity/api/v1/pubkey/%s" % "%s%s/_matrix/identity/api/v1/pubkey/%s" %
(server_hostname, key_name,), (id_server_scheme, server_hostname, key_name,),
) )
if "public_key" not in key_data: if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" % raise AuthError(401, "No public key named %s from %s" %
@ -709,7 +710,9 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _ask_id_server_for_third_party_invite( def _ask_id_server_for_third_party_invite(
self, id_server, medium, address, room_id, sender): self, id_server, medium, address, room_id, sender):
is_url = "https://%s/_matrix/identity/api/v1/store-invite" % (id_server,) is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json( data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url, is_url,
{ {
@ -722,8 +725,8 @@ class RoomMemberHandler(BaseHandler):
# TODO: Check for success # TODO: Check for success
token = data["token"] token = data["token"]
public_key = data["public_key"] public_key = data["public_key"]
key_validity_url = "https://%s/_matrix/identity/api/v1/pubkey/isvalid" % ( key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server, id_server_scheme, id_server,
) )
defer.returnValue((token, public_key, key_validity_url)) defer.returnValue((token, public_key, key_validity_url))

View File

@ -26,7 +26,6 @@ from synapse.events.utils import serialize_event
import simplejson as json import simplejson as json
import logging import logging
import urllib import urllib
from synapse.util import third_party_invites
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -453,7 +452,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
# target user is you unless it is an invite # target user is you unless it is an invite
state_key = user.to_string() state_key = user.to_string()
if membership_action == "invite" and third_party_invites.has_invite_keys(content): if membership_action == "invite" and self._has_3pid_invite_keys(content):
yield self.handlers.room_member_handler.do_3pid_invite( yield self.handlers.room_member_handler.do_3pid_invite(
room_id, room_id,
user, user,
@ -480,19 +479,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
event_content = {
"membership": unicode(membership_action),
}
if membership_action == "join" and third_party_invites.has_join_keys(content):
event_content["third_party_invite"] = (
third_party_invites.extract_join_keys(content)
)
yield msg_handler.create_and_send_event( yield msg_handler.create_and_send_event(
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
"content": event_content, "content": {"membership": unicode(membership_action)},
"room_id": room_id, "room_id": room_id,
"sender": user.to_string(), "sender": user.to_string(),
"state_key": state_key, "state_key": state_key,
@ -503,6 +493,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
def _has_3pid_invite_keys(self, content):
for key in {"id_server", "medium", "address", "display_name"}:
if key not in content:
return False
return True
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id): def on_PUT(self, request, room_id, membership_action, txn_id):
try: try:

View File

@ -1,69 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError
INVITE_KEYS = {"id_server", "medium", "address", "display_name"}
JOIN_KEYS = {
"token",
"public_key",
"key_validity_url",
"sender",
"signed",
}
def has_invite_keys(content):
for key in INVITE_KEYS:
if key not in content:
return False
return True
def has_join_keys(content):
for key in JOIN_KEYS:
if key not in content:
return False
return True
def join_has_third_party_invite(content):
if "third_party_invite" not in content:
return False
return has_join_keys(content["third_party_invite"])
def extract_join_keys(src):
return {
key: value
for key, value in src.items()
if key in JOIN_KEYS
}
@defer.inlineCallbacks
def check_key_valid(http_client, event):
try:
response = yield http_client.get_json(
event.content["third_party_invite"]["key_validity_url"],
{"public_key": event.content["third_party_invite"]["public_key"]}
)
except Exception:
raise AuthError(502, "Third party certificate could not be checked")
if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid")