From 95614e52204c6ffd8be62a4e4cab716c9a985473 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 18 Nov 2014 15:36:36 +0000 Subject: [PATCH] Fix auth to correctly handle initial creation of rooms --- synapse/api/auth.py | 24 +++++++++++++-- synapse/app/homeserver.py | 61 ++++++++++++++++++++++++--------------- 2 files changed, 58 insertions(+), 27 deletions(-) diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 87f19a96d6..635571d2b6 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -21,7 +21,7 @@ from synapse.api.constants import Membership, JoinRules from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.events.room import ( RoomMemberEvent, RoomPowerLevelsEvent, RoomRedactionEvent, - RoomJoinRulesEvent, RoomCreateEvent, + RoomJoinRulesEvent, RoomCreateEvent, RoomAliasesEvent, ) from synapse.util.logutils import log_function from syutil.base64util import encode_base64 @@ -63,6 +63,10 @@ class Auth(object): # FIXME return True + # FIXME: Temp hack + if event.type == RoomAliasesEvent.TYPE: + return True + if event.type == RoomMemberEvent.TYPE: allowed = self.is_membership_change_allowed(event) if allowed: @@ -144,6 +148,17 @@ class Auth(object): @log_function def is_membership_change_allowed(self, event): + membership = event.content["membership"] + + # Check if this is the room creator joining: + if len(event.prev_events) == 1 and Membership.JOIN == membership: + # Get room creation event: + key = (RoomCreateEvent.TYPE, "", ) + create = event.old_state_events.get(key) + if event.prev_events[0][0] == create.event_id: + if create.content["creator"] == event.state_key: + return True + target_user_id = event.state_key # get info about the caller @@ -159,8 +174,6 @@ class Auth(object): target_in_room = target and target.membership == Membership.JOIN - membership = event.content["membership"] - key = (RoomJoinRulesEvent.TYPE, "", ) join_rule_event = event.old_state_events.get(key) if join_rule_event: @@ -255,6 +268,11 @@ class Auth(object): level = power_level_event.content.get("users", {}).get(user_id) if not level: level = power_level_event.content.get("users_default", 0) + else: + key = (RoomCreateEvent.TYPE, "", ) + create_event = event.old_state_events.get(key) + if create_event.content["creator"] == user_id: + return 100 return level diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 85284a4919..53ca1f8f51 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -184,15 +184,7 @@ class SynapseHomeServer(HomeServer): logger.info("Synapse now listening on port %d", unsecure_port) -def setup(): - config = HomeServerConfig.load_config( - "Synapse Homeserver", - sys.argv[1:], - generate_section="Homeserver" - ) - - config.setup_logging() - +def setup(config, run_http=True): logger.info("Server hostname: %s", config.server_name) if re.search(":[0-9]+$", config.server_name): @@ -212,12 +204,13 @@ def setup(): content_addr=config.content_addr, ) - hs.register_servlets() + if run_http: + hs.register_servlets() - hs.create_resource_tree( - web_client=config.webclient, - redirect_root_to_web_client=True, - ) + hs.create_resource_tree( + web_client=config.webclient, + redirect_root_to_web_client=True, + ) db_name = hs.get_db_name() @@ -237,11 +230,18 @@ def setup(): f.namespace['hs'] = hs reactor.listenTCP(config.manhole, f, interface='127.0.0.1') - bind_port = config.bind_port - if config.no_tls: - bind_port = None - hs.start_listening(bind_port, config.unsecure_port) + if run_http: + bind_port = config.bind_port + if config.no_tls: + bind_port = None + hs.start_listening(bind_port, config.unsecure_port) + hs.config = config + + return hs + + +def run(config): if config.daemonize: print config.pid_file daemon = Daemonize( @@ -257,13 +257,26 @@ def setup(): else: reactor.run() -def run(): - with LoggingContext("run"): - reactor.run() -def main(): +def main(args, run_http=True): with LoggingContext("main"): - setup() + config = HomeServerConfig.load_config( + "Synapse Homeserver", + args, + generate_section="Homeserver" + ) + + config.setup_logging() + + hs = setup(config, run_http=run_http) + + def r(): + run(config) + hs.run = r + + return hs + if __name__ == '__main__': - main() + hs = main(sys.argv[1:]) + hs.run()