commit
						cd2539ab2a
					
				| 
						 | 
					@ -28,6 +28,12 @@ import logging
 | 
				
			||||||
logger = logging.getLogger(__name__)
 | 
					logger = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					AuthEventTypes = (
 | 
				
			||||||
 | 
					    EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
 | 
				
			||||||
 | 
					    EventTypes.JoinRules,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Auth(object):
 | 
					class Auth(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, hs):
 | 
					    def __init__(self, hs):
 | 
				
			||||||
| 
						 | 
					@ -166,6 +172,7 @@ class Auth(object):
 | 
				
			||||||
        target = auth_events.get(key)
 | 
					        target = auth_events.get(key)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        target_in_room = target and target.membership == Membership.JOIN
 | 
					        target_in_room = target and target.membership == Membership.JOIN
 | 
				
			||||||
 | 
					        target_banned = target and target.membership == Membership.BAN
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        key = (EventTypes.JoinRules, "", )
 | 
					        key = (EventTypes.JoinRules, "", )
 | 
				
			||||||
        join_rule_event = auth_events.get(key)
 | 
					        join_rule_event = auth_events.get(key)
 | 
				
			||||||
| 
						 | 
					@ -194,6 +201,7 @@ class Auth(object):
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
                "caller_in_room": caller_in_room,
 | 
					                "caller_in_room": caller_in_room,
 | 
				
			||||||
                "caller_invited": caller_invited,
 | 
					                "caller_invited": caller_invited,
 | 
				
			||||||
 | 
					                "target_banned": target_banned,
 | 
				
			||||||
                "target_in_room": target_in_room,
 | 
					                "target_in_room": target_in_room,
 | 
				
			||||||
                "membership": membership,
 | 
					                "membership": membership,
 | 
				
			||||||
                "join_rule": join_rule,
 | 
					                "join_rule": join_rule,
 | 
				
			||||||
| 
						 | 
					@ -202,6 +210,11 @@ class Auth(object):
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if ban_level:
 | 
				
			||||||
 | 
					            ban_level = int(ban_level)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            ban_level = 50  # FIXME (erikj): What should we do here?
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if Membership.INVITE == membership:
 | 
					        if Membership.INVITE == membership:
 | 
				
			||||||
            # TODO (erikj): We should probably handle this more intelligently
 | 
					            # TODO (erikj): We should probably handle this more intelligently
 | 
				
			||||||
            # PRIVATE join rules.
 | 
					            # PRIVATE join rules.
 | 
				
			||||||
| 
						 | 
					@ -212,6 +225,10 @@ class Auth(object):
 | 
				
			||||||
                    403,
 | 
					                    403,
 | 
				
			||||||
                    "%s not in room %s." % (event.user_id, event.room_id,)
 | 
					                    "%s not in room %s." % (event.user_id, event.room_id,)
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					            elif target_banned:
 | 
				
			||||||
 | 
					                raise AuthError(
 | 
				
			||||||
 | 
					                    403, "%s is banned from the room" % (target_user_id,)
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
            elif target_in_room:  # the target is already in the room.
 | 
					            elif target_in_room:  # the target is already in the room.
 | 
				
			||||||
                raise AuthError(403, "%s is already in the room." %
 | 
					                raise AuthError(403, "%s is already in the room." %
 | 
				
			||||||
                                     target_user_id)
 | 
					                                     target_user_id)
 | 
				
			||||||
| 
						 | 
					@ -221,6 +238,8 @@ class Auth(object):
 | 
				
			||||||
            # joined: It's a NOOP
 | 
					            # joined: It's a NOOP
 | 
				
			||||||
            if event.user_id != target_user_id:
 | 
					            if event.user_id != target_user_id:
 | 
				
			||||||
                raise AuthError(403, "Cannot force another user to join.")
 | 
					                raise AuthError(403, "Cannot force another user to join.")
 | 
				
			||||||
 | 
					            elif target_banned:
 | 
				
			||||||
 | 
					                raise AuthError(403, "You are banned from this room")
 | 
				
			||||||
            elif join_rule == JoinRules.PUBLIC:
 | 
					            elif join_rule == JoinRules.PUBLIC:
 | 
				
			||||||
                pass
 | 
					                pass
 | 
				
			||||||
            elif join_rule == JoinRules.INVITE:
 | 
					            elif join_rule == JoinRules.INVITE:
 | 
				
			||||||
| 
						 | 
					@ -238,6 +257,10 @@ class Auth(object):
 | 
				
			||||||
                    403,
 | 
					                    403,
 | 
				
			||||||
                    "%s not in room %s." % (target_user_id, event.room_id,)
 | 
					                    "%s not in room %s." % (target_user_id, event.room_id,)
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					            elif target_banned and user_level < ban_level:
 | 
				
			||||||
 | 
					                raise AuthError(
 | 
				
			||||||
 | 
					                    403, "You cannot unban user &s." % (target_user_id,)
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
            elif target_user_id != event.user_id:
 | 
					            elif target_user_id != event.user_id:
 | 
				
			||||||
                if kick_level:
 | 
					                if kick_level:
 | 
				
			||||||
                    kick_level = int(kick_level)
 | 
					                    kick_level = int(kick_level)
 | 
				
			||||||
| 
						 | 
					@ -249,11 +272,6 @@ class Auth(object):
 | 
				
			||||||
                        403, "You cannot kick user %s." % target_user_id
 | 
					                        403, "You cannot kick user %s." % target_user_id
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
        elif Membership.BAN == membership:
 | 
					        elif Membership.BAN == membership:
 | 
				
			||||||
            if ban_level:
 | 
					 | 
				
			||||||
                ban_level = int(ban_level)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                ban_level = 50  # FIXME (erikj): What should we do here?
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if user_level < ban_level:
 | 
					            if user_level < ban_level:
 | 
				
			||||||
                raise AuthError(403, "You don't have permission to ban")
 | 
					                raise AuthError(403, "You don't have permission to ban")
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
| 
						 | 
					@ -412,12 +430,6 @@ class Auth(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        builder.auth_events = auth_events_entries
 | 
					        builder.auth_events = auth_events_entries
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        context.auth_events = {
 | 
					 | 
				
			||||||
            k: v
 | 
					 | 
				
			||||||
            for k, v in context.current_state.items()
 | 
					 | 
				
			||||||
            if v.event_id in auth_ids
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def compute_auth_events(self, event, current_state):
 | 
					    def compute_auth_events(self, event, current_state):
 | 
				
			||||||
        if event.type == EventTypes.Create:
 | 
					        if event.type == EventTypes.Create:
 | 
				
			||||||
            return []
 | 
					            return []
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -16,8 +16,7 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class EventContext(object):
 | 
					class EventContext(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, current_state=None, auth_events=None):
 | 
					    def __init__(self, current_state=None):
 | 
				
			||||||
        self.current_state = current_state
 | 
					        self.current_state = current_state
 | 
				
			||||||
        self.auth_events = auth_events
 | 
					 | 
				
			||||||
        self.state_group = None
 | 
					        self.state_group = None
 | 
				
			||||||
        self.rejected = False
 | 
					        self.rejected = False
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -90,8 +90,8 @@ class BaseHandler(object):
 | 
				
			||||||
        event = builder.build()
 | 
					        event = builder.build()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logger.debug(
 | 
					        logger.debug(
 | 
				
			||||||
            "Created event %s with auth_events: %s, current state: %s",
 | 
					            "Created event %s with current state: %s",
 | 
				
			||||||
            event.event_id, context.auth_events, context.current_state,
 | 
					            event.event_id, context.current_state,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        defer.returnValue(
 | 
					        defer.returnValue(
 | 
				
			||||||
| 
						 | 
					@ -106,7 +106,7 @@ class BaseHandler(object):
 | 
				
			||||||
        # We now need to go and hit out to wherever we need to hit out to.
 | 
					        # We now need to go and hit out to wherever we need to hit out to.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not suppress_auth:
 | 
					        if not suppress_auth:
 | 
				
			||||||
            self.auth.check(event, auth_events=context.auth_events)
 | 
					            self.auth.check(event, auth_events=context.current_state)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        yield self.store.persist_event(event, context=context)
 | 
					        yield self.store.persist_event(event, context=context)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -464,11 +464,9 @@ class FederationHandler(BaseHandler):
 | 
				
			||||||
            builder=builder,
 | 
					            builder=builder,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.auth.check(event, auth_events=context.auth_events)
 | 
					        self.auth.check(event, auth_events=context.current_state)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        pdu = event
 | 
					        defer.returnValue(event)
 | 
				
			||||||
 | 
					 | 
				
			||||||
        defer.returnValue(pdu)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @defer.inlineCallbacks
 | 
					    @defer.inlineCallbacks
 | 
				
			||||||
    @log_function
 | 
					    @log_function
 | 
				
			||||||
| 
						 | 
					@ -705,7 +703,7 @@ class FederationHandler(BaseHandler):
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not auth_events:
 | 
					        if not auth_events:
 | 
				
			||||||
            auth_events = context.auth_events
 | 
					            auth_events = context.current_state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        logger.debug(
 | 
					        logger.debug(
 | 
				
			||||||
            "_handle_new_event: %s, auth_events: %s",
 | 
					            "_handle_new_event: %s, auth_events: %s",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -21,6 +21,7 @@ from synapse.util.async import run_on_reactor
 | 
				
			||||||
from synapse.util.expiringcache import ExpiringCache
 | 
					from synapse.util.expiringcache import ExpiringCache
 | 
				
			||||||
from synapse.api.constants import EventTypes
 | 
					from synapse.api.constants import EventTypes
 | 
				
			||||||
from synapse.api.errors import AuthError
 | 
					from synapse.api.errors import AuthError
 | 
				
			||||||
 | 
					from synapse.api.auth import AuthEventTypes
 | 
				
			||||||
from synapse.events.snapshot import EventContext
 | 
					from synapse.events.snapshot import EventContext
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from collections import namedtuple
 | 
					from collections import namedtuple
 | 
				
			||||||
| 
						 | 
					@ -38,12 +39,6 @@ def _get_state_key_from_event(event):
 | 
				
			||||||
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 | 
					KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
AuthEventTypes = (
 | 
					 | 
				
			||||||
    EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
 | 
					 | 
				
			||||||
    EventTypes.JoinRules,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
SIZE_OF_CACHE = 1000
 | 
					SIZE_OF_CACHE = 1000
 | 
				
			||||||
EVICTION_TIMEOUT_SECONDS = 20
 | 
					EVICTION_TIMEOUT_SECONDS = 20
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -139,18 +134,6 @@ class StateHandler(object):
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            context.state_group = None
 | 
					            context.state_group = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if hasattr(event, "auth_events") and event.auth_events:
 | 
					 | 
				
			||||||
                auth_ids = self.hs.get_auth().compute_auth_events(
 | 
					 | 
				
			||||||
                    event, context.current_state
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
                context.auth_events = {
 | 
					 | 
				
			||||||
                    k: v
 | 
					 | 
				
			||||||
                    for k, v in context.current_state.items()
 | 
					 | 
				
			||||||
                    if v.event_id in auth_ids
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                context.auth_events = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if event.is_state():
 | 
					            if event.is_state():
 | 
				
			||||||
                key = (event.type, event.state_key)
 | 
					                key = (event.type, event.state_key)
 | 
				
			||||||
                if key in context.current_state:
 | 
					                if key in context.current_state:
 | 
				
			||||||
| 
						 | 
					@ -187,18 +170,6 @@ class StateHandler(object):
 | 
				
			||||||
                replaces = context.current_state[key]
 | 
					                replaces = context.current_state[key]
 | 
				
			||||||
                event.unsigned["replaces_state"] = replaces.event_id
 | 
					                event.unsigned["replaces_state"] = replaces.event_id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if hasattr(event, "auth_events") and event.auth_events:
 | 
					 | 
				
			||||||
            auth_ids = self.hs.get_auth().compute_auth_events(
 | 
					 | 
				
			||||||
                event, context.current_state
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            context.auth_events = {
 | 
					 | 
				
			||||||
                k: v
 | 
					 | 
				
			||||||
                for k, v in context.current_state.items()
 | 
					 | 
				
			||||||
                if v.event_id in auth_ids
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            context.auth_events = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        context.prev_state_events = prev_state
 | 
					        context.prev_state_events = prev_state
 | 
				
			||||||
        defer.returnValue(context)
 | 
					        defer.returnValue(context)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -82,7 +82,7 @@ class StateStore(SQLBaseStore):
 | 
				
			||||||
        if context.current_state is None:
 | 
					        if context.current_state is None:
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        state_events = context.current_state
 | 
					        state_events = dict(context.current_state)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if event.is_state():
 | 
					        if event.is_state():
 | 
				
			||||||
            state_events[(event.type, event.state_key)] = event
 | 
					            state_events[(event.type, event.state_key)] = event
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue