Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

pull/4534/head
Andrew Morgan 2019-01-29 10:07:13 +00:00
commit c7f2eaf4f4
58 changed files with 867 additions and 296 deletions

View File

@ -18,7 +18,7 @@ instructions that may be required are listed later in this document.
.. code:: bash .. code:: bash
pip install --upgrade --process-dependency-links matrix-synapse pip install --upgrade matrix-synapse
# restart synapse # restart synapse
synctl restart synctl restart

View File

@ -1 +1 @@
Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+. Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.

1
changelog.d/4412.bugfix Normal file
View File

@ -0,0 +1 @@
Copy over whether a room is a direct message and any associated room tags on room upgrade.

1
changelog.d/4470.misc Normal file
View File

@ -0,0 +1 @@
Add infrastructure to support different event formats

View File

@ -1 +1 @@
Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+. Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.

1
changelog.d/4477.misc Normal file
View File

@ -0,0 +1 @@
Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.

1
changelog.d/4482.misc Normal file
View File

@ -0,0 +1 @@
Add infrastructure to support different event formats

1
changelog.d/4485.misc Normal file
View File

@ -0,0 +1 @@
Remove deprecated --process-dependency-links option from UPGRADE.rst

1
changelog.d/4487.misc Normal file
View File

@ -0,0 +1 @@
Fix idna and ipv6 literal handling in MatrixFederationAgent

1
changelog.d/4488.feature Normal file
View File

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

1
changelog.d/4492.feature Normal file
View File

@ -0,0 +1 @@
Synapse can now automatically provision TLS certificates via ACME (the protocol used by CAs like Let's Encrypt).

1
changelog.d/4493.misc Normal file
View File

@ -0,0 +1 @@
Add infrastructure to support different event formats

1
changelog.d/4497.feature Normal file
View File

@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)

1
changelog.d/4505.misc Normal file
View File

@ -0,0 +1 @@
Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.

View File

@ -65,7 +65,7 @@ class Auth(object):
register_cache("cache", "token_cache", self.token_cache) register_cache("cache", "token_cache", self.token_cache)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True): def check_from_context(self, room_version, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids(self.store) prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.compute_auth_events( auth_events_ids = yield self.compute_auth_events(
event, prev_state_ids, for_verification=True, event, prev_state_ids, for_verification=True,
@ -74,12 +74,16 @@ class Auth(object):
auth_events = { auth_events = {
(e.type, e.state_key): e for e in itervalues(auth_events) (e.type, e.state_key): e for e in itervalues(auth_events)
} }
self.check(event, auth_events=auth_events, do_sig_check=do_sig_check) self.check(
room_version, event,
auth_events=auth_events, do_sig_check=do_sig_check,
)
def check(self, event, auth_events, do_sig_check=True): def check(self, room_version, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Args: Args:
room_version (str): version of the room
event: the event being checked. event: the event being checked.
auth_events (dict: event-key -> event): the existing room state. auth_events (dict: event-key -> event): the existing room state.
@ -88,7 +92,9 @@ class Auth(object):
True if the auth checks pass. True if the auth checks pass.
""" """
with Measure(self.clock, "auth.check"): with Measure(self.clock, "auth.check"):
event_auth.check(event, auth_events, do_sig_check=do_sig_check) event_auth.check(
room_version, event, auth_events, do_sig_check=do_sig_check
)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None): def check_joined_room(self, room_id, user_id, current_state=None):

View File

@ -164,23 +164,23 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = ClientReaderServer( ss = ClientReaderServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
) )
ss.setup() ss.setup()
ss.start_listening(config.worker_listeners)
def start(): def start():
ss.config.read_certificate_from_disk()
ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
config
)
ss.start_listening(config.worker_listeners)
ss.get_datastore().start_profiling() ss.get_datastore().start_profiling()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View File

@ -185,23 +185,23 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = EventCreatorServer( ss = EventCreatorServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
) )
ss.setup() ss.setup()
ss.start_listening(config.worker_listeners)
def start(): def start():
ss.config.read_certificate_from_disk()
ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
config
)
ss.start_listening(config.worker_listeners)
ss.get_datastore().start_profiling() ss.get_datastore().start_profiling()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View File

@ -151,23 +151,23 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = FederationReaderServer( ss = FederationReaderServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
) )
ss.setup() ss.setup()
ss.start_listening(config.worker_listeners)
def start(): def start():
ss.config.read_certificate_from_disk()
ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
config
)
ss.start_listening(config.worker_listeners)
ss.get_datastore().start_profiling() ss.get_datastore().start_profiling()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View File

@ -183,24 +183,24 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config # Force the pushers to start since they will be disabled in the main config
config.send_federation = True config.send_federation = True
tls_server_context_factory = context_factory.ServerContextFactory(config) ss = FederationSenderServer(
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ps = FederationSenderServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
) )
ps.setup() ss.setup()
ps.start_listening(config.worker_listeners)
def start(): def start():
ps.get_datastore().start_profiling() ss.config.read_certificate_from_disk()
ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
config
)
ss.start_listening(config.worker_listeners)
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)
_base.start_worker_reactor("synapse-federation-sender", config) _base.start_worker_reactor("synapse-federation-sender", config)

View File

@ -241,23 +241,23 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = FrontendProxyServer( ss = FrontendProxyServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
) )
ss.setup() ss.setup()
ss.start_listening(config.worker_listeners)
def start(): def start():
ss.config.read_certificate_from_disk()
ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
config
)
ss.start_listening(config.worker_listeners)
ss.get_datastore().start_profiling() ss.get_datastore().start_profiling()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View File

@ -151,23 +151,23 @@ def start(config_options):
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ss = MediaRepositoryServer( ss = MediaRepositoryServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
) )
ss.setup() ss.setup()
ss.start_listening(config.worker_listeners)
def start(): def start():
ss.config.read_certificate_from_disk()
ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
config
)
ss.start_listening(config.worker_listeners)
ss.get_datastore().start_profiling() ss.get_datastore().start_profiling()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View File

@ -211,24 +211,24 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config # Force the pushers to start since they will be disabled in the main config
config.update_user_directory = True config.update_user_directory = True
tls_server_context_factory = context_factory.ServerContextFactory(config) ss = UserDirectoryServer(
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(config)
ps = UserDirectoryServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
tls_client_options_factory=tls_client_options_factory,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
) )
ps.setup() ss.setup()
ps.start_listening(config.worker_listeners)
def start(): def start():
ps.get_datastore().start_profiling() ss.config.read_certificate_from_disk()
ss.tls_server_context_factory = context_factory.ServerContextFactory(config)
ss.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
config
)
ss.start_listening(config.worker_listeners)
ss.get_datastore().start_profiling()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)

View File

@ -23,14 +23,14 @@ from signedjson.sign import sign_json
from unpaddedbase64 import decode_base64, encode_base64 from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.events.utils import prune_event from synapse.events.utils import prune_event, prune_event_dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def check_event_content_hash(event, hash_algorithm=hashlib.sha256): def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents""" """Check whether the hash for this PDU matches the contents"""
name, expected_hash = compute_content_hash(event, hash_algorithm) name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
logger.debug("Expecting hash: %s", encode_base64(expected_hash)) logger.debug("Expecting hash: %s", encode_base64(expected_hash))
# some malformed events lack a 'hashes'. Protect against it being missing # some malformed events lack a 'hashes'. Protect against it being missing
@ -59,35 +59,70 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
return message_hash_bytes == expected_hash return message_hash_bytes == expected_hash
def compute_content_hash(event, hash_algorithm): def compute_content_hash(event_dict, hash_algorithm):
event_json = event.get_pdu_json() """Compute the content hash of an event, which is the hash of the
event_json.pop("age_ts", None) unredacted event.
event_json.pop("unsigned", None)
event_json.pop("signatures", None)
event_json.pop("hashes", None)
event_json.pop("outlier", None)
event_json.pop("destinations", None)
event_json_bytes = encode_canonical_json(event_json) Args:
event_dict (dict): The unredacted event as a dict
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event
Returns:
tuple[str, bytes]: A tuple of the name of hash and the hash as raw
bytes.
"""
event_dict = dict(event_dict)
event_dict.pop("age_ts", None)
event_dict.pop("unsigned", None)
event_dict.pop("signatures", None)
event_dict.pop("hashes", None)
event_dict.pop("outlier", None)
event_dict.pop("destinations", None)
event_json_bytes = encode_canonical_json(event_dict)
hashed = hash_algorithm(event_json_bytes) hashed = hash_algorithm(event_json_bytes)
return (hashed.name, hashed.digest()) return (hashed.name, hashed.digest())
def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256): def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
"""Computes the event reference hash. This is the hash of the redacted
event.
Args:
event (FrozenEvent)
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event
Returns:
tuple[str, bytes]: A tuple of the name of hash and the hash as raw
bytes.
"""
tmp_event = prune_event(event) tmp_event = prune_event(event)
event_json = tmp_event.get_pdu_json() event_dict = tmp_event.get_pdu_json()
event_json.pop("signatures", None) event_dict.pop("signatures", None)
event_json.pop("age_ts", None) event_dict.pop("age_ts", None)
event_json.pop("unsigned", None) event_dict.pop("unsigned", None)
event_json_bytes = encode_canonical_json(event_json) event_json_bytes = encode_canonical_json(event_dict)
hashed = hash_algorithm(event_json_bytes) hashed = hash_algorithm(event_json_bytes)
return (hashed.name, hashed.digest()) return (hashed.name, hashed.digest())
def compute_event_signature(event, signature_name, signing_key): def compute_event_signature(event_dict, signature_name, signing_key):
tmp_event = prune_event(event) """Compute the signature of the event for the given name and key.
redact_json = tmp_event.get_pdu_json()
Args:
event_dict (dict): The event as a dict
signature_name (str): The name of the entity signing the event
(typically the server's hostname).
signing_key (syutil.crypto.SigningKey): The key to sign with
Returns:
dict[str, dict[str, str]]: Returns a dictionary in the same format of
an event's signatures field.
"""
redact_json = prune_event_dict(event_dict)
redact_json.pop("age_ts", None) redact_json.pop("age_ts", None)
redact_json.pop("unsigned", None) redact_json.pop("unsigned", None)
logger.debug("Signing event: %s", encode_canonical_json(redact_json)) logger.debug("Signing event: %s", encode_canonical_json(redact_json))
@ -98,23 +133,27 @@ def compute_event_signature(event, signature_name, signing_key):
def add_hashes_and_signatures(event, signature_name, signing_key, def add_hashes_and_signatures(event, signature_name, signing_key,
hash_algorithm=hashlib.sha256): hash_algorithm=hashlib.sha256):
# if hasattr(event, "old_state_events"): """Add content hash and sign the event
# state_json_bytes = encode_canonical_json(
# [e.event_id for e in event.old_state_events.values()]
# )
# hashed = hash_algorithm(state_json_bytes)
# event.state_hash = {
# hashed.name: encode_base64(hashed.digest())
# }
name, digest = compute_content_hash(event, hash_algorithm=hash_algorithm) Args:
event_dict (EventBuilder): The event to add hashes to and sign
signature_name (str): The name of the entity signing the event
(typically the server's hostname).
signing_key (syutil.crypto.SigningKey): The key to sign with
hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
to hash the event
"""
name, digest = compute_content_hash(
event.get_pdu_json(), hash_algorithm=hash_algorithm,
)
if not hasattr(event, "hashes"): if not hasattr(event, "hashes"):
event.hashes = {} event.hashes = {}
event.hashes[name] = encode_base64(digest) event.hashes[name] = encode_base64(digest)
event.signatures = compute_event_signature( event.signatures = compute_event_signature(
event, event.get_pdu_json(),
signature_name=signature_name, signature_name=signature_name,
signing_key=signing_key, signing_key=signing_key,
) )

View File

@ -27,10 +27,11 @@ from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def check(event, auth_events, do_sig_check=True, do_size_check=True): def check(room_version, event, auth_events, do_sig_check=True, do_size_check=True):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Args: Args:
room_version (str): the version of the room
event: the event being checked. event: the event being checked.
auth_events (dict: event-key -> event): the existing room state. auth_events (dict: event-key -> event): the existing room state.

View File

@ -18,7 +18,11 @@ from distutils.util import strtobool
import six import six
from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventFormatVersions from synapse.api.constants import (
KNOWN_EVENT_FORMAT_VERSIONS,
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
)
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze from synapse.util.frozenutils import freeze
@ -256,3 +260,21 @@ def room_version_to_event_format(room_version):
raise RuntimeError("Unrecognized room version %s" % (room_version,)) raise RuntimeError("Unrecognized room version %s" % (room_version,))
return EventFormatVersions.V1 return EventFormatVersions.V1
def event_type_from_format_version(format_version):
"""Returns the python type to use to construct an Event object for the
given event format version.
Args:
format_version (int): The event format version
Returns:
type: A type that can be initialized as per the initializer of
`FrozenEvent`
"""
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
raise Exception(
"No event format %r" % (format_version,)
)
return FrozenEvent

View File

@ -15,12 +15,39 @@
import copy import copy
from synapse.api.constants import RoomVersions
from synapse.types import EventID from synapse.types import EventID
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from . import EventBase, FrozenEvent, _event_dict_property from . import EventBase, FrozenEvent, _event_dict_property
def get_event_builder(room_version, key_values={}, internal_metadata_dict={}):
"""Generate an event builder appropriate for the given room version
Args:
room_version (str): Version of the room that we're creating an
event builder for
key_values (dict): Fields used as the basis of the new event
internal_metadata_dict (dict): Used to create the `_EventInternalMetadata`
object.
Returns:
EventBuilder
"""
if room_version in {
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.VDH_TEST,
RoomVersions.STATE_V2_TEST,
}:
return EventBuilder(key_values, internal_metadata_dict)
else:
raise Exception(
"No event format defined for version %r" % (room_version,)
)
class EventBuilder(EventBase): class EventBuilder(EventBase):
def __init__(self, key_values={}, internal_metadata_dict={}): def __init__(self, key_values={}, internal_metadata_dict={}):
signatures = copy.deepcopy(key_values.pop("signatures", {})) signatures = copy.deepcopy(key_values.pop("signatures", {}))
@ -58,7 +85,29 @@ class EventBuilderFactory(object):
return e_id.to_string() return e_id.to_string()
def new(self, key_values={}): def new(self, room_version, key_values={}):
"""Generate an event builder appropriate for the given room version
Args:
room_version (str): Version of the room that we're creating an
event builder for
key_values (dict): Fields used as the basis of the new event
Returns:
EventBuilder
"""
# There's currently only the one event version defined
if room_version not in {
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.VDH_TEST,
RoomVersions.STATE_V2_TEST,
}:
raise Exception(
"No event format defined for version %r" % (room_version,)
)
key_values["event_id"] = self.create_event_id() key_values["event_id"] = self.create_event_id()
time_now = int(self.clock.time_msec()) time_now = int(self.clock.time_msec())

View File

@ -38,8 +38,31 @@ def prune_event(event):
This is used when we "redact" an event. We want to remove all fields that This is used when we "redact" an event. We want to remove all fields that
the user has specified, but we do want to keep necessary information like the user has specified, but we do want to keep necessary information like
type, state_key etc. type, state_key etc.
Args:
event (FrozenEvent)
Returns:
FrozenEvent
"""
pruned_event_dict = prune_event_dict(event.get_dict())
from . import event_type_from_format_version
return event_type_from_format_version(event.format_version)(
pruned_event_dict, event.internal_metadata.get_dict()
)
def prune_event_dict(event_dict):
"""Redacts the event_dict in the same way as `prune_event`, except it
operates on dicts rather than event objects
Args:
event_dict (dict)
Returns:
dict: A copy of the pruned event dict
""" """
event_type = event.type
allowed_keys = [ allowed_keys = [
"event_id", "event_id",
@ -59,13 +82,13 @@ def prune_event(event):
"membership", "membership",
] ]
event_dict = event.get_dict() event_type = event_dict["type"]
new_content = {} new_content = {}
def add_fields(*fields): def add_fields(*fields):
for field in fields: for field in fields:
if field in event.content: if field in event_dict["content"]:
new_content[field] = event_dict["content"][field] new_content[field] = event_dict["content"][field]
if event_type == EventTypes.Member: if event_type == EventTypes.Member:
@ -98,17 +121,17 @@ def prune_event(event):
allowed_fields["content"] = new_content allowed_fields["content"] = new_content
allowed_fields["unsigned"] = {} unsigned = {}
allowed_fields["unsigned"] = unsigned
if "age_ts" in event.unsigned: event_unsigned = event_dict.get("unsigned", {})
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
if "replaces_state" in event.unsigned:
allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
return type(event)( if "age_ts" in event_unsigned:
allowed_fields, unsigned["age_ts"] = event_unsigned["age_ts"]
internal_metadata_dict=event.internal_metadata.get_dict() if "replaces_state" in event_unsigned:
) unsigned["replaces_state"] = event_unsigned["replaces_state"]
return allowed_fields
def _copy_field(src, dst, field): def _copy_field(src, dst, field):

View File

@ -23,7 +23,7 @@ from twisted.internet.defer import DeferredList
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent from synapse.events import event_type_from_format_version
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict from synapse.http.servlet import assert_params_in_dict
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
@ -302,11 +302,12 @@ def _is_invite_via_3pid(event):
) )
def event_from_pdu_json(pdu_json, outlier=False): def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
"""Construct a FrozenEvent from an event json received over federation """Construct a FrozenEvent from an event json received over federation
Args: Args:
pdu_json (object): pdu as received over federation pdu_json (object): pdu as received over federation
event_format_version (int): The event format version
outlier (bool): True to mark this event as an outlier outlier (bool): True to mark this event as an outlier
Returns: Returns:
@ -330,8 +331,8 @@ def event_from_pdu_json(pdu_json, outlier=False):
elif depth > MAX_DEPTH: elif depth > MAX_DEPTH:
raise SynapseError(400, "Depth too large", Codes.BAD_JSON) raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
event = FrozenEvent( event = event_type_from_format_version(event_format_version)(
pdu_json pdu_json,
) )
event.internal_metadata.outlier = outlier event.internal_metadata.outlier = outlier

View File

@ -170,13 +170,13 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def backfill(self, dest, context, limit, extremities): def backfill(self, dest, room_id, limit, extremities):
"""Requests some more historic PDUs for the given context from the """Requests some more historic PDUs for the given context from the
given destination server. given destination server.
Args: Args:
dest (str): The remote home server to ask. dest (str): The remote home server to ask.
context (str): The context to backfill. room_id (str): The room_id to backfill.
limit (int): The maximum number of PDUs to return. limit (int): The maximum number of PDUs to return.
extremities (list): List of PDU id and origins of the first pdus extremities (list): List of PDU id and origins of the first pdus
we have seen from the context we have seen from the context
@ -191,12 +191,15 @@ class FederationClient(FederationBase):
return return
transaction_data = yield self.transport_layer.backfill( transaction_data = yield self.transport_layer.backfill(
dest, context, extremities, limit) dest, room_id, extremities, limit)
logger.debug("backfill transaction_data=%s", repr(transaction_data)) logger.debug("backfill transaction_data=%s", repr(transaction_data))
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdus = [ pdus = [
event_from_pdu_json(p, outlier=False) event_from_pdu_json(p, format_ver, outlier=False)
for p in transaction_data["pdus"] for p in transaction_data["pdus"]
] ]
@ -240,6 +243,8 @@ class FederationClient(FederationBase):
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
format_ver = room_version_to_event_format(room_version)
signed_pdu = None signed_pdu = None
for destination in destinations: for destination in destinations:
now = self._clock.time_msec() now = self._clock.time_msec()
@ -255,7 +260,7 @@ class FederationClient(FederationBase):
logger.debug("transaction_data %r", transaction_data) logger.debug("transaction_data %r", transaction_data)
pdu_list = [ pdu_list = [
event_from_pdu_json(p, outlier=outlier) event_from_pdu_json(p, format_ver, outlier=outlier)
for p in transaction_data["pdus"] for p in transaction_data["pdus"]
] ]
@ -349,12 +354,16 @@ class FederationClient(FederationBase):
destination, room_id, event_id=event_id, destination, room_id, event_id=event_id,
) )
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdus = [ pdus = [
event_from_pdu_json(p, outlier=True) for p in result["pdus"] event_from_pdu_json(p, format_ver, outlier=True)
for p in result["pdus"]
] ]
auth_chain = [ auth_chain = [
event_from_pdu_json(p, outlier=True) event_from_pdu_json(p, format_ver, outlier=True)
for p in result.get("auth_chain", []) for p in result.get("auth_chain", [])
] ]
@ -362,8 +371,6 @@ class FederationClient(FederationBase):
ev.event_id for ev in itertools.chain(pdus, auth_chain) ev.event_id for ev in itertools.chain(pdus, auth_chain)
]) ])
room_version = yield self.store.get_room_version(room_id)
signed_pdus = yield self._check_sigs_and_hash_and_fetch( signed_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, destination,
[p for p in pdus if p.event_id not in seen_events], [p for p in pdus if p.event_id not in seen_events],
@ -462,13 +469,14 @@ class FederationClient(FederationBase):
destination, room_id, event_id, destination, room_id, event_id,
) )
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [ auth_chain = [
event_from_pdu_json(p, outlier=True) event_from_pdu_json(p, format_ver, outlier=True)
for p in res["auth_chain"] for p in res["auth_chain"]
] ]
room_version = yield self.store.get_room_version(room_id)
signed_auth = yield self._check_sigs_and_hash_and_fetch( signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, destination, auth_chain,
outlier=True, room_version=room_version, outlier=True, room_version=room_version,
@ -605,7 +613,7 @@ class FederationClient(FederationBase):
pdu_dict.pop("origin_server_ts", None) pdu_dict.pop("origin_server_ts", None)
pdu_dict.pop("unsigned", None) pdu_dict.pop("unsigned", None)
builder = self.event_builder_factory.new(pdu_dict) builder = self.event_builder_factory.new(room_version, pdu_dict)
add_hashes_and_signatures( add_hashes_and_signatures(
builder, builder,
self.hs.hostname, self.hs.hostname,
@ -621,7 +629,7 @@ class FederationClient(FederationBase):
"make_" + membership, destinations, send_request, "make_" + membership, destinations, send_request,
) )
def send_join(self, destinations, pdu): def send_join(self, destinations, pdu, event_format_version):
"""Sends a join event to one of a list of homeservers. """Sends a join event to one of a list of homeservers.
Doing so will cause the remote server to add the event to the graph, Doing so will cause the remote server to add the event to the graph,
@ -631,6 +639,7 @@ class FederationClient(FederationBase):
destinations (str): Candidate homeservers which are probably destinations (str): Candidate homeservers which are probably
participating in the room. participating in the room.
pdu (BaseEvent): event to be sent pdu (BaseEvent): event to be sent
event_format_version (int): The event format version
Return: Return:
Deferred: resolves to a dict with members ``origin`` (a string Deferred: resolves to a dict with members ``origin`` (a string
@ -676,12 +685,12 @@ class FederationClient(FederationBase):
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
state = [ state = [
event_from_pdu_json(p, outlier=True) event_from_pdu_json(p, event_format_version, outlier=True)
for p in content.get("state", []) for p in content.get("state", [])
] ]
auth_chain = [ auth_chain = [
event_from_pdu_json(p, outlier=True) event_from_pdu_json(p, event_format_version, outlier=True)
for p in content.get("auth_chain", []) for p in content.get("auth_chain", [])
] ]
@ -759,7 +768,10 @@ class FederationClient(FederationBase):
logger.debug("Got response to send_invite: %s", pdu_dict) logger.debug("Got response to send_invite: %s", pdu_dict)
pdu = event_from_pdu_json(pdu_dict) room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdu = event_from_pdu_json(pdu_dict, format_ver)
# Check signatures are correct. # Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu) pdu = yield self._check_sigs_and_hash(pdu)
@ -837,13 +849,14 @@ class FederationClient(FederationBase):
content=send_content, content=send_content,
) )
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [ auth_chain = [
event_from_pdu_json(e) event_from_pdu_json(e, format_ver)
for e in content["auth_chain"] for e in content["auth_chain"]
] ]
room_version = yield self.store.get_room_version(room_id)
signed_auth = yield self._check_sigs_and_hash_and_fetch( signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True, room_version=room_version, destination, auth_chain, outlier=True, room_version=room_version,
) )
@ -887,13 +900,14 @@ class FederationClient(FederationBase):
timeout=timeout, timeout=timeout,
) )
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
events = [ events = [
event_from_pdu_json(e) event_from_pdu_json(e, format_ver)
for e in content.get("events", []) for e in content.get("events", [])
] ]
room_version = yield self.store.get_room_version(room_id)
signed_events = yield self._check_sigs_and_hash_and_fetch( signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False, room_version=room_version, destination, events, outlier=False, room_version=room_version,
) )

View File

@ -34,6 +34,7 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction from synapse.federation.units import Edu, Transaction
@ -178,14 +179,13 @@ class FederationServer(FederationBase):
continue continue
try: try:
# In future we will actually use the room version to parse the room_version = yield self.store.get_room_version(room_id)
# PDU into an event. format_ver = room_version_to_event_format(room_version)
yield self.store.get_room_version(room_id)
except NotFoundError: except NotFoundError:
logger.info("Ignoring PDU for unknown room_id: %s", room_id) logger.info("Ignoring PDU for unknown room_id: %s", room_id)
continue continue
event = event_from_pdu_json(p) event = event_from_pdu_json(p, format_ver)
pdus_by_room.setdefault(room_id, []).append(event) pdus_by_room.setdefault(room_id, []).append(event)
pdu_results = {} pdu_results = {}
@ -370,7 +370,9 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_invite_request(self, origin, content, room_version): def on_invite_request(self, origin, content, room_version):
pdu = event_from_pdu_json(content) format_ver = room_version_to_event_format(room_version)
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, pdu.room_id) yield self.check_server_matches_acl(origin_host, pdu.room_id)
ret_pdu = yield self.handler.on_invite_request(origin, pdu) ret_pdu = yield self.handler.on_invite_request(origin, pdu)
@ -378,9 +380,12 @@ class FederationServer(FederationBase):
defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)}) defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks @defer.inlineCallbacks
def on_send_join_request(self, origin, content): def on_send_join_request(self, origin, content, room_id):
logger.debug("on_send_join_request: content: %s", content) logger.debug("on_send_join_request: content: %s", content)
pdu = event_from_pdu_json(content)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, pdu.room_id) yield self.check_server_matches_acl(origin_host, pdu.room_id)
@ -410,9 +415,12 @@ class FederationServer(FederationBase):
}) })
@defer.inlineCallbacks @defer.inlineCallbacks
def on_send_leave_request(self, origin, content): def on_send_leave_request(self, origin, content, room_id):
logger.debug("on_send_leave_request: content: %s", content) logger.debug("on_send_leave_request: content: %s", content)
pdu = event_from_pdu_json(content)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdu = event_from_pdu_json(content, format_ver)
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, pdu.room_id) yield self.check_server_matches_acl(origin_host, pdu.room_id)
@ -458,13 +466,14 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin) origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id) yield self.check_server_matches_acl(origin_host, room_id)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [ auth_chain = [
event_from_pdu_json(e) event_from_pdu_json(e, format_ver)
for e in content["auth_chain"] for e in content["auth_chain"]
] ]
room_version = yield self.store.get_room_version(room_id)
signed_auth = yield self._check_sigs_and_hash_and_fetch( signed_auth = yield self._check_sigs_and_hash_and_fetch(
origin, auth_chain, outlier=True, room_version=room_version, origin, auth_chain, outlier=True, room_version=room_version,
) )

View File

@ -469,7 +469,7 @@ class FederationSendLeaveServlet(BaseFederationServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, origin, content, query, room_id, event_id): def on_PUT(self, origin, content, query, room_id, event_id):
content = yield self.handler.on_send_leave_request(origin, content) content = yield self.handler.on_send_leave_request(origin, content, room_id)
defer.returnValue((200, content)) defer.returnValue((200, content))
@ -487,7 +487,7 @@ class FederationSendJoinServlet(BaseFederationServlet):
def on_PUT(self, origin, content, query, context, event_id): def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually # TODO(paul): assert that context/event_id parsed from path actually
# match those given in content # match those given in content
content = yield self.handler.on_send_join_request(origin, content) content = yield self.handler.on_send_join_request(origin, content, context)
defer.returnValue((200, content)) defer.returnValue((200, content))

View File

@ -1061,7 +1061,7 @@ class FederationHandler(BaseHandler):
""" """
logger.debug("Joining %s to %s", joinee, room_id) logger.debug("Joining %s to %s", joinee, room_id)
origin, event = yield self._make_and_verify_event( origin, event, event_format_version = yield self._make_and_verify_event(
target_hosts, target_hosts,
room_id, room_id,
joinee, joinee,
@ -1091,7 +1091,9 @@ class FederationHandler(BaseHandler):
target_hosts.insert(0, origin) target_hosts.insert(0, origin)
except ValueError: except ValueError:
pass pass
ret = yield self.federation_client.send_join(target_hosts, event) ret = yield self.federation_client.send_join(
target_hosts, event, event_format_version,
)
origin = ret["origin"] origin = ret["origin"]
state = ret["state"] state = ret["state"]
@ -1164,13 +1166,18 @@ class FederationHandler(BaseHandler):
""" """
event_content = {"membership": Membership.JOIN} event_content = {"membership": Membership.JOIN}
builder = self.event_builder_factory.new({ room_version = yield self.store.get_room_version(room_id)
"type": EventTypes.Member,
"content": event_content, builder = self.event_builder_factory.new(
"room_id": room_id, room_version,
"sender": user_id, {
"state_key": user_id, "type": EventTypes.Member,
}) "content": event_content,
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
}
)
try: try:
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
@ -1182,7 +1189,9 @@ class FederationHandler(BaseHandler):
# The remote hasn't signed it yet, obviously. We'll do the full checks # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request` # when we get the event back in `on_send_join_request`
yield self.auth.check_from_context(event, context, do_sig_check=False) yield self.auth.check_from_context(
room_version, event, context, do_sig_check=False,
)
defer.returnValue(event) defer.returnValue(event)
@ -1304,7 +1313,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id): def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
origin, event = yield self._make_and_verify_event( origin, event, event_format_version = yield self._make_and_verify_event(
target_hosts, target_hosts,
room_id, room_id,
user_id, user_id,
@ -1336,7 +1345,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _make_and_verify_event(self, target_hosts, room_id, user_id, membership, def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
content={}, params=None): content={}, params=None):
origin, pdu, _ = yield self.federation_client.make_membership_event( origin, event, format_ver = yield self.federation_client.make_membership_event(
target_hosts, target_hosts,
room_id, room_id,
user_id, user_id,
@ -1345,9 +1354,7 @@ class FederationHandler(BaseHandler):
params=params, params=params,
) )
logger.debug("Got response to make_%s: %s", membership, pdu) logger.debug("Got response to make_%s: %s", membership, event)
event = pdu
# We should assert some things. # We should assert some things.
# FIXME: Do this in a nicer way # FIXME: Do this in a nicer way
@ -1355,7 +1362,7 @@ class FederationHandler(BaseHandler):
assert(event.user_id == user_id) assert(event.user_id == user_id)
assert(event.state_key == user_id) assert(event.state_key == user_id)
assert(event.room_id == room_id) assert(event.room_id == room_id)
defer.returnValue((origin, event)) defer.returnValue((origin, event, format_ver))
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -1364,13 +1371,17 @@ class FederationHandler(BaseHandler):
leave event for the room and return that. We do *not* persist or leave event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back. process it until the other server has signed it and sent it back.
""" """
builder = self.event_builder_factory.new({ room_version = yield self.store.get_room_version(room_id)
"type": EventTypes.Member, builder = self.event_builder_factory.new(
"content": {"membership": Membership.LEAVE}, room_version,
"room_id": room_id, {
"sender": user_id, "type": EventTypes.Member,
"state_key": user_id, "content": {"membership": Membership.LEAVE},
}) "room_id": room_id,
"sender": user_id,
"state_key": user_id,
}
)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder, builder=builder,
@ -1379,7 +1390,9 @@ class FederationHandler(BaseHandler):
try: try:
# The remote hasn't signed it yet, obviously. We'll do the full checks # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request` # when we get the event back in `on_send_leave_request`
yield self.auth.check_from_context(event, context, do_sig_check=False) yield self.auth.check_from_context(
room_version, event, context, do_sig_check=False,
)
except AuthError as e: except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e) logger.warn("Failed to create new leave %r because %s", event, e)
raise e raise e
@ -1674,7 +1687,7 @@ class FederationHandler(BaseHandler):
auth_for_e[(EventTypes.Create, "")] = create_event auth_for_e[(EventTypes.Create, "")] = create_event
try: try:
self.auth.check(e, auth_events=auth_for_e) self.auth.check(room_version, e, auth_events=auth_for_e)
except SynapseError as err: except SynapseError as err:
# we may get SynapseErrors here as well as AuthErrors. For # we may get SynapseErrors here as well as AuthErrors. For
# instance, there are a couple of (ancient) events in some # instance, there are a couple of (ancient) events in some
@ -1918,6 +1931,8 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values()) current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state different_auth = event_auth_events - current_state
room_version = yield self.store.get_room_version(event.room_id)
if different_auth and not event.internal_metadata.is_outlier(): if different_auth and not event.internal_metadata.is_outlier():
# Do auth conflict res. # Do auth conflict res.
logger.info("Different auth: %s", different_auth) logger.info("Different auth: %s", different_auth)
@ -1942,8 +1957,6 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d (d.type, d.state_key): d for d in different_events if d
}) })
room_version = yield self.store.get_room_version(event.room_id)
new_state = yield self.state_handler.resolve_events( new_state = yield self.state_handler.resolve_events(
room_version, room_version,
[list(local_view.values()), list(remote_view.values())], [list(local_view.values()), list(remote_view.values())],
@ -2043,7 +2056,7 @@ class FederationHandler(BaseHandler):
) )
try: try:
self.auth.check(event, auth_events=auth_events) self.auth.check(room_version, event, auth_events=auth_events)
except AuthError as e: except AuthError as e:
logger.warn("Failed auth resolution for %r because %s", event, e) logger.warn("Failed auth resolution for %r because %s", event, e)
raise e raise e
@ -2266,18 +2279,20 @@ class FederationHandler(BaseHandler):
} }
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict) room_version = yield self.store.get_room_version(room_id)
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_new(builder) EventValidator().validate_new(builder)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder builder=builder
) )
event, context = yield self.add_display_name_to_third_party_invite( event, context = yield self.add_display_name_to_third_party_invite(
event_dict, event, context room_version, event_dict, event, context
) )
try: try:
yield self.auth.check_from_context(event, context) yield self.auth.check_from_context(room_version, event, context)
except AuthError as e: except AuthError as e:
logger.warn("Denying new third party invite %r because %s", event, e) logger.warn("Denying new third party invite %r because %s", event, e)
raise e raise e
@ -2304,18 +2319,22 @@ class FederationHandler(BaseHandler):
Returns: Returns:
Deferred: resolves (to None) Deferred: resolves (to None)
""" """
builder = self.event_builder_factory.new(event_dict) room_version = yield self.store.get_room_version(room_id)
# NB: event_dict has a particular specced format we might need to fudge
# if we change event formats too much.
builder = self.event_builder_factory.new(room_version, event_dict)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder, builder=builder,
) )
event, context = yield self.add_display_name_to_third_party_invite( event, context = yield self.add_display_name_to_third_party_invite(
event_dict, event, context room_version, event_dict, event, context
) )
try: try:
self.auth.check_from_context(event, context) self.auth.check_from_context(room_version, event, context)
except AuthError as e: except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e) logger.warn("Denying third party invite %r because %s", event, e)
raise e raise e
@ -2331,7 +2350,8 @@ class FederationHandler(BaseHandler):
yield member_handler.send_membership_event(None, event, context) yield member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_display_name_to_third_party_invite(self, event_dict, event, context): def add_display_name_to_third_party_invite(self, room_version, event_dict,
event, context):
key = ( key = (
EventTypes.ThirdPartyInvite, EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"] event.content["third_party_invite"]["signed"]["token"]
@ -2355,7 +2375,7 @@ class FederationHandler(BaseHandler):
# auth checks. If we need the invite and don't have it then the # auth checks. If we need the invite and don't have it then the
# auth check code will explode appropriately. # auth check code will explode appropriately.
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_new(builder) EventValidator().validate_new(builder)
event, context = yield self.event_creation_handler.create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder, builder=builder,

View File

@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -278,7 +278,15 @@ class EventCreationHandler(object):
""" """
yield self.auth.check_auth_blocking(requester.user.to_string()) yield self.auth.check_auth_blocking(requester.user.to_string())
builder = self.event_builder_factory.new(event_dict) if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version = event_dict["content"]["room_version"]
else:
try:
room_version = yield self.store.get_room_version(event_dict["room_id"])
except NotFoundError:
raise AuthError(403, "Unknown room")
builder = self.event_builder_factory.new(room_version, event_dict)
self.validator.validate_new(builder) self.validator.validate_new(builder)
@ -603,8 +611,13 @@ class EventCreationHandler(object):
extra_users (list(UserID)): Any extra users to notify about event extra_users (list(UserID)): Any extra users to notify about event
""" """
if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""):
room_version = event.content.get("room_version", RoomVersions.V1)
else:
room_version = yield self.store.get_room_version(event.room_id)
try: try:
yield self.auth.check_from_context(event, context) yield self.auth.check_from_context(room_version, event, context)
except AuthError as err: except AuthError as err:
logger.warn("Denying new event %r because %s", event, err) logger.warn("Denying new event %r because %s", event, err)
raise err raise err

View File

@ -123,9 +123,12 @@ class RoomCreationHandler(BaseHandler):
token_id=requester.access_token_id, token_id=requester.access_token_id,
) )
) )
yield self.auth.check_from_context(tombstone_event, tombstone_context) old_room_version = yield self.store.get_room_version(old_room_id)
yield self.auth.check_from_context(
old_room_version, tombstone_event, tombstone_context,
)
yield self.clone_exiting_room( yield self.clone_existing_room(
requester, requester,
old_room_id=old_room_id, old_room_id=old_room_id,
new_room_id=new_room_id, new_room_id=new_room_id,
@ -230,7 +233,7 @@ class RoomCreationHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def clone_exiting_room( def clone_existing_room(
self, requester, old_room_id, new_room_id, new_room_version, self, requester, old_room_id, new_room_id, new_room_version,
tombstone_event_id, tombstone_event_id,
): ):
@ -262,6 +265,7 @@ class RoomCreationHandler(BaseHandler):
initial_state = dict() initial_state = dict()
# Replicate relevant room events
types_to_copy = ( types_to_copy = (
(EventTypes.JoinRules, ""), (EventTypes.JoinRules, ""),
(EventTypes.Name, ""), (EventTypes.Name, ""),

View File

@ -63,7 +63,7 @@ class RoomMemberHandler(object):
self.directory_handler = hs.get_handlers().directory_handler self.directory_handler = hs.get_handlers().directory_handler
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.event_creation_hander = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.member_linearizer = Linearizer(name="member") self.member_linearizer = Linearizer(name="member")
self.member_limiter = Linearizer(max_count=10, name="member_as_limiter") self.member_limiter = Linearizer(max_count=10, name="member_as_limiter")
@ -162,6 +162,8 @@ class RoomMemberHandler(object):
ratelimit=True, ratelimit=True,
content=None, content=None,
): ):
user_id = target.to_string()
if content is None: if content is None:
content = {} content = {}
@ -169,14 +171,14 @@ class RoomMemberHandler(object):
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
event, context = yield self.event_creation_hander.create_event( event, context = yield self.event_creation_handler.create_event(
requester, requester,
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
"content": content, "content": content,
"room_id": room_id, "room_id": room_id,
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
"state_key": target.to_string(), "state_key": user_id,
# For backwards compatibility: # For backwards compatibility:
"membership": membership, "membership": membership,
@ -187,14 +189,14 @@ class RoomMemberHandler(object):
) )
# Check if this event matches the previous membership event for the user. # Check if this event matches the previous membership event for the user.
duplicate = yield self.event_creation_hander.deduplicate_state_event( duplicate = yield self.event_creation_handler.deduplicate_state_event(
event, context, event, context,
) )
if duplicate is not None: if duplicate is not None:
# Discard the new event since this membership change is a no-op. # Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate) defer.returnValue(duplicate)
yield self.event_creation_hander.handle_new_client_event( yield self.event_creation_handler.handle_new_client_event(
requester, requester,
event, event,
context, context,
@ -205,12 +207,12 @@ class RoomMemberHandler(object):
prev_state_ids = yield context.get_prev_state_ids(self.store) prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_member_event_id = prev_state_ids.get( prev_member_event_id = prev_state_ids.get(
(EventTypes.Member, target.to_string()), (EventTypes.Member, user_id),
None None
) )
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the # Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile # room. Don't bother if the user is just changing their profile
# info. # info.
newly_joined = True newly_joined = True
@ -219,6 +221,18 @@ class RoomMemberHandler(object):
newly_joined = prev_member_event.membership != Membership.JOIN newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined: if newly_joined:
yield self._user_joined_room(target, room_id) yield self._user_joined_room(target, room_id)
# Copy over direct message status and room tags if this is a join
# on an upgraded room
# Check if this is an upgraded room
predecessor = yield self.store.get_room_predecessor(room_id)
if predecessor:
# It is an upgraded room. Copy over old tags
self.copy_room_tags_and_direct_to_room(
predecessor["room_id"], room_id, user_id,
)
elif event.membership == Membership.LEAVE: elif event.membership == Membership.LEAVE:
if prev_member_event_id: if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id) prev_member_event = yield self.store.get_event(prev_member_event_id)
@ -227,6 +241,55 @@ class RoomMemberHandler(object):
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks
def copy_room_tags_and_direct_to_room(
self,
old_room_id,
new_room_id,
user_id,
):
"""Copies the tags and direct room state from one room to another.
Args:
old_room_id (str)
new_room_id (str)
user_id (str)
Returns:
Deferred[None]
"""
# Retrieve user account data for predecessor room
user_account_data, _ = yield self.store.get_account_data_for_user(
user_id,
)
# Copy direct message state if applicable
direct_rooms = user_account_data.get("m.direct", {})
# Check which key this room is under
if isinstance(direct_rooms, dict):
for key, room_id_list in direct_rooms.items():
if old_room_id in room_id_list and new_room_id not in room_id_list:
# Add new room_id to this key
direct_rooms[key].append(new_room_id)
# Save back to user's m.direct account data
yield self.store.add_account_data_for_user(
user_id, "m.direct", direct_rooms,
)
break
# Copy room tags if applicable
room_tags = yield self.store.get_tags_for_room(
user_id, old_room_id,
)
# Copy each room tag to the new room
for tag, tag_content in room_tags.items():
yield self.store.add_tag_to_room(
user_id, new_room_id, tag, tag_content
)
@defer.inlineCallbacks @defer.inlineCallbacks
def update_membership( def update_membership(
self, self,
@ -513,7 +576,7 @@ class RoomMemberHandler(object):
else: else:
requester = synapse.types.create_requester(target_user) requester = synapse.types.create_requester(target_user)
prev_event = yield self.event_creation_hander.deduplicate_state_event( prev_event = yield self.event_creation_handler.deduplicate_state_event(
event, context, event, context,
) )
if prev_event is not None: if prev_event is not None:
@ -533,7 +596,7 @@ class RoomMemberHandler(object):
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
yield self.event_creation_hander.handle_new_client_event( yield self.event_creation_handler.handle_new_client_event(
requester, requester,
event, event,
context, context,
@ -547,7 +610,7 @@ class RoomMemberHandler(object):
) )
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the # Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile # room. Don't bother if the user is just changing their profile
# info. # info.
newly_joined = True newly_joined = True
@ -775,7 +838,7 @@ class RoomMemberHandler(object):
) )
) )
yield self.event_creation_hander.create_and_send_nonmember_event( yield self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.ThirdPartyInvite, "type": EventTypes.ThirdPartyInvite,

View File

@ -14,6 +14,8 @@
# limitations under the License. # limitations under the License.
import logging import logging
import attr
from netaddr import IPAddress
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import defer from twisted.internet import defer
@ -22,7 +24,6 @@ from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent from twisted.web.iweb import IAgent
from synapse.http.endpoint import parse_server_name
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
from synapse.util.logcontext import make_deferred_yieldable from synapse.util.logcontext import make_deferred_yieldable
@ -86,29 +87,20 @@ class MatrixFederationAgent(object):
response from being received (including problems that prevent the request response from being received (including problems that prevent the request
from being sent). from being sent).
""" """
parsed_uri = URI.fromBytes(uri, defaultPort=-1)
res = yield self._route_matrix_uri(parsed_uri)
parsed_uri = URI.fromBytes(uri) # set up the TLS connection params
server_name_bytes = parsed_uri.netloc #
host, port = parse_server_name(server_name_bytes.decode("ascii"))
# XXX disabling TLS is really only supported here for the benefit of the # XXX disabling TLS is really only supported here for the benefit of the
# unit tests. We should make the UTs cope with TLS rather than having to make # unit tests. We should make the UTs cope with TLS rather than having to make
# the code support the unit tests. # the code support the unit tests.
if self._tls_client_options_factory is None: if self._tls_client_options_factory is None:
tls_options = None tls_options = None
else: else:
tls_options = self._tls_client_options_factory.get_options(host) tls_options = self._tls_client_options_factory.get_options(
res.tls_server_name.decode("ascii")
if port is not None: )
target = (host, port)
else:
service_name = b"_matrix._tcp.%s" % (server_name_bytes, )
server_list = yield self._srv_resolver.resolve_service(service_name)
if not server_list:
target = (host, 8448)
logger.debug("No SRV record for %s, using %s", host, target)
else:
target = pick_server_from_list(server_list)
# make sure that the Host header is set correctly # make sure that the Host header is set correctly
if headers is None: if headers is None:
@ -117,13 +109,13 @@ class MatrixFederationAgent(object):
headers = headers.copy() headers = headers.copy()
if not headers.hasHeader(b'host'): if not headers.hasHeader(b'host'):
headers.addRawHeader(b'host', server_name_bytes) headers.addRawHeader(b'host', res.host_header)
class EndpointFactory(object): class EndpointFactory(object):
@staticmethod @staticmethod
def endpointForURI(_uri): def endpointForURI(_uri):
logger.info("Connecting to %s:%s", target[0], target[1]) logger.info("Connecting to %s:%s", res.target_host, res.target_port)
ep = HostnameEndpoint(self._reactor, host=target[0], port=target[1]) ep = HostnameEndpoint(self._reactor, res.target_host, res.target_port)
if tls_options is not None: if tls_options is not None:
ep = wrapClientTLS(tls_options, ep) ep = wrapClientTLS(tls_options, ep)
return ep return ep
@ -133,3 +125,111 @@ class MatrixFederationAgent(object):
agent.request(method, uri, headers, bodyProducer) agent.request(method, uri, headers, bodyProducer)
) )
defer.returnValue(res) defer.returnValue(res)
@defer.inlineCallbacks
def _route_matrix_uri(self, parsed_uri):
"""Helper for `request`: determine the routing for a Matrix URI
Args:
parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
if there is no explicit port given.
Returns:
Deferred[_RoutingResult]
"""
# check for an IP literal
try:
ip_address = IPAddress(parsed_uri.host.decode("ascii"))
except Exception:
# not an IP address
ip_address = None
if ip_address:
port = parsed_uri.port
if port == -1:
port = 8448
defer.returnValue(_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=parsed_uri.host,
target_port=port,
))
if parsed_uri.port != -1:
# there is an explicit port
defer.returnValue(_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=parsed_uri.host,
target_port=parsed_uri.port,
))
# try a SRV lookup
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
server_list = yield self._srv_resolver.resolve_service(service_name)
if not server_list:
target_host = parsed_uri.host
port = 8448
logger.debug(
"No SRV record for %s, using %s:%i",
parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
)
else:
target_host, port = pick_server_from_list(server_list)
logger.debug(
"Picked %s:%i from SRV records for %s",
target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
)
defer.returnValue(_RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=target_host,
target_port=port,
))
@attr.s
class _RoutingResult(object):
"""The result returned by `_route_matrix_uri`.
Contains the parameters needed to direct a federation connection to a particular
server.
Where a SRV record points to several servers, this object contains a single server
chosen from the list.
"""
host_header = attr.ib()
"""
The value we should assign to the Host header (host:port from the matrix
URI, or .well-known).
:type: bytes
"""
tls_server_name = attr.ib()
"""
The server name we should set in the SNI (typically host, without port, from the
matrix URI or .well-known)
:type: bytes
"""
target_host = attr.ib()
"""
The hostname (or IP literal) we should route the TCP connection to (the target of the
SRV record, or the hostname from the URL/.well-known)
:type: bytes
"""
target_port = attr.ib()
"""
The port we should route the TCP connection to (the target of the SRV record, or
the port from the URL/.well-known, or 8448)
:type: int
"""

View File

@ -17,7 +17,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.events import FrozenEvent from synapse.events import event_type_from_format_version
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
@ -70,6 +70,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_payloads.append({ event_payloads.append({
"event": event.get_pdu_json(), "event": event.get_pdu_json(),
"event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(), "internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason, "rejected_reason": event.rejected_reason,
"context": serialized_context, "context": serialized_context,
@ -94,9 +95,12 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_and_contexts = [] event_and_contexts = []
for event_payload in event_payloads: for event_payload in event_payloads:
event_dict = event_payload["event"] event_dict = event_payload["event"]
format_ver = content["event_format_version"]
internal_metadata = event_payload["internal_metadata"] internal_metadata = event_payload["internal_metadata"]
rejected_reason = event_payload["rejected_reason"] rejected_reason = event_payload["rejected_reason"]
event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
EventType = event_type_from_format_version(format_ver)
event = EventType(event_dict, internal_metadata, rejected_reason)
context = yield EventContext.deserialize( context = yield EventContext.deserialize(
self.store, event_payload["context"], self.store, event_payload["context"],

View File

@ -17,7 +17,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.events import FrozenEvent from synapse.events import event_type_from_format_version
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
@ -74,6 +74,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
payload = { payload = {
"event": event.get_pdu_json(), "event": event.get_pdu_json(),
"event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(), "internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason, "rejected_reason": event.rejected_reason,
"context": serialized_context, "context": serialized_context,
@ -90,9 +91,12 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
event_dict = content["event"] event_dict = content["event"]
format_ver = content["event_format_version"]
internal_metadata = content["internal_metadata"] internal_metadata = content["internal_metadata"]
rejected_reason = content["rejected_reason"] rejected_reason = content["rejected_reason"]
event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
EventType = event_type_from_format_version(format_ver)
event = EventType(event_dict, internal_metadata, rejected_reason)
requester = Requester.deserialize(self.store, content["requester"]) requester = Requester.deserialize(self.store, content["requester"])
context = yield EventContext.deserialize(self.store, content["context"]) context = yield EventContext.deserialize(self.store, content["context"])

View File

@ -89,7 +89,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs) super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_hander = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler() self.message_handler = hs.get_message_handler()
@ -172,7 +172,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content, content=content,
) )
else: else:
event = yield self.event_creation_hander.create_and_send_nonmember_event( event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
event_dict, event_dict,
txn_id=txn_id, txn_id=txn_id,
@ -189,7 +189,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs) super(RoomSendEventRestServlet, self).__init__(hs)
self.event_creation_hander = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
@ -211,7 +211,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
if b'ts' in request.args and requester.app_service: if b'ts' in request.args and requester.app_service:
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
event = yield self.event_creation_hander.create_and_send_nonmember_event( event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
event_dict, event_dict,
txn_id=txn_id, txn_id=txn_id,

View File

@ -611,7 +611,7 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
RoomVersions.VDH_TEST, RoomVersions.STATE_V2_TEST, RoomVersions.V2, RoomVersions.VDH_TEST, RoomVersions.STATE_V2_TEST, RoomVersions.V2,
): ):
return v2.resolve_events_with_store( return v2.resolve_events_with_store(
state_sets, event_map, state_res_store, room_version, state_sets, event_map, state_res_store,
) )
else: else:
# This should only happen if we added a version but forgot to add it to # This should only happen if we added a version but forgot to add it to

View File

@ -21,7 +21,7 @@ from six import iteritems, iterkeys, itervalues
from twisted.internet import defer from twisted.internet import defer
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes, RoomVersions
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -274,7 +274,11 @@ def _resolve_auth_events(events, auth_events):
auth_events[(prev_event.type, prev_event.state_key)] = prev_event auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try: try:
# The signatures have already been checked at this point # The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) event_auth.check(
RoomVersions.V1, event, auth_events,
do_sig_check=False,
do_size_check=False,
)
prev_event = event prev_event = event
except AuthError: except AuthError:
return prev_event return prev_event
@ -286,7 +290,11 @@ def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events): for event in _ordered_events(events):
try: try:
# The signatures have already been checked at this point # The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False) event_auth.check(
RoomVersions.V1, event, auth_events,
do_sig_check=False,
do_size_check=False,
)
return event return event
except AuthError: except AuthError:
pass pass

View File

@ -29,10 +29,12 @@ logger = logging.getLogger(__name__)
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve_events_with_store(state_sets, event_map, state_res_store): def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
"""Resolves the state using the v2 state resolution algorithm """Resolves the state using the v2 state resolution algorithm
Args: Args:
room_version (str): The room version
state_sets(list): List of dicts of (type, state_key) -> event_id, state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve. which are the different state groups to resolve.
@ -104,7 +106,7 @@ def resolve_events_with_store(state_sets, event_map, state_res_store):
# Now sequentially auth each one # Now sequentially auth each one
resolved_state = yield _iterative_auth_checks( resolved_state = yield _iterative_auth_checks(
sorted_power_events, unconflicted_state, event_map, room_version, sorted_power_events, unconflicted_state, event_map,
state_res_store, state_res_store,
) )
@ -129,7 +131,7 @@ def resolve_events_with_store(state_sets, event_map, state_res_store):
logger.debug("resolving remaining events") logger.debug("resolving remaining events")
resolved_state = yield _iterative_auth_checks( resolved_state = yield _iterative_auth_checks(
leftover_events, resolved_state, event_map, room_version, leftover_events, resolved_state, event_map,
state_res_store, state_res_store,
) )
@ -350,11 +352,13 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
@defer.inlineCallbacks @defer.inlineCallbacks
def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store): def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
state_res_store):
"""Sequentially apply auth checks to each event in given list, updating the """Sequentially apply auth checks to each event in given list, updating the
state as it goes along. state as it goes along.
Args: Args:
room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to event_ids (list[str]): Ordered list of events to apply auth checks to
base_state (dict[tuple[str, str], str]): The set of state to start with base_state (dict[tuple[str, str], str]): The set of state to start with
event_map (dict[str,FrozenEvent]) event_map (dict[str,FrozenEvent])
@ -385,7 +389,7 @@ def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store):
try: try:
event_auth.check( event_auth.check(
event, auth_events, room_version, event, auth_events,
do_sig_check=False, do_sig_check=False,
do_size_check=False do_size_check=False
) )

View File

@ -27,7 +27,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.stringutils import exception_to_unicode from synapse.util.stringutils import exception_to_unicode
@ -196,6 +196,12 @@ class SQLBaseStore(object):
# A set of tables that are not safe to use native upserts in. # A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = {"user_ips"} self._unsafe_to_upsert_tables = {"user_ips"}
# We add the user_directory_search table to the blacklist on SQLite
# because the existing search table does not have an index, making it
# unsafe to use native upserts.
if isinstance(self.database_engine, Sqlite3Engine):
self._unsafe_to_upsert_tables.add("user_directory_search")
if self.database_engine.can_native_upsert: if self.database_engine.can_native_upsert:
# Check ASAP (and then later, every 1s) to see if we have finished # Check ASAP (and then later, every 1s) to see if we have finished
# background updates of tables that aren't safe to update. # background updates of tables that aren't safe to update.
@ -230,7 +236,7 @@ class SQLBaseStore(object):
self._unsafe_to_upsert_tables.discard("user_ips") self._unsafe_to_upsert_tables.discard("user_ips")
# If there's any tables left to check, reschedule to run. # If there's any tables left to check, reschedule to run.
if self._unsafe_to_upsert_tables: if updates:
self._clock.call_later( self._clock.call_later(
15.0, 15.0,
run_as_background_process, run_as_background_process,

View File

@ -240,7 +240,7 @@ class BackgroundUpdateStore(SQLBaseStore):
* An integer count of the number of items to update in this batch. * An integer count of the number of items to update in this batch.
The handler should return a deferred integer count of items updated. The handler should return a deferred integer count of items updated.
The hander is responsible for updating the progress of the update. The handler is responsible for updating the progress of the update.
Args: Args:
update_name(str): The name of the update that this code handles. update_name(str): The name of the update that this code handles.

View File

@ -33,14 +33,10 @@ class Sqlite3Engine(object):
@property @property
def can_native_upsert(self): def can_native_upsert(self):
""" """
Do we support native UPSERTs? Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
more work we haven't done yet to tell what was inserted vs updated.
""" """
# SQLite3 3.24+ supports them, but empirically the unit tests don't work return self.module.sqlite_version_info >= (3, 24, 0)
# when its enabled.
# FIXME: Figure out what is wrong so we can re-enable native upserts
# return self.module.sqlite_version_info >= (3, 24, 0)
return False
def check_database(self, txn): def check_database(self, txn):
pass pass

View File

@ -23,7 +23,7 @@ from twisted.internet import defer
from synapse.api.constants import EventFormatVersions from synapse.api.constants import EventFormatVersions
from synapse.api.errors import NotFoundError from synapse.api.errors import NotFoundError
from synapse.events import FrozenEvent from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
# these are only included to make the type annotations work # these are only included to make the type annotations work
from synapse.events.snapshot import EventContext # noqa: F401 from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
@ -412,11 +412,7 @@ class EventsWorkerStore(SQLBaseStore):
# of a event format version, so it must be a V1 event. # of a event format version, so it must be a V1 event.
format_version = EventFormatVersions.V1 format_version = EventFormatVersions.V1
# TODO: When we implement new event formats we'll need to use a original_ev = event_type_from_format_version(format_version)(
# different event python type
assert format_version == EventFormatVersions.V1
original_ev = FrozenEvent(
event_dict=d, event_dict=d,
internal_metadata_dict=internal_metadata, internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason, rejected_reason=rejected_reason,

View File

@ -197,15 +197,21 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if is_support: if is_support:
return return
is_insert = yield self.runInteraction( yield self.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, "upsert_monthly_active_user", self.upsert_monthly_active_user_txn,
user_id user_id
) )
if is_insert: user_in_mau = self.user_last_seen_monthly_active.cache.get(
self.user_last_seen_monthly_active.invalidate((user_id,)) (user_id,),
None,
update_metrics=False
)
if user_in_mau is None:
self.get_monthly_active_count.invalidate(()) self.get_monthly_active_count.invalidate(())
self.user_last_seen_monthly_active.invalidate((user_id,))
def upsert_monthly_active_user_txn(self, txn, user_id): def upsert_monthly_active_user_txn(self, txn, user_id):
"""Updates or inserts monthly active user member """Updates or inserts monthly active user member

View File

@ -166,11 +166,7 @@ class MatrixFederationAgentTests(TestCase):
""" """
Test the behaviour when the server name contains an explicit IP (with no port) Test the behaviour when the server name contains an explicit IP (with no port)
""" """
# there will be a getaddrinfo on the IP
# the SRV lookup will return an empty list (XXX: why do we even do an SRV lookup?)
self.mock_resolver.resolve_service.side_effect = lambda _: []
# then there will be a getaddrinfo on the IP
self.reactor.lookups["1.2.3.4"] = "1.2.3.4" self.reactor.lookups["1.2.3.4"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar") test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar")
@ -178,10 +174,6 @@ class MatrixFederationAgentTests(TestCase):
# Nothing happened yet # Nothing happened yet
self.assertNoResult(test_d) self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.1.2.3.4",
)
# Make sure treq is trying to connect # Make sure treq is trying to connect
clients = self.reactor.tcpClients clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1) self.assertEqual(len(clients), 1)
@ -209,6 +201,88 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((0.1,)) self.reactor.pump((0.1,))
self.successResultOf(test_d) self.successResultOf(test_d)
def test_get_ipv6_address(self):
"""
Test the behaviour when the server name contains an explicit IPv6 address
(with no port)
"""
# there will be a getaddrinfo on the IP
self.reactor.lookups["::1"] = "::1"
test_d = self._make_get_request(b"matrix://[::1]/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '::1')
self.assertEqual(port, 8448)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=None,
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'[::1]'],
)
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_ipv6_address_with_port(self):
"""
Test the behaviour when the server name contains an explicit IPv6 address
(with explicit port)
"""
# there will be a getaddrinfo on the IP
self.reactor.lookups["::1"] = "::1"
test_d = self._make_get_request(b"matrix://[::1]:80/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '::1')
self.assertEqual(port, 80)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=None,
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'[::1]:80'],
)
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_get_hostname_no_srv(self): def test_get_hostname_no_srv(self):
""" """
Test the behaviour when the server name has no port, and no SRV record Test the behaviour when the server name has no port, and no SRV record
@ -258,7 +332,7 @@ class MatrixFederationAgentTests(TestCase):
Test the behaviour when there is a single SRV record Test the behaviour when there is a single SRV record
""" """
self.mock_resolver.resolve_service.side_effect = lambda _: [ self.mock_resolver.resolve_service.side_effect = lambda _: [
Server(host="srvtarget", port=8443) Server(host=b"srvtarget", port=8443)
] ]
self.reactor.lookups["srvtarget"] = "1.2.3.4" self.reactor.lookups["srvtarget"] = "1.2.3.4"
@ -298,6 +372,95 @@ class MatrixFederationAgentTests(TestCase):
self.reactor.pump((0.1,)) self.reactor.pump((0.1,))
self.successResultOf(test_d) self.successResultOf(test_d)
def test_idna_servername(self):
"""test the behaviour when the server name has idna chars in"""
self.mock_resolver.resolve_service.side_effect = lambda _: []
# the resolver is always called with the IDNA hostname as a native string.
self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4"
# this is idna for bücher.com
test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.xn--bcher-kva.com",
)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8448)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=b'xn--bcher-kva.com',
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'xn--bcher-kva.com'],
)
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def test_idna_srv_target(self):
"""test the behaviour when the target of a SRV record has idna chars"""
self.mock_resolver.resolve_service.side_effect = lambda _: [
Server(host=b"xn--trget-3qa.com", port=8443) # târget.com
]
self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4"
test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar")
# Nothing happened yet
self.assertNoResult(test_d)
self.mock_resolver.resolve_service.assert_called_once_with(
b"_matrix._tcp.xn--bcher-kva.com",
)
# Make sure treq is trying to connect
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, '1.2.3.4')
self.assertEqual(port, 8443)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
expected_sni=b'xn--bcher-kva.com',
)
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b'GET')
self.assertEqual(request.path, b'/foo/bar')
self.assertEqual(
request.requestHeaders.getRawHeaders(b'host'),
[b'xn--bcher-kva.com'],
)
# finish the request
request.finish()
self.reactor.pump((0.1,))
self.successResultOf(test_d)
def _check_logcontext(context): def _check_logcontext(context):
current = LoggingContext.current_context() current = LoggingContext.current_context()

View File

@ -8,11 +8,10 @@ import attr
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import address, threads, udp from twisted.internet import address, threads, udp
from twisted.internet._resolver import HostResolution from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.address import IPv4Address from twisted.internet.defer import Deferred, fail, succeed
from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver from twisted.internet.interfaces import IReactorPluggableNameResolver, IResolverSimple
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactorClock from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.http import unquote from twisted.web.http import unquote
@ -227,30 +226,16 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def __init__(self): def __init__(self):
self._udp = [] self._udp = []
self.lookups = {} lookups = self.lookups = {}
class Resolver(object): @implementer(IResolverSimple)
def resolveHostName( class FakeResolver(object):
_self, def getHostByName(self, name, timeout=None):
resolutionReceiver, if name not in lookups:
hostName, return fail(DNSLookupError("OH NO: unknown %s" % (name, )))
portNumber=0, return succeed(lookups[name])
addressTypes=None,
transportSemantics='TCP',
):
resolution = HostResolution(hostName) self.nameResolver = SimpleResolverComplexifier(FakeResolver())
resolutionReceiver.resolutionBegan(resolution)
if hostName not in self.lookups:
raise DNSLookupError("OH NO")
resolutionReceiver.addressResolved(
IPv4Address('TCP', self.lookups[hostName], portNumber)
)
resolutionReceiver.resolutionComplete()
return resolution
self.nameResolver = Resolver()
super(ThreadedMemoryReactorClock, self).__init__() super(ThreadedMemoryReactorClock, self).__init__()
def listenUDP(self, port, protocol, interface='', maxPacketSize=8196): def listenUDP(self, port, protocol, interface='', maxPacketSize=8196):

View File

@ -19,7 +19,7 @@ from six.moves import zip
import attr import attr
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership, RoomVersions
from synapse.event_auth import auth_types_for_event from synapse.event_auth import auth_types_for_event
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
@ -539,6 +539,7 @@ class StateTestCase(unittest.TestCase):
state_before = dict(state_at_event[prev_events[0]]) state_before = dict(state_at_event[prev_events[0]])
else: else:
state_d = resolve_events_with_store( state_d = resolve_events_with_store(
RoomVersions.V2,
[state_at_event[n] for n in prev_events], [state_at_event[n] for n in prev_events],
event_map=event_map, event_map=event_map,
state_res_store=TestStateResolutionStore(event_map), state_res_store=TestStateResolutionStore(event_map),
@ -685,6 +686,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
# Test that we correctly handle passing `None` as the event_map # Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store( state_d = resolve_events_with_store(
RoomVersions.V2,
[self.state_at_bob, self.state_at_charlie], [self.state_at_bob, self.state_at_charlie],
event_map=None, event_map=None,
state_res_store=TestStateResolutionStore(self.event_map), state_res_store=TestStateResolutionStore(self.event_map),

View File

@ -49,14 +49,17 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.db_pool.runWithConnection = runWithConnection self.db_pool.runWithConnection = runWithConnection
config = Mock() config = Mock()
config._enable_native_upserts = False config._disable_native_upserts = True
config.event_cache_size = 1 config.event_cache_size = 1
config.database_config = {"name": "sqlite3"} config.database_config = {"name": "sqlite3"}
engine = create_engine(config.database_config)
fake_engine = Mock(wraps=engine)
fake_engine.can_native_upsert = False
hs = TestHomeServer( hs = TestHomeServer(
"test", "test",
db_pool=self.db_pool, db_pool=self.db_pool,
config=config, config=config,
database_engine=create_engine(config.database_config), database_engine=fake_engine,
) )
self.datastore = SQLBaseStore(None, hs) self.datastore = SQLBaseStore(None, hs)

View File

@ -18,12 +18,12 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from tests.unittest import HomeserverTestCase from tests import unittest
FORTY_DAYS = 40 * 24 * 60 * 60 FORTY_DAYS = 40 * 24 * 60 * 60
class MonthlyActiveUsersTestCase(HomeserverTestCase): class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock): def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver() hs = self.setup_test_homeserver()

View File

@ -18,7 +18,7 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
from tests import unittest from tests import unittest
@ -52,6 +52,7 @@ class RedactionTestCase(unittest.TestCase):
content = {"membership": membership} content = {"membership": membership}
content.update(extra_content) content.update(extra_content)
builder = self.event_builder_factory.new( builder = self.event_builder_factory.new(
RoomVersions.V1,
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
"sender": user.to_string(), "sender": user.to_string(),
@ -74,6 +75,7 @@ class RedactionTestCase(unittest.TestCase):
self.depth += 1 self.depth += 1
builder = self.event_builder_factory.new( builder = self.event_builder_factory.new(
RoomVersions.V1,
{ {
"type": EventTypes.Message, "type": EventTypes.Message,
"sender": user.to_string(), "sender": user.to_string(),
@ -94,6 +96,7 @@ class RedactionTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_redaction(self, room, event_id, user, reason): def inject_redaction(self, room, event_id, user, reason):
builder = self.event_builder_factory.new( builder = self.event_builder_factory.new(
RoomVersions.V1,
{ {
"type": EventTypes.Redaction, "type": EventTypes.Redaction,
"sender": user.to_string(), "sender": user.to_string(),

View File

@ -18,7 +18,7 @@ from mock import Mock
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
from tests import unittest from tests import unittest
@ -50,6 +50,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_room_member(self, room, user, membership, replaces_state=None): def inject_room_member(self, room, user, membership, replaces_state=None):
builder = self.event_builder_factory.new( builder = self.event_builder_factory.new(
RoomVersions.V1,
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
"sender": user.to_string(), "sender": user.to_string(),

View File

@ -17,7 +17,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership, RoomVersions
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
@ -52,6 +52,7 @@ class StateStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content): def inject_state_event(self, room, sender, typ, state_key, content):
builder = self.event_builder_factory.new( builder = self.event_builder_factory.new(
RoomVersions.V1,
{ {
"type": typ, "type": typ,
"sender": sender.to_string(), "sender": sender.to_string(),

View File

@ -16,6 +16,7 @@
import unittest import unittest
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import RoomVersions
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
@ -35,12 +36,16 @@ class EventAuthTestCase(unittest.TestCase):
} }
# creator should be able to send state # creator should be able to send state
event_auth.check(_random_state_event(creator), auth_events, do_sig_check=False) event_auth.check(
RoomVersions.V1, _random_state_event(creator), auth_events,
do_sig_check=False,
)
# joiner should not be able to send state # joiner should not be able to send state
self.assertRaises( self.assertRaises(
AuthError, AuthError,
event_auth.check, event_auth.check,
RoomVersions.V1,
_random_state_event(joiner), _random_state_event(joiner),
auth_events, auth_events,
do_sig_check=False, do_sig_check=False,
@ -69,13 +74,17 @@ class EventAuthTestCase(unittest.TestCase):
self.assertRaises( self.assertRaises(
AuthError, AuthError,
event_auth.check, event_auth.check,
RoomVersions.V1,
_random_state_event(pleb), _random_state_event(pleb),
auth_events, auth_events,
do_sig_check=False, do_sig_check=False,
), ),
# king should be able to send state # king should be able to send state
event_auth.check(_random_state_event(king), auth_events, do_sig_check=False) event_auth.check(
RoomVersions.V1, _random_state_event(king), auth_events,
do_sig_check=False,
)
# helpers for making events # helpers for making events

View File

@ -17,6 +17,7 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import succeed from twisted.internet.defer import succeed
from synapse.api.constants import RoomVersions
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.visibility import filter_events_for_server from synapse.visibility import filter_events_for_server
@ -124,6 +125,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
def inject_visibility(self, user_id, visibility): def inject_visibility(self, user_id, visibility):
content = {"history_visibility": visibility} content = {"history_visibility": visibility}
builder = self.event_builder_factory.new( builder = self.event_builder_factory.new(
RoomVersions.V1,
{ {
"type": "m.room.history_visibility", "type": "m.room.history_visibility",
"sender": user_id, "sender": user_id,
@ -144,6 +146,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
content = {"membership": membership} content = {"membership": membership}
content.update(extra_content) content.update(extra_content)
builder = self.event_builder_factory.new( builder = self.event_builder_factory.new(
RoomVersions.V1,
{ {
"type": "m.room.member", "type": "m.room.member",
"sender": user_id, "sender": user_id,
@ -165,6 +168,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
if content is None: if content is None:
content = {"body": "testytest"} content = {"body": "testytest"}
builder = self.event_builder_factory.new( builder = self.event_builder_factory.new(
RoomVersions.V1,
{ {
"type": "m.room.message", "type": "m.room.message",
"sender": user_id, "sender": user_id,

View File

@ -26,7 +26,7 @@ from six.moves.urllib import parse as urlparse
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes, RoomVersions
from synapse.api.errors import CodeMessageException, cs_error from synapse.api.errors import CodeMessageException, cs_error
from synapse.config.server import ServerConfig from synapse.config.server import ServerConfig
from synapse.federation.transport import server from synapse.federation.transport import server
@ -624,6 +624,7 @@ def create_room(hs, room_id, creator_id):
event_creation_handler = hs.get_event_creation_handler() event_creation_handler = hs.get_event_creation_handler()
builder = event_builder_factory.new( builder = event_builder_factory.new(
RoomVersions.V1,
{ {
"type": EventTypes.Create, "type": EventTypes.Create,
"state_key": "", "state_key": "",