Revert "Merge two of the room join codepaths"

This reverts commit cf81375b94.

It subtly violates a guest joining auth check
pull/577/head
Daniel Wagner-Hall 2016-02-12 16:17:24 +00:00
parent d7aa103f00
commit 4de08a4672
5 changed files with 69 additions and 73 deletions

View File

@ -84,11 +84,6 @@ class RegistrationError(SynapseError):
pass pass
class BadIdentifierError(SynapseError):
"""An error indicating an identifier couldn't be parsed."""
pass
class UnrecognizedRequestError(SynapseError): class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make""" """An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):

View File

@ -169,14 +169,7 @@ class ProfileHandler(BaseHandler):
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
if displayname is None:
del state["displayname"]
else:
state["displayname"] = displayname state["displayname"] = displayname
if avatar_url is None:
del state["avatar_url"]
else:
state["avatar_url"] = avatar_url state["avatar_url"] = avatar_url
defer.returnValue(None) defer.returnValue(None)

View File

@ -527,17 +527,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
def lookup_room_alias(self, room_alias): def join_room_alias(self, joinee, room_alias, content={}):
"""
Gets the room ID for an alias.
Args:
room_alias (str): The room alias to look up.
Returns:
A tuple of the room ID (str) and the hosts hosting the room ([str])
Raises:
SynapseError if the room couldn't be looked up.
"""
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias) mapping = yield directory_handler.get_association(room_alias)
@ -549,40 +539,24 @@ class RoomMemberHandler(BaseHandler):
if not hosts: if not hosts:
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
defer.returnValue((room_id, hosts)) # If event doesn't include a display name, add one.
yield collect_presencelike_data(self.distributor, joinee, content)
@defer.inlineCallbacks
def do_join(self, requester, room_id, hosts=None):
"""
Joins requester to room_id.
Args:
requester (Requester): The user joining the room.
room_id (str): The room ID (not alias) being joined.
hosts ([str]): A list of hosts which are hopefully in the room.
Raises:
SynapseError if the room couldn't be joined.
"""
hosts = hosts or []
content = {"membership": Membership.JOIN}
if requester.is_guest:
content["kind"] = "guest"
yield collect_presencelike_data(self.distributor, requester.user, content)
content.update({"membership": Membership.JOIN})
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new({
"type": EventTypes.Member, "type": EventTypes.Member,
"state_key": requester.user.to_string(), "state_key": joinee.to_string(),
"room_id": room_id, "room_id": room_id,
"sender": requester.user.to_string(), "sender": joinee.to_string(),
"membership": Membership.JOIN, # For backwards compatibility "membership": Membership.JOIN,
"content": content, "content": content,
}) })
event, context = yield self._create_new_client_event(builder) event, context = yield self._create_new_client_event(builder)
yield self._do_join(event, context, room_hosts=hosts) yield self._do_join(event, context, room_hosts=hosts)
defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_join(self, event, context, room_hosts=None): def _do_join(self, event, context, room_hosts=None):
room_id = event.room_id room_id = event.room_id

View File

@ -216,7 +216,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet): class JoinRoomAliasServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/join/(?P<room_identifier>[^/]*)$")
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
PATTERNS = ("/join/(?P<room_identifier>[^/]*)")
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_identifier, txn_id=None): def on_POST(self, request, room_identifier, txn_id=None):
@ -225,22 +229,60 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
allow_guest=True, allow_guest=True,
) )
handler = self.handlers.room_member_handler # the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid
# SynapseErrors.
room_id = None identifier = None
hosts = [] is_room_alias = False
if RoomAlias.is_valid(room_identifier): try:
room_alias = RoomAlias.from_string(room_identifier) identifier = RoomAlias.from_string(room_identifier)
room_id, hosts = yield handler.lookup_room_alias(room_alias) is_room_alias = True
else: except SynapseError:
room_id = RoomID.from_string(room_identifier).to_string() identifier = RoomID.from_string(room_identifier)
# TODO: Support for specifying the home server to join with? # TODO: Support for specifying the home server to join with?
yield handler.do_join( if is_room_alias:
requester, room_id, hosts=hosts handler = self.handlers.room_member_handler
ret_dict = yield handler.join_room_alias(
requester.user,
identifier,
) )
defer.returnValue((200, {"room_id": room_id})) defer.returnValue((200, ret_dict))
else: # room id
msg_handler = self.handlers.message_handler
content = {"membership": Membership.JOIN}
if requester.is_guest:
content["kind"] = "guest"
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": identifier.to_string(),
"sender": requester.user.to_string(),
"state_key": requester.user.to_string(),
},
token_id=requester.access_token_id,
txn_id=txn_id,
is_guest=requester.is_guest,
)
defer.returnValue((200, {"room_id": identifier.to_string()}))
@defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
response = yield self.on_POST(request, room_identifier, txn_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing # TODO: Needs unit testing

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.api.errors import SynapseError, BadIdentifierError from synapse.api.errors import SynapseError
from collections import namedtuple from collections import namedtuple
@ -51,13 +51,13 @@ class DomainSpecificString(
def from_string(cls, s): def from_string(cls, s):
"""Parse the string given by 's' into a structure object.""" """Parse the string given by 's' into a structure object."""
if len(s) < 1 or s[0] != cls.SIGIL: if len(s) < 1 or s[0] != cls.SIGIL:
raise BadIdentifierError(400, "Expected %s string to start with '%s'" % ( raise SynapseError(400, "Expected %s string to start with '%s'" % (
cls.__name__, cls.SIGIL, cls.__name__, cls.SIGIL,
)) ))
parts = s[1:].split(':', 1) parts = s[1:].split(':', 1)
if len(parts) != 2: if len(parts) != 2:
raise BadIdentifierError( raise SynapseError(
400, "Expected %s of the form '%slocalname:domain'" % ( 400, "Expected %s of the form '%slocalname:domain'" % (
cls.__name__, cls.SIGIL, cls.__name__, cls.SIGIL,
) )
@ -69,14 +69,6 @@ class DomainSpecificString(
# names on one HS # names on one HS
return cls(localpart=parts[0], domain=domain) return cls(localpart=parts[0], domain=domain)
@classmethod
def is_valid(cls, s):
try:
cls.from_string(s)
return True
except:
return False
def to_string(self): def to_string(self):
"""Return a string encoding the fields of the structure object.""" """Return a string encoding the fields of the structure object."""
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)