Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

anoa/log_11772
David Robertson 2021-09-22 13:35:31 +01:00
commit a8340692ab
164 changed files with 1903 additions and 1288 deletions

View File

@@ -61,6 +61,5 @@ jobs:
         uses: peaceiris/actions-gh-pages@068dc23d9710f1ba62e86896f84735d869951305 # v3.8.0
         with:
           github_token: ${{ secrets.GITHUB_TOKEN }}
-          keep_files: true
           publish_dir: ./book
           destination_dir: ./${{ steps.vars.outputs.branch-version }}

View File

@@ -192,6 +192,7 @@ jobs:
       volumes:
         - ${{ github.workspace }}:/src
       env:
+        SYTEST_BRANCH: ${{ github.head_ref }}
         POSTGRES: ${{ matrix.postgres && 1}}
         MULTI_POSTGRES: ${{ (matrix.postgres == 'multi-postgres') && 1}}
         WORKERS: ${{ matrix.workers && 1 }}

View File

@@ -1,7 +1,23 @@
-Synapse 1.43.0rc1 (2021-09-14)
+Synapse 1.43.0 (2021-09-21)
+===========================
+
+This release drops support for the deprecated, unstable API for [MSC2858 (Multiple SSO Identity Providers)](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2858-Multiple-SSO-Identity-Providers.md#unstable-prefix), as well as the undocumented `experimental.msc2858_enabled` config option. Client authors should update their clients to use the stable API, available since Synapse 1.30.
+
+The documentation has been updated with configuration for routing `/spaces`, `/hierarchy` and `/summary` to workers. See [the upgrade notes](https://github.com/matrix-org/synapse/blob/release-v1.43/docs/upgrade.md#upgrading-to-v1430) for more details.
+
+No significant changes since 1.43.0rc2.
+
+
+Synapse 1.43.0rc2 (2021-09-17)
 ==============================

-This release drops support for the deprecated, unstable API for [MSC2858](https://github.com/matrix-org/matrix-doc/blob/master/proposals/2858-Multiple-SSO-Identity-Providers.md#unstable-prefix), as well as the undocumented `experimental.msc2858_enabled` config option. Client authors should update their clients to use the stable API, available since Synapse 1.30.
+Bugfixes
+--------
+
+- Added opentracing logging to help debug [\#9424](https://github.com/matrix-org/synapse/issues/9424). ([\#10828](https://github.com/matrix-org/synapse/issues/10828))
+
+
+Synapse 1.43.0rc1 (2021-09-14)
+==============================

 Features
 --------

changelog.d/10659.misc (new file)
@@ -0,0 +1 @@
+Fix GitHub Actions config so we can run sytest on synapse from parallel branches.

(new file)
@@ -0,0 +1 @@
+Only allow the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send?chunk_id=xxx` endpoint to connect to an already existing insertion event.

changelog.d/10777.misc (new file)
@@ -0,0 +1 @@
+Split out [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) meta events to their own fields in the `/batch_send` response.

changelog.d/10785.misc (new file)
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.

changelog.d/10796.misc (new file)
@@ -0,0 +1 @@
+Simplify the internal logic which maintains the user directory database tables.

changelog.d/10807.bugfix (new file)
@@ -0,0 +1 @@
+Allow sending a membership event to unban a user. Contributed by @aaronraimist.

changelog.d/10810.bugfix (new file)
@@ -0,0 +1 @@
+Fix a case where logging contexts would go missing when federation requests time out.

changelog.d/10812.misc (new file)
@@ -0,0 +1 @@
+Use direct references to config flags.

(new file)
@@ -0,0 +1 @@
+Improve oEmbed previews by processing the author name, photo, and video information.

changelog.d/10815.misc (new file)
@@ -0,0 +1 @@
+Specify the type of token in generic "Invalid token" error messages.

changelog.d/10816.misc (new file)
@@ -0,0 +1 @@
+Make `StateFilter` frozen so it is hashable.

changelog.d/10817.misc (new file)
@@ -0,0 +1 @@
+Add missing type hints to REST servlets.

changelog.d/10823.misc (new file)
@@ -0,0 +1 @@
+Add type hints to the state database.

(deleted file)
@@ -1 +0,0 @@
-Added opentrace logging to help debug #9424.

changelog.d/10829.misc (new file)
@@ -0,0 +1 @@
+Track cache eviction rates more finely in Prometheus' monitoring.

changelog.d/10831.misc (new file)
@@ -0,0 +1 @@
+Add missing type hints to handlers.

changelog.d/10834.misc (new file)
@@ -0,0 +1 @@
+Factor out PNG image data to a constant to be used in several tests.

changelog.d/10835.misc (new file)
@@ -0,0 +1 @@
+Add a test to ensure state events sent by modules get persisted correctly.

changelog.d/10838.misc (new file)
@@ -0,0 +1 @@
+Rename [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) fields and event types from `chunk` to `batch` to match the `/batch_send` endpoint.

changelog.d/10839.misc (new file)
@@ -0,0 +1 @@
+Rename [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` query parameter from `?prev_event` to more obvious usage with `?prev_event_id`.

changelog.d/10843.bugfix (new file)
@@ -0,0 +1 @@
+Fix a bug causing the `remove_stale_pushers` background job to repeatedly fail and log errors. This bug affected Synapse servers that had been upgraded from version 1.28 or older and are using SQLite.

changelog.d/10845.doc (new file)
@@ -0,0 +1 @@
+Fix some crashes in the Module API example code, by adding JSON encoding/decoding.

changelog.d/10856.misc (new file)
@@ -0,0 +1 @@
+Add missing type hints to handlers.

changelog.d/10859.bugfix (new file)
@@ -0,0 +1 @@
+Fix a bug in Unicode support of the room search admin API. It is now possible to search for rooms with non-ASCII characters.

changelog.d/10867.misc (new file)
@@ -0,0 +1 @@
+Add type hints to `synapse.http.site`.

changelog.d/10869.doc (new file)
@@ -0,0 +1 @@
+Properly remove deleted files from GitHub pages when generating the documentation.

changelog.d/10879.misc (new file)
@@ -0,0 +1 @@
+Include outlier status when we log V2 or V3 events.

debian/changelog (vendored)

@@ -1,3 +1,15 @@
+matrix-synapse-py3 (1.43.0) stable; urgency=medium
+
+  * New synapse release 1.43.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Tue, 21 Sep 2021 11:49:05 +0100
+
+matrix-synapse-py3 (1.43.0~rc2) stable; urgency=medium
+
+  * New synapse release 1.43.0~rc2.
+
+ -- Synapse Packaging team <packages@matrix.org>  Fri, 17 Sep 2021 10:43:21 +0100
+
 matrix-synapse-py3 (1.43.0~rc1) stable; urgency=medium

   * New synapse release 1.43.0~rc1.

View File

@@ -25,16 +25,14 @@ When Synapse is asked to preview a URL it does the following:
 3. Kicks off a background process to generate a preview:
    1. Checks the database cache by URL and timestamp and returns the result if it
       has not expired and was successful (a 2xx return code).
-   2. Checks if the URL matches an oEmbed pattern. If it does, fetch the oEmbed
-      response. If this is an image, replace the URL to fetch and continue. If
-      it is HTML content, use the HTML as the document and continue.
-   3. If it doesn't match an oEmbed pattern, downloads the URL and stores it
-      into a file via the media storage provider and saves the local media
-      metadata.
-   5. If the media is an image:
+   2. Checks if the URL matches an [oEmbed](https://oembed.com/) pattern. If it
+      does, update the URL to download.
+   3. Downloads the URL and stores it into a file via the media storage provider
+      and saves the local media metadata.
+   4. If the media is an image:
       1. Generates thumbnails.
       2. Generates an Open Graph response based on image properties.
-   6. If the media is HTML:
+   5. If the media is HTML:
       1. Decodes the HTML via the stored file.
       2. Generates an Open Graph response from the HTML.
      3. If an image exists in the Open Graph response:
@@ -42,6 +40,13 @@ When Synapse is asked to preview a URL it does the following:
          provider and saves the local media metadata.
       2. Generates thumbnails.
       3. Updates the Open Graph response based on image properties.
+   6. If the media is JSON and an oEmbed URL was found:
+      1. Convert the oEmbed response to an Open Graph response.
+      2. If a thumbnail or image is in the oEmbed response:
+         1. Downloads the URL and stores it into a file via the media storage
+            provider and saves the local media metadata.
+         2. Generates thumbnails.
+         3. Updates the Open Graph response based on image properties.
    7. Stores the result in the database cache.
 4. Returns the result.

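The oEmbed branch of this flow amounts to mapping a JSON discovery response onto Open Graph tags before the image and thumbnail steps run. A minimal sketch of that mapping (a hypothetical helper; the real logic lives in Synapse's oEmbed preview code and handles more variants and validation):

```python
from typing import Any, Dict

def oembed_to_open_graph(oembed: Dict[str, Any]) -> Dict[str, str]:
    """Map a subset of an oEmbed JSON response onto Open Graph tags.

    A simplified sketch: the real code also validates types and handles
    the "photo"/"video"/"rich" oEmbed variants differently.
    """
    og: Dict[str, str] = {}
    if "title" in oembed:
        og["og:title"] = oembed["title"]
    # Author name is one of the fields this change starts processing.
    if "author_name" in oembed:
        og["og:description"] = oembed["author_name"]
    # A thumbnail URL triggers the image download/thumbnailing steps above.
    if "thumbnail_url" in oembed:
        og["og:image"] = oembed["thumbnail_url"]
    if oembed.get("type") == "video" and "html" in oembed:
        og["og:video:type"] = "text/html"
    return og
```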
View File

@@ -136,9 +136,9 @@ class IsUserEvilResource(Resource):
         self.evil_users = config.get("evil_users") or []

     def render_GET(self, request: Request):
-        user = request.args.get(b"user")[0]
+        user = request.args.get(b"user")[0].decode()
         request.setHeader(b"Content-Type", b"application/json")
-        return json.dumps({"evil": user in self.evil_users})
+        return json.dumps({"evil": user in self.evil_users}).encode()

 class ListSpamChecker:

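The two fixes above exist because Twisted hands query arguments to `render_GET` as `bytes` and requires the response body to be `bytes` as well. A self-contained sketch of the corrected resource (standalone for illustration; config handling omitted):

```python
import json

from twisted.web.resource import Resource
from twisted.web.server import Request


class IsUserEvilResource(Resource):
    """Standalone version of the docs example; `evil_users` is a plain list here."""

    def __init__(self, evil_users: list):
        super().__init__()
        self.evil_users = evil_users

    def render_GET(self, request: Request) -> bytes:
        # Query arguments arrive as bytes; decode before comparing to str.
        user = request.args.get(b"user")[0].decode()
        request.setHeader(b"Content-Type", b"application/json")
        # Twisted requires the response body to be bytes, hence .encode().
        return json.dumps({"evil": user in self.evil_users}).encode()
```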
View File

@@ -2362,11 +2362,15 @@ user_directory:
     #enabled: false

     # Defines whether to search all users visible to your HS when searching
-    # the user directory, rather than limiting to users visible in public
-    # rooms. Defaults to false.
+    # the user directory. If false, search results will only contain users
+    # visible in public rooms and users sharing a room with the requester.
+    # Defaults to false.
     #
-    # If you set it true, you'll have to rebuild the user_directory search
-    # indexes, see:
+    # NB. If you set this to true, and the last time the user_directory search
+    # indexes were (re)built was before Synapse 1.44, you'll have to
+    # rebuild the indexes in order to search through all known users.
+    # These indexes are built the first time Synapse starts; admins can
+    # manually trigger a rebuild following the instructions at
     # https://matrix-org.github.io/synapse/latest/user_directory.html
     #
     # Uncomment to return search results containing all known users, even if that

View File

@@ -60,6 +60,7 @@ files =
   synapse/storage/databases/main/session.py,
   synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
+  synapse/storage/databases/state,
   synapse/storage/database.py,
   synapse/storage/engines,
   synapse/storage/keys.py,
@@ -86,10 +87,14 @@ files =
   tests/handlers/test_sync.py,
   tests/rest/client/test_login.py,
   tests/rest/client/test_auth.py,
+  tests/storage/test_state.py,
   tests/util/test_itertools.py,
   tests/util/test_stream_change_cache.py

-[mypy-synapse.rest.client.*]
+[mypy-synapse.handlers.*]
+disallow_untyped_defs = True
+
+[mypy-synapse.rest.*]
 disallow_untyped_defs = True

 [mypy-synapse.util.batching_queue]

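For context, `disallow_untyped_defs = True` makes mypy reject any function in the matched modules that lacks annotations, which is what drives the many `-> None` additions in the handler diffs below. A tiny illustration (hypothetical function names):

```python
# Rejected under disallow_untyped_defs:
def get_displayname(user_id):  # error: Function is missing a type annotation
    ...

# Accepted: parameters and return type are fully annotated.
def get_displayname_typed(user_id: str) -> str:
    return ""
```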
View File

@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass

-__version__ = "1.43.0rc1"
+__version__ = "1.43.0"

 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when

View File

@@ -70,8 +70,8 @@ class Auth:
         self._auth_blocking = AuthBlocking(self.hs)

-        self._track_appservice_user_ips = hs.config.track_appservice_user_ips
-        self._macaroon_secret_key = hs.config.macaroon_secret_key
+        self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
+        self._macaroon_secret_key = hs.config.key.macaroon_secret_key
         self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users

     async def check_user_in_room(

View File

@@ -30,13 +30,15 @@ class AuthBlocking:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()

-        self._server_notices_mxid = hs.config.server_notices_mxid
-        self._hs_disabled = hs.config.hs_disabled
-        self._hs_disabled_message = hs.config.hs_disabled_message
-        self._admin_contact = hs.config.admin_contact
-        self._max_mau_value = hs.config.max_mau_value
-        self._limit_usage_by_mau = hs.config.limit_usage_by_mau
-        self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
+        self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
+        self._hs_disabled = hs.config.server.hs_disabled
+        self._hs_disabled_message = hs.config.server.hs_disabled_message
+        self._admin_contact = hs.config.server.admin_contact
+        self._max_mau_value = hs.config.server.max_mau_value
+        self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
+        self._mau_limits_reserved_threepids = (
+            hs.config.server.mau_limits_reserved_threepids
+        )
         self._server_name = hs.hostname
         self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips

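These hunks are instances of the "use direct references to config flags" change: flags are read from typed per-section config objects (`hs.config.server.*`, `hs.config.key.*`, and so on) instead of attributes copied onto the root config. A simplified sketch of the structure this relies on (names abbreviated, not the actual classes):

```python
class ServerConfig:
    def __init__(self, raw: dict) -> None:
        self.use_presence: bool = raw.get("use_presence", True)
        self.hs_disabled: bool = raw.get("hs_disabled", False)


class RootConfig:
    def __init__(self, raw: dict) -> None:
        # Each section is a typed object; `config.server.use_presence` is an
        # attribute mypy can resolve, unlike a dynamically copied
        # `config.use_presence`.
        self.server = ServerConfig(raw.get("server", {}))


config = RootConfig({"server": {"use_presence": False}})
assert config.server.use_presence is False
```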
View File

@@ -121,7 +121,7 @@ class EventTypes:
     SpaceParent = "m.space.parent"

     MSC2716_INSERTION = "org.matrix.msc2716.insertion"
-    MSC2716_CHUNK = "org.matrix.msc2716.chunk"
+    MSC2716_BATCH = "org.matrix.msc2716.batch"
     MSC2716_MARKER = "org.matrix.msc2716.marker"
@@ -209,11 +209,11 @@ class EventContentFields:
     # Used on normal messages to indicate they were historically imported after the fact
     MSC2716_HISTORICAL = "org.matrix.msc2716.historical"

-    # For "insertion" events to indicate what the next chunk ID should be in
+    # For "insertion" events to indicate what the next batch ID should be in
     # order to connect to it
-    MSC2716_NEXT_CHUNK_ID = "org.matrix.msc2716.next_chunk_id"
-    # Used on "chunk" events to indicate which insertion event it connects to
-    MSC2716_CHUNK_ID = "org.matrix.msc2716.chunk_id"
+    MSC2716_NEXT_BATCH_ID = "org.matrix.msc2716.next_batch_id"
+    # Used on "batch" events to indicate which insertion event it connects to
+    MSC2716_BATCH_ID = "org.matrix.msc2716.batch_id"

     # For "marker" events
     MSC2716_MARKER_INSERTION = "org.matrix.msc2716.marker.insertion"

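The renamed fields are what chain a `batch` event back to its `insertion` event: the insertion event advertises a `next_batch_id`, and a later batch quotes it as its `batch_id`. A sketch with illustrative content dicts (the IDs are invented):

```python
# Field names from EventContentFields above; IDs are invented for illustration.
NEXT_BATCH_ID = "org.matrix.msc2716.next_batch_id"
BATCH_ID = "org.matrix.msc2716.batch_id"

# An "insertion" event advertises the batch ID that may connect to it...
insertion_content = {NEXT_BATCH_ID: "batch-abc123"}

# ...and a later "batch" event connects itself by quoting that ID.
batch_content = {BATCH_ID: "batch-abc123"}

assert batch_content[BATCH_ID] == insertion_content[NEXT_BATCH_ID]
```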
View File

@@ -244,24 +244,8 @@ class RoomVersions:
         msc2716_historical=False,
         msc2716_redactions=False,
     )
-    MSC2716 = RoomVersion(
-        "org.matrix.msc2716",
-        RoomDisposition.UNSTABLE,
-        EventFormatVersions.V3,
-        StateResolutionVersions.V2,
-        enforce_key_validity=True,
-        special_case_aliases_auth=False,
-        strict_canonicaljson=True,
-        limit_notifications_power_levels=True,
-        msc2176_redaction_rules=False,
-        msc3083_join_rules=False,
-        msc3375_redaction_rules=False,
-        msc2403_knocking=True,
-        msc2716_historical=True,
-        msc2716_redactions=False,
-    )
-    MSC2716v2 = RoomVersion(
-        "org.matrix.msc2716v2",
+    MSC2716v3 = RoomVersion(
+        "org.matrix.msc2716v3",
         RoomDisposition.UNSTABLE,
         EventFormatVersions.V3,
         StateResolutionVersions.V2,
@@ -289,9 +273,9 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
         RoomVersions.V6,
         RoomVersions.MSC2176,
         RoomVersions.V7,
-        RoomVersions.MSC2716,
         RoomVersions.V8,
         RoomVersions.V9,
+        RoomVersions.MSC2716v3,
     )
 }

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, List
+from typing import Any, List, Tuple, Type

 from synapse.util.module_loader import load_module

@@ -25,7 +25,7 @@ class PasswordAuthProviderConfig(Config):
     section = "authproviders"

     def read_config(self, config, **kwargs):
-        self.password_providers: List[Any] = []
+        self.password_providers: List[Tuple[Type, Any]] = []
         providers = []

         # We want to be backwards compatible with the old `ldap_config`

View File

@@ -45,11 +45,15 @@ class UserDirectoryConfig(Config):
         #enabled: false

         # Defines whether to search all users visible to your HS when searching
-        # the user directory, rather than limiting to users visible in public
-        # rooms. Defaults to false.
+        # the user directory. If false, search results will only contain users
+        # visible in public rooms and users sharing a room with the requester.
+        # Defaults to false.
         #
-        # If you set it true, you'll have to rebuild the user_directory search
-        # indexes, see:
+        # NB. If you set this to true, and the last time the user_directory search
+        # indexes were (re)built was before Synapse 1.44, you'll have to
+        # rebuild the indexes in order to search through all known users.
+        # These indexes are built the first time Synapse starts; admins can
+        # manually trigger a rebuild following the instructions at
         # https://matrix-org.github.io/synapse/latest/user_directory.html
         #
         # Uncomment to return search results containing all known users, even if that

View File

@@ -102,7 +102,7 @@ class FederationPolicyForHTTPS:
         self._config = config

         # Check if we're using a custom list of a CA certificates
-        trust_root = config.federation_ca_trust_root
+        trust_root = config.tls.federation_ca_trust_root
         if trust_root is None:
             # Use CA root certs provided by OpenSSL
             trust_root = platformTrust()
@@ -113,7 +113,7 @@ class FederationPolicyForHTTPS:
         # moving to TLS 1.2 by default, we want to respect the config option if
         # it is set to 1.0 (which the alternate option, raiseMinimumTo, will not
         # let us do).
-        minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version]
+        minTLS = _TLS_VERSION_MAP[config.tls.federation_client_minimum_tls_version]

         _verify_ssl = CertificateOptions(
             trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS
@@ -125,10 +125,10 @@ class FederationPolicyForHTTPS:
         self._no_verify_ssl_context = _no_verify_ssl.getContext()
         self._no_verify_ssl_context.set_info_callback(_context_info_cb)

-        self._should_verify = self._config.federation_verify_certificates
+        self._should_verify = self._config.tls.federation_verify_certificates

         self._federation_certificate_verification_whitelist = (
-            self._config.federation_certificate_verification_whitelist
+            self._config.tls.federation_certificate_verification_whitelist
         )

     def get_options(self, host: bytes):

View File

@@ -572,7 +572,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         super().__init__(hs)
         self.clock = hs.get_clock()
         self.client = hs.get_federation_http_client()
-        self.key_servers = self.config.key_servers
+        self.key_servers = self.config.key.key_servers

     async def _fetch_keys(
         self, keys_to_fetch: List[_FetchKeyRequest]

View File

@@ -213,7 +213,7 @@ def check(
     if (
         event.type == EventTypes.MSC2716_INSERTION
-        or event.type == EventTypes.MSC2716_CHUNK
+        or event.type == EventTypes.MSC2716_BATCH
         or event.type == EventTypes.MSC2716_MARKER
     ):
         check_historical(room_version_obj, event, auth_events)
@@ -552,14 +552,14 @@ def check_historical(
     auth_events: StateMap[EventBase],
 ) -> None:
     """Check whether the event sender is allowed to send historical related
-    events like "insertion", "chunk", and "marker".
+    events like "insertion", "batch", and "marker".

     Returns:
         None

     Raises:
         AuthError if the event sender is not allowed to send historical related events
-        ("insertion", "chunk", and "marker").
+        ("insertion", "batch", and "marker").
     """
     # Ignore the auth checks in room versions that do not support historical
     # events
@@ -573,7 +573,7 @@ def check_historical(
     if user_level < historical_level:
         raise AuthError(
             403,
-            'You don\'t have permission to send send historical related events ("insertion", "chunk", and "marker")',
+            'You don\'t have permission to send send historical related events ("insertion", "batch", and "marker")',
         )

View File

@@ -344,6 +344,18 @@ class EventBase(metaclass=abc.ABCMeta):
         # this will be a no-op if the event dict is already frozen.
         self._dict = freeze(self._dict)

+    def __str__(self):
+        return self.__repr__()
+
+    def __repr__(self):
+        return "<%s event_id=%r, type=%r, state_key=%r, outlier=%s>" % (
+            self.__class__.__name__,
+            self.event_id,
+            self.get("type", None),
+            self.get("state_key", None),
+            self.internal_metadata.is_outlier(),
+        )
+

 class FrozenEvent(EventBase):
     format_version = EventFormatVersions.V1  # All events of this type are V1
@@ -392,17 +404,6 @@ class FrozenEvent(EventBase):
     def event_id(self) -> str:
         return self._event_id

-    def __str__(self):
-        return self.__repr__()
-
-    def __repr__(self):
-        return "<FrozenEvent event_id=%r, type=%r, state_key=%r, outlier=%s>" % (
-            self.get("event_id", None),
-            self.get("type", None),
-            self.get("state_key", None),
-            self.internal_metadata.is_outlier(),
-        )
-

 class FrozenEventV2(EventBase):
     format_version = EventFormatVersions.V2  # All events of this type are V2
@@ -478,17 +479,6 @@ class FrozenEventV2(EventBase):
         """
         return self.auth_events

-    def __str__(self):
-        return self.__repr__()
-
-    def __repr__(self):
-        return "<%s event_id=%r, type=%r, state_key=%r>" % (
-            self.__class__.__name__,
-            self.event_id,
-            self.get("type", None),
-            self.get("state_key", None),
-        )
-

 class FrozenEventV3(FrozenEventV2):
     """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""

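Moving `__str__`/`__repr__` up to `EventBase` gives every event format the same representation, now including the outlier flag for V2/V3 events (see the `changelog.d/10879.misc` entry). A toy sketch of why defining `__repr__` once on the base class is enough (class names are illustrative, not Synapse's):

```python
import abc

class EventBaseSketch(abc.ABC):
    """Defining __repr__ once on the base means every subclass
    automatically reports its own name via self.__class__.__name__."""

    outlier: bool = False

    def __repr__(self) -> str:
        return "<%s outlier=%s>" % (self.__class__.__name__, self.outlier)

class FrozenEventV3Sketch(EventBaseSketch):
    pass

print(repr(FrozenEventV3Sketch()))  # -> <FrozenEventV3Sketch outlier=False>
```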
View File

@@ -141,9 +141,9 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
     elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
         add_fields("redacts")
     elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_INSERTION:
-        add_fields(EventContentFields.MSC2716_NEXT_CHUNK_ID)
-    elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_CHUNK:
-        add_fields(EventContentFields.MSC2716_CHUNK_ID)
+        add_fields(EventContentFields.MSC2716_NEXT_BATCH_ID)
+    elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_BATCH:
+        add_fields(EventContentFields.MSC2716_BATCH_ID)
     elif room_version.msc2716_redactions and event_type == EventTypes.MSC2716_MARKER:
         add_fields(EventContentFields.MSC2716_MARKER_INSERTION)

View File

@@ -1237,7 +1237,7 @@ class FederationHandlerRegistry:
             self._edu_type_to_instance[edu_type] = instance_names

     async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
-        if not self.config.use_presence and edu_type == EduTypes.Presence:
+        if not self.config.server.use_presence and edu_type == EduTypes.Presence:
             return

         # Check if we have a handler on this instance

View File

@@ -594,7 +594,7 @@ class FederationSender(AbstractFederationSender):
             destinations (list[str])
         """
-        if not states or not self.hs.config.use_presence:
+        if not states or not self.hs.config.server.use_presence:
             # No-op if presence is disabled.
             return

View File

@@ -16,6 +16,7 @@ import logging
 from typing import TYPE_CHECKING, Optional

 from synapse.api.ratelimiting import Ratelimiter
+from synapse.types import Requester

 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -63,16 +64,21 @@ class BaseHandler:
         self.event_builder_factory = hs.get_event_builder_factory()

-    async def ratelimit(self, requester, update=True, is_admin_redaction=False):
+    async def ratelimit(
+        self,
+        requester: Requester,
+        update: bool = True,
+        is_admin_redaction: bool = False,
+    ) -> None:
         """Ratelimits requests.

         Args:
-            requester (Requester)
-            update (bool): Whether to record that a request is being processed.
+            requester
+            update: Whether to record that a request is being processed.
                 Set to False when doing multiple checks for one request (e.g.
                 to check up front if we would reject the request), and set to
                 True for the last call for a given request.
-            is_admin_redaction (bool): Whether this is a room admin/moderator
+            is_admin_redaction: Whether this is a room admin/moderator
                 redacting an event. If so then we may apply different
                 ratelimits depending on config.

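The `update` flag enables a check-then-commit pattern: probe the limiter early without recording the request, then record it on the final call. A hypothetical usage sketch (the handler, requester, and callbacks here are placeholders, not code from this commit):

```python
# Hypothetical call site; `handler` is a BaseHandler, `requester` a Requester.
async def create_and_send(handler, requester, validate, send):
    # Probe: would this request be rejected? Nothing is recorded yet.
    await handler.ratelimit(requester, update=False)

    event = await validate()  # cheap checks that might still fail

    # Commit: the last check for this request records it against the limiter.
    await handler.ratelimit(requester, update=True)
    return await send(event)
```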
View File

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import random
-from typing import TYPE_CHECKING, List, Tuple
+from typing import TYPE_CHECKING, Collection, List, Optional, Tuple

 from synapse.replication.http.account_data import (
     ReplicationAddTagRestServlet,
@@ -21,6 +21,7 @@ from synapse.replication.http.account_data import (
     ReplicationRoomAccountDataRestServlet,
     ReplicationUserAccountDataRestServlet,
 )
+from synapse.streams import EventSource
 from synapse.types import JsonDict, UserID

 if TYPE_CHECKING:
@@ -163,7 +164,7 @@ class AccountDataHandler:
             return response["max_stream_id"]

-class AccountDataEventSource:
+class AccountDataEventSource(EventSource[int, JsonDict]):
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()

@@ -171,7 +172,13 @@ class AccountDataEventSource:
         return self.store.get_max_account_data_stream_id()

     async def get_new_events(
-        self, user: UserID, from_key: int, **kwargs
+        self,
+        user: UserID,
+        from_key: int,
+        limit: Optional[int],
+        room_ids: Collection[str],
+        is_guest: bool,
+        explicit_room_id: Optional[str] = None,
     ) -> Tuple[List[JsonDict], int]:
         user_id = user.to_string()
         last_stream_id = from_key

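Subclassing `EventSource[int, JsonDict]` is what forces the explicit `get_new_events` signature above. A simplified sketch of what such a generic base plausibly looks like (an assumption about the shape of `synapse.streams`, not a copy of it):

```python
from abc import ABC, abstractmethod
from typing import Any, Collection, Generic, List, Optional, Tuple, TypeVar

K = TypeVar("K")  # stream-key type; int for account data
E = TypeVar("E")  # event payload type; JsonDict for account data

class EventSource(ABC, Generic[K, E]):
    """Sketch of the generic interface implied by EventSource[int, JsonDict]."""

    @abstractmethod
    async def get_new_events(
        self,
        user: Any,
        from_key: K,
        limit: Optional[int],
        room_ids: Collection[str],
        is_guest: bool,
        explicit_room_id: Optional[str] = None,
    ) -> Tuple[List[E], K]:
        """Return events since `from_key`, plus the new stream key."""
```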
View File

@@ -99,7 +99,7 @@ class AccountValidityHandler:
         on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
         on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
         on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
-    ):
+    ) -> None:
         """Register callbacks from module for each hook."""
         if is_user_expired is not None:
             self._is_user_expired_callbacks.append(is_user_expired)
@@ -165,7 +165,7 @@ class AccountValidityHandler:

         return False

-    async def on_user_registration(self, user_id: str):
+    async def on_user_registration(self, user_id: str) -> None:
         """Tell third-party modules about a user's registration.

         Args:

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union

 from prometheus_client import Counter

@@ -58,7 +58,7 @@ class ApplicationServicesHandler:
         self.current_max = 0
         self.is_processing = False

-    def notify_interested_services(self, max_token: RoomStreamToken):
+    def notify_interested_services(self, max_token: RoomStreamToken) -> None:
         """Notifies (pushes) all application services interested in this event.

         Pushing is done asynchronously, so this method won't block for any
@@ -82,7 +82,7 @@ class ApplicationServicesHandler:
         self._notify_interested_services(max_token)

     @wrap_as_background_process("notify_interested_services")
-    async def _notify_interested_services(self, max_token: RoomStreamToken):
+    async def _notify_interested_services(self, max_token: RoomStreamToken) -> None:
         with Measure(self.clock, "notify_interested_services"):
             self.is_processing = True
             try:
@@ -100,7 +100,7 @@ class ApplicationServicesHandler:
                     for event in events:
                         events_by_room.setdefault(event.room_id, []).append(event)

-                    async def handle_event(event):
+                    async def handle_event(event: EventBase) -> None:
                         # Gather interested services
                         services = await self._get_services_for_event(event)
                         if len(services) == 0:
@@ -116,9 +116,9 @@ class ApplicationServicesHandler:

                         if not self.started_scheduler:

-                            async def start_scheduler():
+                            async def start_scheduler() -> None:
                                 try:
-                                    return await self.scheduler.start()
+                                    await self.scheduler.start()
                                 except Exception:
                                     logger.error("Application Services Failure")
@@ -137,7 +137,7 @@ class ApplicationServicesHandler:
                             "appservice_sender"
                         ).observe((now - ts) / 1000)

-                    async def handle_room_events(events):
+                    async def handle_room_events(events: Iterable[EventBase]) -> None:
                         for event in events:
                             await handle_event(event)

@@ -184,7 +184,7 @@ class ApplicationServicesHandler:
         stream_key: str,
         new_token: Optional[int],
         users: Optional[Collection[Union[str, UserID]]] = None,
-    ):
+    ) -> None:
         """This is called by the notifier in the background
         when a ephemeral event handled by the homeserver.
@@ -226,7 +226,7 @@ class ApplicationServicesHandler:
         stream_key: str,
         new_token: Optional[int],
         users: Collection[Union[str, UserID]],
-    ):
+    ) -> None:
         logger.debug("Checking interested services for %s" % (stream_key))
         with Measure(self.clock, "notify_interested_services_ephemeral"):
             for service in services:
@@ -254,7 +254,7 @@ class ApplicationServicesHandler:
     async def _handle_typing(
         self, service: ApplicationService, new_token: int
     ) -> List[JsonDict]:
-        typing_source = self.event_sources.sources["typing"]
+        typing_source = self.event_sources.sources.typing
         # Get the typing events from just before current
         typing, _ = await typing_source.get_new_events_as(
             service=service,
@@ -269,7 +269,7 @@ class ApplicationServicesHandler:
         from_key = await self.store.get_type_stream_id_for_appservice(
             service, "read_receipt"
         )
-        receipts_source = self.event_sources.sources["receipt"]
+        receipts_source = self.event_sources.sources.receipt
         receipts, _ = await receipts_source.get_new_events_as(
             service=service, from_key=from_key
         )
@@ -279,7 +279,7 @@ class ApplicationServicesHandler:
         self, service: ApplicationService, users: Collection[Union[str, UserID]]
     ) -> List[JsonDict]:
         events: List[JsonDict] = []
-        presence_source = self.event_sources.sources["presence"]
+        presence_source = self.event_sources.sources.presence
         from_key = await self.store.get_type_stream_id_for_appservice(
             service, "presence"
         )

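The switch from `sources["typing"]` to `sources.typing` (also in the initial-sync diff below) implies the event-sources container is now a typed object rather than a dict, so each lookup has a statically known type. A sketch of the pattern (field types elided to `object` here; the real container holds the actual source objects):

```python
import attr

# Sketch of the dict-to-attrs move implied by `sources.typing` etc.
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventSourcesInner:
    presence: object
    typing: object
    receipt: object
    account_data: object

sources = _EventSourcesInner(presence=..., typing=..., receipt=..., account_data=...)
# Attribute access is type-checked by mypy, unlike sources["typing"].
typing_source = sources.typing
```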
View File

@@ -29,6 +29,7 @@ from typing import (
     Mapping,
     Optional,
     Tuple,
+    Type,
     Union,
     cast,
 )
@@ -439,7 +440,7 @@ class AuthHandler(BaseHandler):

         return ui_auth_types

-    def get_enabled_auth_types(self):
+    def get_enabled_auth_types(self) -> Iterable[str]:
         """Return the enabled user-interactive authentication types

         Returns the UI-Auth types which are supported by the homeserver's current
@@ -702,7 +703,7 @@ class AuthHandler(BaseHandler):
         except StoreError:
             raise SynapseError(400, "Unknown session ID: %s" % (session_id,))

-    async def _expire_old_sessions(self):
+    async def _expire_old_sessions(self) -> None:
         """
         Invalidate any user interactive authentication sessions that have expired.
         """
@@ -1347,12 +1348,12 @@ class AuthHandler(BaseHandler):
         try:
             res = self.macaroon_gen.verify_short_term_login_token(login_token)
         except Exception:
-            raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
+            raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)

         await self.auth.check_auth_blocking(res.user_id)
         return res

-    async def delete_access_token(self, access_token: str):
+    async def delete_access_token(self, access_token: str) -> None:
         """Invalidate a single access token

         Args:
@@ -1381,7 +1382,7 @@ class AuthHandler(BaseHandler):
         user_id: str,
         except_token_id: Optional[int] = None,
         device_id: Optional[str] = None,
-    ):
+    ) -> None:
         """Invalidate access tokens belonging to a user

         Args:
@@ -1409,7 +1410,7 @@ class AuthHandler(BaseHandler):

     async def add_threepid(
         self, user_id: str, medium: str, address: str, validated_at: int
-    ):
+    ) -> None:
         # check if medium has a valid value
         if medium not in ["email", "msisdn"]:
             raise SynapseError(
@@ -1480,7 +1481,7 @@ class AuthHandler(BaseHandler):
             Hashed password.
         """

-        def _do_hash():
+        def _do_hash() -> str:
             # Normalise the Unicode in the password
             pw = unicodedata.normalize("NFKC", password)

@@ -1504,7 +1505,7 @@ class AuthHandler(BaseHandler):
             Whether self.hash(password) == stored_hash.
         """

-        def _do_validate_hash(checked_hash: bytes):
+        def _do_validate_hash(checked_hash: bytes) -> bool:
             # Normalise the Unicode in the password
             pw = unicodedata.normalize("NFKC", password)

@@ -1581,7 +1582,7 @@ class AuthHandler(BaseHandler):
         client_redirect_url: str,
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
-    ):
+    ) -> None:
         """Having figured out a mxid for this user, complete the HTTP request

         Args:
@@ -1627,7 +1628,7 @@ class AuthHandler(BaseHandler):
         extra_attributes: Optional[JsonDict] = None,
         new_user: bool = False,
         user_profile_data: Optional[ProfileInfo] = None,
-    ):
+    ) -> None:
         """
         The synchronous portion of complete_sso_login.

@@ -1726,7 +1727,7 @@ class AuthHandler(BaseHandler):
             del self._extra_attributes[user_id]

     @staticmethod
-    def add_query_param_to_url(url: str, param_name: str, param: Any):
+    def add_query_param_to_url(url: str, param_name: str, param: Any) -> str:
         url_parts = list(urllib.parse.urlparse(url))
         query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
         query.append((param_name, param))
@@ -1734,9 +1735,9 @@ class AuthHandler(BaseHandler):
         return urllib.parse.urlunparse(url_parts)

-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class MacaroonGenerator:
-    hs = attr.ib()
+    hs: "HomeServer"

     def generate_guest_access_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
@@ -1816,7 +1817,9 @@ class PasswordProvider:
     """

     @classmethod
-    def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
+    def load(
+        cls, module: Type, config: JsonDict, module_api: ModuleApi
+    ) -> "PasswordProvider":
         try:
             pp = module(config=config, account_handler=module_api)
         except Exception as e:
@@ -1824,7 +1827,7 @@ class PasswordProvider:
             raise
         return cls(pp, module_api)

-    def __init__(self, pp, module_api: ModuleApi):
+    def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
         self._pp = pp
         self._module_api = module_api

@@ -1838,7 +1841,7 @@ class PasswordProvider:
         if g:
             self._supported_login_types.update(g())

-    def __str__(self):
+    def __str__(self) -> str:
         return str(self._pp)

     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
@@ -1876,19 +1879,19 @@ class PasswordProvider:
         """
         # first grandfather in a call to check_password
         if login_type == LoginType.PASSWORD:
-            g = getattr(self._pp, "check_password", None)
-            if g:
+            check_password = getattr(self._pp, "check_password", None)
+            if check_password:
                 qualified_user_id = self._module_api.get_qualified_user_id(username)
-                is_valid = await self._pp.check_password(
+                is_valid = await check_password(
                     qualified_user_id, login_dict["password"]
                 )
                 if is_valid:
                     return qualified_user_id, None

-        g = getattr(self._pp, "check_auth", None)
-        if not g:
+        check_auth = getattr(self._pp, "check_auth", None)
+        if not check_auth:
             return None
-        result = await g(username, login_type, login_dict)
+        result = await check_auth(username, login_type, login_dict)

         # Check if the return value is a str or a tuple
         if isinstance(result, str):

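The `getattr` refactor above binds the optional legacy hooks (`check_password`, `check_auth`) once and awaits the bound method, which also keeps mypy happy. For reference, a minimal module shape this loader would accept (a hypothetical example, not a real provider):

```python
# Minimal legacy password-provider module compatible with PasswordProvider.load
# (hypothetical; check_password/check_auth are the optional grandfathered hooks).
class DummyPasswordProvider:
    def __init__(self, config, account_handler):
        # e.g. {"@alice:example.com": "hunter2"}; account_handler is a ModuleApi
        self._users = config.get("users", {})

    @staticmethod
    def parse_config(config):
        return config

    async def check_password(self, user_id: str, password: str) -> bool:
        return self._users.get(user_id) == password
```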
View File

@@ -34,20 +34,20 @@ logger = logging.getLogger(__name__)
 class CasError(Exception):
     """Used to catch errors when validating the CAS ticket."""

-    def __init__(self, error, error_description=None):
+    def __init__(self, error: str, error_description: Optional[str] = None):
         self.error = error
         self.error_description = error_description

-    def __str__(self):
+    def __str__(self) -> str:
         if self.error_description:
             return f"{self.error}: {self.error_description}"
         return self.error

-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class CasResponse:
-    username = attr.ib(type=str)
-    attributes = attr.ib(type=Dict[str, List[Optional[str]]])
+    username: str
+    attributes: Dict[str, List[Optional[str]]]

 class CasHandler:
@@ -133,11 +133,9 @@ class CasHandler:
             body = pde.response
         except HttpResponseException as e:
             description = (
-                (
-                    'Authorization server responded with a "{status}" error '
-                    "while exchanging the authorization code."
-                ).format(status=e.code),
-            )
+                'Authorization server responded with a "{status}" error '
+                "while exchanging the authorization code."
+            ).format(status=e.code)
             raise CasError("server_error", description) from e

         return self._parse_cas_response(body)

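`auto_attribs=True` lets attrs collect plain annotated class attributes instead of explicit `attr.ib(type=...)` declarations; the same change is applied to `MacaroonGenerator` above. The two spellings define equivalent classes:

```python
import attr
from typing import Dict, List, Optional

# Old spelling:
@attr.s(slots=True, frozen=True)
class CasResponseOld:
    username = attr.ib(type=str)
    attributes = attr.ib(type=Dict[str, List[Optional[str]]])

# New spelling; attrs reads the annotated attributes automatically:
@attr.s(slots=True, frozen=True, auto_attribs=True)
class CasResponseNew:
    username: str
    attributes: Dict[str, List[Optional[str]]]
```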
View File

@@ -257,11 +257,8 @@ class DeactivateAccountHandler(BaseHandler):
         """
         # Add the user to the directory, if necessary.
         user = UserID.from_string(user_id)
-        if self.hs.config.user_directory_search_all_users:
-            profile = await self.store.get_profileinfo(user.localpart)
-            await self.user_directory_handler.handle_local_profile_change(
-                user_id, profile
-            )
+        profile = await self.store.get_profileinfo(user.localpart)
+        await self.user_directory_handler.handle_local_profile_change(user_id, profile)

         # Ensure the user is not marked as erased.
         await self.store.mark_user_not_erased(user_id)

View File

@@ -267,7 +267,7 @@ class DeviceHandler(DeviceWorkerHandler):

         hs.get_distributor().observe("user_left_room", self.user_left_room)

-    def _check_device_name_length(self, name: Optional[str]):
+    def _check_device_name_length(self, name: Optional[str]) -> None:
         """
         Checks whether a device name is longer than the maximum allowed length.

View File

@@ -202,7 +202,7 @@ class E2eKeysHandler:
         # Now fetch any devices that we don't have in our cache
         @trace
-        async def do_remote_query(destination):
+        async def do_remote_query(destination: str) -> None:
             """This is called when we are querying the device list of a user on
             a remote homeserver and their device list is not in the device list
             cache. If we share a room with this user and we're not querying for
@@ -447,7 +447,7 @@ class E2eKeysHandler:
         }

         @trace
-        async def claim_client_keys(destination):
+        async def claim_client_keys(destination: str) -> None:
             set_tag("destination", destination)
             device_keys = remote_queries[destination]
             try:

View File

@@ -25,6 +25,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
 from synapse.events.builder import EventBuilder
+from synapse.events.snapshot import EventContext
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util.metrics import Measure

@@ -45,7 +46,11 @@ class EventAuthHandler:
         self._server_name = hs.hostname

     async def check_from_context(
-        self, room_version: str, event, context, do_sig_check=True
+        self,
+        room_version: str,
+        event: EventBase,
+        context: EventContext,
+        do_sig_check: bool = True,
     ) -> None:
         auth_event_ids = event.auth_event_ids()
         auth_events_by_id = await self._store.get_events(auth_event_ids)

View File

@@ -1221,136 +1221,6 @@ class FederationHandler(BaseHandler):

         return missing_events

-    async def construct_auth_difference(
-        self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
-    ) -> Dict:
-        """Given a local and remote auth chain, find the differences. This
-        assumes that we have already processed all events in remote_auth
-
-        Params:
-            local_auth
-            remote_auth
-
-        Returns:
-            dict
-        """
-
-        logger.debug("construct_auth_difference Start!")
-
-        # TODO: Make sure we are OK with local_auth or remote_auth having more
-        # auth events in them than strictly necessary.
-
-        def sort_fun(ev):
-            return ev.depth, ev.event_id
-
-        logger.debug("construct_auth_difference after sort_fun!")
-
-        # We find the differences by starting at the "bottom" of each list
-        # and iterating up on both lists. The lists are ordered by depth and
-        # then event_id, we iterate up both lists until we find the event ids
-        # don't match. Then we look at depth/event_id to see which side is
-        # missing that event, and iterate only up that list. Repeat.
-
-        remote_list = list(remote_auth)
-        remote_list.sort(key=sort_fun)
-
-        local_list = list(local_auth)
-        local_list.sort(key=sort_fun)
-
-        local_iter = iter(local_list)
-        remote_iter = iter(remote_list)
-
-        logger.debug("construct_auth_difference before get_next!")
-
-        def get_next(it, opt=None):
-            try:
-                return next(it)
-            except Exception:
-                return opt
-
-        current_local = get_next(local_iter)
-        current_remote = get_next(remote_iter)
-
-        logger.debug("construct_auth_difference before while")
-
-        missing_remotes = []
-        missing_locals = []
-        while current_local or current_remote:
-            if current_remote is None:
-                missing_locals.append(current_local)
-                current_local = get_next(local_iter)
-                continue
-
-            if current_local is None:
-                missing_remotes.append(current_remote)
-                current_remote = get_next(remote_iter)
-                continue
-
-            if current_local.event_id == current_remote.event_id:
-                current_local = get_next(local_iter)
-                current_remote = get_next(remote_iter)
-                continue
-
-            if current_local.depth < current_remote.depth:
-                missing_locals.append(current_local)
-                current_local = get_next(local_iter)
-                continue
-
-            if current_local.depth > current_remote.depth:
-                missing_remotes.append(current_remote)
-                current_remote = get_next(remote_iter)
-                continue
-
-            # They have the same depth, so we fall back to the event_id order
-            if current_local.event_id < current_remote.event_id:
-                missing_locals.append(current_local)
-                current_local = get_next(local_iter)
-
-            if current_local.event_id > current_remote.event_id:
-                missing_remotes.append(current_remote)
-                current_remote = get_next(remote_iter)
-                continue
-
-        logger.debug("construct_auth_difference after while")
-
-        # missing locals should be sent to the server
-        # We should find why we are missing remotes, as they will have been
-        # rejected.
-
-        # Remove events from missing_remotes if they are referencing a missing
-        # remote. We only care about the "root" rejected ones.
-        missing_remote_ids = [e.event_id for e in missing_remotes]
-        base_remote_rejected = list(missing_remotes)
-        for e in missing_remotes:
-            for e_id in e.auth_event_ids():
-                if e_id in missing_remote_ids:
-                    try:
-                        base_remote_rejected.remove(e)
-                    except ValueError:
-                        pass
-
-        reason_map = {}
-
-        for e in base_remote_rejected:
-            reason = await self.store.get_rejection_reason(e.event_id)
-            if reason is None:
-                # TODO: e is not in the current state, so we should
-                # construct some proof of that.
-                continue
-
-            reason_map[e.event_id] = reason
-
-        logger.debug("construct_auth_difference returning")
-
-        return {
-            "auth_chain": local_auth,
-            "rejects": {
-                e.event_id: {"reason": reason_map[e.event_id], "proof": None}
-                for e in base_remote_rejected
-            },
-            "missing": [e.event_id for e in missing_locals],
-        }
-
     @log_function
     async def exchange_third_party_invite(
         self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict

View File

@@ -1016,7 +1016,7 @@ class FederationEventHandler:
         except Exception:
             logger.exception("Failed to resync device for %s", sender)

-    async def _handle_marker_event(self, origin: str, marker_event: EventBase):
+    async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
         """Handles backfilling the insertion event when we receive a marker
         event that points to one.

@@ -1109,7 +1109,7 @@ class FederationEventHandler:
         event_map: Dict[str, EventBase] = {}

-        async def get_event(event_id: str):
+        async def get_event(event_id: str) -> None:
             with nested_logging_context(event_id):
                 try:
                     event = await self._federation_client.get_pdu(
@@ -1218,7 +1218,7 @@ class FederationEventHandler:
         if not event_infos:
             return

-        async def prep(ev_info: _NewEventInfo):
+        async def prep(ev_info: _NewEventInfo) -> EventContext:
             event = ev_info.event
             with nested_logging_context(suffix=event.event_id):
                 res = await self._state_handler.compute_event_context(event)
@@ -1692,7 +1692,7 @@ class FederationEventHandler:
     async def _run_push_actions_and_persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
-    ):
+    ) -> None:
         """Run the push actions for a received event, and persist it.

         Args:

View File

@@ -14,7 +14,7 @@
 # limitations under the License.

 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Set

 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
 from synapse.types import GroupID, JsonDict, get_domain_from_id
@@ -25,12 +25,14 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)

-def _create_rerouter(func_name):
+def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]:
     """Returns an async function that looks at the group id and calls the function
     on federation or the local group server if the group is local
     """

-    async def f(self, group_id, *args, **kwargs):
+    async def f(
+        self: "GroupsLocalWorkerHandler", group_id: str, *args: Any, **kwargs: Any
+    ) -> JsonDict:
         if not GroupID.is_valid(group_id):
             raise SynapseError(400, "%s is not a legal group ID" % (group_id,))

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 from twisted.internet import defer
@@ -125,7 +125,7 @@ class InitialSyncHandler(BaseHandler):
 now_token = self.hs.get_event_sources().get_current_token()
-presence_stream = self.hs.get_event_sources().sources["presence"]
+presence_stream = self.hs.get_event_sources().sources.presence
 presence, _ = await presence_stream.get_new_events(
 user, from_key=None, include_offline=False
 )
@@ -150,7 +150,7 @@ class InitialSyncHandler(BaseHandler):
 if limit is None:
 limit = 10
-async def handle_room(event: RoomsForUser):
+async def handle_room(event: RoomsForUser) -> None:
 d: JsonDict = {
 "room_id": event.room_id,
 "membership": event.membership,
@@ -411,9 +411,9 @@ class InitialSyncHandler(BaseHandler):
 presence_handler = self.hs.get_presence_handler()
-async def get_presence():
+async def get_presence() -> List[JsonDict]:
 # If presence is disabled, return an empty list
-if not self.hs.config.use_presence:
+if not self.hs.config.server.use_presence:
 return []
 states = await presence_handler.get_states(
@@ -428,7 +428,7 @@ class InitialSyncHandler(BaseHandler):
 for s in states
 ]
-async def get_receipts():
+async def get_receipts() -> List[JsonDict]:
 receipts = await self.store.get_linearized_receipts_for_room(
 room_id, to_key=now_token.receipt_key
 )

View File

@@ -46,6 +46,7 @@ from synapse.events import EventBase
 from synapse.events.builder import EventBuilder
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
+from synapse.handlers.directory import DirectoryHandler
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@@ -298,7 +299,7 @@ class MessageHandler:
 for user_id, profile in users_with_profile.items()
 }
-def maybe_schedule_expiry(self, event: EventBase):
+def maybe_schedule_expiry(self, event: EventBase) -> None:
 """Schedule the expiry of an event if there's not already one scheduled,
 or if the one running is for an event that will expire after the provided
 timestamp.
@@ -318,7 +319,7 @@ class MessageHandler:
 # a task scheduled for a timestamp that's sooner than the provided one.
 self._schedule_expiry_for_event(event.event_id, expiry_ts)
-async def _schedule_next_expiry(self):
+async def _schedule_next_expiry(self) -> None:
 """Retrieve the ID and the expiry timestamp of the next event to be expired,
 and schedule an expiry task for it.
@@ -331,7 +332,7 @@ class MessageHandler:
 event_id, expiry_ts = res
 self._schedule_expiry_for_event(event_id, expiry_ts)
-def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int):
+def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int) -> None:
 """Schedule an expiry task for the provided event if there's not already one
 scheduled at a timestamp that's sooner than the provided one.
@@ -367,7 +368,7 @@ class MessageHandler:
 event_id,
 )
-async def _expire_event(self, event_id: str):
+async def _expire_event(self, event_id: str) -> None:
 """Retrieve and expire an event that needs to be expired from the database.
 If the event doesn't exist in the database, log it and delete the expiry date
@@ -1229,7 +1230,10 @@ class EventCreationHandler:
 self._external_cache_joined_hosts_updates[state_entry.state_group] = None
 async def _validate_canonical_alias(
-self, directory_handler, room_alias_str: str, expected_room_id: str
+self,
+directory_handler: DirectoryHandler,
+room_alias_str: str,
+expected_room_id: str,
 ) -> None:
 """
 Ensure that the given room alias points to the expected room ID.
@@ -1421,7 +1425,7 @@ class EventCreationHandler:
 # structural protocol level).
 is_msc2716_event = (
 original_event.type == EventTypes.MSC2716_INSERTION
-or original_event.type == EventTypes.MSC2716_CHUNK
+or original_event.type == EventTypes.MSC2716_BATCH
 or original_event.type == EventTypes.MSC2716_MARKER
 )
 if not room_version_obj.msc2716_historical and is_msc2716_event:
@@ -1477,7 +1481,7 @@ class EventCreationHandler:
 # If there's an expiry timestamp on the event, schedule its expiry.
 self._message_handler.maybe_schedule_expiry(event)
-def _notify():
+def _notify() -> None:
 try:
 self.notifier.on_new_room_event(
 event, event_pos, max_stream_token, extra_users=extra_users
@@ -1523,7 +1527,7 @@ class EventCreationHandler:
 except Exception:
 logger.exception("Error bumping presence active time")
-async def _send_dummy_events_to_fill_extremities(self):
+async def _send_dummy_events_to_fill_extremities(self) -> None:
 """Background task to send dummy events into rooms that have a large
 number of extremities
 """
@@ -1600,7 +1604,7 @@ class EventCreationHandler:
 )
 return False
-def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
+def _expire_rooms_to_exclude_from_dummy_event_insertion(self) -> None:
 expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
 to_expire = set()
 for room_id, time in self._rooms_to_exclude_from_dummy_event_insertion.items():

View File

@@ -14,7 +14,7 @@
 # limitations under the License.
 import inspect
 import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
 from urllib.parse import urlencode, urlparse
 import attr
@@ -249,11 +249,11 @@ class OidcHandler:
 class OidcError(Exception):
 """Used to catch errors when calling the token_endpoint"""
-def __init__(self, error, error_description=None):
+def __init__(self, error: str, error_description: Optional[str] = None):
 self.error = error
 self.error_description = error_description
-def __str__(self):
+def __str__(self) -> str:
 if self.error_description:
 return f"{self.error}: {self.error_description}"
 return self.error
@@ -1057,13 +1057,13 @@ class JwtClientSecret:
 self._cached_secret = b""
 self._cached_secret_replacement_time = 0
-def __str__(self):
+def __str__(self) -> str:
 # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
 # encode_client_secret_basic, which calls "{}".format(secret), which ends up
 # here.
 return self._get_secret().decode("ascii")
-def __bytes__(self):
+def __bytes__(self) -> bytes:
 # if client_auth_method is client_secret_post, then ClientAuth.prepare calls
 # encode_client_secret_post, which ends up here.
 return self._get_secret()
@@ -1197,21 +1197,21 @@ class OidcSessionTokenGenerator:
 )
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
 class OidcSessionData:
 """The attributes which are stored in a OIDC session cookie"""
 # the Identity Provider being used
-idp_id = attr.ib(type=str)
+idp_id: str
 # The `nonce` parameter passed to the OIDC provider.
-nonce = attr.ib(type=str)
+nonce: str
 # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
-client_redirect_url = attr.ib(type=str)
+client_redirect_url: str
 # The session ID of the ongoing UI Auth ("" if this is a login)
-ui_auth_session_id = attr.ib(type=str)
+ui_auth_session_id: str
 class UserAttributeDict(TypedDict):
@@ -1290,20 +1290,20 @@ class OidcMappingProvider(Generic[C]):
 # Used to clear out "None" values in templates
-def jinja_finalize(thing):
+def jinja_finalize(thing: Any) -> Any:
 return thing if thing is not None else ""
 env = Environment(finalize=jinja_finalize)
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class JinjaOidcMappingConfig:
-subject_claim = attr.ib(type=str)
+subject_claim: str
-localpart_template = attr.ib(type=Optional[Template])
+localpart_template: Optional[Template]
-display_name_template = attr.ib(type=Optional[Template])
+display_name_template: Optional[Template]
-email_template = attr.ib(type=Optional[Template])
+email_template: Optional[Template]
-extra_attributes = attr.ib(type=Dict[str, Template])
+extra_attributes: Dict[str, Template]
 class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):

View File

@@ -15,6 +15,8 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict, Optional, Set
+import attr
 from twisted.python.failure import Failure
 from synapse.api.constants import EventTypes, Membership
@@ -24,7 +26,7 @@ from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
-from synapse.types import Requester
+from synapse.types import JsonDict, Requester
 from synapse.util.async_helpers import ReadWriteLock
 from synapse.util.stringutils import random_string
 from synapse.visibility import filter_events_for_client
@@ -36,15 +38,12 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
+@attr.s(slots=True, auto_attribs=True)
 class PurgeStatus:
 """Object tracking the status of a purge request
 This class contains information on the progress of a purge request, for
 return by get_purge_status.
-Attributes:
-status (int): Tracks whether this request has completed. One of
-STATUS_{ACTIVE,COMPLETE,FAILED}
 """
 STATUS_ACTIVE = 0
@@ -57,10 +56,10 @@ class PurgeStatus:
 STATUS_FAILED: "failed",
 }
-def __init__(self):
-self.status = PurgeStatus.STATUS_ACTIVE
+# Tracks whether this request has completed. One of STATUS_{ACTIVE,COMPLETE,FAILED}.
+status: int = STATUS_ACTIVE
-def asdict(self):
+def asdict(self) -> JsonDict:
 return {"status": PurgeStatus.STATUS_TEXT[self.status]}
@@ -107,7 +106,7 @@ class PaginationHandler:
 async def purge_history_for_rooms_in_range(
 self, min_ms: Optional[int], max_ms: Optional[int]
-):
+) -> None:
 """Purge outdated events from rooms within the given retention range.
 If a default retention policy is defined in the server's configuration and its
@@ -291,7 +290,7 @@ class PaginationHandler:
 self._purges_in_progress_by_room.discard(room_id)
 # remove the purge from the list 24 hours after it completes
-def clear_purge():
+def clear_purge() -> None:
 del self._purges_by_id[purge_id]
 self.hs.get_reactor().callLater(24 * 3600, clear_purge)
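
Several attrs classes in this changeset (`PurgeStatus` here, plus `OidcSessionData`, `JinjaOidcMappingConfig`, `Saml2SessionData` and `SamlConfig` elsewhere) migrate from explicit `attr.ib(...)` declarations to `auto_attribs=True`, under which plain annotated class attributes become fields. A minimal sketch of the two equivalent spellings, assuming a reasonably recent `attrs`:

from typing import Optional

import attr

@attr.s(slots=True)
class OldStyle:
    name = attr.ib(type=str)
    nickname = attr.ib(type=Optional[str], default=None)

@attr.s(slots=True, auto_attribs=True)
class NewStyle:
    # Each annotated attribute becomes an attrs field; defaults read naturally.
    name: str
    nickname: Optional[str] = None

assert OldStyle("a") == OldStyle("a")
assert NewStyle("a").nickname is None

The `PurgeStatus` hunk also shows why the conversion pays off: the `__init__` boilerplate disappears and `status: int = STATUS_ACTIVE` documents itself.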

View File

@@ -26,18 +26,22 @@ import contextlib
 import logging
 from bisect import bisect
 from contextlib import contextmanager
+from types import TracebackType
 from typing import (
 TYPE_CHECKING,
 Any,
+Awaitable,
 Callable,
 Collection,
 Dict,
 FrozenSet,
+Generator,
 Iterable,
 List,
 Optional,
 Set,
 Tuple,
+Type,
 Union,
 )
@@ -61,6 +65,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
 from synapse.replication.tcp.commands import ClearUserSyncsCommand
 from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
 from synapse.storage.databases.main import DataStore
+from synapse.streams import EventSource
 from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.descriptors import _CacheContext, cached
@@ -240,7 +245,7 @@ class BasePresenceHandler(abc.ABC):
 """
 @abc.abstractmethod
-async def bump_presence_active_time(self, user: UserID):
+async def bump_presence_active_time(self, user: UserID) -> None:
 """We've seen the user do something that indicates they're interacting
 with the app.
 """
@@ -274,7 +279,7 @@ class BasePresenceHandler(abc.ABC):
 async def process_replication_rows(
 self, stream_name: str, instance_name: str, token: int, rows: list
-):
+) -> None:
 """Process streams received over replication."""
 await self._federation_queue.process_replication_rows(
 stream_name, instance_name, token, rows
@@ -286,7 +291,7 @@ class BasePresenceHandler(abc.ABC):
 async def maybe_send_presence_to_interested_destinations(
 self, states: List[UserPresenceState]
-):
+) -> None:
 """If this instance is a federation sender, send the states to all
 destinations that are interested. Filters out any states for remote
 users.
@@ -309,7 +314,7 @@ class BasePresenceHandler(abc.ABC):
 for destination, host_states in hosts_to_states.items():
 self._federation.send_presence_to_destinations(host_states, [destination])
-async def send_full_presence_to_users(self, user_ids: Collection[str]):
+async def send_full_presence_to_users(self, user_ids: Collection[str]) -> None:
 """
 Adds to the list of users who should receive a full snapshot of presence
 upon their next sync. Note that this only works for local users.
@@ -363,7 +368,12 @@
 class _NullContextManager(ContextManager[None]):
 """A context manager which does nothing."""
-def __exit__(self, exc_type, exc_val, exc_tb):
+def __exit__(
+self,
+exc_type: Optional[Type[BaseException]],
+exc_val: Optional[BaseException],
+exc_tb: Optional[TracebackType],
+) -> None:
 pass
@@ -374,7 +384,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
 self._presence_writer_instance = hs.config.worker.writers.presence[0]
-self._presence_enabled = hs.config.use_presence
+self._presence_enabled = hs.config.server.use_presence
 # Route presence EDUs to the right worker
 hs.get_federation_registry().register_instances_for_edu(
@@ -468,7 +478,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
 if self._user_to_num_current_syncs[user_id] == 1:
 self.mark_as_coming_online(user_id)
-def _end():
+def _end() -> None:
 # We check that the user_id is in user_to_num_current_syncs because
 # user_to_num_current_syncs may have been cleared if we are
 # shutting down.
@@ -480,7 +490,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
 self.mark_as_going_offline(user_id)
 @contextlib.contextmanager
-def _user_syncing():
+def _user_syncing() -> Generator[None, None, None]:
 try:
 yield
 finally:
@@ -503,7 +513,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
 async def process_replication_rows(
 self, stream_name: str, instance_name: str, token: int, rows: list
-):
+) -> None:
 await super().process_replication_rows(stream_name, instance_name, token, rows)
 if stream_name != PresenceStream.NAME:
@@ -584,7 +594,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
 user_id = target_user.to_string()
 # If presence is disabled, no-op
-if not self.hs.config.use_presence:
+if not self.hs.config.server.use_presence:
 return
 # Proxy request to instance that writes presence
@@ -601,7 +611,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
 with the app.
 """
 # If presence is disabled, no-op
-if not self.hs.config.use_presence:
+if not self.hs.config.server.use_presence:
 return
 # Proxy request to instance that writes presence
@@ -618,7 +628,7 @@ class PresenceHandler(BasePresenceHandler):
 self.server_name = hs.hostname
 self.wheel_timer: WheelTimer[str] = WheelTimer()
 self.notifier = hs.get_notifier()
-self._presence_enabled = hs.config.use_presence
+self._presence_enabled = hs.config.server.use_presence
 federation_registry = hs.get_federation_registry()
@@ -689,7 +699,7 @@ class PresenceHandler(BasePresenceHandler):
 # Start a LoopingCall in 30s that fires every 5s.
 # The initial delay is to allow disconnected clients a chance to
 # reconnect before we treat them as offline.
-def run_timeout_handler():
+def run_timeout_handler() -> Awaitable[None]:
 return run_as_background_process(
 "handle_presence_timeouts", self._handle_timeouts
 )
@@ -698,7 +708,7 @@ class PresenceHandler(BasePresenceHandler):
 30, self.clock.looping_call, run_timeout_handler, 5000
 )
-def run_persister():
+def run_persister() -> Awaitable[None]:
 return run_as_background_process(
 "persist_presence_changes", self._persist_unpersisted_changes
 )
@@ -916,7 +926,7 @@ class PresenceHandler(BasePresenceHandler):
 with the app.
 """
 # If presence is disabled, no-op
-if not self.hs.config.use_presence:
+if not self.hs.config.server.use_presence:
 return
 user_id = user.to_string()
@@ -942,14 +952,14 @@ class PresenceHandler(BasePresenceHandler):
 when users disconnect/reconnect.
 Args:
-user_id (str)
-affect_presence (bool): If false this function will be a no-op.
+user_id
+affect_presence: If false this function will be a no-op.
 Useful for streams that are not associated with an actual
 client that is being used by a user.
 """
 # Override if it should affect the user's presence, if presence is
 # disabled.
-if not self.hs.config.use_presence:
+if not self.hs.config.server.use_presence:
 affect_presence = False
 if affect_presence:
@@ -978,7 +988,7 @@ class PresenceHandler(BasePresenceHandler):
 ]
 )
-async def _end():
+async def _end() -> None:
 try:
 self.user_to_num_current_syncs[user_id] -= 1
@@ -994,7 +1004,7 @@ class PresenceHandler(BasePresenceHandler):
 logger.exception("Error updating presence after sync")
 @contextmanager
-def _user_syncing():
+def _user_syncing() -> Generator[None, None, None]:
 try:
 yield
 finally:
@@ -1264,7 +1274,7 @@ class PresenceHandler(BasePresenceHandler):
 if self._event_processing:
 return
-async def _process_presence():
+async def _process_presence() -> None:
 assert not self._event_processing
 self._event_processing = True
@@ -1491,7 +1501,7 @@
 return content
-class PresenceEventSource:
+class PresenceEventSource(EventSource[int, UserPresenceState]):
 def __init__(self, hs: "HomeServer"):
 # We can't call get_presence_handler here because there's a cycle:
 #
@@ -1510,10 +1520,11 @@
 self,
 user: UserID,
 from_key: Optional[int],
+limit: Optional[int] = None,
 room_ids: Optional[List[str]] = None,
-include_offline: bool = True,
+is_guest: bool = False,
 explicit_room_id: Optional[str] = None,
-**kwargs,
+include_offline: bool = True,
 ) -> Tuple[List[UserPresenceState], int]:
 # The process for getting presence events are:
 # 1. Get the rooms the user is in.
@@ -2074,7 +2085,7 @@ class PresenceFederationQueue:
 if self._queue_presence_updates:
 self._clock.looping_call(self._clear_queue, self._CLEAR_ITEMS_EVERY_MS)
-def _clear_queue(self):
+def _clear_queue(self) -> None:
 """Clear out older entries from the queue."""
 clear_before = self._clock.time_msec() - self._KEEP_ITEMS_IN_QUEUE_FOR_MS
@@ -2205,7 +2216,7 @@ class PresenceFederationQueue:
 async def process_replication_rows(
 self, stream_name: str, instance_name: str, token: int, rows: list
-):
+) -> None:
 if stream_name != PresenceFederationStream.NAME:
 return
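
The small `@contextmanager` helpers in this file (`_user_syncing`) are now annotated as `Generator[None, None, None]`, the standard typing for a generator-based context manager: the three parameters are the yield type, the send type and the return type. A tiny standalone illustration:

from contextlib import contextmanager
from typing import Generator

@contextmanager
def user_syncing() -> Generator[None, None, None]:
    # Yields None, accepts nothing via send(), returns None.
    try:
        yield
    finally:
        pass  # e.g. decrement a per-user sync counter here

with user_syncing():
    print("user is syncing")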

View File

@@ -214,7 +214,6 @@ class ProfileHandler(BaseHandler):
 target_user.localpart, displayname_to_set
 )
-if self.hs.config.user_directory_search_all_users:
 profile = await self.store.get_profileinfo(target_user.localpart)
 await self.user_directory_handler.handle_local_profile_change(
 target_user.to_string(), profile
@@ -254,7 +253,7 @@ class ProfileHandler(BaseHandler):
 requester: Requester,
 new_avatar_url: str,
 by_admin: bool = False,
-):
+) -> None:
 """Set a new avatar URL for a user.
 Args:
@@ -300,7 +299,6 @@ class ProfileHandler(BaseHandler):
 target_user.localpart, avatar_url_to_set
 )
-if self.hs.config.user_directory_search_all_users:
 profile = await self.store.get_profileinfo(target_user.localpart)
 await self.user_directory_handler.handle_local_profile_change(
 target_user.to_string(), profile
@@ -425,7 +423,7 @@ class ProfileHandler(BaseHandler):
 raise
 @wrap_as_background_process("Update remote profile")
-async def _update_remote_profile_cache(self):
+async def _update_remote_profile_cache(self) -> None:
 """Called periodically to check profiles of remote users we haven't
 checked in a while.
 """

View File

@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 from synapse.api.constants import ReadReceiptEventFields
 from synapse.appservice import ApplicationService
 from synapse.handlers._base import BaseHandler
+from synapse.streams import EventSource
 from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
 if TYPE_CHECKING:
@@ -162,7 +163,7 @@ class ReceiptsHandler(BaseHandler):
 await self.federation_sender.send_read_receipt(receipt)
-class ReceiptEventSource:
+class ReceiptEventSource(EventSource[int, JsonDict]):
 def __init__(self, hs: "HomeServer"):
 self.store = hs.get_datastore()
 self.config = hs.config
@@ -216,7 +217,13 @@ class ReceiptEventSource:
 return visible_events
 async def get_new_events(
-self, from_key: int, room_ids: List[str], user: UserID, **kwargs
+self,
+user: UserID,
+from_key: int,
+limit: Optional[int],
+room_ids: Iterable[str],
+is_guest: bool,
+explicit_room_id: Optional[str] = None,
 ) -> Tuple[List[JsonDict], int]:
 from_key = int(from_key)
 to_key = self.get_current_key()
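
`ReceiptEventSource` (here), `PresenceEventSource`, `RoomEventSource` and `TypingNotificationEventSource` all gain an explicit `EventSource` base in this changeset, and their `get_new_events` signatures are unified to match it. A sketch of what such a generic base plausibly looks like, inferred from the signatures in this diff — the real definition lives in `synapse.streams` and may differ in detail:

import abc
from typing import Generic, Iterable, List, Optional, Tuple, TypeVar

K = TypeVar("K")  # stream key type, e.g. int or RoomStreamToken
R = TypeVar("R")  # event type, e.g. JsonDict or EventBase

class EventSource(Generic[K, R], metaclass=abc.ABCMeta):
    @abc.abstractmethod
    async def get_new_events(
        self,
        user: "UserID",  # synapse.types.UserID
        from_key: K,
        limit: Optional[int],
        room_ids: Iterable[str],
        is_guest: bool,
        explicit_room_id: Optional[str] = None,
    ) -> Tuple[List[R], K]:
        """Return events newer than from_key, plus the new stream key."""
        ...

Typing the base once is what lets the attribute-style `sources.receipt` / `sources.typing` lookups replace the old string-keyed `sources["receipt"]` form seen elsewhere in this diff.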

View File

@@ -125,7 +125,7 @@ class RegistrationHandler(BaseHandler):
 localpart: str,
 guest_access_token: Optional[str] = None,
 assigned_user_id: Optional[str] = None,
-):
+) -> None:
 if types.contains_invalid_mxid_characters(localpart):
 raise SynapseError(
 400,
@@ -295,7 +295,6 @@ class RegistrationHandler(BaseHandler):
 shadow_banned=shadow_banned,
 )
-if self.hs.config.user_directory_search_all_users:
 profile = await self.store.get_profileinfo(localpart)
 await self.user_directory_handler.handle_local_profile_change(
 user_id, profile

View File

@@ -1,6 +1,4 @@
-# Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,7 +20,16 @@ import math
 import random
 import string
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
+from typing import (
+TYPE_CHECKING,
+Any,
+Awaitable,
+Collection,
+Dict,
+List,
+Optional,
+Tuple,
+)
 from synapse.api.constants import (
 EventContentFields,
@@ -49,6 +56,7 @@ from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
 from synapse.rest.admin._base import assert_user_is_admin
 from synapse.storage.state import StateFilter
+from synapse.streams import EventSource
 from synapse.types import (
 JsonDict,
 MutableStateMap,
@@ -186,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
 async def _upgrade_room(
 self, requester: Requester, old_room_id: str, new_version: RoomVersion
-):
+) -> str:
 """
 Args:
 requester: the user requesting the upgrade
@@ -512,7 +520,7 @@ class RoomCreationHandler(BaseHandler):
 old_room_id: str,
 new_room_id: str,
 old_room_state: StateMap[str],
-):
+) -> None:
 # check to see if we have a canonical alias.
 canonical_alias_event = None
 canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
@@ -902,7 +910,7 @@ class RoomCreationHandler(BaseHandler):
 event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
-def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
+def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
 e = {"type": etype, "content": content}
 e.update(event_keys)
@@ -910,7 +918,7 @@ class RoomCreationHandler(BaseHandler):
 return e
-async def send(etype: str, content: JsonDict, **kwargs) -> int:
+async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
 event = create(etype, content, **kwargs)
 logger.debug("Sending %s in new room", etype)
 # Allow these events to be sent even if the user is shadow-banned to
@@ -1033,7 +1041,7 @@ class RoomCreationHandler(BaseHandler):
 creator_id: str,
 is_public: bool,
 room_version: RoomVersion,
-):
+) -> str:
 # autogen room IDs and try to create it. We may clash, so just
 # try a few times till one goes through, giving up eventually.
 attempts = 0
@@ -1097,7 +1105,7 @@ class RoomContextHandler:
 users = await self.store.get_users_in_room(room_id)
 is_peeking = user.to_string() not in users
-async def filter_evts(events):
+async def filter_evts(events: List[EventBase]) -> List[EventBase]:
 if use_admin_priviledge:
 return events
 return await filter_events_for_client(
@@ -1175,7 +1183,7 @@ class RoomContextHandler:
 return results
-class RoomEventSource:
+class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
 def __init__(self, hs: "HomeServer"):
 self.store = hs.get_datastore()
@@ -1183,8 +1191,8 @@ class RoomEventSource:
 self,
 user: UserID,
 from_key: RoomStreamToken,
-limit: int,
-room_ids: List[str],
+limit: Optional[int],
+room_ids: Collection[str],
 is_guest: bool,
 explicit_room_id: Optional[str] = None,
 ) -> Tuple[List[EventBase], RoomStreamToken]:

View File

@@ -14,7 +14,7 @@
 import logging
 from collections import namedtuple
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Optional, Tuple
 import msgpack
 from unpaddedbase64 import decode_base64, encode_base64
@@ -33,7 +33,7 @@ from synapse.api.errors import (
 SynapseError,
 )
 from synapse.types import JsonDict, ThirdPartyInstanceID
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.caches.response_cache import ResponseCache
 from ._base import BaseHandler
@@ -169,7 +169,7 @@ class RoomListHandler(BaseHandler):
 ignore_non_federatable=from_federation,
 )
-def build_room_entry(room):
+def build_room_entry(room: JsonDict) -> JsonDict:
 entry = {
 "room_id": room["room_id"],
 "name": room["name"],
@@ -249,10 +249,10 @@
 self,
 room_id: str,
 num_joined_users: int,
-cache_context,
+cache_context: _CacheContext,
 with_alias: bool = True,
 allow_private: bool = False,
-) -> Optional[dict]:
+) -> Optional[JsonDict]:
 """Returns the entry for a room
 Args:
@@ -507,7 +507,7 @@ class RoomListNextBatch(
 )
 )
-def copy_and_replace(self, **kwds) -> "RoomListNextBatch":
+def copy_and_replace(self, **kwds: Any) -> "RoomListNextBatch":
 return self._replace(**kwds)
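
The room-entry helper above takes a `cache_context` parameter because it is wrapped with Synapse's `@cached(..., cache_context=True)` decorator, which supplies a `_CacheContext` so callees can chain cache invalidation; this hunk merely gives the parameter its real type. A hedged sketch of the idiom — the `RoomEntrySource`, `lookup` and `_fetch` names are illustrative, not real Synapse code:

from synapse.util.caches.descriptors import _CacheContext, cached

class RoomEntrySource:
    @cached(cache_context=True)
    async def lookup(self, room_id: str, cache_context: _CacheContext) -> dict:
        # Passing cache_context.invalidate into nested cached calls ties their
        # entries to this one, so they can be invalidated together.
        return await self._fetch(room_id, on_invalidate=cache_context.invalidate)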

View File

@@ -226,7 +226,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 room_id: Optional[str],
 n_invites: int,
 update: bool = True,
-):
+) -> None:
 """Ratelimit more than one invite sent by the given requester in the given room.
 Args:
@@ -250,7 +250,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 requester: Optional[Requester],
 room_id: Optional[str],
 invitee_user_id: str,
-):
+) -> None:
 """Ratelimit invites by room and by target user.
 If room ID is missing then we just rate limit by target user.
@@ -387,7 +387,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 return result_event.event_id, result_event.internal_metadata.stream_ordering
 async def copy_room_tags_and_direct_to_room(
-self, old_room_id, new_room_id, user_id
+self, old_room_id: str, new_room_id: str, user_id: str
 ) -> None:
 """Copies the tags and direct room state from one room to another.
@@ -688,7 +688,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 " (membership=%s)" % old_membership,
 errcode=Codes.BAD_STATE,
 )
-if old_membership == "ban" and action != "unban":
+if old_membership == "ban" and action not in ["ban", "unban", "leave"]:
 raise SynapseError(
 403,
 "Cannot %s user who was banned" % (action,),
@@ -1050,7 +1050,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 event: EventBase,
 context: EventContext,
 ratelimit: bool = True,
-):
+) -> None:
 """
 Change the membership status of a user in a room.

View File

@@ -541,7 +541,7 @@ class RoomSummaryHandler:
 origin: str,
 requested_room_id: str,
 suggested_only: bool,
-):
+) -> JsonDict:
 """
 Implementation of the room hierarchy Federation API.

View File

@@ -40,15 +40,15 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class Saml2SessionData:
 """Data we track about SAML2 sessions"""
 # time the session was created, in milliseconds
-creation_time = attr.ib()
+creation_time: int
 # The user interactive authentication session ID associated with this SAML
 # session (or None if this SAML session is for an initial login).
-ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+ui_auth_session_id: Optional[str] = None
 class SamlHandler(BaseHandler):
@@ -359,7 +359,7 @@ class SamlHandler(BaseHandler):
 return remote_user_id
-def expire_sessions(self):
+def expire_sessions(self) -> None:
 expire_before = self.clock.time_msec() - self._saml2_session_lifetime
 to_expire = set()
 for reqid, data in self._outstanding_requests_dict.items():
@@ -391,10 +391,10 @@ MXID_MAPPER_MAP: Dict[str, Callable[[str], str]] = {
 }
-@attr.s
+@attr.s(auto_attribs=True)
 class SamlConfig:
-mxid_source_attribute = attr.ib()
+mxid_source_attribute: str
-mxid_mapper = attr.ib()
+mxid_mapper: Callable[[str], str]
 class DefaultSamlMappingProvider:

View File

@@ -17,7 +17,7 @@ import logging
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
 from io import BytesIO
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Any, Optional
 from pkg_resources import parse_version
@@ -79,7 +79,7 @@ async def _sendmail(
 msg = BytesIO(msg_bytes)
 d: "Deferred[object]" = Deferred()
-def build_sender_factory(**kwargs) -> ESMTPSenderFactory:
+def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory:
 return ESMTPSenderFactory(
 username,
 password,

View File

@@ -205,7 +205,7 @@ class SsoHandler:
 self._consent_at_registration = hs.config.consent.user_consent_at_registration
-def register_identity_provider(self, p: SsoIdentityProvider):
+def register_identity_provider(self, p: SsoIdentityProvider) -> None:
 p_id = p.idp_id
 assert p_id not in self._identity_providers
 self._identity_providers[p_id] = p
@@ -856,7 +856,7 @@ class SsoHandler:
 async def handle_terms_accepted(
 self, request: Request, session_id: str, terms_version: str
-):
+) -> None:
 """Handle a request to the new-user 'consent' endpoint
 Will serve an HTTP response to the request.
@@ -959,7 +959,7 @@ class SsoHandler:
 new_user=True,
 )
-def _expire_old_sessions(self):
+def _expire_old_sessions(self) -> None:
 to_expire = []
 now = int(self._clock.time_msec())

View File

@@ -68,7 +68,7 @@ class StatsHandler:
 self._is_processing = True
-async def process():
+async def process() -> None:
 try:
 await self._unsafe_process()
 finally:

View File

@@ -364,7 +364,9 @@ class SyncHandler:
 )
 else:
-async def current_sync_callback(before_token, after_token) -> SyncResult:
+async def current_sync_callback(
+before_token: StreamToken, after_token: StreamToken
+) -> SyncResult:
 return await self.current_sync_for_user(sync_config, since_token)
 result = await self.notifier.wait_for_events(
@@ -441,7 +443,7 @@ class SyncHandler:
 room_ids = sync_result_builder.joined_room_ids
-typing_source = self.event_sources.sources["typing"]
+typing_source = self.event_sources.sources.typing
 typing, typing_key = await typing_source.get_new_events(
 user=sync_config.user,
 from_key=typing_key,
@@ -463,7 +465,7 @@ class SyncHandler:
 receipt_key = since_token.receipt_key if since_token else 0
-receipt_source = self.event_sources.sources["receipt"]
+receipt_source = self.event_sources.sources.receipt
 receipts, receipt_key = await receipt_source.get_new_events(
 user=sync_config.user,
 from_key=receipt_key,
@@ -1090,7 +1092,7 @@ class SyncHandler:
 block_all_presence_data = (
 since_token is None and sync_config.filter_collection.blocks_all_presence()
 )
-if self.hs_config.use_presence and not block_all_presence_data:
+if self.hs_config.server.use_presence and not block_all_presence_data:
 logger.debug("Fetching presence data")
 await self._generate_sync_entry_for_presence(
 sync_result_builder,
@@ -1413,7 +1415,7 @@ class SyncHandler:
 sync_config = sync_result_builder.sync_config
 user = sync_result_builder.sync_config.user
-presence_source = self.event_sources.sources["presence"]
+presence_source = self.event_sources.sources.presence
 since_token = sync_result_builder.since_token
 presence_key = None
@@ -1532,9 +1534,9 @@ class SyncHandler:
 newly_joined_rooms = room_changes.newly_joined_rooms
 newly_left_rooms = room_changes.newly_left_rooms
-async def handle_room_entries(room_entry: "RoomSyncResultBuilder"):
+async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None:
 logger.debug("Generating room entry for %s", room_entry.room_id)
-res = await self._generate_room_entry(
+await self._generate_room_entry(
 sync_result_builder,
 ignored_users,
 room_entry,
@@ -1544,7 +1546,6 @@ class SyncHandler:
 always_include=sync_result_builder.full_state,
 )
 logger.debug("Generated room entry for %s", room_entry.room_id)
-return res
 await concurrently_execute(handle_room_entries, room_entries, 10)
@@ -1925,7 +1926,7 @@ class SyncHandler:
 tags: Optional[Dict[str, Dict[str, Any]]],
 account_data: Dict[str, JsonDict],
 always_include: bool = False,
-):
+) -> None:
 """Populates the `joined` and `archived` section of `sync_result_builder`
 based on the `room_builder`.

View File

@@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import (
 wrap_as_background_process,
 )
 from synapse.replication.tcp.streams import TypingStream
+from synapse.streams import EventSource
 from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
@@ -439,7 +440,7 @@ class TypingWriterHandler(FollowerTypingHandler):
 raise Exception("Typing writer instance got typing info over replication")
-class TypingNotificationEventSource:
+class TypingNotificationEventSource(EventSource[int, JsonDict]):
 def __init__(self, hs: "HomeServer"):
 self.hs = hs
 self.clock = hs.get_clock()
@@ -485,7 +486,13 @@ class TypingNotificationEventSource:
 return (events, handler._latest_room_serial)
 async def get_new_events(
-self, from_key: int, room_ids: Iterable[str], **kwargs
+self,
+user: UserID,
+from_key: int,
+limit: Optional[int],
+room_ids: Iterable[str],
+is_guest: bool,
+explicit_room_id: Optional[str] = None,
 ) -> Tuple[List[JsonDict], int]:
 with Measure(self.clock, "typing.get_new_events"):
 from_key = int(from_key)

View File

@@ -70,7 +70,7 @@ class DummyAuthChecker(UserInteractiveAuthChecker):
 class TermsAuthChecker(UserInteractiveAuthChecker):
 AUTH_TYPE = LoginType.TERMS
-def is_enabled(self):
+def is_enabled(self) -> bool:
 return True
 async def check_auth(self, authdict: dict, clientip: str) -> Any:

View File

@@ -114,7 +114,7 @@ class UserDirectoryHandler(StateDeltasHandler):
 if self._is_processing:
 return
-async def process():
+async def process() -> None:
 try:
 await self._unsafe_process()
 finally:

View File

@@ -321,8 +321,11 @@ class SimpleHttpClient:
 self.user_agent = hs.version_string
 self.clock = hs.get_clock()
-if hs.config.user_agent_suffix:
-self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
+if hs.config.server.user_agent_suffix:
+self.user_agent = "%s %s" % (
+self.user_agent,
+hs.config.server.user_agent_suffix,
+)
 # We use this for our body producers to ensure that they use the correct
 # reactor.

View File

@@ -66,7 +66,7 @@ from synapse.http.client import (
 )
 from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
 from synapse.logging import opentracing
-from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import set_tag, start_active_span, tags
 from synapse.types import JsonDict
 from synapse.util import json_decoder
@@ -553,20 +553,29 @@ class MatrixFederationHttpClient:
 with Measure(self.clock, "outbound_request"):
 # we don't want all the fancy cookie and redirect handling
 # that treq.request gives: just use the raw Agent.
-request_deferred = self.agent.request(
+# To preserve the logging context, the timeout is treated
+# in a similar way to `defer.gatherResults`:
+# * Each logging context-preserving fork is wrapped in
+#   `run_in_background`. In this case there is only one,
+#   since the timeout fork is not logging-context aware.
+# * The `Deferred` that joins the forks back together is
+#   wrapped in `make_deferred_yieldable` to restore the
+#   logging context regardless of the path taken.
+request_deferred = run_in_background(
+self.agent.request,
 method_bytes,
 url_bytes,
 headers=Headers(headers_dict),
 bodyProducer=producer,
 )
 request_deferred = timeout_deferred(
 request_deferred,
 timeout=_sec_timeout,
 reactor=self.reactor,
 )
-response = await request_deferred
+response = await make_deferred_yieldable(request_deferred)
 except DNSLookupError as e:
 raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
 except Exception as e:
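
The long comment added in this hunk spells out the recipe for timing out a federation request without corrupting the logging context. Condensed into a standalone helper built from the same calls (a sketch assuming the documented behaviour of these utilities):

from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util.async_helpers import timeout_deferred

async def request_with_timeout(agent, reactor, method: bytes, uri: bytes, timeout: float):
    # Fork: run_in_background preserves the current logging context across
    # the Deferred boundary (timeout_deferred itself is not context-aware).
    d = run_in_background(agent.request, method, uri)
    d = timeout_deferred(d, timeout=timeout, reactor=reactor)
    # Join: make_deferred_yieldable restores the logging context whether the
    # request completed or the timeout fired first.
    return await make_deferred_yieldable(d)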

View File

@@ -21,7 +21,7 @@ from zope.interface import implementer
 from twisted.internet.interfaces import IAddress, IReactorTime
 from twisted.python.failure import Failure
-from twisted.web.resource import IResource
+from twisted.web.resource import IResource, Resource
 from twisted.web.server import Request, Site
 from synapse.config.server import ListenerConfig
@@ -61,7 +61,7 @@ class SynapseRequest(Request):
 logcontext: the log context for this request
 """
-def __init__(self, channel, *args, max_request_body_size=1024, **kw):
+def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw):
 Request.__init__(self, channel, *args, **kw)
 self._max_request_body_size = max_request_body_size
 self.site: SynapseSite = channel.site
@@ -83,13 +83,13 @@ class SynapseRequest(Request):
 self._is_processing = False
 # the time when the asynchronous request handler completed its processing
-self._processing_finished_time = None
+self._processing_finished_time: Optional[float] = None
 # what time we finished sending the response to the client (or the connection
 # dropped)
-self.finish_time = None
+self.finish_time: Optional[float] = None
-def __repr__(self):
+def __repr__(self) -> str:
 # We overwrite this so that we don't log ``access_token``
 return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
 self.__class__.__name__,
@@ -100,7 +100,7 @@ class SynapseRequest(Request):
 self.site.site_tag,
 )
-def handleContentChunk(self, data):
+def handleContentChunk(self, data: bytes) -> None:
 # we should have a `content` by now.
 assert self.content, "handleContentChunk() called before gotLength()"
 if self.content.tell() + len(data) > self._max_request_body_size:
@@ -139,7 +139,7 @@ class SynapseRequest(Request):
 # If there's no authenticated entity, it was the requester.
 self.logcontext.request.authenticated_entity = authenticated_entity or requester
-def get_request_id(self):
+def get_request_id(self) -> str:
 return "%s-%i" % (self.get_method(), self.request_seq)
 def get_redacted_uri(self) -> str:
@@ -205,7 +205,7 @@ class SynapseRequest(Request):
 return None, None
-def render(self, resrc):
+def render(self, resrc: Resource) -> None:
 # this is called once a Resource has been found to serve the request; in our
 # case the Resource in question will normally be a JsonResource.
@@ -282,7 +282,7 @@ class SynapseRequest(Request):
 if self.finish_time is not None:
 self._finished_processing()
-def finish(self):
+def finish(self) -> None:
 """Called when all response data has been written to this Request.
 Overrides twisted.web.server.Request.finish to record the finish time and do
@@ -295,7 +295,7 @@ class SynapseRequest(Request):
 with PreserveLoggingContext(self.logcontext):
 self._finished_processing()
-def connectionLost(self, reason):
+def connectionLost(self, reason: Union[Failure, Exception]) -> None:
 """Called when the client connection is closed before the response is written.
 Overrides twisted.web.server.Request.connectionLost to record the finish time and
@@ -327,7 +327,7 @@ class SynapseRequest(Request):
 if not self._is_processing:
 self._finished_processing()
-def _started_processing(self, servlet_name):
+def _started_processing(self, servlet_name: str) -> None:
 """Record the fact that we are processing this request.
 This will log the request's arrival. Once the request completes,
@@ -354,9 +354,11 @@
self.get_redacted_uri(), self.get_redacted_uri(),
) )
def _finished_processing(self): def _finished_processing(self) -> None:
"""Log the completion of this request and update the metrics""" """Log the completion of this request and update the metrics"""
assert self.logcontext is not None assert self.logcontext is not None
assert self.finish_time is not None
usage = self.logcontext.get_resource_usage() usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None: if self._processing_finished_time is None:
@ -437,7 +439,7 @@ class XForwardedForRequest(SynapseRequest):
_forwarded_for: "Optional[_XForwardedForAddress]" = None _forwarded_for: "Optional[_XForwardedForAddress]" = None
_forwarded_https: bool = False _forwarded_https: bool = False
def requestReceived(self, command, path, version): def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
# this method is called by the Channel once the full request has been # this method is called by the Channel once the full request has been
# received, to dispatch the request to a resource. # received, to dispatch the request to a resource.
# We can use it to set the IP address and protocol according to the # We can use it to set the IP address and protocol according to the
@ -445,7 +447,7 @@ class XForwardedForRequest(SynapseRequest):
self._process_forwarded_headers() self._process_forwarded_headers()
return super().requestReceived(command, path, version) return super().requestReceived(command, path, version)
def _process_forwarded_headers(self): def _process_forwarded_headers(self) -> None:
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for") headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
if not headers: if not headers:
return return
@ -470,7 +472,7 @@ class XForwardedForRequest(SynapseRequest):
) )
self._forwarded_https = True self._forwarded_https = True
def isSecure(self): def isSecure(self) -> bool:
if self._forwarded_https: if self._forwarded_https:
return True return True
return super().isSecure() return super().isSecure()
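As a rough, simplified sketch of the forwarded-header handling above (assumption: only the trailing hop is trusted, mirroring how each proxy appends its upstream peer; the real method also records the scheme from x-forwarded-proto):

def client_ip_from_xff(header_values: list) -> str:
    # The right-most entry of the last header instance is the hop
    # reported by the proxy nearest to us.
    return header_values[-1].split(b",")[-1].strip().decode("ascii")

# e.g. client_ip_from_xff([b"203.0.113.5, 10.0.0.1"]) -> "10.0.0.1"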
@ -545,14 +547,16 @@ class SynapseSite(Site):
proxied = config.http_options.x_forwarded proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest request_class = XForwardedForRequest if proxied else SynapseRequest
def request_factory(channel, queued) -> Request: def request_factory(channel, queued: bool) -> Request:
return request_class( return request_class(
channel, max_request_body_size=max_request_body_size, queued=queued channel,
max_request_body_size=max_request_body_size,
queued=queued,
) )
self.requestFactory = request_factory # type: ignore self.requestFactory = request_factory # type: ignore
self.access_logger = logging.getLogger(logger_name) self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii") self.server_version_string = server_version_string.encode("ascii")
def log(self, request): def log(self, request: SynapseRequest) -> None:
pass pass

View File

@ -91,7 +91,7 @@ class ModuleApi:
self._auth = hs.get_auth() self._auth = hs.get_auth()
self._auth_handler = auth_handler self._auth_handler = auth_handler
self._server_name = hs.hostname self._server_name = hs.hostname
self._presence_stream = hs.get_event_sources().sources["presence"] self._presence_stream = hs.get_event_sources().sources.presence
self._state = hs.get_state_handler() self._state = hs.get_state_handler()
self._clock: Clock = hs.get_clock() self._clock: Clock = hs.get_clock()
self._send_email_handler = hs.get_send_email_handler() self._send_email_handler = hs.get_send_email_handler()

View File

@ -584,7 +584,7 @@ class Notifier:
events: List[EventBase] = [] events: List[EventBase] = []
end_token = from_token end_token = from_token
for name, source in self.event_sources.sources.items(): for name, source in self.event_sources.sources.get_sources():
keyname = "%s_key" % name keyname = "%s_key" % name
before_id = getattr(before_token, keyname) before_id = getattr(before_token, keyname)
after_id = getattr(after_token, keyname) after_id = getattr(after_token, keyname)
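The loop above now goes through `get_sources()`. A hedged sketch of what that accessor could look like once `sources` is an attrs class rather than a plain dict (field names here are assumptions; `presence` matches the attribute access seen earlier):

import attr

@attr.s(auto_attribs=True)
class _EventSources:
    room: object
    presence: object
    typing: object
    receipt: object
    account_data: object

    def get_sources(self):
        # Preserve the old dict-style (name, source) pairs for callers
        # that iterate over every source, like the Notifier loop above.
        return list(attr.asdict(self, recurse=False).items())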

View File

@ -370,7 +370,7 @@ class HttpPusher(Pusher):
if event.type == "m.room.member" and event.is_state(): if event.type == "m.room.member" and event.is_state():
d["notification"]["membership"] = event.content["membership"] d["notification"]["membership"] = event.content["membership"]
d["notification"]["user_is_target"] = event.state_key == self.user_id d["notification"]["user_is_target"] = event.state_key == self.user_id
if self.hs.config.push_include_content and event.content: if self.hs.config.push.push_include_content and event.content:
d["notification"]["content"] = event.content d["notification"]["content"] = event.content
# We no longer send aliases separately, instead, we send the human # We no longer send aliases separately, instead, we send the human

View File

@ -110,7 +110,7 @@ class Mailer:
self.state_handler = self.hs.get_state_handler() self.state_handler = self.hs.get_state_handler()
self.storage = hs.get_storage() self.storage = hs.get_storage()
self.app_name = app_name self.app_name = app_name
self.email_subjects: EmailSubjectConfig = hs.config.email_subjects self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects
logger.info("Created Mailer for app_name %s" % app_name) logger.info("Created Mailer for app_name %s" % app_name)
@ -796,8 +796,8 @@ class Mailer:
Returns: Returns:
A link to open a room in the web client. A link to open a room in the web client.
""" """
if self.hs.config.email_riot_base_url: if self.hs.config.email.email_riot_base_url:
base_url = "%s/#/room" % (self.hs.config.email_riot_base_url) base_url = "%s/#/room" % (self.hs.config.email.email_riot_base_url)
elif self.app_name == "Vector": elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS # need /beta for Universal Links to work on iOS
base_url = "https://vector.im/beta/#/room" base_url = "https://vector.im/beta/#/room"
@ -815,9 +815,9 @@ class Mailer:
Returns: Returns:
A link to open the notification in the web client. A link to open the notification in the web client.
""" """
if self.hs.config.email_riot_base_url: if self.hs.config.email.email_riot_base_url:
return "%s/#/room/%s/%s" % ( return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url, self.hs.config.email.email_riot_base_url,
notif["room_id"], notif["room_id"],
notif["event_id"], notif["event_id"],
) )
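For instance (hypothetical identifiers), with `email_riot_base_url` set to `https://element.example.org`, the permalink above renders as `https://element.example.org/#/room/!abc:example.org/$event123`.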

View File

@ -35,12 +35,12 @@ class PusherFactory:
"http": HttpPusher "http": HttpPusher
} }
logger.info("email enable notifs: %r", hs.config.email_enable_notifs) logger.info("email enable notifs: %r", hs.config.email.email_enable_notifs)
if hs.config.email_enable_notifs: if hs.config.email.email_enable_notifs:
self.mailers: Dict[str, Mailer] = {} self.mailers: Dict[str, Mailer] = {}
self._notif_template_html = hs.config.email_notif_template_html self._notif_template_html = hs.config.email.email_notif_template_html
self._notif_template_text = hs.config.email_notif_template_text self._notif_template_text = hs.config.email.email_notif_template_text
self.pusher_types["email"] = self._create_email_pusher self.pusher_types["email"] = self._create_email_pusher

View File

@ -62,7 +62,7 @@ class PusherPool:
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
# We shard the handling of push notifications by user ID. # We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config self._pusher_shard_config = hs.config.worker.pusher_shard_config
self._instance_name = hs.get_instance_name() self._instance_name = hs.get_instance_name()
self._should_start_pushers = ( self._should_start_pushers = (
self._instance_name in self._pusher_shard_config.instances self._instance_name in self._pusher_shard_config.instances

View File

@ -12,7 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.http.server import JsonResource from typing import TYPE_CHECKING
from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import ( from synapse.rest.client import (
account, account,
@ -57,6 +59,9 @@ from synapse.rest.client import (
voip, voip,
) )
if TYPE_CHECKING:
from synapse.server import HomeServer
class ClientRestResource(JsonResource): class ClientRestResource(JsonResource):
"""Matrix Client API REST resource. """Matrix Client API REST resource.
@ -68,12 +73,12 @@ class ClientRestResource(JsonResource):
* etc * etc
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
JsonResource.__init__(self, hs, canonical_json=False) JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs) self.register_servlets(self, hs)
@staticmethod @staticmethod
def register_servlets(client_resource, hs): def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None:
versions.register_servlets(hs, client_resource) versions.register_servlets(hs, client_resource)
# Deprecated in r0 # Deprecated in r0

View File

@ -47,7 +47,7 @@ class DeviceRestServlet(RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET( async def on_GET(
self, request: SynapseRequest, user_id, device_id: str self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)

View File

@ -125,7 +125,7 @@ class ListRoomRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM, errcode=Codes.INVALID_PARAM,
) )
search_term = parse_string(request, "search_term") search_term = parse_string(request, "search_term", encoding="utf-8")
if search_term == "": if search_term == "":
raise SynapseError( raise SynapseError(
400, 400,

View File

@ -57,7 +57,7 @@ class SendServerNoticeServlet(RestServlet):
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
self.txns = HttpTransactionCache(hs) self.txns = HttpTransactionCache(hs)
def register(self, json_resource: HttpServer): def register(self, json_resource: HttpServer) -> None:
PATTERN = "/send_server_notice" PATTERN = "/send_server_notice"
json_resource.register_paths( json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__ "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__

View File

@ -419,7 +419,7 @@ class UserRegisterServlet(RestServlet):
self.nonces: Dict[str, int] = {} self.nonces: Dict[str, int] = {}
self.hs = hs self.hs = hs
def _clear_old_nonces(self): def _clear_old_nonces(self) -> None:
""" """
Clear out old nonces that are older than NONCE_TIMEOUT. Clear out old nonces that are older than NONCE_TIMEOUT.
""" """

View File

@ -14,6 +14,7 @@
import logging import logging
import re import re
from http import HTTPStatus
from typing import TYPE_CHECKING, Awaitable, List, Tuple from typing import TYPE_CHECKING, Awaitable, List, Tuple
from twisted.web.server import Request from twisted.web.server import Request
@ -42,25 +43,25 @@ logger = logging.getLogger(__name__)
class RoomBatchSendEventRestServlet(RestServlet): class RoomBatchSendEventRestServlet(RestServlet):
""" """
API endpoint which can insert a chunk of events historically back in time API endpoint which can insert a batch of events historically back in time
next to the given `prev_event`. next to the given `prev_event`.
`chunk_id` comes from `next_chunk_id` in the response of the batch send `batch_id` comes from `next_batch_id` in the response of the batch send
endpoint and is derived from the "insertion" events added to each chunk. endpoint and is derived from the "insertion" events added to each batch.
It's not required for the first batch send. It's not required for the first batch send.
`state_events_at_start` is used to define the historical state events `state_events_at_start` is used to define the historical state events
needed to auth the events (e.g. join events). These events will float needed to auth the events (e.g. join events). These events will float
outside of the normal DAG as outliers and won't be visible in the chat outside of the normal DAG as outliers and won't be visible in the chat
history which also allows us to insert multiple chunks without having a bunch history which also allows us to insert multiple batches without having a bunch
of `@mxid joined the room` noise between each chunk. of `@mxid joined the room` noise between each batch.
`events` is a chronological chunk/list of events you want to insert. `events` is a chronological list of events you want to insert.
There is a reverse-chronological constraint on chunks, so once you insert There is a reverse-chronological constraint on batches, so once you insert
some messages, you can only insert older ones after that. some messages, you can only insert older ones after that.
tldr; Insert chunks from your most recent history -> oldest history. tldr; Insert batches from your most recent history -> oldest history.
POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event=<eventID>&chunk_id=<chunkID> POST /_matrix/client/unstable/org.matrix.msc2716/rooms/<roomID>/batch_send?prev_event_id=<eventID>&batch_id=<batchID>
{ {
"events": [ ... ], "events": [ ... ],
"state_events_at_start": [ ... ] "state_events_at_start": [ ... ]
@ -128,7 +129,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
self, sender: str, room_id: str, origin_server_ts: int self, sender: str, room_id: str, origin_server_ts: int
) -> JsonDict: ) -> JsonDict:
"""Creates an event dict for an "insertion" event with the proper fields """Creates an event dict for an "insertion" event with the proper fields
and a random chunk ID. and a random batch ID.
Args: Args:
sender: The event author MXID sender: The event author MXID
@ -139,13 +140,13 @@ class RoomBatchSendEventRestServlet(RestServlet):
The new event dictionary to insert. The new event dictionary to insert.
""" """
next_chunk_id = random_string(8) next_batch_id = random_string(8)
insertion_event = { insertion_event = {
"type": EventTypes.MSC2716_INSERTION, "type": EventTypes.MSC2716_INSERTION,
"sender": sender, "sender": sender,
"room_id": room_id, "room_id": room_id,
"content": { "content": {
EventContentFields.MSC2716_NEXT_CHUNK_ID: next_chunk_id, EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id,
EventContentFields.MSC2716_HISTORICAL: True, EventContentFields.MSC2716_HISTORICAL: True,
}, },
"origin_server_ts": origin_server_ts, "origin_server_ts": origin_server_ts,
@ -179,7 +180,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
if not requester.app_service: if not requester.app_service:
raise AuthError( raise AuthError(
403, HTTPStatus.FORBIDDEN,
"Only application services can use the /batchsend endpoint", "Only application services can use the /batchsend endpoint",
) )
@ -187,24 +188,26 @@ class RoomBatchSendEventRestServlet(RestServlet):
assert_params_in_dict(body, ["state_events_at_start", "events"]) assert_params_in_dict(body, ["state_events_at_start", "events"])
assert request.args is not None assert request.args is not None
prev_events_from_query = parse_strings_from_args(request.args, "prev_event") prev_event_ids_from_query = parse_strings_from_args(
chunk_id_from_query = parse_string(request, "chunk_id") request.args, "prev_event_id"
)
batch_id_from_query = parse_string(request, "batch_id")
if prev_events_from_query is None: if prev_event_ids_from_query is None:
raise SynapseError( raise SynapseError(
400, HTTPStatus.BAD_REQUEST,
"prev_event query parameter is required when inserting historical messages back in time", "prev_event query parameter is required when inserting historical messages back in time",
errcode=Codes.MISSING_PARAM, errcode=Codes.MISSING_PARAM,
) )
# For the event we are inserting next to (`prev_events_from_query`), # For the event we are inserting next to (`prev_event_ids_from_query`),
# find the most recent auth events (derived from state events) that # find the most recent auth events (derived from state events) that
# allowed that message to be sent. We will use that as a base # allowed that message to be sent. We will use that as a base
# to auth our historical messages against. # to auth our historical messages against.
( (
most_recent_prev_event_id, most_recent_prev_event_id,
_, _,
) = await self.store.get_max_depth_of(prev_events_from_query) ) = await self.store.get_max_depth_of(prev_event_ids_from_query)
# mapping from (type, state_key) -> state_event_id # mapping from (type, state_key) -> state_event_id
prev_state_map = await self.state_store.get_state_ids_for_event( prev_state_map = await self.state_store.get_state_ids_for_event(
most_recent_prev_event_id most_recent_prev_event_id
@ -213,7 +216,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
prev_state_ids = list(prev_state_map.values()) prev_state_ids = list(prev_state_map.values())
auth_event_ids = prev_state_ids auth_event_ids = prev_state_ids
state_events_at_start = [] state_event_ids_at_start = []
for state_event in body["state_events_at_start"]: for state_event in body["state_events_at_start"]:
assert_params_in_dict( assert_params_in_dict(
state_event, ["type", "origin_server_ts", "content", "sender"] state_event, ["type", "origin_server_ts", "content", "sender"]
@ -279,27 +282,38 @@ class RoomBatchSendEventRestServlet(RestServlet):
) )
event_id = event.event_id event_id = event.event_id
state_events_at_start.append(event_id) state_event_ids_at_start.append(event_id)
auth_event_ids.append(event_id) auth_event_ids.append(event_id)
events_to_create = body["events"] events_to_create = body["events"]
inherited_depth = await self._inherit_depth_from_prev_ids( inherited_depth = await self._inherit_depth_from_prev_ids(
prev_events_from_query prev_event_ids_from_query
) )
# Figure out which chunk to connect to. If they passed in # Figure out which batch to connect to. If they passed in
# chunk_id_from_query let's use it. The chunk ID passed in comes # batch_id_from_query let's use it. The batch ID passed in comes
# from the chunk_id in the "insertion" event from the previous chunk. # from the batch_id in the "insertion" event from the previous batch.
last_event_in_chunk = events_to_create[-1] last_event_in_batch = events_to_create[-1]
chunk_id_to_connect_to = chunk_id_from_query batch_id_to_connect_to = batch_id_from_query
base_insertion_event = None base_insertion_event = None
if chunk_id_from_query: if batch_id_from_query:
# All but the first base insertion event should point at a fake # All but the first base insertion event should point at a fake
# event, which causes the HS to ask for the state at the start of # event, which causes the HS to ask for the state at the start of
# the chunk later. # the batch later.
prev_event_ids = [fake_prev_event_id] prev_event_ids = [fake_prev_event_id]
# TODO: Verify the chunk_id_from_query corresponds to an insertion event
# Verify the batch_id_from_query corresponds to an actual insertion event
# so that the batch can be connected to it.
corresponding_insertion_event_id = (
await self.store.get_insertion_event_by_batch_id(batch_id_from_query)
)
if corresponding_insertion_event_id is None:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"No insertion event corresponds to the given ?batch_id",
errcode=Codes.INVALID_PARAM,
)
pass pass
# Otherwise, create an insertion event to act as a starting point. # Otherwise, create an insertion event to act as a starting point.
# #
@ -309,12 +323,12 @@ class RoomBatchSendEventRestServlet(RestServlet):
# an insertion event), in which case we just create a new insertion event # an insertion event), in which case we just create a new insertion event
# that can then get pointed to by a "marker" event later. # that can then get pointed to by a "marker" event later.
else: else:
prev_event_ids = prev_events_from_query prev_event_ids = prev_event_ids_from_query
base_insertion_event_dict = self._create_insertion_event_dict( base_insertion_event_dict = self._create_insertion_event_dict(
sender=requester.user.to_string(), sender=requester.user.to_string(),
room_id=room_id, room_id=room_id,
origin_server_ts=last_event_in_chunk["origin_server_ts"], origin_server_ts=last_event_in_batch["origin_server_ts"],
) )
base_insertion_event_dict["prev_events"] = prev_event_ids.copy() base_insertion_event_dict["prev_events"] = prev_event_ids.copy()
@ -333,38 +347,38 @@ class RoomBatchSendEventRestServlet(RestServlet):
depth=inherited_depth, depth=inherited_depth,
) )
chunk_id_to_connect_to = base_insertion_event["content"][ batch_id_to_connect_to = base_insertion_event["content"][
EventContentFields.MSC2716_NEXT_CHUNK_ID EventContentFields.MSC2716_NEXT_BATCH_ID
] ]
# Connect this current chunk to the insertion event from the previous chunk # Connect this current batch to the insertion event from the previous batch
chunk_event = { batch_event = {
"type": EventTypes.MSC2716_CHUNK, "type": EventTypes.MSC2716_BATCH,
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
"room_id": room_id, "room_id": room_id,
"content": { "content": {
EventContentFields.MSC2716_CHUNK_ID: chunk_id_to_connect_to, EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to,
EventContentFields.MSC2716_HISTORICAL: True, EventContentFields.MSC2716_HISTORICAL: True,
}, },
# Since the chunk event is put at the end of the chunk, # Since the batch event is put at the end of the batch,
# where the newest-in-time event is, copy the origin_server_ts from # where the newest-in-time event is, copy the origin_server_ts from
# the last event we're inserting # the last event we're inserting
"origin_server_ts": last_event_in_chunk["origin_server_ts"], "origin_server_ts": last_event_in_batch["origin_server_ts"],
} }
# Add the chunk event to the end of the chunk (newest-in-time) # Add the batch event to the end of the batch (newest-in-time)
events_to_create.append(chunk_event) events_to_create.append(batch_event)
# Add an "insertion" event to the start of each chunk (next to the oldest-in-time # Add an "insertion" event to the start of each batch (next to the oldest-in-time
# event in the chunk) so the next chunk can be connected to this one. # event in the batch) so the next batch can be connected to this one.
insertion_event = self._create_insertion_event_dict( insertion_event = self._create_insertion_event_dict(
sender=requester.user.to_string(), sender=requester.user.to_string(),
room_id=room_id, room_id=room_id,
# Since the insertion event is put at the start of the chunk, # Since the insertion event is put at the start of the batch,
# where the oldest-in-time event is, copy the origin_server_ts from # where the oldest-in-time event is, copy the origin_server_ts from
# the first event we're inserting # the first event we're inserting
origin_server_ts=events_to_create[0]["origin_server_ts"], origin_server_ts=events_to_create[0]["origin_server_ts"],
) )
# Prepend the insertion event to the start of the chunk (oldest-in-time) # Prepend the insertion event to the start of the batch (oldest-in-time)
events_to_create = [insertion_event] + events_to_create events_to_create = [insertion_event] + events_to_create
event_ids = [] event_ids = []
@ -424,20 +438,26 @@ class RoomBatchSendEventRestServlet(RestServlet):
context=context, context=context,
) )
# Add the base_insertion_event to the bottom of the list we return insertion_event_id = event_ids[0]
if base_insertion_event is not None: batch_event_id = event_ids[-1]
event_ids.append(base_insertion_event.event_id) historical_event_ids = event_ids[1:-1]
return 200, { response_dict = {
"state_events": state_events_at_start, "state_event_ids": state_event_ids_at_start,
"events": event_ids, "event_ids": historical_event_ids,
"next_chunk_id": insertion_event["content"][ "next_batch_id": insertion_event["content"][
EventContentFields.MSC2716_NEXT_CHUNK_ID EventContentFields.MSC2716_NEXT_BATCH_ID
], ],
"insertion_event_id": insertion_event_id,
"batch_event_id": batch_event_id,
} }
if base_insertion_event is not None:
response_dict["base_insertion_event_id"] = base_insertion_event.event_id
return HTTPStatus.OK, response_dict
def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]: def on_GET(self, request: Request, room_id: str) -> Tuple[int, str]:
return 501, "Not implemented" return HTTPStatus.NOT_IMPLEMENTED, "Not implemented"
def on_PUT( def on_PUT(
self, request: SynapseRequest, room_id: str self, request: SynapseRequest, room_id: str
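Pulling the renamed fields together, a successful batch send now returns a body shaped like this (all IDs hypothetical; `base_insertion_event_id` is present only when no `?batch_id` was supplied and a base insertion event had to be created):

{
    "state_event_ids": ["$stateMemberEvent"],
    "event_ids": ["$historicalMessage1", "$historicalMessage2"],
    "next_batch_id": "xyz789",
    "insertion_event_id": "$insertionEvent",
    "batch_event_id": "$batchEvent",
    "base_insertion_event_id": "$baseInsertionEvent"
}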

View File

@ -17,17 +17,22 @@ import logging
from hashlib import sha256 from hashlib import sha256
from http import HTTPStatus from http import HTTPStatus
from os import path from os import path
from typing import Dict, List from typing import TYPE_CHECKING, Any, Dict, List
import jinja2 import jinja2
from jinja2 import TemplateNotFound from jinja2 import TemplateNotFound
from twisted.web.server import Request
from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError from synapse.config import ConfigError
from synapse.http.server import DirectServeHtmlResource, respond_with_html from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_bytes_from_args, parse_string from synapse.http.servlet import parse_bytes_from_args, parse_string
from synapse.types import UserID from synapse.types import UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
# language to use for the templates. TODO: figure this out from Accept-Language # language to use for the templates. TODO: figure this out from Accept-Language
TEMPLATE_LANGUAGE = "en" TEMPLATE_LANGUAGE = "en"
@ -69,11 +74,7 @@ class ConsentResource(DirectServeHtmlResource):
against the user. against the user.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): homeserver
"""
super().__init__() super().__init__()
self.hs = hs self.hs = hs
@ -106,18 +107,14 @@ class ConsentResource(DirectServeHtmlResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8") self._hmac_secret = hs.config.form_secret.encode("utf-8")
async def _async_render_GET(self, request): async def _async_render_GET(self, request: Request) -> None:
"""
Args:
request (twisted.web.http.Request):
"""
version = parse_string(request, "v", default=self._default_consent_version) version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", default="") username = parse_string(request, "u", default="")
userhmac = None userhmac = None
has_consented = False has_consented = False
public_version = username == "" public_version = username == ""
if not public_version: if not public_version:
args: Dict[bytes, List[bytes]] = request.args args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac_bytes = parse_bytes_from_args(args, "h", required=True) userhmac_bytes = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac_bytes) self._check_hash(username, userhmac_bytes)
@ -147,14 +144,10 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound: except TemplateNotFound:
raise NotFoundError("Unknown policy version") raise NotFoundError("Unknown policy version")
async def _async_render_POST(self, request): async def _async_render_POST(self, request: Request) -> None:
"""
Args:
request (twisted.web.http.Request):
"""
version = parse_string(request, "v", required=True) version = parse_string(request, "v", required=True)
username = parse_string(request, "u", required=True) username = parse_string(request, "u", required=True)
args: Dict[bytes, List[bytes]] = request.args args: Dict[bytes, List[bytes]] = request.args # type: ignore
userhmac = parse_bytes_from_args(args, "h", required=True) userhmac = parse_bytes_from_args(args, "h", required=True)
self._check_hash(username, userhmac) self._check_hash(username, userhmac)
@ -177,7 +170,9 @@ class ConsentResource(DirectServeHtmlResource):
except TemplateNotFound: except TemplateNotFound:
raise NotFoundError("success.html not found") raise NotFoundError("success.html not found")
def _render_template(self, request, template_name, **template_args): def _render_template(
self, request: Request, template_name: str, **template_args: Any
) -> None:
# get_template checks for ".." so we don't need to worry too much # get_template checks for ".." so we don't need to worry too much
# about path traversal here. # about path traversal here.
template_html = self._jinja_env.get_template( template_html = self._jinja_env.get_template(
@ -186,11 +181,11 @@ class ConsentResource(DirectServeHtmlResource):
html = template_html.render(**template_args) html = template_html.render(**template_args)
respond_with_html(request, 200, html) respond_with_html(request, 200, html)
def _check_hash(self, userid, userhmac): def _check_hash(self, userid: str, userhmac: bytes) -> None:
""" """
Args: Args:
userid (unicode): userid: the user whose consent hash is being verified
userhmac (bytes): userhmac: the hash supplied by the client
Raises: Raises:
SynapseError if the hash doesn't match SynapseError if the hash doesn't match
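A self-contained sketch of the check being documented, assuming the HMAC-SHA256-over-username scheme implied by the `sha256` import and the `form_secret` seen above (not a verbatim copy of the implementation):

import hmac
from hashlib import sha256
from http import HTTPStatus

from synapse.api.errors import SynapseError

def check_hash(hmac_secret: bytes, userid: str, userhmac: bytes) -> None:
    # Recompute the MAC we would have issued for this user and compare
    # in constant time to avoid timing side channels.
    want = hmac.new(hmac_secret, userid.encode("utf-8"), sha256).hexdigest()
    if not hmac.compare_digest(want.encode("ascii"), userhmac):
        raise SynapseError(HTTPStatus.UNAUTHORIZED, "Incorrect hash")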

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request
class HealthResource(Resource): class HealthResource(Resource):
@ -25,6 +26,6 @@ class HealthResource(Resource):
isLeaf = 1 isLeaf = 1
def render_GET(self, request): def render_GET(self, request: Request) -> bytes:
request.setHeader(b"Content-Type", b"text/plain") request.setHeader(b"Content-Type", b"text/plain")
return b"OK" return b"OK"

View File

@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING
from twisted.web.resource import Resource from twisted.web.resource import Resource
from .local_key_resource import LocalKey from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey from .remote_key_resource import RemoteKey
if TYPE_CHECKING:
from synapse.server import HomeServer
class KeyApiV2Resource(Resource): class KeyApiV2Resource(Resource):
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
Resource.__init__(self) Resource.__init__(self)
self.putChild(b"server", LocalKey(hs)) self.putChild(b"server", LocalKey(hs))
self.putChild(b"query", RemoteKey(hs)) self.putChild(b"query", RemoteKey(hs))

Some files were not shown because too many files have changed in this diff