Merge branch 'develop' of github.com:matrix-org/synapse into erikj/paginate_sync

erikj/paginate_sync
Erik Johnston 2016-05-18 11:38:24 +01:00
commit 5941346c5b
36 changed files with 370 additions and 427 deletions

View File

@ -2,6 +2,11 @@ body {
margin: 0px;
}
pre, code {
word-break: break-word;
white-space: pre-wrap;
}
#page {
font-family: 'Open Sans', Helvetica, Arial, Sans-Serif;
font-color: #454545;

View File

@ -30,6 +30,17 @@
{% include 'room.html' with context %}
{% endfor %}
<div class="footer">
<small>
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room '{{ reason.room_name }}' because:<br/>
1. An event was received at {{ reason.received_at|format_ts("%c") }}
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago.<br/>
{% if reason.last_sent_ts %}
2. The last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
{% else %}
2. We can't remember the last time we sent a mail for this room.
{% endif %}
</small>
<a href="{{ unsubscribe_link }}">Unsubscribe</a>
</div>
</td>

View File

@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, UserID, get_domian_from_id
from synapse.types import Requester, UserID, get_domain_from_id
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.metrics import Measure
@ -91,8 +91,8 @@ class Auth(object):
"Room %r does not exist" % (event.room_id,)
)
creating_domain = get_domian_from_id(event.room_id)
originating_domain = get_domian_from_id(event.sender)
creating_domain = get_domain_from_id(event.room_id)
originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
@ -219,7 +219,7 @@ class Auth(object):
for event in curr_state.values():
if event.type == EventTypes.Member:
try:
if get_domian_from_id(event.state_key) != host:
if get_domain_from_id(event.state_key) != host:
continue
except:
logger.warn("state_key not user_id: %s", event.state_key)
@ -266,8 +266,8 @@ class Auth(object):
target_user_id = event.state_key
creating_domain = get_domian_from_id(event.room_id)
target_domain = get_domian_from_id(target_user_id)
creating_domain = get_domain_from_id(event.room_id)
target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
@ -890,8 +890,8 @@ class Auth(object):
if user_level >= redact_level:
return False
redacter_domain = get_domian_from_id(event.event_id)
redactee_domain = get_domian_from_id(event.redacts)
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True

View File

@ -66,6 +66,10 @@ def main():
config = yaml.load(open(configfile))
pidfile = config["pid_file"]
cache_factor = config.get("synctl_cache_factor", None)
if cache_factor:
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
action = sys.argv[1] if sys.argv[1:] else "usage"
if action == "start":

View File

@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
from ._base import Config, ConfigError
from synapse.appservice import ApplicationService
from synapse.types import UserID
import urllib
import yaml
import logging
logger = logging.getLogger(__name__)
class AppServiceConfig(Config):
@ -25,3 +34,99 @@ class AppServiceConfig(Config):
# A list of application service config file to use
app_service_config_files: []
"""
def load_appservices(hostname, config_files):
"""Returns a list of Application Services from the config files."""
if not isinstance(config_files, list):
logger.warning(
"Expected %s to be a list of AS config files.", config_files
)
return []
# Dicts of value -> filename
seen_as_tokens = {}
seen_ids = {}
appservices = []
for config_file in config_files:
try:
with open(config_file, 'r') as f:
appservice = _load_appservice(
hostname, yaml.load(f), config_file
)
if appservice.id in seen_ids:
raise ConfigError(
"Cannot reuse ID across application services: "
"%s (files: %s, %s)" % (
appservice.id, config_file, seen_ids[appservice.id],
)
)
seen_ids[appservice.id] = config_file
if appservice.token in seen_as_tokens:
raise ConfigError(
"Cannot reuse as_token across application services: "
"%s (files: %s, %s)" % (
appservice.token,
config_file,
seen_as_tokens[appservice.token],
)
)
seen_as_tokens[appservice.token] = config_file
logger.info("Loaded application service: %s", appservice)
appservices.append(appservice)
except Exception as e:
logger.error("Failed to load appservice from '%s'", config_file)
logger.exception(e)
raise
return appservices
def _load_appservice(hostname, as_info, config_filename):
required_string_fields = [
"id", "url", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
if not isinstance(as_info.get(field), basestring):
raise KeyError("Required string field: '%s' (%s)" % (
field, config_filename,
))
localpart = as_info["sender_localpart"]
if urllib.quote(localpart) != localpart:
raise ValueError(
"sender_localpart needs characters which are not URL encoded."
)
user = UserID(localpart, hostname)
user_id = user.to_string()
# namespace checks
if not isinstance(as_info.get("namespaces"), dict):
raise KeyError("Requires 'namespaces' object.")
for ns in ApplicationService.NS_LIST:
# specific namespaces are optional
if ns in as_info["namespaces"]:
# expect a list of dicts with exclusive and regex keys
for regex_obj in as_info["namespaces"][ns]:
if not isinstance(regex_obj, dict):
raise ValueError(
"Expected namespace entry in %s to be an object,"
" but got %s", ns, regex_obj
)
if not isinstance(regex_obj.get("regex"), basestring):
raise ValueError(
"Missing/bad type 'regex' key in %s", regex_obj
)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
)
return ApplicationService(
token=as_info["as_token"],
url=as_info["url"],
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
)

View File

@ -100,8 +100,13 @@ class ContentRepositoryConfig(Config):
"to work"
)
if "url_preview_url_blacklist" in config:
self.url_preview_url_blacklist = config["url_preview_url_blacklist"]
self.url_preview_ip_range_whitelist = IPSet(
config.get("url_preview_ip_range_whitelist", ())
)
self.url_preview_url_blacklist = config.get(
"url_preview_url_blacklist", ()
)
def default_config(self, **kwargs):
media_store = self.default_path("media_store")
@ -162,6 +167,15 @@ class ContentRepositoryConfig(Config):
# - '10.0.0.0/8'
# - '172.16.0.0/12'
# - '192.168.0.0/16'
#
# List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist.
# This is useful for specifying exceptions to wide-ranging blacklisted
# target IP ranges - e.g. for enabling URL previews for a specific private
# website only visible in your network.
#
# url_preview_ip_range_whitelist:
# - '192.168.1.1'
# Optional list of URL matches that the URL preview spider is
# denied from accessing. You should use url_preview_ip_range_blacklist

View File

@ -24,12 +24,9 @@ from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler
from .profile import ProfileHandler
from .presence import PresenceHandler
from .directory import DirectoryHandler
from .typing import TypingNotificationHandler
from .admin import AdminHandler
from .appservice import ApplicationServicesHandler
from .sync import SyncHandler
from .auth import AuthHandler
from .identity import IdentityHandler
from .receipts import ReceiptsHandler
@ -53,10 +50,8 @@ class Handlers(object):
self.event_handler = EventHandler(hs)
self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs)
self.presence_handler = PresenceHandler(hs)
self.room_list_handler = RoomListHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs)
asapi = ApplicationServiceApi(hs)
@ -67,7 +62,6 @@ class Handlers(object):
as_api=asapi
)
)
self.sync_handler = SyncHandler(hs)
self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs)

View File

@ -58,7 +58,7 @@ class EventStreamHandler(BaseHandler):
If `only_keys` is not None, events from keys will be sent down.
"""
auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_handlers().presence_handler
presence_handler = self.hs.get_presence_handler()
context = yield presence_handler.user_syncing(
auth_user_id, affect_presence=affect_presence,

View File

@ -33,7 +33,7 @@ from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import (
compute_event_signature, add_hashes_and_signatures,
)
from synapse.types import UserID, get_domian_from_id
from synapse.types import UserID, get_domain_from_id
from synapse.events.utils import prune_event
@ -453,7 +453,7 @@ class FederationHandler(BaseHandler):
joined_domains = {}
for u, d in joined_users:
try:
dom = get_domian_from_id(u)
dom = get_domain_from_id(u)
old_d = joined_domains.get(dom)
if old_d:
joined_domains[dom] = min(d, old_d)
@ -744,7 +744,7 @@ class FederationHandler(BaseHandler):
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(get_domian_from_id(s.state_key))
destinations.add(get_domain_from_id(s.state_key))
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
@ -970,7 +970,7 @@ class FederationHandler(BaseHandler):
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.LEAVE:
destinations.add(get_domian_from_id(s.state_key))
destinations.add(get_domain_from_id(s.state_key))
except:
logger.warn(
"Failed to get destination from event %s", s.event_id

View File

@ -23,7 +23,7 @@ from synapse.events.validator import EventValidator
from synapse.push.action_generator import ActionGenerator
from synapse.streams.config import PaginationConfig
from synapse.types import (
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domian_from_id
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
)
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
@ -236,7 +236,7 @@ class MessageHandler(BaseHandler):
)
if event.type == EventTypes.Message:
presence = self.hs.get_handlers().presence_handler
presence = self.hs.get_presence_handler()
yield presence.bump_presence_active_time(user)
def deduplicate_state_event(self, event, context):
@ -674,7 +674,7 @@ class MessageHandler(BaseHandler):
and m.content["membership"] == Membership.JOIN
]
presence_handler = self.hs.get_handlers().presence_handler
presence_handler = self.hs.get_presence_handler()
@defer.inlineCallbacks
def get_presence():
@ -902,7 +902,7 @@ class MessageHandler(BaseHandler):
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(get_domian_from_id(s.state_key))
destinations.add(get_domain_from_id(s.state_key))
except SynapseError:
logger.warn(
"Failed to get destination from event %s", s.event_id

View File

@ -33,11 +33,9 @@ from synapse.util.logcontext import preserve_fn
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID, get_domian_from_id
from synapse.types import UserID, get_domain_from_id
import synapse.metrics
from ._base import BaseHandler
import logging
@ -73,11 +71,11 @@ FEDERATION_PING_INTERVAL = 25 * 60 * 1000
assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class PresenceHandler(BaseHandler):
class PresenceHandler(object):
def __init__(self, hs):
super(PresenceHandler, self).__init__(hs)
self.hs = hs
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.wheel_timer = WheelTimer()
@ -138,7 +136,7 @@ class PresenceHandler(BaseHandler):
obj=state.user_id,
then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
)
if self.hs.is_mine_id(state.user_id):
if self.is_mine_id(state.user_id):
self.wheel_timer.insert(
now=now,
obj=state.user_id,
@ -228,7 +226,7 @@ class PresenceHandler(BaseHandler):
new_state, should_notify, should_ping = handle_update(
prev_state, new_state,
is_mine=self.hs.is_mine_id(user_id),
is_mine=self.is_mine_id(user_id),
wheel_timer=self.wheel_timer,
now=now
)
@ -287,7 +285,7 @@ class PresenceHandler(BaseHandler):
changes = handle_timeouts(
states,
is_mine_fn=self.hs.is_mine_id,
is_mine_fn=self.is_mine_id,
user_to_num_current_syncs=self.user_to_num_current_syncs,
now=now,
)
@ -427,7 +425,7 @@ class PresenceHandler(BaseHandler):
hosts_to_states = {}
for room_id, states in room_ids_to_states.items():
local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states)
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
if not local_states:
continue
@ -436,11 +434,11 @@ class PresenceHandler(BaseHandler):
hosts_to_states.setdefault(host, []).extend(local_states)
for user_id, states in users_to_states.items():
local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states)
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
if not local_states:
continue
host = get_domian_from_id(user_id)
host = get_domain_from_id(user_id)
hosts_to_states.setdefault(host, []).extend(local_states)
# TODO: de-dup hosts_to_states, as a single host might have multiple
@ -611,14 +609,14 @@ class PresenceHandler(BaseHandler):
# don't need to send to local clients here, as that is done as part
# of the event stream/sync.
# TODO: Only send to servers not already in the room.
if self.hs.is_mine(user):
if self.is_mine(user):
state = yield self.current_state_for_user(user.to_string())
hosts = yield self.store.get_joined_hosts_for_room(room_id)
self._push_to_remotes({host: (state,) for host in hosts})
else:
user_ids = yield self.store.get_users_in_room(room_id)
user_ids = filter(self.hs.is_mine_id, user_ids)
user_ids = filter(self.is_mine_id, user_ids)
states = yield self.current_state_for_users(user_ids)
@ -628,7 +626,7 @@ class PresenceHandler(BaseHandler):
def get_presence_list(self, observer_user, accepted=None):
"""Returns the presence for all users in their presence list.
"""
if not self.hs.is_mine(observer_user):
if not self.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
presence_list = yield self.store.get_presence_list(
@ -659,7 +657,7 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string()
)
if self.hs.is_mine(observed_user):
if self.is_mine(observed_user):
yield self.invite_presence(observed_user, observer_user)
else:
yield self.federation.send_edu(
@ -675,11 +673,11 @@ class PresenceHandler(BaseHandler):
def invite_presence(self, observed_user, observer_user):
"""Handles new presence invites.
"""
if not self.hs.is_mine(observed_user):
if not self.is_mine(observed_user):
raise SynapseError(400, "User is not hosted on this Home Server")
# TODO: Don't auto accept
if self.hs.is_mine(observer_user):
if self.is_mine(observer_user):
yield self.accept_presence(observed_user, observer_user)
else:
self.federation.send_edu(
@ -742,7 +740,7 @@ class PresenceHandler(BaseHandler):
Returns:
A Deferred.
"""
if not self.hs.is_mine(observer_user):
if not self.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
yield self.store.del_presence_list(
@ -834,7 +832,11 @@ def _format_user_presence_state(state, now):
class PresenceEventSource(object):
def __init__(self, hs):
self.hs = hs
# We can't call get_presence_handler here because there's a cycle:
#
# Presence -> Notifier -> PresenceEventSource -> Presence
#
self.get_presence_handler = hs.get_presence_handler
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@ -860,7 +862,7 @@ class PresenceEventSource(object):
from_key = int(from_key)
room_ids = room_ids or []
presence = self.hs.get_handlers().presence_handler
presence = self.get_presence_handler()
stream_change_cache = self.store.presence_stream_cache
if not room_ids:

View File

@ -29,6 +29,8 @@ class ReceiptsHandler(BaseHandler):
def __init__(self, hs):
super(ReceiptsHandler, self).__init__(hs)
self.server_name = hs.config.server_name
self.store = hs.get_datastore()
self.hs = hs
self.federation = hs.get_replication_layer()
self.federation.register_edu_handler(
@ -131,12 +133,9 @@ class ReceiptsHandler(BaseHandler):
event_ids = receipt["event_ids"]
data = receipt["data"]
remotedomains = set()
rm_handler = self.hs.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=None, remotedomains=remotedomains
)
remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
remotedomains = remotedomains.copy()
remotedomains.discard(self.server_name)
logger.debug("Sending receipt to: %r", remotedomains)

View File

@ -55,35 +55,6 @@ class RoomMemberHandler(BaseHandler):
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def get_room_members(self, room_id):
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
remotedomains=None, ignore_user=None):
"""Fetch the distribution of a room, adding elements to either
'localusers' or 'remotedomains', which should be a set() if supplied.
If ignore_user is set, ignore that user.
This function returns nothing; its result is performed by the
side-effect on the two passed sets. This allows easy accumulation of
member lists of multiple rooms at once if required.
"""
members = yield self.get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks
def _local_membership_update(
self, requester, target, room_id, membership,
@ -426,21 +397,6 @@ class RoomMemberHandler(BaseHandler):
if invite:
defer.returnValue(UserID.from_string(invite.sender))
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
room_ids = set(r.room_id for r in rooms)
defer.returnValue(room_ids)
@defer.inlineCallbacks
def do_3pid_invite(
self,
@ -457,8 +413,7 @@ class RoomMemberHandler(BaseHandler):
)
if invitee:
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
yield self.update_membership(
requester,
UserID.from_string(invitee),
room_id,

View File

@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseHandler
from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes
from synapse.util.async import concurrently_execute
@ -135,10 +133,12 @@ class SyncResult(collections.namedtuple("SyncResult", [
)
class SyncHandler(BaseHandler):
class SyncHandler(object):
def __init__(self, hs):
super(SyncHandler, self).__init__(hs)
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.response_cache = ResponseCache()
@ -604,7 +604,6 @@ class SyncHandler(BaseHandler):
sync_config, now_token, since_token
)
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string()
)
@ -612,9 +611,10 @@ class SyncHandler(BaseHandler):
rooms = yield self.store.get_app_service_rooms(app_service)
joined_room_ids = set(r.room_id for r in rooms)
else:
joined_room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
rooms = yield self.store.get_rooms_for_user(
sync_config.user.to_string()
)
joined_room_ids = set(r.room_id for r in rooms)
user_id = sync_config.user.to_string()
@ -785,7 +785,7 @@ class SyncHandler(BaseHandler):
# For each newly joined room, we want to send down presence of
# existing users.
presence_handler = self.hs.get_handlers().presence_handler
presence_handler = self.presence_handler
extra_presence_users = set()
for room_id in newly_joined_rooms:
users = yield self.store.get_users_in_room(event.room_id)

View File

@ -15,8 +15,6 @@
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.metrics import Measure
@ -35,11 +33,13 @@ logger = logging.getLogger(__name__)
RoomMember = namedtuple("RoomMember", ("room_id", "user"))
class TypingNotificationHandler(BaseHandler):
class TypingHandler(object):
def __init__(self, hs):
super(TypingNotificationHandler, self).__init__(hs)
self.homeserver = hs
self.store = hs.get_datastore()
self.server_name = hs.config.server_name
self.auth = hs.get_auth()
self.is_mine = hs.is_mine
self.notifier = hs.get_notifier()
self.clock = hs.get_clock()
@ -67,7 +67,7 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout):
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user:
@ -110,7 +110,7 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks
def stopped_typing(self, target_user, auth_user, room_id):
if not self.hs.is_mine(target_user):
if not self.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user:
@ -132,7 +132,7 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
if self.hs.is_mine(user):
if self.is_mine(user):
member = RoomMember(room_id=room_id, user=user)
yield self._stopped_typing(member)
@ -157,32 +157,26 @@ class TypingNotificationHandler(BaseHandler):
@defer.inlineCallbacks
def _push_update(self, room_id, user, typing):
localusers = set()
remotedomains = set()
rm_handler = self.homeserver.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=localusers, remotedomains=remotedomains
)
if localusers:
self._push_update_local(
room_id=room_id,
user=user,
typing=typing
)
domains = yield self.store.get_joined_hosts_for_room(room_id)
deferreds = []
for domain in remotedomains:
deferreds.append(self.federation.send_edu(
destination=domain,
edu_type="m.typing",
content={
"room_id": room_id,
"user_id": user.to_string(),
"typing": typing,
},
))
for domain in domains:
if domain == self.server_name:
self._push_update_local(
room_id=room_id,
user=user,
typing=typing
)
else:
deferreds.append(self.federation.send_edu(
destination=domain,
edu_type="m.typing",
content={
"room_id": room_id,
"user_id": user.to_string(),
"typing": typing,
},
))
yield defer.DeferredList(deferreds, consumeErrors=True)
@ -191,14 +185,9 @@ class TypingNotificationHandler(BaseHandler):
room_id = content["room_id"]
user = UserID.from_string(content["user_id"])
localusers = set()
domains = yield self.store.get_joined_hosts_for_room(room_id)
rm_handler = self.homeserver.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=localusers
)
if localusers:
if self.server_name in domains:
self._push_update_local(
room_id=room_id,
user=user,
@ -238,22 +227,14 @@ class TypingNotificationEventSource(object):
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
self._handler = None
self._room_member_handler = None
def handler(self):
# Avoid cyclic dependency in handler setup
if not self._handler:
self._handler = self.hs.get_handlers().typing_notification_handler
return self._handler
def room_member_handler(self):
if not self._room_member_handler:
self._room_member_handler = self.hs.get_handlers().room_member_handler
return self._room_member_handler
# We can't call get_typing_handler here because there's a cycle:
#
# Typing -> Notifier -> TypingNotificationEventSource -> Typing
#
self.get_typing_handler = hs.get_typing_handler
def _make_event_for(self, room_id):
typing = self.handler()._room_typing[room_id]
typing = self.get_typing_handler()._room_typing[room_id]
return {
"type": "m.typing",
"room_id": room_id,
@ -265,7 +246,7 @@ class TypingNotificationEventSource(object):
def get_new_events(self, from_key, room_ids, **kwargs):
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
handler = self.handler()
handler = self.get_typing_handler()
events = []
for room_id in room_ids:
@ -279,7 +260,7 @@ class TypingNotificationEventSource(object):
return events, handler._latest_room_serial
def get_current_key(self):
return self.handler()._latest_room_serial
return self.get_typing_handler()._latest_room_serial
def get_pagination_rows(self, user, pagination_config, key):
return ([], pagination_config.from_key)

View File

@ -380,13 +380,14 @@ class CaptchaServerHttpClient(SimpleHttpClient):
class SpiderEndpointFactory(object):
def __init__(self, hs):
self.blacklist = hs.config.url_preview_ip_range_blacklist
self.whitelist = hs.config.url_preview_ip_range_whitelist
self.policyForHTTPS = hs.get_http_client_context_factory()
def endpointForURI(self, uri):
logger.info("Getting endpoint for %s", uri.toBytes())
if uri.scheme == "http":
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist,
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
endpoint=TCP4ClientEndpoint,
endpoint_kw_args={
'timeout': 15
@ -395,7 +396,7 @@ class SpiderEndpointFactory(object):
elif uri.scheme == "https":
tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist,
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
endpoint=SSL4ClientEndpoint,
endpoint_kw_args={
'sslContextFactory': tlsPolicy,

View File

@ -79,12 +79,13 @@ class SpiderEndpoint(object):
"""An endpoint which refuses to connect to blacklisted IP addresses
Implements twisted.internet.interfaces.IStreamClientEndpoint.
"""
def __init__(self, reactor, host, port, blacklist,
def __init__(self, reactor, host, port, blacklist, whitelist,
endpoint=TCP4ClientEndpoint, endpoint_kw_args={}):
self.reactor = reactor
self.host = host
self.port = port
self.blacklist = blacklist
self.whitelist = whitelist
self.endpoint = endpoint
self.endpoint_kw_args = endpoint_kw_args
@ -93,10 +94,13 @@ class SpiderEndpoint(object):
address = yield self.reactor.resolve(self.host)
from netaddr import IPAddress
if IPAddress(address) in self.blacklist:
raise ConnectError(
"Refusing to spider blacklisted IP address %s" % address
)
ip_address = IPAddress(address)
if ip_address in self.blacklist:
if self.whitelist is None or ip_address not in self.whitelist:
raise ConnectError(
"Refusing to spider blacklisted IP address %s" % address
)
logger.info("Connecting to %s:%s", address, self.port)
endpoint = self.endpoint(

View File

@ -26,11 +26,14 @@ logger = logging.getLogger(__name__)
# The amount of time we always wait before ever emailing about a notification
# (to give the user a chance to respond to other push or notice the window)
DELAY_BEFORE_MAIL_MS = 2 * 60 * 1000
DELAY_BEFORE_MAIL_MS = 10 * 60 * 1000
THROTTLE_START_MS = 2 * 60 * 1000
THROTTLE_MAX_MS = (2 * 60 * 1000) * (2 ** 11) # ~3 days
THROTTLE_MULTIPLIER = 2
# THROTTLE is the minimum time between mail notifications sent for a given room.
# Each room maintains its own throttle counter, but each new mail notification
# sends the pending notifications for all rooms.
THROTTLE_START_MS = 10 * 60 * 1000
THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # (2 * 60 * 1000) * (2 ** 11) # ~3 days
THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours
# If no event triggers a notification for this long after the previous,
# the throttle is released.
@ -146,7 +149,18 @@ class EmailPusher(object):
# *one* email updating the user on their notifications,
# we then consider all previously outstanding notifications
# to be delivered.
yield self.send_notification(unprocessed)
# debugging:
reason = {
'room_id': push_action['room_id'],
'now': self.clock.time_msec(),
'received_at': received_at,
'delay_before_mail_ms': DELAY_BEFORE_MAIL_MS,
'last_sent_ts': self.get_room_last_sent_ts(push_action['room_id']),
'throttle_ms': self.get_room_throttle_ms(push_action['room_id']),
}
yield self.send_notification(unprocessed, reason)
yield self.save_last_stream_ordering_and_success(max([
ea['stream_ordering'] for ea in unprocessed
@ -195,7 +209,8 @@ class EmailPusher(object):
"""
Determines whether throttling should prevent us from sending an email
for the given room
Returns: True if we should send, False if we should not
Returns: The timestamp when we are next allowed to send an email notif
for this room
"""
last_sent_ts = self.get_room_last_sent_ts(room_id)
throttle_ms = self.get_room_throttle_ms(room_id)
@ -244,8 +259,9 @@ class EmailPusher(object):
)
@defer.inlineCallbacks
def send_notification(self, push_actions):
def send_notification(self, push_actions, reason):
logger.info("Sending notif email for user %r", self.user_id)
yield self.mailer.send_notification_mail(
self.user_id, self.email, push_actions
self.user_id, self.email, push_actions, reason
)

View File

@ -92,7 +92,7 @@ class Mailer(object):
)
@defer.inlineCallbacks
def send_notification_mail(self, user_id, email_address, push_actions):
def send_notification_mail(self, user_id, email_address, push_actions, reason):
raw_from = email.utils.parseaddr(self.hs.config.email_notif_from)[1]
raw_to = email.utils.parseaddr(email_address)[1]
@ -143,12 +143,17 @@ class Mailer(object):
notifs_by_room, state_by_room, notif_events, user_id
)
reason['room_name'] = calculate_room_name(
state_by_room[reason['room_id']], user_id, fallback_to_members=False
)
template_vars = {
"user_display_name": user_display_name,
"unsubscribe_link": self.make_unsubscribe_link(),
"summary_text": summary_text,
"app_name": self.app_name,
"rooms": rooms,
"reason": reason,
}
html_text = self.notif_template_html.render(**template_vars)

View File

@ -18,6 +18,17 @@ from httppusher import HttpPusher
import logging
logger = logging.getLogger(__name__)
# We try importing this if we can (it will fail if we don't
# have the optional email dependencies installed). We don't
# yet have the config to know if we need the email pusher,
# but importing this after daemonizing seems to fail
# (even though a simple test of importing from a daemonized
# process works fine)
try:
from synapse.push.emailpusher import EmailPusher
except:
pass
def create_pusher(hs, pusherdict):
logger.info("trying to create_pusher for %r", pusherdict)
@ -28,7 +39,6 @@ def create_pusher(hs, pusherdict):
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
from synapse.push.emailpusher import EmailPusher
PUSHER_TYPES["email"] = EmailPusher
logger.info("defined email pusher type")

View File

@ -109,8 +109,8 @@ class ReplicationResource(Resource):
self.version_string = hs.version_string
self.store = hs.get_datastore()
self.sources = hs.get_event_sources()
self.presence_handler = hs.get_handlers().presence_handler
self.typing_handler = hs.get_handlers().typing_notification_handler
self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler()
self.notifier = hs.notifier
self.clock = hs.get_clock()

View File

@ -30,20 +30,24 @@ logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
def __init__(self, hs):
super(PresenceStatusRestServlet, self).__init__(hs)
self.presence_handler = hs.get_presence_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if requester.user != user:
allowed = yield self.handlers.presence_handler.is_visible(
allowed = yield self.presence_handler.is_visible(
observed_user=user, observer_user=requester.user,
)
if not allowed:
raise AuthError(403, "You are not allowed to see their presence.")
state = yield self.handlers.presence_handler.get_state(target_user=user)
state = yield self.presence_handler.get_state(target_user=user)
defer.returnValue((200, state))
@ -74,7 +78,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
except:
raise SynapseError(400, "Unable to parse state")
yield self.handlers.presence_handler.set_state(user, state)
yield self.presence_handler.set_state(user, state)
defer.returnValue((200, {}))
@ -85,6 +89,10 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
class PresenceListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(PresenceListRestServlet, self).__init__(hs)
self.presence_handler = hs.get_presence_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
requester = yield self.auth.get_user_by_req(request)
@ -96,7 +104,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if requester.user != user:
raise SynapseError(400, "Cannot get another user's presence list")
presence = yield self.handlers.presence_handler.get_presence_list(
presence = yield self.presence_handler.get_presence_list(
observer_user=user, accepted=True
)
@ -123,7 +131,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if len(u) == 0:
continue
invited_user = UserID.from_string(u)
yield self.handlers.presence_handler.send_presence_invite(
yield self.presence_handler.send_presence_invite(
observer_user=user, observed_user=invited_user
)
@ -134,7 +142,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if len(u) == 0:
continue
dropped_user = UserID.from_string(u)
yield self.handlers.presence_handler.drop(
yield self.presence_handler.drop(
observer_user=user, observed_user=dropped_user
)

View File

@ -570,7 +570,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomTypingRestServlet, self).__init__(hs)
self.presence_handler = hs.get_handlers().presence_handler
self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler()
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
@ -581,19 +582,17 @@ class RoomTypingRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
typing_handler = self.handlers.typing_notification_handler
yield self.presence_handler.bump_presence_active_time(requester.user)
if content["typing"]:
yield typing_handler.started_typing(
yield self.typing_handler.started_typing(
target_user=target_user,
auth_user=requester.user,
room_id=room_id,
timeout=content.get("timeout", 30000),
)
else:
yield typing_handler.stopped_typing(
yield self.typing_handler.stopped_typing(
target_user=target_user,
auth_user=requester.user,
room_id=room_id,

View File

@ -37,7 +37,7 @@ class ReceiptRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_handlers().receipts_handler
self.presence_handler = hs.get_handlers().presence_handler
self.presence_handler = hs.get_presence_handler()
@defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id):

View File

@ -79,11 +79,10 @@ class SyncRestServlet(RestServlet):
def __init__(self, hs):
super(SyncRestServlet, self).__init__()
self.auth = hs.get_auth()
self.event_stream_handler = hs.get_handlers().event_stream_handler
self.sync_handler = hs.get_handlers().sync_handler
self.sync_handler = hs.get_sync_handler()
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
self.presence_handler = hs.get_handlers().presence_handler
self.presence_handler = hs.get_presence_handler()
@defer.inlineCallbacks
def on_GET(self, request):

View File

@ -56,8 +56,7 @@ class PreviewUrlResource(Resource):
self.client = SpiderHttpClient(hs)
self.media_repo = media_repo
if hasattr(hs.config, "url_preview_url_blacklist"):
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
# simple memory cache mapping urls to OG metadata
self.cache = ExpiringCache(
@ -86,39 +85,37 @@ class PreviewUrlResource(Resource):
else:
ts = self.clock.time_msec()
# impose the URL pattern blacklist
if hasattr(self, "url_preview_url_blacklist"):
url_tuple = urlparse.urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
for attrib in entry:
pattern = entry[attrib]
value = getattr(url_tuple, attrib)
logger.debug((
"Matching attrib '%s' with value '%s' against"
" pattern '%s'"
) % (attrib, value, pattern))
url_tuple = urlparse.urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
for attrib in entry:
pattern = entry[attrib]
value = getattr(url_tuple, attrib)
logger.debug((
"Matching attrib '%s' with value '%s' against"
" pattern '%s'"
) % (attrib, value, pattern))
if value is None:
if value is None:
match = False
continue
if pattern.startswith('^'):
if not re.match(pattern, getattr(url_tuple, attrib)):
match = False
continue
if pattern.startswith('^'):
if not re.match(pattern, getattr(url_tuple, attrib)):
match = False
continue
else:
if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
match = False
continue
if match:
logger.warn(
"URL %s blocked by url_blacklist entry %s", url, entry
)
raise SynapseError(
403, "URL blocked by url pattern blacklist entry",
Codes.UNKNOWN
)
else:
if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
match = False
continue
if match:
logger.warn(
"URL %s blocked by url_blacklist entry %s", url, entry
)
raise SynapseError(
403, "URL blocked by url pattern blacklist entry",
Codes.UNKNOWN
)
# first check the memory cache - good to handle all the clients on this
# HS thundering away to preview the same URL at the same time.

View File

@ -27,6 +27,9 @@ from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFa
from synapse.notifier import Notifier
from synapse.api.auth import Auth
from synapse.handlers import Handlers
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.state import StateHandler
from synapse.storage import DataStore
from synapse.util import Clock
@ -78,6 +81,9 @@ class HomeServer(object):
'auth',
'rest_servlet_factory',
'state_handler',
'presence_handler',
'sync_handler',
'typing_handler',
'notifier',
'distributor',
'client_resource',
@ -164,6 +170,15 @@ class HomeServer(object):
def build_state_handler(self):
return StateHandler(self)
def build_presence_handler(self):
return PresenceHandler(self)
def build_typing_handler(self):
return TypingHandler(self)
def build_sync_handler(self):
return SyncHandler(self)
def build_event_sources(self):
return EventSources(self)

View File

@ -13,16 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
import yaml
import simplejson as json
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.config._base import ConfigError
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.storage.roommember import RoomsForUser
from synapse.types import UserID
from ._base import SQLBaseStore
@ -34,7 +31,7 @@ class ApplicationServiceStore(SQLBaseStore):
def __init__(self, hs):
super(ApplicationServiceStore, self).__init__(hs)
self.hostname = hs.hostname
self.services_cache = ApplicationServiceStore.load_appservices(
self.services_cache = load_appservices(
hs.hostname,
hs.config.app_service_config_files
)
@ -144,102 +141,6 @@ class ApplicationServiceStore(SQLBaseStore):
return rooms_for_user_matching_user_id
@classmethod
def _load_appservice(cls, hostname, as_info, config_filename):
required_string_fields = [
"id", "url", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
if not isinstance(as_info.get(field), basestring):
raise KeyError("Required string field: '%s' (%s)" % (
field, config_filename,
))
localpart = as_info["sender_localpart"]
if urllib.quote(localpart) != localpart:
raise ValueError(
"sender_localpart needs characters which are not URL encoded."
)
user = UserID(localpart, hostname)
user_id = user.to_string()
# namespace checks
if not isinstance(as_info.get("namespaces"), dict):
raise KeyError("Requires 'namespaces' object.")
for ns in ApplicationService.NS_LIST:
# specific namespaces are optional
if ns in as_info["namespaces"]:
# expect a list of dicts with exclusive and regex keys
for regex_obj in as_info["namespaces"][ns]:
if not isinstance(regex_obj, dict):
raise ValueError(
"Expected namespace entry in %s to be an object,"
" but got %s", ns, regex_obj
)
if not isinstance(regex_obj.get("regex"), basestring):
raise ValueError(
"Missing/bad type 'regex' key in %s", regex_obj
)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
)
return ApplicationService(
token=as_info["as_token"],
url=as_info["url"],
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
)
@classmethod
def load_appservices(cls, hostname, config_files):
"""Returns a list of Application Services from the config files."""
if not isinstance(config_files, list):
logger.warning(
"Expected %s to be a list of AS config files.", config_files
)
return []
# Dicts of value -> filename
seen_as_tokens = {}
seen_ids = {}
appservices = []
for config_file in config_files:
try:
with open(config_file, 'r') as f:
appservice = ApplicationServiceStore._load_appservice(
hostname, yaml.load(f), config_file
)
if appservice.id in seen_ids:
raise ConfigError(
"Cannot reuse ID across application services: "
"%s (files: %s, %s)" % (
appservice.id, config_file, seen_ids[appservice.id],
)
)
seen_ids[appservice.id] = config_file
if appservice.token in seen_as_tokens:
raise ConfigError(
"Cannot reuse as_token across application services: "
"%s (files: %s, %s)" % (
appservice.token,
config_file,
seen_as_tokens[appservice.token],
)
)
seen_as_tokens[appservice.token] = config_file
logger.info("Loaded application service: %s", appservice)
appservices.append(appservice)
except Exception as e:
logger.error("Failed to load appservice from '%s'", config_file)
logger.exception(e)
raise
return appservices
class ApplicationServiceTransactionStore(SQLBaseStore):

View File

@ -149,6 +149,7 @@ class PresenceStore(SQLBaseStore):
"status_msg",
"currently_active",
),
desc="get_presence_for_users",
)
for row in rows:

View File

@ -21,7 +21,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.api.constants import Membership
from synapse.types import get_domian_from_id
from synapse.types import get_domain_from_id
import logging
@ -137,24 +137,6 @@ class RoomMemberStore(SQLBaseStore):
return [r["user_id"] for r in rows]
return self.runInteraction("get_users_in_room", f)
def get_room_members(self, room_id, membership=None):
"""Retrieve the current room member list for a room.
Args:
room_id (str): The room to get the list of members.
membership (synapse.api.constants.Membership): The filter to apply
to this list, or None to return all members with some state
associated with this room.
Returns:
list of namedtuples representing the members in this room.
"""
return self.runInteraction(
"get_room_members",
self._get_members_events_txn,
room_id,
membership=membership,
).addCallback(self._get_events)
@cached()
def get_invited_rooms_for_user(self, user_id):
""" Get all the rooms the user is invited to
@ -273,7 +255,7 @@ class RoomMemberStore(SQLBaseStore):
room_id, membership=Membership.JOIN
)
joined_domains = set(get_domian_from_id(r["user_id"]) for r in rows)
joined_domains = set(get_domain_from_id(r["user_id"]) for r in rows)
return joined_domains

View File

@ -24,7 +24,7 @@ import ujson as json
Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
def get_domian_from_id(string):
def get_domain_from_id(string):
return string.split(":", 1)[1]

View File

@ -25,8 +25,6 @@ from ..utils import (
)
from synapse.api.errors import AuthError
from synapse.handlers.typing import TypingNotificationHandler
from synapse.types import UserID
@ -49,11 +47,6 @@ def _make_edu_json(origin, edu_type, content):
return json.dumps(_expect_edu("test", edu_type, content, origin=origin))
class JustTypingNotificationHandlers(object):
def __init__(self, hs):
self.typing_notification_handler = TypingNotificationHandler(hs)
class TypingNotificationsTestCase(unittest.TestCase):
"""Tests typing notifications to rooms."""
@defer.inlineCallbacks
@ -71,6 +64,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.auth = Mock(spec=[])
hs = yield setup_test_homeserver(
"test",
auth=self.auth,
clock=self.clock,
datastore=Mock(spec=[
@ -88,9 +82,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
http_client=self.mock_http_client,
keyring=Mock(),
)
hs.handlers = JustTypingNotificationHandlers(hs)
self.handler = hs.get_handlers().typing_notification_handler
self.handler = hs.get_typing_handler()
self.event_source = hs.get_event_sources().sources["typing"]
@ -110,56 +103,16 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.room_id = "a-room"
# Mock the RoomMemberHandler
hs.handlers.room_member_handler = Mock(spec=[])
self.room_member_handler = hs.handlers.room_member_handler
self.room_members = []
def get_rooms_for_user(user):
if user in self.room_members:
return defer.succeed([self.room_id])
else:
return defer.succeed([])
self.room_member_handler.get_rooms_for_user = get_rooms_for_user
def get_room_members(room_id):
if room_id == self.room_id:
return defer.succeed(self.room_members)
else:
return defer.succeed([])
self.room_member_handler.get_room_members = get_room_members
def get_joined_rooms_for_user(user):
if user in self.room_members:
return defer.succeed([self.room_id])
else:
return defer.succeed([])
self.room_member_handler.get_joined_rooms_for_user = get_joined_rooms_for_user
@defer.inlineCallbacks
def fetch_room_distributions_into(
room_id, localusers=None, remotedomains=None, ignore_user=None
):
members = yield get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
self.room_member_handler.fetch_room_distributions_into = (
fetch_room_distributions_into
)
def check_joined_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
def get_joined_hosts_for_room(room_id):
return set(member.domain for member in self.room_members)
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
self.auth.check_joined_room = check_joined_room
# Some local users to test with

View File

@ -78,7 +78,7 @@ class ReplicationResourceCase(unittest.TestCase):
@defer.inlineCallbacks
def test_presence(self):
get = self.get(presence="-1")
yield self.hs.get_handlers().presence_handler.set_state(
yield self.hs.get_presence_handler().set_state(
self.user, {"presence": "online"}
)
code, body = yield get
@ -93,7 +93,7 @@ class ReplicationResourceCase(unittest.TestCase):
def test_typing(self):
room_id = yield self.create_room()
get = self.get(typing="-1")
yield self.hs.get_handlers().typing_notification_handler.started_typing(
yield self.hs.get_typing_handler().started_typing(
self.user, self.user, room_id, timeout=2
)
code, body = yield get

View File

@ -106,7 +106,7 @@ class RoomTypingTestCase(RestTestCase):
yield self.join(self.room_id, user="@jim:red")
def tearDown(self):
self.hs.get_handlers().typing_notification_handler.tearDown()
self.hs.get_typing_handler().tearDown()
@defer.inlineCallbacks
def test_set_typing(self):

View File

@ -70,12 +70,6 @@ class RoomMemberStoreTestCase(unittest.TestCase):
def test_one_member(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
self.assertEquals(
[self.u_alice.to_string()],
[m.user_id for m in (
yield self.store.get_room_members(self.room.to_string())
)]
)
self.assertEquals(
[self.room.to_string()],
[m.room_id for m in (
@ -85,18 +79,6 @@ class RoomMemberStoreTestCase(unittest.TestCase):
)]
)
@defer.inlineCallbacks
def test_two_members(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)
yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN)
self.assertEquals(
{self.u_alice.to_string(), self.u_bob.to_string()},
{m.user_id for m in (
yield self.store.get_room_members(self.room.to_string())
)}
)
@defer.inlineCallbacks
def test_room_hosts(self):
yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN)

View File

@ -50,7 +50,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.enable_registration = True
config.macaroon_secret_key = "not even a little secret"
config.expire_access_token = False
config.server_name = "server.under.test"
config.server_name = name
config.trusted_third_party_id_servers = []
config.room_invite_state_types = []