Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
commit 512e313f18

@@ -10,4 +10,7 @@ apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev x
 export LANG="C.UTF-8"
 
+# Prevent virtualenv from auto-updating pip to an incompatible version
+export VIRTUALENV_NO_DOWNLOAD=1
+
 exec tox -e py35-old,combine

17  CHANGES.md
@@ -1,3 +1,20 @@
+Synapse 1.26.0rc2 (2021-01-25)
+==============================
+
+Bugfixes
+--------
+
+- Fix receipts and account data not being sent down sync. Introduced in v1.26.0rc1. ([\#9193](https://github.com/matrix-org/synapse/issues/9193), [\#9195](https://github.com/matrix-org/synapse/issues/9195))
+- Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1. ([\#9210](https://github.com/matrix-org/synapse/issues/9210))
+
+
+Internal Changes
+----------------
+
+- Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. ([\#9189](https://github.com/matrix-org/synapse/issues/9189))
+- Bump minimum `psycopg2` version to v2.8. ([\#9204](https://github.com/matrix-org/synapse/issues/9204))
+
+
 Synapse 1.26.0rc1 (2021-01-20)
 ==============================
 

@@ -0,0 +1 @@
+Add tests to `test_user.UsersListTestCase` for List Users Admin API.

@@ -0,0 +1 @@
+Add admin API for getting and deleting forward extremities for a room.

@@ -0,0 +1 @@
+Fix spurious errors in logs when deleting a non-existent pusher.

@@ -0,0 +1 @@
+Various improvements to the federation client.

@@ -0,0 +1 @@
+Add link to Matrix VoIP tester for turn-howto.

@@ -0,0 +1 @@
+Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled).

@@ -0,0 +1 @@
+Fix a long-standing bug where an internal server error was raised when attempting to preview an HTML document in an unknown character encoding.

@@ -0,0 +1 @@
+Fix a long-standing bug where invalid data could cause errors when calculating the presentable room name for push.

@@ -0,0 +1 @@
+Speed up chain cover calculation when persisting a batch of state events at once.

@@ -0,0 +1 @@
+Add a `long_description_type` to the package metadata.

@@ -0,0 +1 @@
+Speed up batch insertion when using PostgreSQL.

@@ -0,0 +1 @@
+Emit an error at startup if different Identity Providers are configured with the same `idp_id`.

@@ -0,0 +1 @@
+Speed up batch insertion when using PostgreSQL.

@@ -1 +0,0 @@
-Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration.

@@ -0,0 +1 @@
+Improve performance of concurrent use of `StreamIDGenerators`.

@@ -0,0 +1 @@
+Add some missing source directories to the automatic linting script.

@@ -1 +0,0 @@
-Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1.

@@ -1 +0,0 @@
-Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1.

@@ -0,0 +1 @@
+Precompute joined hosts and store in Redis.

@@ -1 +0,0 @@
-Bump minimum `psycopg2` version.

@@ -0,0 +1 @@
+Add an admin API endpoint for shadow-banning users.

@@ -1 +0,0 @@
-Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1.

@@ -0,0 +1 @@
+Fix the Python 3.5 old dependencies build.

@@ -0,0 +1 @@
+Fix bug where we sometimes didn't detect that Redis connections had died, causing workers to not see new data.

@@ -0,0 +1 @@
+Update `isort` to v5.7.0 to bypass a bug where it would disagree with `black` about formatting.

@@ -9,6 +9,7 @@
 * [Response](#response)
 * [Undoing room shutdowns](#undoing-room-shutdowns)
 - [Make Room Admin API](#make-room-admin-api)
+- [Forward Extremities Admin API](#forward-extremities-admin-api)
 
 # List Room API
 
@@ -511,3 +512,55 @@ optionally be specified, e.g.:
     "user_id": "@foo:example.com"
 }
 ```
+
+# Forward Extremities Admin API
+
+Enables querying and deleting forward extremities from rooms. When a lot of forward
+extremities accumulate in a room, performance can become degraded. For details, see
+[#1760](https://github.com/matrix-org/synapse/issues/1760).
+
+## Check for forward extremities
+
+To check the status of forward extremities for a room:
+
+```
+GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+```
+
+A response as follows will be returned:
+
+```json
+{
+  "count": 1,
+  "results": [
+    {
+      "event_id": "$M5SP266vsnxctfwFgFLNceaCo3ujhRtg_NiiHabcdefgh",
+      "state_group": 439,
+      "depth": 123,
+      "received_ts": 1611263016761
+    }
+  ]
+}
+```
+
+## Deleting forward extremities
+
+**WARNING**: Please ensure you know what you're doing and have read
+the related issue [#1760](https://github.com/matrix-org/synapse/issues/1760).
+Under no circumstances should this API be executed as an automated maintenance task!
+
+If a room has lots of forward extremities, the extra ones can be
+deleted as follows:
+
+```
+DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+```
+
+A response as follows will be returned, indicating the number of forward extremities
+that were deleted.
+
+```json
+{
+  "deleted": 1
+}
+```
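For illustration only, here is a minimal sketch of calling the two endpoints documented above from Python with the `requests` library. The homeserver URL, admin access token, and room ID are placeholders, not values from this changeset:

```python
from urllib.parse import quote

import requests

BASE = "https://homeserver.example"  # assumption: your homeserver's base URL
HEADERS = {"Authorization": "Bearer <admin_access_token>"}  # a server admin's token
ROOM = quote("!somewhere:example.com")  # room ID or alias, URL-encoded

# Check how many forward extremities the room currently has.
resp = requests.get(
    f"{BASE}/_synapse/admin/v1/rooms/{ROOM}/forward_extremities", headers=HEADERS
)
print(resp.json()["count"], "forward extremities")

# Delete the surplus ones (read issue #1760 first; never run this as an automated task).
resp = requests.delete(
    f"{BASE}/_synapse/admin/v1/rooms/{ROOM}/forward_extremities", headers=HEADERS
)
print(resp.json()["deleted"], "deleted")
```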
@@ -760,3 +760,33 @@ The following fields are returned in the JSON response body:
 - ``total`` - integer - Number of pushers.
 
 See also `Client-Server API Spec <https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers>`_
+
+Shadow-banning users
+====================
+
+Shadow-banning is a useful tool for moderating malicious or egregiously abusive users.
+A shadow-banned user receives successful responses to their client-server API requests,
+but the events are not propagated into rooms. This can be an effective tool as it
+(hopefully) takes longer for the user to realise they are being moderated before
+pivoting to another account.
+
+Shadow-banning a user should be used as a tool of last resort and may lead to confusing
+or broken behaviour for the client. A shadow-banned user will not receive any
+notification and it is generally more appropriate to ban or kick abusive users.
+A shadow-banned user will be unable to contact anyone on the server.
+
+The API is::
+
+    POST /_synapse/admin/v1/users/<user_id>/shadow_ban
+
+To use it, you will need to authenticate by providing an ``access_token`` for a
+server admin: see `README.rst <README.rst>`_.
+
+An empty JSON dict is returned.
+
+**Parameters**
+
+The following parameters should be set in the URL:
+
+- ``user_id`` - The fully qualified MXID: for example, ``@user:server.com``. The user must
+  be local.
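Purely as a sketch of the endpoint documented above (the server, token, and user ID are invented placeholders):

```python
from urllib.parse import quote

import requests

resp = requests.post(
    "https://homeserver.example/_synapse/admin/v1/users/"
    + quote("@spammer:homeserver.example")
    + "/shadow_ban",
    headers={"Authorization": "Bearer <admin_access_token>"},
)
resp.raise_for_status()  # an empty JSON dict ({}) is returned on success
```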
@@ -232,6 +232,12 @@ Here are a few things to try:
 
 (Understanding the output is beyond the scope of this document!)
 
+ * You can test your Matrix homeserver TURN setup with https://test.voip.librepush.net/.
+   Note that this test is not fully reliable yet, so don't be discouraged if
+   the test fails.
+   [Here](https://github.com/matrix-org/voip-tester) is the github repo of the
+   source of the tester, where you can file bug reports.
+
  * There is a WebRTC test tool at
    https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To
    use it, you will need a username/password for your TURN server. You can

@@ -80,7 +80,8 @@ else
     # then lint everything!
     if [[ -z ${files+x} ]]; then
         # Lint all source code files and directories
-        files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark")
+        # Note: this list aims to mirror the one in tox.ini
+        files=("synapse" "docker" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite")
    fi
 fi
 

3  setup.py
@@ -96,7 +96,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
 #
 # We pin black so that our tests don't start failing on new releases.
 CONDITIONAL_REQUIREMENTS["lint"] = [
-    "isort==5.0.3",
+    "isort==5.7.0",
     "black==19.10b0",
     "flake8-comprehensions",
     "flake8",
@@ -121,6 +121,7 @@ setup(
     include_package_data=True,
     zip_safe=False,
     long_description=long_description,
+    long_description_content_type="text/x-rst",
     python_requires="~=3.5",
     classifiers=[
         "Development Status :: 5 - Production/Stable",

@@ -15,13 +15,23 @@
 
 """Contains *incomplete* type hints for txredisapi.
 """
-from typing import List, Optional, Type, Union
+from typing import Any, List, Optional, Type, Union
 
 class RedisProtocol:
     def publish(self, channel: str, message: bytes): ...
+    async def ping(self) -> None: ...
+    async def set(
+        self,
+        key: str,
+        value: Any,
+        expire: Optional[int] = None,
+        pexpire: Optional[int] = None,
+        only_if_not_exists: bool = False,
+        only_if_exists: bool = False,
+    ) -> None: ...
+    async def get(self, key: str) -> Any: ...
 
-class SubscriberProtocol:
+class SubscriberProtocol(RedisProtocol):
     def __init__(self, *args, **kwargs): ...
     password: Optional[str]
     def subscribe(self, channels: Union[str, List[str]]): ...
@@ -40,14 +50,13 @@ def lazyConnection(
     convertNumbers: bool = ...,
 ) -> RedisProtocol: ...
 
-class SubscriberFactory:
-    def buildProtocol(self, addr): ...
-
 class ConnectionHandler: ...
 
 class RedisFactory:
     continueTrying: bool
     handler: RedisProtocol
+    pool: List[RedisProtocol]
+    replyTimeout: Optional[int]
     def __init__(
         self,
         uuid: str,
@@ -60,3 +69,7 @@ class RedisFactory:
         replyTimeout: Optional[int] = None,
         convertNumbers: Optional[int] = True,
     ): ...
+    def buildProtocol(self, addr) -> RedisProtocol: ...
+
+class SubscriberFactory(RedisFactory):
+    def __init__(self): ...

@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.26.0rc1"
+__version__ = "1.26.0rc2"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when

@@ -18,6 +18,7 @@ from synapse.config import (
     password_auth_providers,
     push,
     ratelimiting,
+    redis,
     registration,
     repository,
     room_directory,
@@ -79,6 +80,7 @@ class RootConfig:
     roomdirectory: room_directory.RoomDirectoryConfig
     thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
     tracer: tracer.TracerConfig
+    redis: redis.RedisConfig
 
     config_classes: List = ...
     def __init__(self) -> None: ...

@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import string
+from collections import Counter
 from typing import Iterable, Optional, Tuple, Type
 
 import attr
@@ -43,6 +44,16 @@ class OIDCConfig(Config):
         except DependencyException as e:
             raise ConfigError(e.message) from e
 
+        # check we don't have any duplicate idp_ids now. (The SSO handler will also
+        # check for duplicates when the REST listeners get registered, but that happens
+        # after synapse has forked so doesn't give nice errors.)
+        c = Counter([i.idp_id for i in self.oidc_providers])
+        for idp_id, count in c.items():
+            if count > 1:
+                raise ConfigError(
+                    "Multiple OIDC providers have the idp_id %r." % idp_id
+                )
+
         public_baseurl = self.public_baseurl
         self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
 

@@ -18,6 +18,7 @@ import copy
 import itertools
 import logging
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Callable,
@@ -26,7 +27,6 @@ from typing import (
     List,
     Mapping,
     Optional,
-    Sequence,
     Tuple,
     TypeVar,
     Union,
@@ -61,6 +61,9 @@ from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
@@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError):
 
 
 class FederationClient(FederationBase):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self.pdu_destination_tried = {}
+        self.pdu_destination_tried = {}  # type: Dict[str, Dict[str, int]]
         self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
         self.state = hs.get_state_handler()
         self.transport_layer = hs.get_federation_transport_client()
@@ -116,33 +119,32 @@ class FederationClient(FederationBase):
         self.pdu_destination_tried[event_id] = destination_dict
 
     @log_function
-    def make_query(
+    async def make_query(
         self,
-        destination,
-        query_type,
-        args,
-        retry_on_dns_fail=False,
-        ignore_backoff=False,
-    ):
+        destination: str,
+        query_type: str,
+        args: dict,
+        retry_on_dns_fail: bool = False,
+        ignore_backoff: bool = False,
+    ) -> JsonDict:
         """Sends a federation Query to a remote homeserver of the given type
         and arguments.
 
         Args:
-            destination (str): Domain name of the remote homeserver
-            query_type (str): Category of the query type; should match the
+            destination: Domain name of the remote homeserver
+            query_type: Category of the query type; should match the
                 handler name used in register_query_handler().
-            args (dict): Mapping of strings to strings containing the details
+            args: Mapping of strings to strings containing the details
                 of the query request.
-            ignore_backoff (bool): true to ignore the historical backoff data
+            ignore_backoff: true to ignore the historical backoff data
                 and try the request anyway.
 
         Returns:
-            a Awaitable which will eventually yield a JSON object from the
-            response
+            The JSON object from the response
         """
         sent_queries_counter.labels(query_type).inc()
 
-        return self.transport_layer.make_query(
+        return await self.transport_layer.make_query(
             destination,
             query_type,
             args,
@@ -151,42 +153,52 @@ class FederationClient(FederationBase):
         )
 
     @log_function
-    def query_client_keys(self, destination, content, timeout):
+    async def query_client_keys(
+        self, destination: str, content: JsonDict, timeout: int
+    ) -> JsonDict:
         """Query device keys for a device hosted on a remote server.
 
         Args:
-            destination (str): Domain name of the remote homeserver
-            content (dict): The query content.
+            destination: Domain name of the remote homeserver
+            content: The query content.
 
         Returns:
-            an Awaitable which will eventually yield a JSON object from the
-            response
+            The JSON object from the response
         """
         sent_queries_counter.labels("client_device_keys").inc()
-        return self.transport_layer.query_client_keys(destination, content, timeout)
+        return await self.transport_layer.query_client_keys(
+            destination, content, timeout
+        )
 
     @log_function
-    def query_user_devices(self, destination, user_id, timeout=30000):
+    async def query_user_devices(
+        self, destination: str, user_id: str, timeout: int = 30000
+    ) -> JsonDict:
         """Query the device keys for a list of user ids hosted on a remote
         server.
         """
         sent_queries_counter.labels("user_devices").inc()
-        return self.transport_layer.query_user_devices(destination, user_id, timeout)
+        return await self.transport_layer.query_user_devices(
+            destination, user_id, timeout
+        )
 
     @log_function
-    def claim_client_keys(self, destination, content, timeout):
+    async def claim_client_keys(
+        self, destination: str, content: JsonDict, timeout: int
+    ) -> JsonDict:
         """Claims one-time keys for a device hosted on a remote server.
 
         Args:
-            destination (str): Domain name of the remote homeserver
-            content (dict): The query content.
+            destination: Domain name of the remote homeserver
+            content: The query content.
 
         Returns:
-            an Awaitable which will eventually yield a JSON object from the
-            response
+            The JSON object from the response
         """
         sent_queries_counter.labels("client_one_time_keys").inc()
-        return self.transport_layer.claim_client_keys(destination, content, timeout)
+        return await self.transport_layer.claim_client_keys(
+            destination, content, timeout
+        )
 
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
@@ -195,10 +207,10 @@ class FederationClient(FederationBase):
         given destination server.
 
         Args:
-            dest (str): The remote homeserver to ask.
-            room_id (str): The room_id to backfill.
-            limit (int): The maximum number of events to return.
-            extremities (list): our current backwards extremities, to backfill from
+            dest: The remote homeserver to ask.
+            room_id: The room_id to backfill.
+            limit: The maximum number of events to return.
+            extremities: our current backwards extremities, to backfill from
         """
         logger.debug("backfill extrem=%s", extremities)
 
@@ -370,7 +382,7 @@ class FederationClient(FederationBase):
             for events that have failed their checks
 
         Returns:
-            Deferred : A list of PDUs that have valid signatures and hashes.
+            A list of PDUs that have valid signatures and hashes.
         """
         deferreds = self._check_sigs_and_hashes(room_version, pdus)
 
@@ -418,7 +430,9 @@ class FederationClient(FederationBase):
         else:
             return [p for p in valid_pdus if p]
 
-    async def get_event_auth(self, destination, room_id, event_id):
+    async def get_event_auth(
+        self, destination: str, room_id: str, event_id: str
+    ) -> List[EventBase]:
         res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
 
         room_version = await self.store.get_room_version(room_id)
@@ -700,18 +714,16 @@ class FederationClient(FederationBase):
 
         return await self._try_destination_list("send_join", destinations, send_request)
 
-    async def _do_send_join(self, destination: str, pdu: EventBase):
+    async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
         time_now = self._clock.time_msec()
 
         try:
-            content = await self.transport_layer.send_join_v2(
+            return await self.transport_layer.send_join_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
                 content=pdu.get_pdu_json(time_now),
             )
 
-            return content
-
         except HttpResponseException as e:
             if e.code in [400, 404]:
                 err = e.to_synapse_error()
@@ -769,7 +781,7 @@ class FederationClient(FederationBase):
         time_now = self._clock.time_msec()
 
         try:
-            content = await self.transport_layer.send_invite_v2(
+            return await self.transport_layer.send_invite_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
@@ -779,7 +791,6 @@ class FederationClient(FederationBase):
                     "invite_room_state": pdu.unsigned.get("invite_room_state", []),
                 },
             )
-            return content
         except HttpResponseException as e:
             if e.code in [400, 404]:
                 err = e.to_synapse_error()
@@ -842,18 +853,16 @@ class FederationClient(FederationBase):
             "send_leave", destinations, send_request
         )
 
-    async def _do_send_leave(self, destination, pdu):
+    async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
         time_now = self._clock.time_msec()
 
         try:
-            content = await self.transport_layer.send_leave_v2(
+            return await self.transport_layer.send_leave_v2(
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
                 content=pdu.get_pdu_json(time_now),
             )
 
-            return content
-
         except HttpResponseException as e:
             if e.code in [400, 404]:
                 err = e.to_synapse_error()
@@ -879,7 +888,7 @@ class FederationClient(FederationBase):
             # content.
             return resp[1]
 
-    def get_public_rooms(
+    async def get_public_rooms(
         self,
         remote_server: str,
         limit: Optional[int] = None,
@@ -887,7 +896,7 @@ class FederationClient(FederationBase):
         search_filter: Optional[Dict] = None,
         include_all_networks: bool = False,
         third_party_instance_id: Optional[str] = None,
-    ):
+    ) -> JsonDict:
         """Get the list of public rooms from a remote homeserver
 
         Args:
@@ -901,8 +910,7 @@ class FederationClient(FederationBase):
                 party instance
 
         Returns:
-            Awaitable[Dict[str, Any]]: The response from the remote server, or None if
-            `remote_server` is the same as the local server_name
+            The response from the remote server.
 
         Raises:
             HttpResponseException: There was an exception returned from the remote server
@@ -910,7 +918,7 @@ class FederationClient(FederationBase):
             requests over federation
 
         """
-        return self.transport_layer.get_public_rooms(
+        return await self.transport_layer.get_public_rooms(
             remote_server,
             limit,
             since_token,
@@ -923,7 +931,7 @@ class FederationClient(FederationBase):
         self,
         destination: str,
         room_id: str,
-        earliest_events_ids: Sequence[str],
+        earliest_events_ids: Iterable[str],
         latest_events: Iterable[EventBase],
         limit: int,
         min_depth: int,
@@ -974,7 +982,9 @@ class FederationClient(FederationBase):
 
         return signed_events
 
-    async def forward_third_party_invite(self, destinations, room_id, event_dict):
+    async def forward_third_party_invite(
+        self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
+    ) -> None:
         for destination in destinations:
             if destination == self.server_name:
                 continue
@@ -983,7 +993,7 @@ class FederationClient(FederationBase):
                 await self.transport_layer.exchange_third_party_invite(
                     destination=destination, room_id=room_id, event_dict=event_dict
                 )
-                return None
+                return
             except CodeMessageException:
                 raise
             except Exception as e:
@@ -995,7 +1005,7 @@ class FederationClient(FederationBase):
 
     async def get_room_complexity(
         self, destination: str, room_id: str
-    ) -> Optional[dict]:
+    ) -> Optional[JsonDict]:
         """
         Fetch the complexity of a remote room from another server.
 
@@ -1008,10 +1018,9 @@ class FederationClient(FederationBase):
             could not fetch the complexity.
         """
         try:
-            complexity = await self.transport_layer.get_room_complexity(
+            return await self.transport_layer.get_room_complexity(
                 destination=destination, room_id=room_id
            )
-            return complexity
         except CodeMessageException as e:
             # We didn't manage to get it -- probably a 404. We are okay if other
             # servers don't give it to us.

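For context on the `async` conversion above, a hypothetical caller of the new `make_query` signature might look like the sketch below; the `"profile"` query type and its arguments are illustrative assumptions, not part of this diff:

```python
# Illustrative only: how a handler might call the now-async make_query.
async def fetch_remote_displayname(federation_client, user_id: str) -> str:
    response = await federation_client.make_query(
        destination=user_id.split(":", 1)[1],  # the remote server name
        query_type="profile",
        args={"user_id": user_id, "field": "displayname"},
        ignore_backoff=True,
    )
    return response.get("displayname", "")
```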
@@ -142,6 +142,8 @@ class FederationSender:
             self._wake_destinations_needing_catchup,
         )
 
+        self._external_cache = hs.get_external_cache()
+
     def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
         """Get or create a PerDestinationQueue for the given destination
 
@@ -197,22 +199,40 @@ class FederationSender:
         if not event.internal_metadata.should_proactively_send():
             return
 
-        try:
-            # Get the state from before the event.
-            # We need to make sure that this is the state from before
-            # the event and not from after it.
-            # Otherwise if the last member on a server in a room is
-            # banned then it won't receive the event because it won't
-            # be in the room after the ban.
-            destinations = await self.state.get_hosts_in_room_at_events(
-                event.room_id, event_ids=event.prev_event_ids()
-            )
-        except Exception:
-            logger.exception(
-                "Failed to calculate hosts in room for event: %s",
-                event.event_id,
-            )
-            return
+        destinations = None  # type: Optional[Set[str]]
+        if not event.prev_event_ids():
+            # If there are no prev event IDs then the state is empty
+            # and so no remote servers in the room
+            destinations = set()
+        else:
+            # We check the external cache for the destinations, which is
+            # stored per state group.
+
+            sg = await self._external_cache.get(
+                "event_to_prev_state_group", event.event_id
+            )
+            if sg:
+                destinations = await self._external_cache.get(
+                    "get_joined_hosts", str(sg)
+                )
+
+        if destinations is None:
+            try:
+                # Get the state from before the event.
+                # We need to make sure that this is the state from before
+                # the event and not from after it.
+                # Otherwise if the last member on a server in a room is
+                # banned then it won't receive the event because it won't
+                # be in the room after the ban.
+                destinations = await self.state.get_hosts_in_room_at_events(
+                    event.room_id, event_ids=event.prev_event_ids()
+                )
+            except Exception:
+                logger.exception(
+                    "Failed to calculate hosts in room for event: %s",
+                    event.event_id,
+                )
+                return
 
         destinations = {
             d

@@ -2093,6 +2093,11 @@ class FederationHandler(BaseHandler):
         if event.type == EventTypes.GuestAccess and not context.rejected:
             await self.maybe_kick_guest_users(event)
 
+        # If we are going to send this event over federation we precalculate
+        # the joined hosts.
+        if event.internal_metadata.get_send_on_behalf_of():
+            await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
         return context
 
     async def _check_for_soft_fail(

@@ -432,6 +432,8 @@ class EventCreationHandler:
 
         self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
 
+        self._external_cache = hs.get_external_cache()
+
     async def create_event(
         self,
         requester: Requester,
@@ -939,6 +941,8 @@ class EventCreationHandler:
 
         await self.action_generator.handle_push_actions_for_event(event, context)
 
+        await self.cache_joined_hosts_for_event(event)
+
         try:
             # If we're a worker we need to hit out to the master.
             writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -978,6 +982,44 @@ class EventCreationHandler:
             await self.store.remove_push_actions_from_staging(event.event_id)
             raise
 
+    async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+        """Precalculate the joined hosts at the event, when using Redis, so that
+        external federation senders don't have to recalculate it themselves.
+        """
+
+        if not self._external_cache.is_enabled():
+            return
+
+        # We actually store two mappings, event ID -> prev state group,
+        # state group -> joined hosts, which is much more space efficient
+        # than event ID -> joined hosts.
+        #
+        # Note: We have to cache event ID -> prev state group, as we don't
+        # store that in the DB.
+        #
+        # Note: We always set the state group -> joined hosts cache, even if
+        # we already set it, so that the expiry time is reset.
+
+        state_entry = await self.state.resolve_state_groups_for_events(
+            event.room_id, event_ids=event.prev_event_ids()
+        )
+
+        if state_entry.state_group:
+            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+            await self._external_cache.set(
+                "event_to_prev_state_group",
+                event.event_id,
+                state_entry.state_group,
+                expiry_ms=60 * 60 * 1000,
+            )
+            await self._external_cache.set(
+                "get_joined_hosts",
+                str(state_entry.state_group),
+                list(joined_hosts),
+                expiry_ms=60 * 60 * 1000,
+            )
+
     async def _validate_canonical_alias(
         self, directory_handler, room_alias_str: str, expected_room_id: str
     ) -> None:

@@ -17,7 +17,7 @@ import logging
 import re
 from typing import TYPE_CHECKING, Dict, Iterable, Optional
 
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.types import StateMap
 
@@ -63,7 +63,7 @@ async def calculate_room_name(
         m_room_name = await store.get_event(
             room_state_ids[(EventTypes.Name, "")], allow_none=True
         )
-        if m_room_name and m_room_name.content and m_room_name.content["name"]:
+        if m_room_name and m_room_name.content and m_room_name.content.get("name"):
             return m_room_name.content["name"]
 
     # does it have a canonical alias?
@@ -74,15 +74,11 @@ async def calculate_room_name(
         if (
             canon_alias
             and canon_alias.content
-            and canon_alias.content["alias"]
+            and canon_alias.content.get("alias")
             and _looks_like_an_alias(canon_alias.content["alias"])
         ):
             return canon_alias.content["alias"]
 
-    # at this point we're going to need to search the state by all state keys
-    # for an event type, so rearrange the data structure
-    room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
-
     if not fallback_to_members:
         return None
 
@@ -94,7 +90,7 @@ async def calculate_room_name(
 
     if (
         my_member_event is not None
-        and my_member_event.content["membership"] == "invite"
+        and my_member_event.content.get("membership") == Membership.INVITE
     ):
         if (EventTypes.Member, my_member_event.sender) in room_state_ids:
             inviter_member_event = await store.get_event(
@@ -111,6 +107,10 @@ async def calculate_room_name(
         else:
             return "Room Invite"
 
+    # at this point we're going to need to search the state by all state keys
+    # for an event type, so rearrange the data structure
+    room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
+
     # we're going to have to generate a name based on who's in the room,
     # so find out who is in the room that isn't the user.
     if EventTypes.Member in room_state_bytype_ids:
@@ -120,8 +120,8 @@ async def calculate_room_name(
         all_members = [
             ev
             for ev in member_events.values()
-            if ev.content["membership"] == "join"
-            or ev.content["membership"] == "invite"
+            if ev.content.get("membership") == Membership.JOIN
+            or ev.content.get("membership") == Membership.INVITE
         ]
         # Sort the member events oldest-first so the we name people in the
         # order the joined (it should at least be deterministic rather than
@@ -194,11 +194,7 @@ def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
 
 
 def name_from_member_event(member_event: EventBase) -> str:
-    if (
-        member_event.content
-        and "displayname" in member_event.content
-        and member_event.content["displayname"]
-    ):
+    if member_event.content and member_event.content.get("displayname"):
         return member_event.content["displayname"]
     return member_event.state_key
 

@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+from prometheus_client import Counter
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util import json_decoder, json_encoder
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+set_counter = Counter(
+    "synapse_external_cache_set",
+    "Number of times we set a cache",
+    labelnames=["cache_name"],
+)
+
+get_counter = Counter(
+    "synapse_external_cache_get",
+    "Number of times we get a cache",
+    labelnames=["cache_name", "hit"],
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalCache:
+    """A cache backed by an external Redis. Does nothing if no Redis is
+    configured.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._redis_connection = hs.get_outbound_redis_connection()
+
+    def _get_redis_key(self, cache_name: str, key: str) -> str:
+        return "cache_v1:%s:%s" % (cache_name, key)
+
+    def is_enabled(self) -> bool:
+        """Whether the external cache is used or not.
+
+        It's safe to use the cache when this returns false, the methods will
+        just no-op, but the function is useful to avoid doing unnecessary work.
+        """
+        return self._redis_connection is not None
+
+    async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
+        """Add the key/value to the named cache, with the expiry time given.
+        """
+
+        if self._redis_connection is None:
+            return
+
+        set_counter.labels(cache_name).inc()
+
+        # txredisapi requires the value to be string, bytes or numbers, so we
+        # encode stuff in JSON.
+        encoded_value = json_encoder.encode(value)
+
+        logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
+
+        return await make_deferred_yieldable(
+            self._redis_connection.set(
+                self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
+            )
+        )
+
+    async def get(self, cache_name: str, key: str) -> Optional[Any]:
+        """Look up a key/value in the named cache.
+        """
+
+        if self._redis_connection is None:
+            return None
+
+        result = await make_deferred_yieldable(
+            self._redis_connection.get(self._get_redis_key(cache_name, key))
+        )
+
+        logger.debug("Got cache result %s %s: %r", cache_name, key, result)
+
+        get_counter.labels(cache_name, result is not None).inc()
+
+        if not result:
+            return None
+
+        # For some reason the integers get magically converted back to integers
+        if isinstance(result, int):
+            return result
+
+        return json_decoder.decode(result)
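A minimal usage sketch of the `ExternalCache` added above. The `hs` homeserver object, the cache name, and the one-hour expiry mirror how the handlers in this changeset use it; the wrapper functions themselves are invented for illustration:

```python
cache = ExternalCache(hs)

async def remember_prev_state_group(event_id: str, state_group: int) -> None:
    # Skip the work entirely when no Redis is configured (calls would no-op anyway).
    if not cache.is_enabled():
        return
    await cache.set(
        "event_to_prev_state_group", event_id, state_group, expiry_ms=60 * 60 * 1000
    )

async def lookup_prev_state_group(event_id: str):
    # Returns None on a cache miss, or when Redis is not configured.
    return await cache.get("event_to_prev_state_group", event_id)
```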
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 from typing import (
+    TYPE_CHECKING,
     Any,
     Awaitable,
     Dict,
@@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
     TypingStream,
 )
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -88,7 +92,7 @@ class ReplicationCommandHandler:
     back out to connections.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self._replication_data_handler = hs.get_replication_data_handler()
         self._presence_handler = hs.get_presence_handler()
         self._store = hs.get_datastore()
@@ -282,13 +286,6 @@ class ReplicationCommandHandler:
         if hs.config.redis.redis_enabled:
             from synapse.replication.tcp.redis import (
                 RedisDirectTcpReplicationClientFactory,
-                lazyConnection,
             )
 
-            logger.info(
-                "Connecting to redis (host=%r port=%r)",
-                hs.config.redis_host,
-                hs.config.redis_port,
-            )
-
             # First let's ensure that we have a ReplicationStreamer started.
@@ -299,13 +296,7 @@ class ReplicationCommandHandler:
             # connection after SUBSCRIBE is called).
 
             # First create the connection for sending commands.
-            outbound_redis_connection = lazyConnection(
-                reactor=hs.get_reactor(),
-                host=hs.config.redis_host,
-                port=hs.config.redis_port,
-                password=hs.config.redis.redis_password,
-                reconnect=True,
-            )
+            outbound_redis_connection = hs.get_outbound_redis_connection()
 
             # Now create the factory/connection for the subscription stream.
             self._factory = RedisDirectTcpReplicationClientFactory(

@ -15,7 +15,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from inspect import isawaitable
|
from inspect import isawaitable
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional, Type, cast
|
||||||
|
|
||||||
import txredisapi
|
import txredisapi
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
|
||||||
from synapse.metrics.background_process_metrics import (
|
from synapse.metrics.background_process_metrics import (
|
||||||
BackgroundProcessLoggingContext,
|
BackgroundProcessLoggingContext,
|
||||||
run_as_background_process,
|
run_as_background_process,
|
||||||
|
wrap_as_background_process,
|
||||||
)
|
)
|
||||||
from synapse.replication.tcp.commands import (
|
from synapse.replication.tcp.commands import (
|
||||||
Command,
|
Command,
|
||||||
|
@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||||
immediately after initialisation.
|
immediately after initialisation.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
handler: The command handler to handle incoming commands.
|
synapse_handler: The command handler to handle incoming commands.
|
||||||
stream_name: The *redis* stream name to subscribe to and publish from
|
synapse_stream_name: The *redis* stream name to subscribe to and publish
|
||||||
(not anything to do with Synapse replication streams).
|
from (not anything to do with Synapse replication streams).
|
||||||
outbound_redis_connection: The connection to redis to use to send
|
synapse_outbound_redis_connection: The connection to redis to use to send
|
||||||
commands.
|
commands.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
handler = None # type: ReplicationCommandHandler
|
synapse_handler = None # type: ReplicationCommandHandler
|
||||||
stream_name = None # type: str
|
synapse_stream_name = None # type: str
|
||||||
outbound_redis_connection = None # type: txredisapi.RedisProtocol
|
synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||||
# it's important to make sure that we only send the REPLICATE command once we
|
# it's important to make sure that we only send the REPLICATE command once we
|
||||||
# have successfully subscribed to the stream - otherwise we might miss the
|
# have successfully subscribed to the stream - otherwise we might miss the
|
||||||
# POSITION response sent back by the other end.
|
# POSITION response sent back by the other end.
|
||||||
logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
|
logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
|
||||||
await make_deferred_yieldable(self.subscribe(self.stream_name))
|
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
|
||||||
logger.info(
|
logger.info(
|
||||||
"Successfully subscribed to redis stream, sending REPLICATE command"
|
"Successfully subscribed to redis stream, sending REPLICATE command"
|
||||||
)
|
)
|
||||||
self.handler.new_connection(self)
|
self.synapse_handler.new_connection(self)
|
||||||
await self._async_send_command(ReplicateCommand())
|
await self._async_send_command(ReplicateCommand())
|
||||||
logger.info("REPLICATE successfully sent")
|
logger.info("REPLICATE successfully sent")
|
||||||
|
|
||||||
# We send out our positions when there is a new connection in case the
|
# We send out our positions when there is a new connection in case the
|
||||||
# other side missed updates. We do this for Redis connections as the
|
# other side missed updates. We do this for Redis connections as the
|
||||||
# otherside won't know we've connected and so won't issue a REPLICATE.
|
# otherside won't know we've connected and so won't issue a REPLICATE.
|
||||||
self.handler.send_positions_to_connection(self)
|
self.synapse_handler.send_positions_to_connection(self)
|
||||||
|
|
||||||
def messageReceived(self, pattern: str, channel: str, message: str):
|
def messageReceived(self, pattern: str, channel: str, message: str):
|
||||||
"""Received a message from redis.
|
"""Received a message from redis.
|
||||||
|
@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||||
cmd: received command
|
cmd: received command
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
|
cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
|
||||||
if not cmd_func:
|
if not cmd_func:
|
||||||
logger.warning("Unhandled command: %r", cmd)
|
logger.warning("Unhandled command: %r", cmd)
|
||||||
return
|
return
|
||||||
|
@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||||
def connectionLost(self, reason):
|
def connectionLost(self, reason):
|
||||||
logger.info("Lost connection to redis")
|
logger.info("Lost connection to redis")
|
||||||
super().connectionLost(reason)
|
super().connectionLost(reason)
|
||||||
self.handler.lost_connection(self)
|
self.synapse_handler.lost_connection(self)
|
||||||
|
|
||||||
# mark the logging context as finished
|
# mark the logging context as finished
|
||||||
self._logging_context.__exit__(None, None, None)
|
self._logging_context.__exit__(None, None, None)
|
||||||
|
@@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
             tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()

         await make_deferred_yieldable(
-            self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+            self.synapse_outbound_redis_connection.publish(
+                self.synapse_stream_name, encoded_string
+            )
         )


-class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+class SynapseRedisFactory(txredisapi.RedisFactory):
+    """A subclass of RedisFactory that periodically sends pings to ensure that
+    we detect dead connections.
+    """
+
+    def __init__(
+        self,
+        hs: "HomeServer",
+        uuid: str,
+        dbid: Optional[int],
+        poolsize: int,
+        isLazy: bool = False,
+        handler: Type = txredisapi.ConnectionHandler,
+        charset: str = "utf-8",
+        password: Optional[str] = None,
+        replyTimeout: int = 30,
+        convertNumbers: Optional[int] = True,
+    ):
+        super().__init__(
+            uuid=uuid,
+            dbid=dbid,
+            poolsize=poolsize,
+            isLazy=isLazy,
+            handler=handler,
+            charset=charset,
+            password=password,
+            replyTimeout=replyTimeout,
+            convertNumbers=convertNumbers,
+        )
+
+        hs.get_clock().looping_call(self._send_ping, 30 * 1000)
+
+    @wrap_as_background_process("redis_ping")
+    async def _send_ping(self):
+        for connection in self.pool:
+            try:
+                await make_deferred_yieldable(connection.ping())
+            except Exception:
+                logger.warning("Failed to send ping to a redis connection")
+
+
+class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
     """This is a reconnecting factory that connects to redis and immediately
     subscribes to a stream.

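The factory above keeps its Redis connections alive by pinging every pooled connection on a 30-second timer. A minimal stand-alone sketch of the same keepalive idea, using txredisapi and Twisted directly rather than Synapse's clock and background-process helpers (the host, port, cadence and the `ping_once` helper are illustrative assumptions, not part of this change):

    from twisted.internet import defer, reactor, task

    import txredisapi

    # Illustrative host/port; the lazy handler queues commands until connected.
    connection = txredisapi.lazyConnection(host="localhost", port=6379)


    def _send_ping():
        async def ping_once():
            try:
                # ping() resolves to "PONG" on a healthy connection.
                await connection.ping()
            except Exception:
                print("Failed to send ping to the redis connection")

        # LoopingCall is happy with a callable that returns a Deferred.
        return defer.ensureDeferred(ping_once())


    task.LoopingCall(_send_ping).start(30)  # same 30s cadence as the factory
    reactor.run()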
|
@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
|
||||||
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
|
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__(
|
||||||
|
hs,
|
||||||
|
uuid="subscriber",
|
||||||
|
dbid=None,
|
||||||
|
poolsize=1,
|
||||||
|
replyTimeout=30,
|
||||||
|
password=hs.config.redis.redis_password,
|
||||||
|
)
|
||||||
|
|
||||||
# This sets the password on the RedisFactory base class (as
|
self.synapse_handler = hs.get_tcp_replication()
|
||||||
# SubscriberFactory constructor doesn't pass it through).
|
self.synapse_stream_name = hs.hostname
|
||||||
self.password = hs.config.redis.redis_password
|
|
||||||
|
|
||||||
self.handler = hs.get_tcp_replication()
|
self.synapse_outbound_redis_connection = outbound_redis_connection
|
||||||
self.stream_name = hs.hostname
|
|
||||||
|
|
||||||
self.outbound_redis_connection = outbound_redis_connection
|
|
||||||
|
|
||||||
def buildProtocol(self, addr):
|
def buildProtocol(self, addr):
|
||||||
p = super().buildProtocol(addr) # type: RedisSubscriber
|
p = super().buildProtocol(addr)
|
||||||
|
p = cast(RedisSubscriber, p)
|
||||||
|
|
||||||
# We do this here rather than add to the constructor of `RedisSubcriber`
|
# We do this here rather than add to the constructor of `RedisSubcriber`
|
||||||
# as to do so would involve overriding `buildProtocol` entirely, however
|
# as to do so would involve overriding `buildProtocol` entirely, however
|
||||||
# the base method does some other things than just instantiating the
|
# the base method does some other things than just instantiating the
|
||||||
# protocol.
|
# protocol.
|
||||||
p.handler = self.handler
|
p.synapse_handler = self.synapse_handler
|
||||||
p.outbound_redis_connection = self.outbound_redis_connection
|
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
|
||||||
p.stream_name = self.stream_name
|
p.synapse_stream_name = self.synapse_stream_name
|
||||||
p.password = self.password
|
|
||||||
|
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
def lazyConnection(
|
def lazyConnection(
|
||||||
reactor,
|
hs: "HomeServer",
|
||||||
host: str = "localhost",
|
host: str = "localhost",
|
||||||
port: int = 6379,
|
port: int = 6379,
|
||||||
dbid: Optional[int] = None,
|
dbid: Optional[int] = None,
|
||||||
reconnect: bool = True,
|
reconnect: bool = True,
|
||||||
charset: str = "utf-8",
|
|
||||||
password: Optional[str] = None,
|
password: Optional[str] = None,
|
||||||
connectTimeout: Optional[int] = None,
|
replyTimeout: int = 30,
|
||||||
replyTimeout: Optional[int] = None,
|
|
||||||
convertNumbers: bool = True,
|
|
||||||
) -> txredisapi.RedisProtocol:
|
) -> txredisapi.RedisProtocol:
|
||||||
"""Equivalent to `txredisapi.lazyConnection`, except allows specifying a
|
"""Creates a connection to Redis that is lazily set up and reconnects if the
|
||||||
reactor.
|
connection is lost.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
isLazy = True
|
|
||||||
poolsize = 1
|
|
||||||
|
|
||||||
uuid = "%s:%d" % (host, port)
|
uuid = "%s:%d" % (host, port)
|
||||||
factory = txredisapi.RedisFactory(
|
factory = SynapseRedisFactory(
|
||||||
uuid,
|
hs,
|
||||||
dbid,
|
uuid=uuid,
|
||||||
poolsize,
|
dbid=dbid,
|
||||||
isLazy,
|
poolsize=1,
|
||||||
txredisapi.ConnectionHandler,
|
isLazy=True,
|
||||||
charset,
|
handler=txredisapi.ConnectionHandler,
|
||||||
password,
|
password=password,
|
||||||
replyTimeout,
|
replyTimeout=replyTimeout,
|
||||||
convertNumbers,
|
|
||||||
)
|
)
|
||||||
factory.continueTrying = reconnect
|
factory.continueTrying = reconnect
|
||||||
for x in range(poolsize):
|
|
||||||
reactor.connectTCP(host, port, factory, connectTimeout)
|
reactor = hs.get_reactor()
|
||||||
|
reactor.connectTCP(host, port, factory, 30)
|
||||||
|
|
||||||
return factory.handler
|
return factory.handler
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
# Copyright 2018-2019 New Vector Ltd
|
# Copyright 2018-2019 New Vector Ltd
|
||||||
|
# Copyright 2020, 2021 The Matrix.org Foundation C.I.C.
|
||||||
|
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -36,6 +38,7 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi
|
||||||
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
|
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
|
||||||
from synapse.rest.admin.rooms import (
|
from synapse.rest.admin.rooms import (
|
||||||
DeleteRoomRestServlet,
|
DeleteRoomRestServlet,
|
||||||
|
ForwardExtremitiesRestServlet,
|
||||||
JoinRoomAliasServlet,
|
JoinRoomAliasServlet,
|
||||||
ListRoomRestServlet,
|
ListRoomRestServlet,
|
||||||
MakeRoomAdminRestServlet,
|
MakeRoomAdminRestServlet,
|
||||||
|
@ -51,6 +54,7 @@ from synapse.rest.admin.users import (
|
||||||
PushersRestServlet,
|
PushersRestServlet,
|
||||||
ResetPasswordRestServlet,
|
ResetPasswordRestServlet,
|
||||||
SearchUsersRestServlet,
|
SearchUsersRestServlet,
|
||||||
|
ShadowBanRestServlet,
|
||||||
UserAdminServlet,
|
UserAdminServlet,
|
||||||
UserMediaRestServlet,
|
UserMediaRestServlet,
|
||||||
UserMembershipRestServlet,
|
UserMembershipRestServlet,
|
||||||
|
@ -230,6 +234,8 @@ def register_servlets(hs, http_server):
|
||||||
EventReportsRestServlet(hs).register(http_server)
|
EventReportsRestServlet(hs).register(http_server)
|
||||||
PushersRestServlet(hs).register(http_server)
|
PushersRestServlet(hs).register(http_server)
|
||||||
MakeRoomAdminRestServlet(hs).register(http_server)
|
MakeRoomAdminRestServlet(hs).register(http_server)
|
||||||
|
ShadowBanRestServlet(hs).register(http_server)
|
||||||
|
ForwardExtremitiesRestServlet(hs).register(http_server)
|
||||||
|
|
||||||
|
|
||||||
def register_servlets_for_client_rest_resource(hs, http_server):
|
def register_servlets_for_client_rest_resource(hs, http_server):
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -499,3 +499,60 @@ class MakeRoomAdminRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
|
class ForwardExtremitiesRestServlet(RestServlet):
|
||||||
|
"""Allows a server admin to get or clear forward extremities.
|
||||||
|
|
||||||
|
Clearing does not require restarting the server.
|
||||||
|
|
||||||
|
Clear forward extremities:
|
||||||
|
DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
|
||||||
|
|
||||||
|
Get forward_extremities:
|
||||||
|
GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self.hs = hs
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.room_member_handler = hs.get_room_member_handler()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
async def resolve_room_id(self, room_identifier: str) -> str:
|
||||||
|
"""Resolve to a room ID, if necessary."""
|
||||||
|
if RoomID.is_valid(room_identifier):
|
||||||
|
resolved_room_id = room_identifier
|
||||||
|
elif RoomAlias.is_valid(room_identifier):
|
||||||
|
room_alias = RoomAlias.from_string(room_identifier)
|
||||||
|
room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
|
||||||
|
resolved_room_id = room_id.to_string()
|
||||||
|
else:
|
||||||
|
raise SynapseError(
|
||||||
|
400, "%s was not legal room ID or room alias" % (room_identifier,)
|
||||||
|
)
|
||||||
|
if not resolved_room_id:
|
||||||
|
raise SynapseError(
|
||||||
|
400, "Unknown room ID or room alias %s" % room_identifier
|
||||||
|
)
|
||||||
|
return resolved_room_id
|
||||||
|
|
||||||
|
async def on_DELETE(self, request, room_identifier):
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
|
|
||||||
|
room_id = await self.resolve_room_id(room_identifier)
|
||||||
|
|
||||||
|
deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
|
||||||
|
return 200, {"deleted": deleted_count}
|
||||||
|
|
||||||
|
async def on_GET(self, request, room_identifier):
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
|
|
||||||
|
room_id = await self.resolve_room_id(room_identifier)
|
||||||
|
|
||||||
|
extremities = await self.store.get_forward_extremities_for_room(room_id)
|
||||||
|
return 200, {"count": len(extremities), "results": extremities}
|
||||||
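The docstring above lists the two endpoints the servlet serves. As a usage sketch, an admin could drive them over HTTP like this (the base URL, access token and room ID are placeholders; the response shapes follow the `on_GET`/`on_DELETE` handlers above):

    import requests

    # Placeholders: substitute a real homeserver URL, admin access token and
    # room ID or alias before running.
    BASE = "https://homeserver.example.com"
    HEADERS = {"Authorization": "Bearer <admin_access_token>"}
    ROOM = "!roomid:example.com"

    # Get the current forward extremities for the room.
    resp = requests.get(
        f"{BASE}/_synapse/admin/v1/rooms/{ROOM}/forward_extremities",
        headers=HEADERS,
    )
    print(resp.json())  # {"count": ..., "results": [...]}

    # Delete any extra forward extremities; no server restart is required.
    resp = requests.delete(
        f"{BASE}/_synapse/admin/v1/rooms/{ROOM}/forward_extremities",
        headers=HEADERS,
    )
    print(resp.json())  # {"deleted": ...}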
|
|
|
@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet):
|
||||||
The parameter `deactivated` can be used to include deactivated users.
|
The parameter `deactivated` can be used to include deactivated users.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.admin_handler = hs.get_admin_handler()
|
self.admin_handler = hs.get_admin_handler()
|
||||||
|
|
||||||
async def on_GET(self, request):
|
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
await assert_requester_is_admin(self.auth, request)
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
start = parse_integer(request, "from", default=0)
|
start = parse_integer(request, "from", default=0)
|
||||||
limit = parse_integer(request, "limit", default=100)
|
limit = parse_integer(request, "limit", default=100)
|
||||||
|
|
||||||
|
if start < 0:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"Query parameter from must be a string representing a positive integer.",
|
||||||
|
errcode=Codes.INVALID_PARAM,
|
||||||
|
)
|
||||||
|
|
||||||
|
if limit < 0:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"Query parameter limit must be a string representing a positive integer.",
|
||||||
|
errcode=Codes.INVALID_PARAM,
|
||||||
|
)
|
||||||
|
|
||||||
user_id = parse_string(request, "user_id", default=None)
|
user_id = parse_string(request, "user_id", default=None)
|
||||||
name = parse_string(request, "name", default=None)
|
name = parse_string(request, "name", default=None)
|
||||||
guests = parse_boolean(request, "guests", default=True)
|
guests = parse_boolean(request, "guests", default=True)
|
||||||
@@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet):
             start, limit, user_id, name, guests, deactivated
         )
         ret = {"users": users, "total": total}
-        if len(users) >= limit:
+        if (start + limit) < total:
             ret["next_token"] = str(start + len(users))

         return 200, ret
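The corrected condition only includes `next_token` while `start + limit` is still below `total`, so paging clients see the token disappear exactly on the last page instead of making one extra, empty request. A sketch of a paging client against the list-users admin API (the base URL and access token are placeholders, and the `/_synapse/admin/v2/users` path is assumed from this servlet's registration):

    import requests

    # Placeholders: substitute a real homeserver URL and admin access token.
    BASE = "https://homeserver.example.com"
    HEADERS = {"Authorization": "Bearer <admin_access_token>"}

    users = []
    params = {"from": 0, "limit": 100}
    while True:
        resp = requests.get(
            f"{BASE}/_synapse/admin/v2/users", headers=HEADERS, params=params
        ).json()
        users.extend(resp["users"])
        # With the fix above, next_token only appears while start + limit < total,
        # so the loop stops cleanly on the final page.
        if "next_token" not in resp:
            break
        params["from"] = resp["next_token"]

    print("fetched", len(users), "of", resp["total"], "users")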
|
@ -875,3 +890,39 @@ class UserTokenRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, {"access_token": token}
|
return 200, {"access_token": token}
|
||||||
|
|
||||||
|
|
||||||
|
class ShadowBanRestServlet(RestServlet):
|
||||||
|
"""An admin API for shadow-banning a user.
|
||||||
|
|
||||||
|
A shadow-banned user receives successful responses to their client-server
|
||||||
|
API requests, but the events are not propagated into rooms.
|
||||||
|
|
||||||
|
Shadow-banning a user should be used as a tool of last resort and may lead
|
||||||
|
to confusing or broken behaviour for the client.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
|
||||||
|
{}
|
||||||
|
|
||||||
|
200 OK
|
||||||
|
{}
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self.hs = hs
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
|
async def on_POST(self, request, user_id):
|
||||||
|
await assert_requester_is_admin(self.auth, request)
|
||||||
|
|
||||||
|
if not self.hs.is_mine_id(user_id):
|
||||||
|
raise SynapseError(400, "Only local users can be shadow-banned")
|
||||||
|
|
||||||
|
await self.store.set_shadow_banned(UserID.from_string(user_id), True)
|
||||||
|
|
||||||
|
return 200, {}
|
||||||
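As the docstring shows, the new endpoint takes an empty JSON body and returns an empty object on success. A minimal sketch of calling it (the base URL, access token and user ID are placeholders):

    import requests

    # Placeholders: substitute a real homeserver URL, admin access token and
    # local user ID before running.
    BASE = "https://homeserver.example.com"
    HEADERS = {"Authorization": "Bearer <admin_access_token>"}

    # The body is an empty JSON object; the response is {} on success, and a
    # 400 is returned for non-local users, per the servlet above.
    resp = requests.post(
        f"{BASE}/_synapse/admin/v1/users/@test:example.com/shadow_ban",
        headers=HEADERS,
        json={},
    )
    print(resp.status_code, resp.json())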
|
|
|
@ -300,6 +300,7 @@ class FileInfo:
|
||||||
thumbnail_height (int)
|
thumbnail_height (int)
|
||||||
thumbnail_method (str)
|
thumbnail_method (str)
|
||||||
thumbnail_type (str): Content type of thumbnail, e.g. image/png
|
thumbnail_type (str): Content type of thumbnail, e.g. image/png
|
||||||
|
thumbnail_length (int): The size of the media file, in bytes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -312,6 +313,7 @@ class FileInfo:
|
||||||
thumbnail_height=None,
|
thumbnail_height=None,
|
||||||
thumbnail_method=None,
|
thumbnail_method=None,
|
||||||
thumbnail_type=None,
|
thumbnail_type=None,
|
||||||
|
thumbnail_length=None,
|
||||||
):
|
):
|
||||||
self.server_name = server_name
|
self.server_name = server_name
|
||||||
self.file_id = file_id
|
self.file_id = file_id
|
||||||
|
@ -321,6 +323,7 @@ class FileInfo:
|
||||||
self.thumbnail_height = thumbnail_height
|
self.thumbnail_height = thumbnail_height
|
||||||
self.thumbnail_method = thumbnail_method
|
self.thumbnail_method = thumbnail_method
|
||||||
self.thumbnail_type = thumbnail_type
|
self.thumbnail_type = thumbnail_type
|
||||||
|
self.thumbnail_length = thumbnail_length
|
||||||
|
|
||||||
|
|
||||||
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
|
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
|
||||||
|
|
|
@ -386,7 +386,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||||
"""
|
"""
|
||||||
Check whether the URL should be downloaded as oEmbed content instead.
|
Check whether the URL should be downloaded as oEmbed content instead.
|
||||||
|
|
||||||
Params:
|
Args:
|
||||||
url: The URL to check.
|
url: The URL to check.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -403,7 +403,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||||
"""
|
"""
|
||||||
Request content from an oEmbed endpoint.
|
Request content from an oEmbed endpoint.
|
||||||
|
|
||||||
Params:
|
Args:
|
||||||
endpoint: The oEmbed API endpoint.
|
endpoint: The oEmbed API endpoint.
|
||||||
url: The URL to pass to the API.
|
url: The URL to pass to the API.
|
||||||
|
|
||||||
|
@ -692,27 +692,51 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||||
def decode_and_calc_og(
|
def decode_and_calc_og(
|
||||||
body: bytes, media_uri: str, request_encoding: Optional[str] = None
|
body: bytes, media_uri: str, request_encoding: Optional[str] = None
|
||||||
) -> Dict[str, Optional[str]]:
|
) -> Dict[str, Optional[str]]:
|
||||||
|
"""
|
||||||
|
Calculate metadata for an HTML document.
|
||||||
|
|
||||||
|
This uses lxml to parse the HTML document into the OG response. If errors
|
||||||
|
occur during processing of the document, an empty response is returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
body: The HTML document, as bytes.
|
||||||
|
media_url: The URI used to download the body.
|
||||||
|
request_encoding: The character encoding of the body, as a string.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The OG response as a dictionary.
|
||||||
|
"""
|
||||||
# If there's no body, nothing useful is going to be found.
|
# If there's no body, nothing useful is going to be found.
|
||||||
if not body:
|
if not body:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
from lxml import etree
|
from lxml import etree
|
||||||
|
|
||||||
|
# Create an HTML parser. If this fails, log and return no metadata.
|
||||||
try:
|
try:
|
||||||
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
|
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
|
||||||
tree = etree.fromstring(body, parser)
|
except LookupError:
|
||||||
og = _calc_og(tree, media_uri)
|
# blindly consider the encoding as utf-8.
|
||||||
|
parser = etree.HTMLParser(recover=True, encoding="utf-8")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Unable to create HTML parser: %s" % (e,))
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
|
||||||
|
# Attempt to parse the body. If this fails, log and return no metadata.
|
||||||
|
tree = etree.fromstring(body_attempt, parser)
|
||||||
|
return _calc_og(tree, media_uri)
|
||||||
|
|
||||||
|
# Attempt to parse the body. If this fails, log and return no metadata.
|
||||||
|
try:
|
||||||
|
return _attempt_calc_og(body)
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
# blindly try decoding the body as utf-8, which seems to fix
|
# blindly try decoding the body as utf-8, which seems to fix
|
||||||
# the charset mismatches on https://google.com
|
# the charset mismatches on https://google.com
|
||||||
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
|
return _attempt_calc_og(body.decode("utf-8", "ignore"))
|
||||||
tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
|
|
||||||
og = _calc_og(tree, media_uri)
|
|
||||||
|
|
||||||
return og
|
|
||||||
|
|
||||||
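The rewritten helper creates the parser once, falls back to utf-8 when the requested character encoding is unknown (`LookupError`), and re-parses a lossily decoded string on `UnicodeDecodeError`. The same fallback pattern in isolation, as a small sketch (the `parse_html` name, sample document and bogus encoding label are invented for illustration):

    from typing import Optional

    from lxml import etree


    def parse_html(body: bytes, encoding: Optional[str] = None):
        """Parse raw HTML bytes, tolerating unknown or mismatched encodings."""
        # An unknown encoding makes HTMLParser() raise LookupError, in which
        # case we blindly fall back to utf-8 rather than failing the preview.
        try:
            parser = etree.HTMLParser(recover=True, encoding=encoding)
        except LookupError:
            parser = etree.HTMLParser(recover=True, encoding="utf-8")

        # If the declared encoding doesn't match the content, retry on a
        # lossily decoded unicode string.
        try:
            return etree.fromstring(body, parser)
        except UnicodeDecodeError:
            return etree.fromstring(body.decode("utf-8", "ignore"), parser)


    tree = parse_html(b"<html><head><title>hi</title></head></html>", "no-such-encoding")
    print(tree.findtext(".//title"))  # hi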
|
|
||||||
def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
|
def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
|
||||||
# suck our tree into lxml and define our OG response.
|
# suck our tree into lxml and define our OG response.
|
||||||
|
|
||||||
# if we see any image URLs in the OG response, then spider them
|
# if we see any image URLs in the OG response, then spider them
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
from twisted.web.http import Request
|
from twisted.web.http import Request
|
||||||
|
|
||||||
|
@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource):
|
||||||
return
|
return
|
||||||
|
|
||||||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
||||||
|
await self._select_and_respond_with_thumbnail(
|
||||||
if thumbnail_infos:
|
request,
|
||||||
thumbnail_info = self._select_thumbnail(
|
width,
|
||||||
width, height, method, m_type, thumbnail_infos
|
height,
|
||||||
)
|
method,
|
||||||
|
m_type,
|
||||||
file_info = FileInfo(
|
thumbnail_infos,
|
||||||
server_name=None,
|
media_id,
|
||||||
file_id=media_id,
|
url_cache=media_info["url_cache"],
|
||||||
url_cache=media_info["url_cache"],
|
server_name=None,
|
||||||
thumbnail=True,
|
)
|
||||||
thumbnail_width=thumbnail_info["thumbnail_width"],
|
|
||||||
thumbnail_height=thumbnail_info["thumbnail_height"],
|
|
||||||
thumbnail_type=thumbnail_info["thumbnail_type"],
|
|
||||||
thumbnail_method=thumbnail_info["thumbnail_method"],
|
|
||||||
)
|
|
||||||
|
|
||||||
t_type = file_info.thumbnail_type
|
|
||||||
t_length = thumbnail_info["thumbnail_length"]
|
|
||||||
|
|
||||||
responder = await self.media_storage.fetch_media(file_info)
|
|
||||||
await respond_with_responder(request, responder, t_type, t_length)
|
|
||||||
else:
|
|
||||||
logger.info("Couldn't find any generated thumbnails")
|
|
||||||
respond_404(request)
|
|
||||||
|
|
||||||
async def _select_or_generate_local_thumbnail(
|
async def _select_or_generate_local_thumbnail(
|
||||||
self,
|
self,
|
||||||
|
@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource):
|
||||||
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
||||||
server_name, media_id
|
server_name, media_id
|
||||||
)
|
)
|
||||||
|
await self._select_and_respond_with_thumbnail(
|
||||||
|
request,
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
method,
|
||||||
|
m_type,
|
||||||
|
thumbnail_infos,
|
||||||
|
media_info["filesystem_id"],
|
||||||
|
url_cache=None,
|
||||||
|
server_name=server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _select_and_respond_with_thumbnail(
|
||||||
|
self,
|
||||||
|
request: Request,
|
||||||
|
desired_width: int,
|
||||||
|
desired_height: int,
|
||||||
|
desired_method: str,
|
||||||
|
desired_type: str,
|
||||||
|
thumbnail_infos: List[Dict[str, Any]],
|
||||||
|
file_id: str,
|
||||||
|
url_cache: Optional[str] = None,
|
||||||
|
server_name: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The incoming request.
|
||||||
|
desired_width: The desired width, the returned thumbnail may be larger than this.
|
||||||
|
desired_height: The desired height, the returned thumbnail may be larger than this.
|
||||||
|
desired_method: The desired method used to generate the thumbnail.
|
||||||
|
desired_type: The desired content-type of the thumbnail.
|
||||||
|
thumbnail_infos: A list of dictionaries of candidate thumbnails.
|
||||||
|
file_id: The ID of the media that a thumbnail is being requested for.
|
||||||
|
url_cache: The URL cache value.
|
||||||
|
server_name: The server name, if this is a remote thumbnail.
|
||||||
|
"""
|
||||||
if thumbnail_infos:
|
if thumbnail_infos:
|
||||||
thumbnail_info = self._select_thumbnail(
|
file_info = self._select_thumbnail(
|
||||||
width, height, method, m_type, thumbnail_infos
|
desired_width,
|
||||||
|
desired_height,
|
||||||
|
desired_method,
|
||||||
|
desired_type,
|
||||||
|
thumbnail_infos,
|
||||||
|
file_id,
|
||||||
|
url_cache,
|
||||||
|
server_name,
|
||||||
)
|
)
|
||||||
file_info = FileInfo(
|
if not file_info:
|
||||||
server_name=server_name,
|
logger.info("Couldn't find a thumbnail matching the desired inputs")
|
||||||
file_id=media_info["filesystem_id"],
|
respond_404(request)
|
||||||
thumbnail=True,
|
return
|
||||||
thumbnail_width=thumbnail_info["thumbnail_width"],
|
|
||||||
thumbnail_height=thumbnail_info["thumbnail_height"],
|
|
||||||
thumbnail_type=thumbnail_info["thumbnail_type"],
|
|
||||||
thumbnail_method=thumbnail_info["thumbnail_method"],
|
|
||||||
)
|
|
||||||
|
|
||||||
t_type = file_info.thumbnail_type
|
|
||||||
t_length = thumbnail_info["thumbnail_length"]
|
|
||||||
|
|
||||||
responder = await self.media_storage.fetch_media(file_info)
|
responder = await self.media_storage.fetch_media(file_info)
|
||||||
await respond_with_responder(request, responder, t_type, t_length)
|
await respond_with_responder(
|
||||||
|
request, responder, file_info.thumbnail_type, file_info.thumbnail_length
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Failed to find any generated thumbnails")
|
logger.info("Failed to find any generated thumbnails")
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
|
@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource):
|
||||||
desired_height: int,
|
desired_height: int,
|
||||||
desired_method: str,
|
desired_method: str,
|
||||||
desired_type: str,
|
desired_type: str,
|
||||||
thumbnail_infos,
|
thumbnail_infos: List[Dict[str, Any]],
|
||||||
) -> dict:
|
file_id: str,
|
||||||
|
url_cache: Optional[str],
|
||||||
|
server_name: Optional[str],
|
||||||
|
) -> Optional[FileInfo]:
|
||||||
|
"""
|
||||||
|
Choose an appropriate thumbnail from the previously generated thumbnails.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
desired_width: The desired width, the returned thumbnail may be larger than this.
|
||||||
|
desired_height: The desired height, the returned thumbnail may be larger than this.
|
||||||
|
desired_method: The desired method used to generate the thumbnail.
|
||||||
|
desired_type: The desired content-type of the thumbnail.
|
||||||
|
thumbnail_infos: A list of dictionaries of candidate thumbnails.
|
||||||
|
file_id: The ID of the media that a thumbnail is being requested for.
|
||||||
|
url_cache: The URL cache value.
|
||||||
|
server_name: The server name, if this is a remote thumbnail.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The thumbnail which best matches the desired parameters.
|
||||||
|
"""
|
||||||
|
desired_method = desired_method.lower()
|
||||||
|
|
||||||
|
# The chosen thumbnail.
|
||||||
|
thumbnail_info = None
|
||||||
|
|
||||||
d_w = desired_width
|
d_w = desired_width
|
||||||
d_h = desired_height
|
d_h = desired_height
|
||||||
|
|
||||||
if desired_method.lower() == "crop":
|
if desired_method == "crop":
|
||||||
|
# Thumbnails that match equal or larger sizes of desired width/height.
|
||||||
crop_info_list = []
|
crop_info_list = []
|
||||||
|
# Other thumbnails.
|
||||||
crop_info_list2 = []
|
crop_info_list2 = []
|
||||||
for info in thumbnail_infos:
|
for info in thumbnail_infos:
|
||||||
|
# Skip thumbnails generated with different methods.
|
||||||
|
if info["thumbnail_method"] != "crop":
|
||||||
|
continue
|
||||||
|
|
||||||
t_w = info["thumbnail_width"]
|
t_w = info["thumbnail_width"]
|
||||||
t_h = info["thumbnail_height"]
|
t_h = info["thumbnail_height"]
|
||||||
t_method = info["thumbnail_method"]
|
aspect_quality = abs(d_w * t_h - d_h * t_w)
|
||||||
if t_method == "crop":
|
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
|
||||||
aspect_quality = abs(d_w * t_h - d_h * t_w)
|
|
||||||
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
|
|
||||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
|
||||||
type_quality = desired_type != info["thumbnail_type"]
|
|
||||||
length_quality = info["thumbnail_length"]
|
|
||||||
if t_w >= d_w or t_h >= d_h:
|
|
||||||
crop_info_list.append(
|
|
||||||
(
|
|
||||||
aspect_quality,
|
|
||||||
min_quality,
|
|
||||||
size_quality,
|
|
||||||
type_quality,
|
|
||||||
length_quality,
|
|
||||||
info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
crop_info_list2.append(
|
|
||||||
(
|
|
||||||
aspect_quality,
|
|
||||||
min_quality,
|
|
||||||
size_quality,
|
|
||||||
type_quality,
|
|
||||||
length_quality,
|
|
||||||
info,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if crop_info_list:
|
|
||||||
return min(crop_info_list)[-1]
|
|
||||||
else:
|
|
||||||
return min(crop_info_list2)[-1]
|
|
||||||
else:
|
|
||||||
info_list = []
|
|
||||||
info_list2 = []
|
|
||||||
for info in thumbnail_infos:
|
|
||||||
t_w = info["thumbnail_width"]
|
|
||||||
t_h = info["thumbnail_height"]
|
|
||||||
t_method = info["thumbnail_method"]
|
|
||||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||||
type_quality = desired_type != info["thumbnail_type"]
|
type_quality = desired_type != info["thumbnail_type"]
|
||||||
length_quality = info["thumbnail_length"]
|
length_quality = info["thumbnail_length"]
|
||||||
if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
|
if t_w >= d_w or t_h >= d_h:
|
||||||
|
crop_info_list.append(
|
||||||
|
(
|
||||||
|
aspect_quality,
|
||||||
|
min_quality,
|
||||||
|
size_quality,
|
||||||
|
type_quality,
|
||||||
|
length_quality,
|
||||||
|
info,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
crop_info_list2.append(
|
||||||
|
(
|
||||||
|
aspect_quality,
|
||||||
|
min_quality,
|
||||||
|
size_quality,
|
||||||
|
type_quality,
|
||||||
|
length_quality,
|
||||||
|
info,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if crop_info_list:
|
||||||
|
thumbnail_info = min(crop_info_list)[-1]
|
||||||
|
elif crop_info_list2:
|
||||||
|
thumbnail_info = min(crop_info_list2)[-1]
|
||||||
|
elif desired_method == "scale":
|
||||||
|
# Thumbnails that match equal or larger sizes of desired width/height.
|
||||||
|
info_list = []
|
||||||
|
# Other thumbnails.
|
||||||
|
info_list2 = []
|
||||||
|
|
||||||
|
for info in thumbnail_infos:
|
||||||
|
# Skip thumbnails generated with different methods.
|
||||||
|
if info["thumbnail_method"] != "scale":
|
||||||
|
continue
|
||||||
|
|
||||||
|
t_w = info["thumbnail_width"]
|
||||||
|
t_h = info["thumbnail_height"]
|
||||||
|
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||||
|
type_quality = desired_type != info["thumbnail_type"]
|
||||||
|
length_quality = info["thumbnail_length"]
|
||||||
|
if t_w >= d_w or t_h >= d_h:
|
||||||
info_list.append((size_quality, type_quality, length_quality, info))
|
info_list.append((size_quality, type_quality, length_quality, info))
|
||||||
elif t_method == "scale":
|
else:
|
||||||
info_list2.append(
|
info_list2.append(
|
||||||
(size_quality, type_quality, length_quality, info)
|
(size_quality, type_quality, length_quality, info)
|
||||||
)
|
)
|
||||||
if info_list:
|
if info_list:
|
||||||
return min(info_list)[-1]
|
thumbnail_info = min(info_list)[-1]
|
||||||
else:
|
elif info_list2:
|
||||||
return min(info_list2)[-1]
|
thumbnail_info = min(info_list2)[-1]
|
||||||
|
|
||||||
|
if thumbnail_info:
|
||||||
|
return FileInfo(
|
||||||
|
file_id=file_id,
|
||||||
|
url_cache=url_cache,
|
||||||
|
server_name=server_name,
|
||||||
|
thumbnail=True,
|
||||||
|
thumbnail_width=thumbnail_info["thumbnail_width"],
|
||||||
|
thumbnail_height=thumbnail_info["thumbnail_height"],
|
||||||
|
thumbnail_type=thumbnail_info["thumbnail_type"],
|
||||||
|
thumbnail_method=thumbnail_info["thumbnail_method"],
|
||||||
|
thumbnail_length=thumbnail_info["thumbnail_length"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# No matching thumbnail was found.
|
||||||
|
return None
|
||||||
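For the crop branch above, each candidate is scored with a tuple of (aspect-ratio mismatch, whether it fully covers the request, size distance, content-type mismatch, file size) and `min()` picks the winner, preferring candidates at least as large as the request. A toy illustration with invented candidate rows (the extra `id(info)` element is only there to keep ties comparable in this sketch):

    # Invented candidate rows, shaped like the thumbnail_infos dicts above.
    desired_w, desired_h, desired_type = 96, 96, "image/png"
    candidates = [
        {"thumbnail_width": 32, "thumbnail_height": 32,
         "thumbnail_type": "image/png", "thumbnail_length": 800},
        {"thumbnail_width": 96, "thumbnail_height": 96,
         "thumbnail_type": "image/jpeg", "thumbnail_length": 2000},
        {"thumbnail_width": 320, "thumbnail_height": 240,
         "thumbnail_type": "image/png", "thumbnail_length": 9000},
    ]

    big_enough, too_small = [], []
    for info in candidates:
        t_w, t_h = info["thumbnail_width"], info["thumbnail_height"]
        quality = (
            abs(desired_w * t_h - desired_h * t_w),             # aspect ratio mismatch
            0 if desired_w <= t_w and desired_h <= t_h else 1,  # covers the request?
            abs((desired_w - t_w) * (desired_h - t_h)),         # distance from request
            desired_type != info["thumbnail_type"],             # prefer matching type
            info["thumbnail_length"],                           # then the smallest file
            id(info),                                           # tie-break, sketch only
        )
        bucket = big_enough if t_w >= desired_w or t_h >= desired_h else too_small
        bucket.append((quality, info))

    best = min(big_enough or too_small)[1]
    print(best["thumbnail_width"], best["thumbnail_height"], best["thumbnail_type"])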
|
|
|
@ -103,6 +103,7 @@ from synapse.notifier import Notifier
|
||||||
from synapse.push.action_generator import ActionGenerator
|
from synapse.push.action_generator import ActionGenerator
|
||||||
from synapse.push.pusherpool import PusherPool
|
from synapse.push.pusherpool import PusherPool
|
||||||
from synapse.replication.tcp.client import ReplicationDataHandler
|
from synapse.replication.tcp.client import ReplicationDataHandler
|
||||||
|
from synapse.replication.tcp.external_cache import ExternalCache
|
||||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||||
from synapse.replication.tcp.resource import ReplicationStreamer
|
from synapse.replication.tcp.resource import ReplicationStreamer
|
||||||
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
|
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
|
||||||
|
@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from txredisapi import RedisProtocol
|
||||||
|
|
||||||
from synapse.handlers.oidc_handler import OidcHandler
|
from synapse.handlers.oidc_handler import OidcHandler
|
||||||
from synapse.handlers.saml_handler import SamlHandler
|
from synapse.handlers.saml_handler import SamlHandler
|
||||||
|
|
||||||
|
@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||||
def get_account_data_handler(self) -> AccountDataHandler:
|
def get_account_data_handler(self) -> AccountDataHandler:
|
||||||
return AccountDataHandler(self)
|
return AccountDataHandler(self)
|
||||||
|
|
||||||
|
@cache_in_self
|
||||||
|
def get_external_cache(self) -> ExternalCache:
|
||||||
|
return ExternalCache(self)
|
||||||
|
|
||||||
|
@cache_in_self
|
||||||
|
def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
|
||||||
|
if not self.config.redis.redis_enabled:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# We only want to import redis module if we're using it, as we have
|
||||||
|
# `txredisapi` as an optional dependency.
|
||||||
|
from synapse.replication.tcp.redis import lazyConnection
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Connecting to redis (host=%r port=%r) for external cache",
|
||||||
|
self.config.redis_host,
|
||||||
|
self.config.redis_port,
|
||||||
|
)
|
||||||
|
|
||||||
|
return lazyConnection(
|
||||||
|
hs=self,
|
||||||
|
host=self.config.redis_host,
|
||||||
|
port=self.config.redis_port,
|
||||||
|
password=self.config.redis.redis_password,
|
||||||
|
reconnect=True,
|
||||||
|
)
|
||||||
|
|
||||||
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
|
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
|
||||||
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||||
|
|
||||||
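`get_outbound_redis_connection` only imports the redis code once `redis_enabled` is set, keeping txredisapi an optional dependency. A stand-alone sketch of the same guard (the `FakeRedisConfig` class is a made-up stand-in for Synapse's real config objects, and the factory wiring mirrors what `lazyConnection` does):

    from typing import Optional


    class FakeRedisConfig:
        # Made-up stand-in for the relevant bits of Synapse's redis config.
        redis_enabled = False
        redis_host = "localhost"
        redis_port = 6379
        redis_password = None


    def get_outbound_redis_connection(config: FakeRedisConfig) -> Optional[object]:
        if not config.redis_enabled:
            return None

        # Only import redis support once we know it is needed, so txredisapi
        # stays an optional dependency.
        import txredisapi
        from twisted.internet import reactor

        factory = txredisapi.RedisFactory(
            uuid="%s:%d" % (config.redis_host, config.redis_port),
            dbid=None,
            poolsize=1,
            isLazy=True,
            handler=txredisapi.ConnectionHandler,
            password=config.redis_password,
        )
        factory.continueTrying = True  # reconnect if the connection is lost
        reactor.connectTCP(config.redis_host, config.redis_port, factory, 30)
        return factory.handler


    print(get_outbound_redis_connection(FakeRedisConfig()))  # None while disabled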
|
|
|
@@ -310,6 +310,7 @@ class StateHandler:
             state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
+            entry = None

         else:
             # otherwise, we'll need to resolve the state across the prev_events.
@@ -340,9 +341,13 @@ class StateHandler:
                 current_state_ids=state_ids_before_event,
             )

-            # XXX: can we update the state cache entry for the new state group? or
-            # could we set a flag on resolve_state_groups_for_events to tell it to
-            # always make a state group?
+            # Assign the new state group to the cached state entry.
+            #
+            # Note that this can race in that we could generate multiple state
+            # groups for the same state entry, but that is just inefficient
+            # rather than dangerous.
+            if entry and entry.state_group is None:
+                entry.state_group = state_group_before_event

         #
         # now if it's not a state event, we're done
@@ -262,13 +262,18 @@ class LoggingTransaction:
         return self.txn.description

     def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
+        """Similar to `executemany`, except `txn.rowcount` will not be correct
+        afterwards.
+
+        More efficient than `executemany` on PostgreSQL
+        """
+
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch  # type: ignore

             self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
         else:
-            for val in args:
-                self.execute(sql, val)
+            self.executemany(sql, args)

     def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
         """Corresponds to psycopg2.extras.execute_values. Only available when
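On PostgreSQL the new `execute_batch` delegates to psycopg2's `execute_batch` helper, which pages many parameter sets into far fewer round trips than `executemany`. A stand-alone sketch of that underlying call (the DSN is a placeholder; the table and SQL are borrowed from the device-pruning query later in this diff):

    import psycopg2
    from psycopg2.extras import execute_batch

    # Placeholder DSN: point this at a real database to run it.
    conn = psycopg2.connect("dbname=synapse user=synapse_user")
    rows = [("remote1.example.com", "@alice:example.com"),
            ("remote2.example.com", "@bob:example.com")]

    with conn, conn.cursor() as cur:
        # execute_batch groups the parameter sets into multi-statement pages
        # (page_size=100 by default), needing far fewer round trips than
        # cur.executemany for large batches.
        execute_batch(
            cur,
            "DELETE FROM device_lists_outbound_last_success"
            " WHERE destination = %s AND user_id = %s",
            rows,
        )
        # Trade-off noted in the docstring above: cur.rowcount is no longer
        # meaningful after execute_batch.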
@@ -888,7 +893,7 @@ class DatabasePool:
             ", ".join("?" for _ in keys[0]),
         )

-        txn.executemany(sql, vals)
+        txn.execute_batch(sql, vals)

    async def simple_upsert(
        self,
|
|
|
@ -1,7 +1,7 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2014-2016 OpenMarket Ltd
|
# Copyright 2014-2016 OpenMarket Ltd
|
||||||
# Copyright 2018 New Vector Ltd
|
# Copyright 2018 New Vector Ltd
|
||||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore
|
||||||
from .event_federation import EventFederationStore
|
from .event_federation import EventFederationStore
|
||||||
from .event_push_actions import EventPushActionsStore
|
from .event_push_actions import EventPushActionsStore
|
||||||
from .events_bg_updates import EventsBackgroundUpdatesStore
|
from .events_bg_updates import EventsBackgroundUpdatesStore
|
||||||
|
from .events_forward_extremities import EventForwardExtremitiesStore
|
||||||
from .filtering import FilteringStore
|
from .filtering import FilteringStore
|
||||||
from .group_server import GroupServerStore
|
from .group_server import GroupServerStore
|
||||||
from .keys import KeyStore
|
from .keys import KeyStore
|
||||||
|
@ -118,6 +119,7 @@ class DataStore(
|
||||||
UIAuthStore,
|
UIAuthStore,
|
||||||
CacheInvalidationWorkerStore,
|
CacheInvalidationWorkerStore,
|
||||||
ServerMetricsStore,
|
ServerMetricsStore,
|
||||||
|
EventForwardExtremitiesStore,
|
||||||
):
|
):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
|
|
|
@@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
                DELETE FROM device_lists_outbound_last_success
                WHERE destination = ? AND user_id = ?
            """
-            txn.executemany(sql, ((row[0], row[1]) for row in rows))
+            txn.execute_batch(sql, ((row[0], row[1]) for row in rows))

            logger.info("Pruned %d device list outbound pokes", count)

|
@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
|
|
||||||
# Delete older entries in the table, as we really only care about
|
# Delete older entries in the table, as we really only care about
|
||||||
# when the latest change happened.
|
# when the latest change happened.
|
||||||
txn.executemany(
|
txn.execute_batch(
|
||||||
"""
|
"""
|
||||||
DELETE FROM device_lists_stream
|
DELETE FROM device_lists_stream
|
||||||
WHERE user_id = ? AND device_id = ? AND stream_id < ?
|
WHERE user_id = ? AND device_id = ? AND stream_id < ?
|
||||||
|
|
|
@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||||
VALUES (?, ?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?, ?)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
txn.executemany(
|
txn.execute_batch(
|
||||||
sql,
|
sql,
|
||||||
(
|
(
|
||||||
_gen_entry(user_id, actions)
|
_gen_entry(user_id, actions)
|
||||||
|
@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.executemany(
|
txn.execute_batch(
|
||||||
"""
|
"""
|
||||||
UPDATE event_push_summary
|
UPDATE event_push_summary
|
||||||
SET notif_count = ?, unread_count = ?, stream_ordering = ?
|
SET notif_count = ?, unread_count = ?, stream_ordering = ?
|
||||||
|
|
|
@ -473,8 +473,9 @@ class PersistEventsStore:
|
||||||
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
|
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def _add_chain_cover_index(
|
def _add_chain_cover_index(
|
||||||
|
cls,
|
||||||
txn,
|
txn,
|
||||||
db_pool: DatabasePool,
|
db_pool: DatabasePool,
|
||||||
event_to_room_id: Dict[str, str],
|
event_to_room_id: Dict[str, str],
|
||||||
|
@ -614,60 +615,17 @@ class PersistEventsStore:
|
||||||
if not events_to_calc_chain_id_for:
|
if not events_to_calc_chain_id_for:
|
||||||
return
|
return
|
||||||
|
|
||||||
# We now calculate the chain IDs/sequence numbers for the events. We
|
# Allocate chain ID/sequence numbers to each new event.
|
||||||
# do this by looking at the chain ID and sequence number of any auth
|
new_chain_tuples = cls._allocate_chain_ids(
|
||||||
# event with the same type/state_key and incrementing the sequence
|
txn,
|
||||||
# number by one. If there was no match or the chain ID/sequence
|
db_pool,
|
||||||
# number is already taken we generate a new chain.
|
event_to_room_id,
|
||||||
#
|
event_to_types,
|
||||||
# We need to do this in a topologically sorted order as we want to
|
event_to_auth_chain,
|
||||||
# generate chain IDs/sequence numbers of an event's auth events
|
events_to_calc_chain_id_for,
|
||||||
# before the event itself.
|
chain_map,
|
||||||
chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
|
)
|
||||||
new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
|
chain_map.update(new_chain_tuples)
|
||||||
for event_id in sorted_topologically(
|
|
||||||
events_to_calc_chain_id_for, event_to_auth_chain
|
|
||||||
):
|
|
||||||
existing_chain_id = None
|
|
||||||
for auth_id in event_to_auth_chain.get(event_id, []):
|
|
||||||
if event_to_types.get(event_id) == event_to_types.get(auth_id):
|
|
||||||
existing_chain_id = chain_map[auth_id]
|
|
||||||
break
|
|
||||||
|
|
||||||
new_chain_tuple = None
|
|
||||||
if existing_chain_id:
|
|
||||||
# We found a chain ID/sequence number candidate, check its
|
|
||||||
# not already taken.
|
|
||||||
proposed_new_id = existing_chain_id[0]
|
|
||||||
proposed_new_seq = existing_chain_id[1] + 1
|
|
||||||
if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
|
|
||||||
already_allocated = db_pool.simple_select_one_onecol_txn(
|
|
||||||
txn,
|
|
||||||
table="event_auth_chains",
|
|
||||||
keyvalues={
|
|
||||||
"chain_id": proposed_new_id,
|
|
||||||
"sequence_number": proposed_new_seq,
|
|
||||||
},
|
|
||||||
retcol="event_id",
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
if already_allocated:
|
|
||||||
# Mark it as already allocated so we don't need to hit
|
|
||||||
# the DB again.
|
|
||||||
chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
|
|
||||||
else:
|
|
||||||
new_chain_tuple = (
|
|
||||||
proposed_new_id,
|
|
||||||
proposed_new_seq,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not new_chain_tuple:
|
|
||||||
new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
|
|
||||||
|
|
||||||
chains_tuples_allocated.add(new_chain_tuple)
|
|
||||||
|
|
||||||
chain_map[event_id] = new_chain_tuple
|
|
||||||
new_chain_tuples[event_id] = new_chain_tuple
|
|
||||||
|
|
||||||
db_pool.simple_insert_many_txn(
|
db_pool.simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
|
@ -794,6 +752,137 @@ class PersistEventsStore:
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _allocate_chain_ids(
|
||||||
|
txn,
|
||||||
|
db_pool: DatabasePool,
|
||||||
|
event_to_room_id: Dict[str, str],
|
||||||
|
event_to_types: Dict[str, Tuple[str, str]],
|
||||||
|
event_to_auth_chain: Dict[str, List[str]],
|
||||||
|
events_to_calc_chain_id_for: Set[str],
|
||||||
|
chain_map: Dict[str, Tuple[int, int]],
|
||||||
|
) -> Dict[str, Tuple[int, int]]:
|
||||||
|
"""Allocates, but does not persist, chain ID/sequence numbers for the
|
||||||
|
events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
|
||||||
|
for info on args)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We now calculate the chain IDs/sequence numbers for the events. We do
|
||||||
|
# this by looking at the chain ID and sequence number of any auth event
|
||||||
|
# with the same type/state_key and incrementing the sequence number by
|
||||||
|
# one. If there was no match or the chain ID/sequence number is already
|
||||||
|
# taken we generate a new chain.
|
||||||
|
#
|
||||||
|
# We try to reduce the number of times that we hit the database by
|
||||||
|
# batching up calls, to make this more efficient when persisting large
|
||||||
|
# numbers of state events (e.g. during joins).
|
||||||
|
#
|
||||||
|
# We do this by:
|
||||||
|
# 1. Calculating for each event which auth event will be used to
|
||||||
|
# inherit the chain ID, i.e. converting the auth chain graph to a
|
||||||
|
# tree that we can allocate chains on. We also keep track of which
|
||||||
|
# existing chain IDs have been referenced.
|
||||||
|
# 2. Fetching the max allocated sequence number for each referenced
|
||||||
|
# existing chain ID, generating a map from chain ID to the max
|
||||||
|
# allocated sequence number.
|
||||||
|
# 3. Iterating over the tree and allocating a chain ID/seq no. to the
|
||||||
|
# new event, by incrementing the sequence number from the
|
||||||
|
# referenced event's chain ID/seq no. and checking that the
|
||||||
|
# incremented sequence number hasn't already been allocated (by
|
||||||
|
# looking in the map generated in the previous step). We generate a
|
||||||
|
# new chain if the sequence number has already been allocated.
|
||||||
|
#
|
||||||
|
|
||||||
|
existing_chains = set() # type: Set[int]
|
||||||
|
tree = [] # type: List[Tuple[str, Optional[str]]]
|
||||||
|
|
||||||
|
# We need to do this in a topologically sorted order as we want to
|
||||||
|
# generate chain IDs/sequence numbers of an event's auth events before
|
||||||
|
# the event itself.
|
||||||
|
for event_id in sorted_topologically(
|
||||||
|
events_to_calc_chain_id_for, event_to_auth_chain
|
||||||
|
):
|
||||||
|
for auth_id in event_to_auth_chain.get(event_id, []):
|
||||||
|
if event_to_types.get(event_id) == event_to_types.get(auth_id):
|
||||||
|
existing_chain_id = chain_map.get(auth_id)
|
||||||
|
if existing_chain_id:
|
||||||
|
existing_chains.add(existing_chain_id[0])
|
||||||
|
|
||||||
|
tree.append((event_id, auth_id))
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
tree.append((event_id, None))
|
||||||
|
|
||||||
|
# Fetch the current max sequence number for each existing referenced chain.
|
||||||
|
sql = """
|
||||||
|
SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
|
||||||
|
WHERE %s
|
||||||
|
GROUP BY chain_id
|
||||||
|
"""
|
||||||
|
clause, args = make_in_list_sql_clause(
|
||||||
|
db_pool.engine, "chain_id", existing_chains
|
||||||
|
)
|
||||||
|
txn.execute(sql % (clause,), args)
|
||||||
|
|
||||||
|
chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]
|
||||||
|
|
||||||
|
# Allocate the new events chain ID/sequence numbers.
|
||||||
|
#
|
||||||
|
# To reduce the number of calls to the database we don't allocate a
|
||||||
|
# chain ID number in the loop, instead we use a temporary `object()` for
|
||||||
|
# each new chain ID. Once we've done the loop we generate the necessary
|
||||||
|
# number of new chain IDs in one call, replacing all temporary
|
||||||
|
# objects with real allocated chain IDs.
|
||||||
|
|
||||||
|
unallocated_chain_ids = set() # type: Set[object]
|
||||||
|
new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
|
||||||
|
for event_id, auth_event_id in tree:
|
||||||
|
# If we reference an auth_event_id we fetch the allocated chain ID,
|
||||||
|
# either from the existing `chain_map` or the newly generated
|
||||||
|
# `new_chain_tuples` map.
|
||||||
|
existing_chain_id = None
|
||||||
|
if auth_event_id:
|
||||||
|
existing_chain_id = new_chain_tuples.get(auth_event_id)
|
||||||
|
if not existing_chain_id:
|
||||||
|
existing_chain_id = chain_map[auth_event_id]
|
||||||
|
|
||||||
|
new_chain_tuple = None # type: Optional[Tuple[Any, int]]
|
||||||
|
if existing_chain_id:
|
||||||
|
# We found a chain ID/sequence number candidate, check its
|
||||||
|
# not already taken.
|
||||||
|
proposed_new_id = existing_chain_id[0]
|
||||||
|
proposed_new_seq = existing_chain_id[1] + 1
|
||||||
|
|
||||||
|
if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
|
||||||
|
new_chain_tuple = (
|
||||||
|
proposed_new_id,
|
||||||
|
proposed_new_seq,
|
||||||
|
)
|
||||||
|
|
||||||
|
# If we need to start a new chain we allocate a temporary chain ID.
|
||||||
|
if not new_chain_tuple:
|
||||||
|
new_chain_tuple = (object(), 1)
|
||||||
|
unallocated_chain_ids.add(new_chain_tuple[0])
|
||||||
|
|
||||||
|
new_chain_tuples[event_id] = new_chain_tuple
|
||||||
|
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
|
||||||
|
|
||||||
|
# Generate new chain IDs for all unallocated chain IDs.
|
||||||
|
newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
|
||||||
|
txn, len(unallocated_chain_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Map from potentially temporary chain ID to real chain ID
|
||||||
|
chain_id_to_allocated_map = dict(
|
||||||
|
zip(unallocated_chain_ids, newly_allocated_chain_ids)
|
||||||
|
) # type: Dict[Any, int]
|
||||||
|
chain_id_to_allocated_map.update((c, c) for c in existing_chains)
|
||||||
|
|
||||||
|
return {
|
||||||
|
event_id: (chain_id_to_allocated_map[chain_id], seq)
|
||||||
|
for event_id, (chain_id, seq) in new_chain_tuples.items()
|
||||||
|
}
|
||||||
|
|
||||||
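The comments above describe the batching trick: events that need a brand-new chain get a temporary `object()` as their chain ID inside the loop, and one bulk call at the end swaps every placeholder for a real ID. A toy sketch of that placeholder-then-resolve step (`allocate_ids` is an invented stand-in for `db_pool.event_chain_id_gen.get_next_mult_txn`):

    import itertools
    from typing import Any, Dict, List, Set, Tuple

    _counter = itertools.count(100)


    def allocate_ids(n: int) -> List[int]:
        # Stand-in for db_pool.event_chain_id_gen.get_next_mult_txn(txn, n):
        # hands out n fresh chain IDs in a single call.
        return [next(_counter) for _ in range(n)]


    # Events that could not extend an existing chain each get a temporary
    # object() as their chain "ID" while we loop.
    unallocated: Set[object] = set()
    new_chain_tuples: Dict[str, Tuple[Any, int]] = {}
    for event_id in ("$event1", "$event2"):
        placeholder = object()
        unallocated.add(placeholder)
        new_chain_tuples[event_id] = (placeholder, 1)

    # One bulk allocation replaces every placeholder with a real chain ID.
    real_ids = allocate_ids(len(unallocated))
    placeholder_to_real = dict(zip(unallocated, real_ids))

    resolved = {
        event_id: (placeholder_to_real[chain], seq)
        for event_id, (chain, seq) in new_chain_tuples.items()
    }
    print(resolved)  # e.g. {"$event1": (100, 1), "$event2": (101, 1)}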
def _persist_transaction_ids_txn(
|
def _persist_transaction_ids_txn(
|
||||||
self,
|
self,
|
||||||
txn: LoggingTransaction,
|
txn: LoggingTransaction,
|
||||||
|
@ -876,7 +965,7 @@ class PersistEventsStore:
|
||||||
WHERE room_id = ? AND type = ? AND state_key = ?
|
WHERE room_id = ? AND type = ? AND state_key = ?
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
txn.executemany(
|
txn.execute_batch(
|
||||||
sql,
|
sql,
|
||||||
(
|
(
|
||||||
(
|
(
|
||||||
|
@ -895,7 +984,7 @@ class PersistEventsStore:
|
||||||
)
|
)
|
||||||
# Now we actually update the current_state_events table
|
# Now we actually update the current_state_events table
|
||||||
|
|
||||||
txn.executemany(
|
txn.execute_batch(
|
||||||
"DELETE FROM current_state_events"
|
"DELETE FROM current_state_events"
|
||||||
" WHERE room_id = ? AND type = ? AND state_key = ?",
|
" WHERE room_id = ? AND type = ? AND state_key = ?",
|
||||||
(
|
(
|
||||||
|
@ -907,7 +996,7 @@ class PersistEventsStore:
|
||||||
# We include the membership in the current state table, hence we do
|
# We include the membership in the current state table, hence we do
|
||||||
# a lookup when we insert. This assumes that all events have already
|
# a lookup when we insert. This assumes that all events have already
|
||||||
# been inserted into room_memberships.
|
# been inserted into room_memberships.
|
||||||
txn.executemany(
|
txn.execute_batch(
|
||||||
"""INSERT INTO current_state_events
|
"""INSERT INTO current_state_events
|
||||||
(room_id, type, state_key, event_id, membership)
|
(room_id, type, state_key, event_id, membership)
|
||||||
VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
|
VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
|
||||||
|
@ -927,7 +1016,7 @@ class PersistEventsStore:
|
||||||
# we have no record of the fact the user *was* a member of the
|
# we have no record of the fact the user *was* a member of the
|
||||||
# room but got, say, state reset out of it.
|
# room but got, say, state reset out of it.
|
||||||
if to_delete or to_insert:
|
if to_delete or to_insert:
|
||||||
txn.executemany(
|
txn.execute_batch(
|
||||||
"DELETE FROM local_current_membership"
|
"DELETE FROM local_current_membership"
|
||||||
" WHERE room_id = ? AND user_id = ?",
|
" WHERE room_id = ? AND user_id = ?",
|
||||||
(
|
(
|
||||||
|
@ -938,7 +1027,7 @@ class PersistEventsStore:
|
||||||
)
|
)
|
||||||
|
|
||||||
if to_insert:
|
if to_insert:
|
||||||
txn.executemany(
|
txn.execute_batch(
|
||||||
"""INSERT INTO local_current_membership
|
"""INSERT INTO local_current_membership
|
||||||
(room_id, user_id, event_id, membership)
|
(room_id, user_id, event_id, membership)
|
||||||
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
|
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
|
||||||
|
@ -1738,7 +1827,7 @@ class PersistEventsStore:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if events_and_contexts:
|
if events_and_contexts:
|
||||||
txn.executemany(
|
txn.execute_batch(
|
||||||
sql,
|
sql,
|
||||||
(
|
(
|
||||||
(
|
(
|
||||||
|
@ -1767,7 +1856,7 @@ class PersistEventsStore:
|
||||||
|
|
||||||
# Now we delete the staging area for *all* events that were being
|
# Now we delete the staging area for *all* events that were being
|
||||||
# persisted.
|
# persisted.
|
||||||
txn.executemany(
|
txn.execute_batch(
|
||||||
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
|
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
|
||||||
((event.event_id,) for event, _ in all_events_and_contexts),
|
((event.event_id,) for event, _ in all_events_and_contexts),
|
||||||
)
|
)
|
||||||
|
@ -1886,7 +1975,7 @@ class PersistEventsStore:
|
||||||
" )"
|
" )"
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.executemany(
|
txn.execute_batch(
|
||||||
query,
|
query,
|
||||||
[
|
[
|
||||||
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
|
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
|
||||||
|
@ -1900,7 +1989,7 @@ class PersistEventsStore:
|
||||||
"DELETE FROM event_backward_extremities"
|
"DELETE FROM event_backward_extremities"
|
||||||
" WHERE event_id = ? AND room_id = ?"
|
" WHERE event_id = ? AND room_id = ?"
|
||||||
)
|
)
|
||||||
txn.executemany(
|
txn.execute_batch(
|
||||||
query,
|
query,
|
||||||
[
|
[
|
||||||
(ev.event_id, ev.room_id)
|
(ev.event_id, ev.room_id)
|
||||||
|
|
|
@ -139,8 +139,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)

-INSERT_CLUMP_SIZE = 1000

def reindex_txn(txn):
    sql = (
        "SELECT stream_ordering, event_id, json FROM events"
@ -178,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):

sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"

-for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
-    clump = update_rows[index : index + INSERT_CLUMP_SIZE]
-    txn.executemany(sql, clump)
+txn.execute_batch(sql, update_rows)

progress = {
    "target_min_stream_id_inclusive": target_min_stream_id,
@ -210,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)

-INSERT_CLUMP_SIZE = 1000

def reindex_search_txn(txn):
    sql = (
        "SELECT stream_ordering, event_id FROM events"
@ -256,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):

sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"

-for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
-    clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
-    txn.executemany(sql, clump)
+txn.execute_batch(sql, rows_to_update)

progress = {
    "target_min_stream_id_inclusive": target_min_stream_id,
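The two background updates above used to slice their work into clumps of 1000 rows before calling `executemany`; with `execute_batch` that outer loop is redundant because the helper already pages the parameter list itself. A rough before/after, assuming `rows` is the full list of parameter tuples:

# Before: manual clumping around executemany.
INSERT_CLUMP_SIZE = 1000
for index in range(0, len(rows), INSERT_CLUMP_SIZE):
    clump = rows[index : index + INSERT_CLUMP_SIZE]
    txn.executemany(sql, clump)

# After: a single call; paging happens inside execute_batch.
txn.execute_batch(sql, rows)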
@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, List

from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore

logger = logging.getLogger(__name__)


class EventForwardExtremitiesStore(SQLBaseStore):
    async def delete_forward_extremities_for_room(self, room_id: str) -> int:
        """Delete any extra forward extremities for a room.

        Invalidates the "get_latest_event_ids_in_room" cache if any forward
        extremities were deleted.

        Returns count deleted.
        """

        def delete_forward_extremities_for_room_txn(txn):
            # First we need to get the event_id to not delete
            sql = """
                SELECT event_id FROM event_forward_extremities
                INNER JOIN events USING (room_id, event_id)
                WHERE room_id = ?
                ORDER BY stream_ordering DESC
                LIMIT 1
            """
            txn.execute(sql, (room_id,))
            rows = txn.fetchall()
            try:
                event_id = rows[0][0]
                logger.debug(
                    "Found event_id %s as the forward extremity to keep for room %s",
                    event_id,
                    room_id,
                )
            except KeyError:
                msg = "No forward extremity event found for room %s" % room_id
                logger.warning(msg)
                raise SynapseError(400, msg)

            # Now delete the extra forward extremities
            sql = """
                DELETE FROM event_forward_extremities
                WHERE event_id != ? AND room_id = ?
            """

            txn.execute(sql, (event_id, room_id))
            logger.info(
                "Deleted %s extra forward extremities for room %s",
                txn.rowcount,
                room_id,
            )

            if txn.rowcount > 0:
                # Invalidate the cache
                self._invalidate_cache_and_stream(
                    txn, self.get_latest_event_ids_in_room, (room_id,),
                )

            return txn.rowcount

        return await self.db_pool.runInteraction(
            "delete_forward_extremities_for_room",
            delete_forward_extremities_for_room_txn,
        )

    async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
        """Get list of forward extremities for a room."""

        def get_forward_extremities_for_room_txn(txn):
            sql = """
                SELECT event_id, state_group, depth, received_ts
                FROM event_forward_extremities
                INNER JOIN event_to_state_groups USING (event_id)
                INNER JOIN events USING (room_id, event_id)
                WHERE room_id = ?
            """

            txn.execute(sql, (room_id,))
            return self.db_pool.cursor_to_dict(txn)

        return await self.db_pool.runInteraction(
            "get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
        )
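The new store above backs the admin API for inspecting and pruning a room's forward extremities (see the changelog entry for this change). A hedged sketch of how calling code might use the two coroutines; `store` is assumed to be a datastore exposing them:

async def prune_room_extremities(store, room_id: str) -> int:
    # List the current forward extremities for the room.
    extremities = await store.get_forward_extremities_for_room(room_id)
    if len(extremities) <= 1:
        return 0  # nothing extra to delete
    # Deletes everything except the most recent extremity (by stream_ordering)
    # and reports how many rows were removed.
    return await store.delete_forward_extremities_for_room(room_id)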
@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_origin = ? AND media_id = ?"
)

-txn.executemany(
+txn.execute_batch(
    sql,
    (
        (time_ms, media_origin, media_id)
@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_id = ?"
)

-txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
+txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))

return await self.db_pool.runInteraction(
    "update_cached_last_access_time", update_cache_txn
@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"

def _delete_url_cache_txn(txn):
-    txn.executemany(sql, [(media_id,) for media_id in media_ids])
+    txn.execute_batch(sql, [(media_id,) for media_id in media_ids])

return await self.db_pool.runInteraction(
    "delete_url_cache", _delete_url_cache_txn
@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_media_txn(txn):
    sql = "DELETE FROM local_media_repository WHERE media_id = ?"

-    txn.executemany(sql, [(media_id,) for media_id in media_ids])
+    txn.execute_batch(sql, [(media_id,) for media_id in media_ids])

    sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"

-    txn.executemany(sql, [(media_id,) for media_id in media_ids])
+    txn.execute_batch(sql, [(media_id,) for media_id in media_ids])

return await self.db_pool.runInteraction(
    "delete_url_cache_media", _delete_url_cache_media_txn
@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
)

# Update backward extremeties
-txn.executemany(
+txn.execute_batch(
    "INSERT INTO event_backward_extremities (room_id, event_id)"
    " VALUES (?, ?)",
    [(room_id, event_id) for event_id, in new_backwards_extrems],
@ -344,7 +344,9 @@ class PusherStore(PusherWorkerStore):
    txn, self.get_if_user_has_pusher, (user_id,)
)

-self.db_pool.simple_delete_one_txn(
+# It is expected that there is exactly one pusher to delete, but
+# if it isn't there (or there are multiple) delete them all.
+self.db_pool.simple_delete_txn(
    txn,
    "pushers",
    {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
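Replacing `simple_delete_one_txn` with `simple_delete_txn` is what removes the spurious errors when a pusher has already gone: the "one" variant treats anything other than exactly one matching row as a failure, while the plain variant deletes whatever matches, including nothing. A toy model of that distinction (assumed semantics, not the real helpers):

from typing import Dict, List


def simple_delete(rows: List[Dict], keyvalues: Dict) -> int:
    # Delete every matching row; zero matches is fine.
    matches = [r for r in rows if all(r.get(k) == v for k, v in keyvalues.items())]
    for r in matches:
        rows.remove(r)
    return len(matches)


def simple_delete_one(rows: List[Dict], keyvalues: Dict) -> None:
    # Delete exactly one matching row; anything else raises.
    if simple_delete(rows, keyvalues) != 1:
        raise RuntimeError("expected exactly one matching row")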
@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):

await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)

+    async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
+        """Sets whether a user shadow-banned.
+
+        Args:
+            user: user ID of the user to test
+            shadow_banned: true iff the user is to be shadow-banned, false otherwise.
+        """
+
+        def set_shadow_banned_txn(txn):
+            self.db_pool.simple_update_one_txn(
+                txn,
+                table="users",
+                keyvalues={"name": user.to_string()},
+                updatevalues={"shadow_banned": shadow_banned},
+            )
+            # In order for this to apply immediately, clear the cache for this user.
+            tokens = self.db_pool.simple_select_onecol_txn(
+                txn,
+                table="access_tokens",
+                keyvalues={"user_id": user.to_string()},
+                retcol="token",
+            )
+            for token in tokens:
+                self._invalidate_cache_and_stream(
+                    txn, self.get_user_by_access_token, (token,)
+                )
+
+        await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
+
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
    sql = """
        SELECT users.name as user_id,
@ -1104,7 +1133,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
    FROM user_threepids
"""

-txn.executemany(sql, [(id_server,) for id_server in id_servers])
+txn.execute_batch(sql, [(id_server,) for id_server in id_servers])

if id_servers:
    await self.db_pool.runInteraction(
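A hedged example of how admin code might use the new method; `store` is assumed to be a datastore exposing the coroutine above:

from synapse.types import UserID


async def shadow_ban(store, user_id: str) -> None:
    # Flip the flag; because the per-access-token cache is cleared in the same
    # transaction, the ban is visible on the user's very next request.
    await store.set_shadow_banned(UserID.from_string(user_id), True)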
@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
    "max_stream_id_exclusive", self._stream_order_on_start + 1
)

-INSERT_CLUMP_SIZE = 1000

def add_membership_profile_txn(txn):
    sql = """
        SELECT stream_ordering, event_id, events.room_id, event_json.json
@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
    UPDATE room_memberships SET display_name = ?, avatar_url = ?
    WHERE event_id = ? AND room_id = ?
"""
-for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
-    clump = to_update[index : index + INSERT_CLUMP_SIZE]
-    txn.executemany(to_update_sql, clump)
+txn.execute_batch(to_update_sql, to_update)

progress = {
    "target_min_stream_id_inclusive": target_min_stream_id,
@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
# { "ignored_users": "@someone:example.org": {} }
ignored_users = content.get("ignored_users", {})
if isinstance(ignored_users, dict) and ignored_users:
-    cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
+    cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])

# Add indexes after inserting data for efficiency.
logger.info("Adding constraints to ignored_users table")
@ -63,7 +63,7 @@ class SearchWorkerStore(SQLBaseStore):
    for entry in entries
)

-txn.executemany(sql, args)
+txn.execute_batch(sql, args)

elif isinstance(self.database_engine, Sqlite3Engine):
    sql = (
@ -75,7 +75,7 @@ class SearchWorkerStore(SQLBaseStore):
    for entry in entries
)

-txn.executemany(sql, args)
+txn.execute_batch(sql, args)
else:
    # This should be unreachable.
    raise Exception("Unrecognized database engine")
@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)

logger.info("[purge] removing redundant state groups")
-txn.executemany(
+txn.execute_batch(
    "DELETE FROM state_groups_state WHERE state_group = ?",
    ((sg,) for sg in state_groups_to_delete),
)
-txn.executemany(
+txn.execute_batch(
    "DELETE FROM state_groups WHERE id = ?",
    ((sg,) for sg in state_groups_to_delete),
)
@ -15,12 +15,11 @@
import heapq
import logging
import threading
-from collections import deque
+from collections import OrderedDict
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union

import attr
-from typing_extensions import Deque

from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
@ -101,7 +100,13 @@ class StreamIdGenerator:
self._current = (max if step > 0 else min)(
    self._current, _load_current_id(db_conn, table, column, step)
)
-self._unfinished_ids = deque()  # type: Deque[int]
+# We use this as an ordered set, as we want to efficiently append items,
+# remove items and get the first item. Since we insert IDs in order, the
+# insertion ordering will ensure its in the correct ordering.
+#
+# The key and values are the same, but we never look at the values.
+self._unfinished_ids = OrderedDict()  # type: OrderedDict[int, int]

def get_next(self):
    """
@ -113,7 +118,7 @@ class StreamIdGenerator:
self._current += self._step
next_id = self._current

-self._unfinished_ids.append(next_id)
+self._unfinished_ids[next_id] = next_id

@contextmanager
def manager():
@ -121,7 +126,7 @@ class StreamIdGenerator:
    yield next_id
finally:
    with self._lock:
-        self._unfinished_ids.remove(next_id)
+        self._unfinished_ids.pop(next_id)

return _AsyncCtxManagerWrapper(manager())

@ -140,7 +145,7 @@ class StreamIdGenerator:
self._current += n * self._step

for next_id in next_ids:
-    self._unfinished_ids.append(next_id)
+    self._unfinished_ids[next_id] = next_id

@contextmanager
def manager():
@ -149,7 +154,7 @@ class StreamIdGenerator:
finally:
    with self._lock:
        for next_id in next_ids:
-            self._unfinished_ids.remove(next_id)
+            self._unfinished_ids.pop(next_id)

return _AsyncCtxManagerWrapper(manager())

@ -162,7 +167,7 @@ class StreamIdGenerator:
"""
with self._lock:
    if self._unfinished_ids:
-        return self._unfinished_ids[0] - self._step
+        return next(iter(self._unfinished_ids)) - self._step

    return self._current
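The switch from `deque` to `OrderedDict` gives the generator an ordered set: O(1) append, O(1) removal of an arbitrary member (IDs can finish out of order), and O(1) access to the oldest member. A minimal standalone sketch of the trick:

from collections import OrderedDict

unfinished = OrderedDict()  # keys == values; the values are never read
for stream_id in (1, 2, 3):
    unfinished[stream_id] = stream_id      # "append"
unfinished.pop(2)                          # ID 2 completed out of order
oldest_in_flight = next(iter(unfinished))  # 1: what get_current_token works from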
@ -69,6 +69,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
    """Gets the next ID in the sequence"""
    ...

+    @abc.abstractmethod
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        """Get the next `n` IDs in the sequence"""
+        ...
+
@abc.abstractmethod
def check_consistency(
    self,
@ -219,6 +224,17 @@ class LocalSequenceGenerator(SequenceGenerator):
    self._current_max_id += 1
    return self._current_max_id

+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        with self._lock:
+            if self._current_max_id is None:
+                assert self._callback is not None
+                self._current_max_id = self._callback(txn)
+                self._callback = None
+
+            first_id = self._current_max_id + 1
+            self._current_max_id += n
+            return [first_id + i for i in range(n)]
+
def check_consistency(
    self,
    db_conn: Connection,
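`get_next_mult_txn` hands out a contiguous block of IDs under a single lock acquisition, which is what lets a batch of state events be persisted with one sequence bump. A toy stand-in showing the same contract (illustration only, not the real class):

import threading
from typing import List


class ToyLocalSequence:
    def __init__(self, current_max: int) -> None:
        self._lock = threading.Lock()
        self._current_max_id = current_max

    def get_next_mult(self, n: int) -> List[int]:
        # One lock acquisition yields n consecutive IDs.
        with self._lock:
            first_id = self._current_max_id + 1
            self._current_max_id += n
            return [first_id + i for i in range(n)]


print(ToyLocalSequence(7).get_next_mult(3))  # [8, 9, 10]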
@ -0,0 +1,229 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Optional, Tuple

from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent
from synapse.push.presentable_names import calculate_room_name
from synapse.types import StateKey, StateMap

from tests import unittest


class MockDataStore:
    """
    A fake data store which stores a mapping of state key to event content.
    (I.e. the state key is used as the event ID.)
    """

    def __init__(self, events: Iterable[Tuple[StateKey, dict]]):
        """
        Args:
            events: A state map to event contents.
        """
        self._events = {}

        for i, (event_id, content) in enumerate(events):
            self._events[event_id] = FrozenEvent(
                {
                    "event_id": "$event_id",
                    "type": event_id[0],
                    "sender": "@user:test",
                    "state_key": event_id[1],
                    "room_id": "#room:test",
                    "content": content,
                    "origin_server_ts": i,
                },
                RoomVersions.V1,
            )

    async def get_event(
        self, event_id: StateKey, allow_none: bool = False
    ) -> Optional[FrozenEvent]:
        assert allow_none, "Mock not configured for allow_none = False"

        return self._events.get(event_id)

    async def get_events(self, event_ids: Iterable[StateKey]):
        # This is cheating since it just returns all events.
        return self._events


class PresentableNamesTestCase(unittest.HomeserverTestCase):
    USER_ID = "@test:test"
    OTHER_USER_ID = "@user:test"

    def _calculate_room_name(
        self,
        events: StateMap[dict],
        user_id: str = "",
        fallback_to_members: bool = True,
        fallback_to_single_member: bool = True,
    ):
        # This isn't 100% accurate, but works with MockDataStore.
        room_state_ids = {k[0]: k[0] for k in events}

        return self.get_success(
            calculate_room_name(
                MockDataStore(events),
                room_state_ids,
                user_id or self.USER_ID,
                fallback_to_members,
                fallback_to_single_member,
            )
        )

    def test_name(self):
        """A room name event should be used."""
        events = [
            ((EventTypes.Name, ""), {"name": "test-name"}),
        ]
        self.assertEqual("test-name", self._calculate_room_name(events))

        # Check if the event content has garbage.
        events = [((EventTypes.Name, ""), {"foo": 1})]
        self.assertEqual("Empty Room", self._calculate_room_name(events))

        events = [((EventTypes.Name, ""), {"name": 1})]
        self.assertEqual(1, self._calculate_room_name(events))

    def test_canonical_alias(self):
        """An canonical alias should be used."""
        events = [
            ((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}),
        ]
        self.assertEqual("#test-name:test", self._calculate_room_name(events))

        # Check if the event content has garbage.
        events = [((EventTypes.CanonicalAlias, ""), {"foo": 1})]
        self.assertEqual("Empty Room", self._calculate_room_name(events))

        events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})]
        self.assertEqual("Empty Room", self._calculate_room_name(events))

    def test_invite(self):
        """An invite has special behaviour."""
        events = [
            ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
            ((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}),
        ]
        self.assertEqual("Invite from Other User", self._calculate_room_name(events))
        self.assertIsNone(
            self._calculate_room_name(events, fallback_to_single_member=False)
        )
        # Ensure this logic is skipped if we don't fallback to members.
        self.assertIsNone(self._calculate_room_name(events, fallback_to_members=False))

        # Check if the event content has garbage.
        events = [
            ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
            ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
        ]
        self.assertEqual("Invite from @user:test", self._calculate_room_name(events))

        # No member event for sender.
        events = [
            ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
        ]
        self.assertEqual("Room Invite", self._calculate_room_name(events))

    def test_no_members(self):
        """Behaviour of an empty room."""
        events = []
        self.assertEqual("Empty Room", self._calculate_room_name(events))

        # Note that events with invalid (or missing) membership are ignored.
        events = [
            ((EventTypes.Member, self.OTHER_USER_ID), {"foo": 1}),
            ((EventTypes.Member, "@foo:test"), {"membership": "foo"}),
        ]
        self.assertEqual("Empty Room", self._calculate_room_name(events))

    def test_no_other_members(self):
        """Behaviour of a room with no other members in it."""
        events = [
            (
                (EventTypes.Member, self.USER_ID),
                {"membership": Membership.JOIN, "displayname": "Me"},
            ),
        ]
        self.assertEqual("Me", self._calculate_room_name(events))

        # Check if the event content has no displayname.
        events = [
            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
        ]
        self.assertEqual("@test:test", self._calculate_room_name(events))

        # 3pid invite, use the other user (who is set as the sender).
        events = [
            ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
        ]
        self.assertEqual(
            "nobody", self._calculate_room_name(events, user_id=self.OTHER_USER_ID)
        )

        events = [
            ((EventTypes.Member, self.OTHER_USER_ID), {"membership": Membership.JOIN}),
            ((EventTypes.ThirdPartyInvite, self.OTHER_USER_ID), {}),
        ]
        self.assertEqual(
            "Inviting email address",
            self._calculate_room_name(events, user_id=self.OTHER_USER_ID),
        )

    def test_one_other_member(self):
        """Behaviour of a room with a single other member."""
        events = [
            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
            (
                (EventTypes.Member, self.OTHER_USER_ID),
                {"membership": Membership.JOIN, "displayname": "Other User"},
            ),
        ]
        self.assertEqual("Other User", self._calculate_room_name(events))
        self.assertIsNone(
            self._calculate_room_name(events, fallback_to_single_member=False)
        )

        # Check if the event content has no displayname and is an invite.
        events = [
            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
            (
                (EventTypes.Member, self.OTHER_USER_ID),
                {"membership": Membership.INVITE},
            ),
        ]
        self.assertEqual("@user:test", self._calculate_room_name(events))

    def test_other_members(self):
        """Behaviour of a room with multiple other members."""
        # Two other members.
        events = [
            ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
            (
                (EventTypes.Member, self.OTHER_USER_ID),
                {"membership": Membership.JOIN, "displayname": "Other User"},
            ),
            ((EventTypes.Member, "@foo:test"), {"membership": Membership.JOIN}),
        ]
        self.assertEqual("Other User and @foo:test", self._calculate_room_name(events))

        # Three or more other members.
        events.append(
            ((EventTypes.Member, "@fourth:test"), {"membership": Membership.INVITE})
        )
        self.assertEqual("Other User and 2 others", self._calculate_room_name(events))
@ -29,7 +29,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"type": "m.room.history_visibility",
"sender": "@user:test",
"state_key": "",
-"room_id": "@room:test",
+"room_id": "#room:test",
"content": content,
},
RoomVersions.V1,
@ -212,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Fake in memory Redis server that servers can connect to.
self._redis_server = FakeRedisPubSubServer()

+# We may have an attempt to connect to redis for the external cache already.
+self.connect_any_redis_attempts()
+
store = self.hs.get_datastore()
self.database_pool = store.db_pool

@ -401,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
fake one.
"""
clients = self.reactor.tcpClients
-self.assertEqual(len(clients), 1)
+while clients:
    (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
    self.assertEqual(host, "localhost")
    self.assertEqual(port, 6379)

    client_protocol = client_factory.buildProtocol(None)
    server_protocol = self._redis_server.buildProtocol(None)

    client_to_server_transport = FakeTransport(
        server_protocol, self.reactor, client_protocol
    )
    client_protocol.makeConnection(client_to_server_transport)

    server_to_client_transport = FakeTransport(
        client_protocol, self.reactor, server_protocol
    )
    server_protocol.makeConnection(server_to_client_transport)

-return client_to_server_transport, server_to_client_transport


class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@ -624,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol):
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])

+# Since we use SET/GET to cache things we can safely no-op them.
+elif command == b"SET":
+    self.send("OK")
+elif command == b"GET":
+    self.send(None)
else:
    raise Exception("Unknown command")

@ -645,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol):
# We assume bytes are just unicode strings.
obj = obj.decode("utf-8")

+if obj is None:
+    return "$-1\r\n"
if isinstance(obj, str):
    return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
if isinstance(obj, int):
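The fake Redis server now answers GET with `None`, so its encoder needs the RESP "null bulk string". A compact rendition of the encoding rules involved; the integer case goes beyond what the hunk shows and is an assumption:

def encode_resp(obj) -> str:
    # RESP wire encodings used when replying to the worker under test.
    if obj is None:
        return "$-1\r\n"                              # null bulk string
    if isinstance(obj, str):
        return "${}\r\n{}\r\n".format(len(obj), obj)  # bulk string
    if isinstance(obj, int):
        return ":{}\r\n".format(obj)                  # integer (assumed)
    raise TypeError("unsupported type: %r" % (obj,))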
@ -28,6 +28,7 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync
+from synapse.types import JsonDict

from tests import unittest
from tests.test_utils import make_awaitable
@ -468,13 +469,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")

-self.user1 = self.register_user(
-    "user1", "pass1", admin=False, displayname="Name 1"
-)
-self.user2 = self.register_user(
-    "user2", "pass2", admin=False, displayname="Name 2"
-)
-
def test_no_auth(self):
    """
    Try to list users without authentication.
@ -488,6 +482,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
+self._create_users(1)
other_user_token = self.login("user1", "pass1")

channel = self.make_request("GET", self.url, access_token=other_user_token)
@ -499,6 +494,8 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
List all users, including deactivated users.
"""
+self._create_users(2)
+
channel = self.make_request(
    "GET",
    self.url + "?deactivated=true",
@ -511,14 +508,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(3, channel.json_body["total"])

# Check that all fields are available
-for u in channel.json_body["users"]:
-    self.assertIn("name", u)
-    self.assertIn("is_guest", u)
-    self.assertIn("admin", u)
-    self.assertIn("user_type", u)
-    self.assertIn("deactivated", u)
-    self.assertIn("displayname", u)
-    self.assertIn("avatar_url", u)
+self._check_fields(channel.json_body["users"])

def test_search_term(self):
    """Test that searching for a users works correctly"""
@ -549,6 +539,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):

# Check that users were returned
self.assertTrue("users" in channel.json_body)
+self._check_fields(channel.json_body["users"])
users = channel.json_body["users"]

# Check that the expected number of users were returned
@ -561,25 +552,30 @@ class UsersListTestCase(unittest.HomeserverTestCase):
u = users[0]
self.assertEqual(expected_user_id, u["name"])

+self._create_users(2)
+
+user1 = "@user1:test"
+user2 = "@user2:test"
+
# Perform search tests
-_search_test(self.user1, "er1")
-_search_test(self.user1, "me 1")
+_search_test(user1, "er1")
+_search_test(user1, "me 1")

-_search_test(self.user2, "er2")
-_search_test(self.user2, "me 2")
+_search_test(user2, "er2")
+_search_test(user2, "me 2")

-_search_test(self.user1, "er1", "user_id")
-_search_test(self.user2, "er2", "user_id")
+_search_test(user1, "er1", "user_id")
+_search_test(user2, "er2", "user_id")

# Test case insensitive
-_search_test(self.user1, "ER1")
-_search_test(self.user1, "NAME 1")
+_search_test(user1, "ER1")
+_search_test(user1, "NAME 1")

-_search_test(self.user2, "ER2")
-_search_test(self.user2, "NAME 2")
+_search_test(user2, "ER2")
+_search_test(user2, "NAME 2")

-_search_test(self.user1, "ER1", "user_id")
-_search_test(self.user2, "ER2", "user_id")
+_search_test(user1, "ER1", "user_id")
+_search_test(user2, "ER2", "user_id")

_search_test(None, "foo")
_search_test(None, "bar")
@ -587,6 +583,179 @@ class UsersListTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id")

+    def test_invalid_parameter(self):
+        """
+        If parameters are invalid, an error is returned.
+        """
+
+        # negative limit
+        channel = self.make_request(
+            "GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # negative from
+        channel = self.make_request(
+            "GET", self.url + "?from=-5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
+
+        # invalid guests
+        channel = self.make_request(
+            "GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+        # invalid deactivated
+        channel = self.make_request(
+            "GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+
+    def test_limit(self):
+        """
+        Testing list of users with limit
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET", self.url + "?limit=5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 5)
+        self.assertEqual(channel.json_body["next_token"], "5")
+        self._check_fields(channel.json_body["users"])
+
+    def test_from(self):
+        """
+        Testing list of users with a defined starting point (from)
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET", self.url + "?from=5", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 15)
+        self.assertNotIn("next_token", channel.json_body)
+        self._check_fields(channel.json_body["users"])
+
+    def test_limit_and_from(self):
+        """
+        Testing list of users with a defined starting point and limit
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        channel = self.make_request(
+            "GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(channel.json_body["next_token"], "15")
+        self.assertEqual(len(channel.json_body["users"]), 10)
+        self._check_fields(channel.json_body["users"])
+
+    def test_next_token(self):
+        """
+        Testing that `next_token` appears at the right place
+        """
+
+        number_users = 20
+        # Create one less user (since there's already an admin user).
+        self._create_users(number_users - 1)
+
+        # `next_token` does not appear
+        # Number of results is the number of entries
+        channel = self.make_request(
+            "GET", self.url + "?limit=20", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), number_users)
+        self.assertNotIn("next_token", channel.json_body)
+
+        # `next_token` does not appear
+        # Number of max results is larger than the number of entries
+        channel = self.make_request(
+            "GET", self.url + "?limit=21", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), number_users)
+        self.assertNotIn("next_token", channel.json_body)
+
+        # `next_token` does appear
+        # Number of max results is smaller than the number of entries
+        channel = self.make_request(
+            "GET", self.url + "?limit=19", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 19)
+        self.assertEqual(channel.json_body["next_token"], "19")
+
+        # Check
+        # Set `from` to value of `next_token` for request remaining entries
+        # `next_token` does not appear
+        channel = self.make_request(
+            "GET", self.url + "?from=19", access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["total"], number_users)
+        self.assertEqual(len(channel.json_body["users"]), 1)
+        self.assertNotIn("next_token", channel.json_body)
+
+    def _check_fields(self, content: JsonDict):
+        """Checks that the expected user attributes are present in content
+        Args:
+            content: List that is checked for content
+        """
+        for u in content:
+            self.assertIn("name", u)
+            self.assertIn("is_guest", u)
+            self.assertIn("admin", u)
+            self.assertIn("user_type", u)
+            self.assertIn("deactivated", u)
+            self.assertIn("displayname", u)
+            self.assertIn("avatar_url", u)
+
+    def _create_users(self, number_users: int):
+        """
+        Create a number of users
+        Args:
+            number_users: Number of users to be created
+        """
+        for i in range(1, number_users + 1):
+            self.register_user(
+                "user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
+            )
+
class DeactivateAccountTestCase(unittest.HomeserverTestCase):

@ -2211,3 +2380,67 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(self.other_user, channel.json_body["user_id"])
self.assertIn("devices", channel.json_body)

+class ShadowBanRestTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+
+        self.url = "/_synapse/admin/v1/users/%s/shadow_ban" % urllib.parse.quote(
+            self.other_user
+        )
+
+    def test_no_auth(self):
+        """
+        Try to get information of an user without authentication.
+        """
+        channel = self.make_request("POST", self.url)
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_not_admin(self):
+        """
+        If the user is not a server admin, an error is returned.
+        """
+        other_user_token = self.login("user", "pass")
+
+        channel = self.make_request("POST", self.url, access_token=other_user_token)
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_user_is_not_local(self):
+        """
+        Tests that shadow-banning for a user that is not a local returns a 400
+        """
+        url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
+
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok)
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+
+    def test_success(self):
+        """
+        Shadow-banning should succeed for an admin.
+        """
+        # The user starts off as not shadow-banned.
+        other_user_token = self.login("user", "pass")
+        result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        self.assertFalse(result.shadow_banned)
+
+        channel = self.make_request("POST", self.url, access_token=self.admin_user_tok)
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual({}, channel.json_body)
+
+        # Ensure the user is shadow-banned (and the cache was cleared).
+        result = self.get_success(self.store.get_user_by_access_token(other_user_token))
+        self.assertTrue(result.shadow_banned)
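Based on the test case above, shadow-banning a local user goes through a POST to the new admin endpoint and returns 200 with an empty JSON object. A hedged client-side sketch (the base URL is a placeholder, and `requests` is just one convenient HTTP client):

import urllib.parse

import requests


def shadow_ban_user(base_url: str, admin_token: str, user_id: str) -> None:
    url = "%s/_synapse/admin/v1/users/%s/shadow_ban" % (
        base_url,
        urllib.parse.quote(user_id),
    )
    resp = requests.post(url, headers={"Authorization": "Bearer " + admin_token})
    resp.raise_for_status()  # expect 200 and a body of {} on success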
@ -18,6 +18,7 @@ import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
+from synapse.types import UserID

from tests import unittest

@ -31,12 +32,7 @@ class _ShadowBannedBase(unittest.HomeserverTestCase):
self.store = self.hs.get_datastore()

self.get_success(
-    self.store.db_pool.simple_update(
-        table="users",
-        keyvalues={"name": self.banned_user_id},
-        updatevalues={"shadow_banned": True},
-        desc="shadow_ban",
-    )
+    self.store.set_shadow_banned(UserID.from_string(self.banned_user_id), True)
)

self.other_user_id = self.register_user("otheruser", "pass")
@ -202,7 +202,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):

config = self.default_config()
config["media_store_path"] = self.media_store_path
-config["thumbnail_requirements"] = {}
config["max_image_pixels"] = 2000000

provider_config = {
@ -313,15 +312,39 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)

def test_thumbnail_crop(self):
+    """Test that a cropped remote thumbnail is available."""
    self._test_thumbnail(
        "crop", self.test_image.expected_cropped, self.test_image.expected_found
    )

def test_thumbnail_scale(self):
+    """Test that a scaled remote thumbnail is available."""
    self._test_thumbnail(
        "scale", self.test_image.expected_scaled, self.test_image.expected_found
    )

+def test_invalid_type(self):
+    """An invalid thumbnail type is never available."""
+    self._test_thumbnail("invalid", None, False)
+
+@unittest.override_config(
+    {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
+)
+def test_no_thumbnail_crop(self):
+    """
+    Override the config to generate only scaled thumbnails, but request a cropped one.
+    """
+    self._test_thumbnail("crop", None, False)
+
+@unittest.override_config(
+    {"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
+)
+def test_no_thumbnail_scale(self):
+    """
+    Override the config to generate only cropped thumbnails, but request a scaled one.
+    """
+    self._test_thumbnail("scale", None, False)
+
def _test_thumbnail(self, method, expected_body, expected_found):
    params = "?width=32&height=32&method=" + method
    channel = make_request(
@ -261,3 +261,32 @@ class PreviewUrlTestCase(unittest.TestCase):
html = ""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEqual(og, {})

+def test_invalid_encoding(self):
+    """An invalid character encoding should be ignored and treated as UTF-8, if possible."""
+    html = """
+    <html>
+    <head><title>Foo</title></head>
+    <body>
+    Some text.
+    </body>
+    </html>
+    """
+    og = decode_and_calc_og(
+        html, "http://example.com/test.html", "invalid-encoding"
+    )
+    self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
+
+def test_invalid_encoding2(self):
+    """A body which doesn't match the sent character encoding."""
+    # Note that this contains an invalid UTF-8 sequence in the title.
+    html = b"""
+    <html>
+    <head><title>\xff\xff Foo</title></head>
+    <body>
+    Some text.
+    </body>
+    </html>
+    """
+    og = decode_and_calc_og(html, "http://example.com/test.html")
+    self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
tox.ini

@ -18,13 +18,16 @@ deps =
# installed on that).
#
# anyway, make sure that we have a recent enough setuptools.
-setuptools>=18.5
+setuptools>=18.5 ; python_version >= '3.6'
+setuptools>=18.5,<51.0.0 ; python_version < '3.6'

# we also need a semi-recent version of pip, because old ones fail to
# install the "enum34" dependency of cryptography.
-pip>=10
+pip>=10 ; python_version >= '3.6'
+pip>=10,<21.0 ; python_version < '3.6'

-# directories/files we run the linters on
+# directories/files we run the linters on.
+# if you update this list, make sure to do the same in scripts-dev/lint.sh
lint_targets =
    setup.py
    synapse
@ -103,15 +106,10 @@ usedevelop=true
[testenv:py35-old]
skip_install=True
deps =
-    # Ensure a version of setuptools that supports Python 3.5 is installed.
-    setuptools < 51.0.0

    # Old automat version for Twisted
    Automat == 0.3.0

    lxml
-    coverage
-    coverage-enable-subprocess==1.0
+    {[base]deps}

commands =
    # Make all greater-thans equals so we test the oldest version of our direct
@ -168,6 +166,8 @@ commands = {toxinidir}/scripts-dev/generate_sample_config --check
skip_install = True
deps =
    coverage
+    pip>=10 ; python_version >= '3.6'
+    pip>=10,<21.0 ; python_version < '3.6'
commands=
    coverage combine
    coverage report