Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

pull/8675/head
Erik Johnston 2020-10-16 11:34:53 +01:00
commit e9b5e642c3
76 changed files with 2027 additions and 791 deletions


@ -1,3 +1,27 @@
Synapse 1.21.2 (2020-10-15)
===========================
Debian packages and Docker images have been rebuilt using the latest versions of dependency libraries, including authlib 0.15.1. Please see bugfixes below.
Security advisory
-----------------
* HTML pages served via Synapse were vulnerable to cross-site scripting (XSS)
attacks. All server administrators are encouraged to upgrade.
([\#8444](https://github.com/matrix-org/synapse/pull/8444))
([CVE-2020-26891](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-26891))
This fix was originally included in v1.21.0 but was missing a security advisory.
This was reported by [Denis Kasak](https://github.com/dkasak).
Bugfixes
--------
- Fix rare bug where sending an event would fail due to a racy assertion. ([\#8530](https://github.com/matrix-org/synapse/issues/8530))
- An updated version of the authlib dependency is included in the Docker and Debian images to fix an issue using OpenID Connect. See [\#8534](https://github.com/matrix-org/synapse/issues/8534) for details.
Synapse 1.21.1 (2020-10-13)
===========================


@ -63,6 +63,10 @@ run-time:
./scripts-dev/lint.sh path/to/file1.py path/to/file2.py path/to/folder
```
You can also provide the `-d` option, which will lint the files that have been
changed since the last git commit. This will often be significantly faster than
linting the whole codebase.
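For example (a brief sketch; the extra path is illustrative):
```sh
# Lint just the files changed since the last git commit
./scripts-dev/lint.sh -d

# Lint changed files plus an explicitly named file
./scripts-dev/lint.sh -d synapse/handlers/appservice.py
```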
Before pushing new changes, ensure they don't produce linting errors. Commit any
files that were corrected.

changelog.d/8437.feature Normal file

@ -0,0 +1 @@
Implement [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409) to send typing, read receipts, and presence events to appservices.

changelog.d/8472.misc Normal file

@ -0,0 +1 @@
Add `-d` option to `./scripts-dev/lint.sh` to lint files that have changed since the last git commit.

changelog.d/8488.misc Normal file

@ -0,0 +1 @@
Allow events to be sent to clients sooner when using sharded event persisters.

changelog.d/8503.misc Normal file

@ -0,0 +1 @@
Add user agent to user_daily_visits table.

changelog.d/8515.misc Normal file

@ -0,0 +1 @@
Apply some internal fixes to the `HomeServer` class to make its code more idiomatic and statically-verifiable.

changelog.d/8526.doc Normal file

@ -0,0 +1 @@
Added a note about Docker in manhole.md regarding which IP address to bind to. Contributed by @Maquis196.

changelog.d/8529.doc Normal file

@ -0,0 +1 @@
Document the new behaviour of the `allowed_lifetime_min` and `allowed_lifetime_max` settings in the room retention configuration.

changelog.d/8535.feature Normal file

@ -0,0 +1 @@
Support modifying event content in `ThirdPartyRules` modules.

changelog.d/8537.misc Normal file

@ -0,0 +1 @@
Factor out common code between `RoomMemberHandler._locally_reject_invite` and `EventCreationHandler.create_event`.

changelog.d/8542.misc Normal file

@ -0,0 +1 @@
Improve database performance by executing more queries without starting transactions.

changelog.d/8547.misc Normal file

@ -0,0 +1 @@
Enable mypy type checking for `synapse.util.caches`.

changelog.d/8548.misc Normal file

@ -0,0 +1 @@
Rename `Cache` to `DeferredCache`, to better reflect its purpose.

debian/changelog vendored

@ -1,3 +1,10 @@
matrix-synapse-py3 (1.21.2) stable; urgency=medium
[ Synapse Packaging team ]
* New synapse release 1.21.2.
-- Synapse Packaging team <packages@matrix.org> Thu, 15 Oct 2020 09:23:27 -0400
matrix-synapse-py3 (1.21.1) stable; urgency=medium
[ Synapse Packaging team ]


@ -5,8 +5,45 @@ The "manhole" allows server administrators to access a Python shell on a running
Synapse installation. This is a very powerful mechanism for administration and
debugging.
**_Security Warning_**
Note that this will give administrative access to synapse to **all users** with
shell access to the server. It should therefore **not** be enabled in
environments where untrusted users have shell access.
***
To enable it, first uncomment the `manhole` listener configuration in
`homeserver.yaml`. The configuration is slightly different if you're using docker.
#### Docker config
If you are using Docker, set `bind_addresses` to `['0.0.0.0']` as shown:
```yaml
listeners:
- port: 9000
bind_addresses: ['0.0.0.0']
type: manhole
```
When using `docker run` to start the server, you will then need to change the command to the following to include the
`manhole` port forwarding. The `-p 127.0.0.1:9000:9000` below is important: it
ensures that access to the `manhole` is only possible for local users.
```bash
docker run -d --name synapse \
--mount type=volume,src=synapse-data,dst=/data \
-p 8008:8008 \
-p 127.0.0.1:9000:9000 \
matrixdotorg/synapse:latest
```
#### Native config
If you are not using docker, set `bind_addresses` to `['::1', '127.0.0.1']` as shown.
The `bind_addresses` in the example below is important: it ensures that access to the
`manhole` is only possible for local users.
```yaml
listeners:
@ -15,12 +52,7 @@ listeners:
type: manhole
```
#### Accessing synapse manhole
Then restart synapse, and point an ssh client at port 9000 on localhost, using
the username `matrix`:
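For example, with the listener above on its default port:
```bash
ssh -p9000 matrix@localhost
```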


@ -136,24 +136,34 @@ the server's database.
### Lifetime limits
Server admins can set limits on the values of `max_lifetime` to use when
purging old events in a room. These limits can be defined as such in the
`retention` section of the configuration file:
```yaml
allowed_lifetime_min: 1d
allowed_lifetime_max: 1y
```
The limits are considered when running purge jobs. If necessary, the
effective value of `max_lifetime` will be brought between
`allowed_lifetime_min` and `allowed_lifetime_max` (inclusive).
This means that, if the value of `max_lifetime` defined in the room's state
is lower than `allowed_lifetime_min`, the value of `allowed_lifetime_min`
will be used instead. Likewise, if the value of `max_lifetime` is higher
than `allowed_lifetime_max`, the value of `allowed_lifetime_max` will be
used instead.
In the example above, we ensure Synapse never deletes events that are less
than one day old, and that it always deletes events that are over a year
old.
If a default policy is set, and its `max_lifetime` value is lower than
`allowed_lifetime_min` or higher than `allowed_lifetime_max`, the same
process applies.
Both parameters are optional; if one is omitted Synapse won't use it to
adjust the effective value of `max_lifetime`.
Like other settings in this section, these parameters can be expressed
either as a duration or as a number of milliseconds.
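To make the clamping behaviour concrete, here is a small illustrative sketch (not
Synapse's actual implementation) of how the effective `max_lifetime` is derived from
these two settings:
```python
def effective_max_lifetime(max_lifetime, allowed_min=None, allowed_max=None):
    """Clamp a room's max_lifetime between the configured limits.

    Either limit may be None (omitted in the config), in which case it is
    simply not applied.
    """
    if allowed_min is not None and max_lifetime < allowed_min:
        max_lifetime = allowed_min
    if allowed_max is not None and max_lifetime > allowed_max:
        max_lifetime = allowed_max
    return max_lifetime

# With allowed_lifetime_min = 1d and allowed_lifetime_max = 1y (in milliseconds),
# a room whose policy sets max_lifetime to one hour is purged using one day instead:
assert effective_max_lifetime(3_600_000, 86_400_000, 31_536_000_000) == 86_400_000
```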


@ -15,6 +15,7 @@ files =
synapse/events/builder.py,
synapse/events/spamcheck.py,
synapse/federation,
synapse/handlers/appservice.py,
synapse/handlers/account_data.py,
synapse/handlers/auth.py,
synapse/handlers/cas_handler.py,
@ -64,9 +65,7 @@ files =
synapse/streams,
synapse/types.py,
synapse/util/async_helpers.py,
synapse/util/caches,
synapse/util/metrics.py,
tests/replication,
tests/test_utils,


@ -1,4 +1,4 @@
#!/bin/bash
# #
# Runs linting scripts over the local Synapse checkout # Runs linting scripts over the local Synapse checkout
# isort - sorts import statements # isort - sorts import statements
@ -7,15 +7,90 @@
set -e
usage() {
echo
echo "Usage: $0 [-h] [-d] [paths...]"
echo
echo "-d"
echo " Lint files that have changed since the last git commit."
echo
echo " If paths are provided and this option is set, both provided paths and those"
echo " that have changed since the last commit will be linted."
echo
echo " If no paths are provided and this option is not set, all files will be linted."
echo
echo " Note that paths with a file extension that is not '.py' will be excluded."
echo "-h"
echo " Display this help text."
}
USING_DIFF=0
files=()
while getopts ":dh" opt; do
case $opt in
d)
USING_DIFF=1
;;
h)
usage
exit
;;
\?)
echo "ERROR: Invalid option: -$OPTARG" >&2
usage
exit
;;
esac
done
# Strip any options from the command line arguments now that
# we've finished processing them
shift "$((OPTIND-1))"
if [ $USING_DIFF -eq 1 ]; then
# Check both staged and non-staged changes
for path in $(git diff HEAD --name-only); do
filename=$(basename "$path")
file_extension="${filename##*.}"
# If an extension is present, and it's something other than 'py',
# then ignore this file
if [[ -n ${file_extension+x} && $file_extension != "py" ]]; then
continue
fi
# Append this path to our list of files to lint
files+=("$path")
done
fi
echo "Linting these locations: $files" # Append any remaining arguments as files to lint
isort $files files+=("$@")
python3 -m black $files
if [[ $USING_DIFF -eq 1 ]]; then
# If we were asked to lint changed files, and no paths were found as a result...
if [ ${#files[@]} -eq 0 ]; then
# Then print and exit
echo "No files found to lint."
exit 0
fi
else
# If we were not asked to lint changed files, and no paths were found as a result,
# then lint everything!
if [[ -z ${files+x} ]]; then
# Lint all source code files and directories
files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py")
fi
fi
echo "Linting these paths: ${files[*]}"
echo
# Print out the commands being run
set -x
isort "${files[@]}"
python3 -m black "${files[@]}"
./scripts-dev/config-lint.sh
flake8 "${files[@]}"


@ -15,12 +15,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
import sys
from setuptools import Command, find_packages, setup
here = os.path.abspath(os.path.dirname(__file__))


@ -1,13 +1,12 @@
from .sorteddict import SortedDict, SortedItemsView, SortedKeysView, SortedValuesView
from .sortedlist import SortedKeyList, SortedList, SortedListWithKey
__all__ = [
"SortedDict",
"SortedKeysView",
"SortedItemsView",
"SortedValuesView",
"SortedKeyList",
"SortedList",
"SortedListWithKey",
]


@ -0,0 +1,177 @@
# stub for SortedList. This is an exact copy of
# https://github.com/grantjenks/python-sortedcontainers/blob/a419ffbd2b1c935b09f11f0971696e537fd0c510/sortedcontainers/sortedlist.pyi
# (from https://github.com/grantjenks/python-sortedcontainers/pull/107)
from typing import (
Any,
Callable,
Generic,
Iterable,
Iterator,
List,
MutableSequence,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)
_T = TypeVar("_T")
_SL = TypeVar("_SL", bound=SortedList)
_SKL = TypeVar("_SKL", bound=SortedKeyList)
_Key = Callable[[_T], Any]
_Repr = Callable[[], str]
def recursive_repr(fillvalue: str = ...) -> Callable[[_Repr], _Repr]: ...
class SortedList(MutableSequence[_T]):
DEFAULT_LOAD_FACTOR: int = ...
def __init__(
self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ...,
): ...
# NB: currently mypy does not honour return type, see mypy #3307
@overload
def __new__(cls: Type[_SL], iterable: None, key: None) -> _SL: ...
@overload
def __new__(cls: Type[_SL], iterable: None, key: _Key[_T]) -> SortedKeyList[_T]: ...
@overload
def __new__(cls: Type[_SL], iterable: Iterable[_T], key: None) -> _SL: ...
@overload
def __new__(cls, iterable: Iterable[_T], key: _Key[_T]) -> SortedKeyList[_T]: ...
@property
def key(self) -> Optional[Callable[[_T], Any]]: ...
def _reset(self, load: int) -> None: ...
def clear(self) -> None: ...
def _clear(self) -> None: ...
def add(self, value: _T) -> None: ...
def _expand(self, pos: int) -> None: ...
def update(self, iterable: Iterable[_T]) -> None: ...
def _update(self, iterable: Iterable[_T]) -> None: ...
def discard(self, value: _T) -> None: ...
def remove(self, value: _T) -> None: ...
def _delete(self, pos: int, idx: int) -> None: ...
def _loc(self, pos: int, idx: int) -> int: ...
def _pos(self, idx: int) -> int: ...
def _build_index(self) -> None: ...
def __contains__(self, value: Any) -> bool: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
@overload
def __getitem__(self, index: int) -> _T: ...
@overload
def __getitem__(self, index: slice) -> List[_T]: ...
@overload
def _getitem(self, index: int) -> _T: ...
@overload
def _getitem(self, index: slice) -> List[_T]: ...
@overload
def __setitem__(self, index: int, value: _T) -> None: ...
@overload
def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ...
def __iter__(self) -> Iterator[_T]: ...
def __reversed__(self) -> Iterator[_T]: ...
def __len__(self) -> int: ...
def reverse(self) -> None: ...
def islice(
self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
) -> Iterator[_T]: ...
def _islice(
self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool,
) -> Iterator[_T]: ...
def irange(
self,
minimum: Optional[int] = ...,
maximum: Optional[int] = ...,
inclusive: Tuple[bool, bool] = ...,
reverse: bool = ...,
) -> Iterator[_T]: ...
def bisect_left(self, value: _T) -> int: ...
def bisect_right(self, value: _T) -> int: ...
def bisect(self, value: _T) -> int: ...
def _bisect_right(self, value: _T) -> int: ...
def count(self, value: _T) -> int: ...
def copy(self: _SL) -> _SL: ...
def __copy__(self: _SL) -> _SL: ...
def append(self, value: _T) -> None: ...
def extend(self, values: Iterable[_T]) -> None: ...
def insert(self, index: int, value: _T) -> None: ...
def pop(self, index: int = ...) -> _T: ...
def index(
self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
) -> int: ...
def __add__(self: _SL, other: Iterable[_T]) -> _SL: ...
def __radd__(self: _SL, other: Iterable[_T]) -> _SL: ...
def __iadd__(self: _SL, other: Iterable[_T]) -> _SL: ...
def __mul__(self: _SL, num: int) -> _SL: ...
def __rmul__(self: _SL, num: int) -> _SL: ...
def __imul__(self: _SL, num: int) -> _SL: ...
def __eq__(self, other: Any) -> bool: ...
def __ne__(self, other: Any) -> bool: ...
def __lt__(self, other: Sequence[_T]) -> bool: ...
def __gt__(self, other: Sequence[_T]) -> bool: ...
def __le__(self, other: Sequence[_T]) -> bool: ...
def __ge__(self, other: Sequence[_T]) -> bool: ...
def __repr__(self) -> str: ...
def _check(self) -> None: ...
class SortedKeyList(SortedList[_T]):
def __init__(
self, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ...
) -> None: ...
def __new__(
cls, iterable: Optional[Iterable[_T]] = ..., key: _Key[_T] = ...
) -> SortedKeyList[_T]: ...
@property
def key(self) -> Callable[[_T], Any]: ...
def clear(self) -> None: ...
def _clear(self) -> None: ...
def add(self, value: _T) -> None: ...
def _expand(self, pos: int) -> None: ...
def update(self, iterable: Iterable[_T]) -> None: ...
def _update(self, iterable: Iterable[_T]) -> None: ...
# NB: Must be T to be safely passed to self.func, yet base class imposes Any
def __contains__(self, value: _T) -> bool: ... # type: ignore
def discard(self, value: _T) -> None: ...
def remove(self, value: _T) -> None: ...
def _delete(self, pos: int, idx: int) -> None: ...
def irange(
self,
minimum: Optional[int] = ...,
maximum: Optional[int] = ...,
inclusive: Tuple[bool, bool] = ...,
reverse: bool = ...,
): ...
def irange_key(
self,
min_key: Optional[Any] = ...,
max_key: Optional[Any] = ...,
inclusive: Tuple[bool, bool] = ...,
reserve: bool = ...,
): ...
def bisect_left(self, value: _T) -> int: ...
def bisect_right(self, value: _T) -> int: ...
def bisect(self, value: _T) -> int: ...
def bisect_key_left(self, key: Any) -> int: ...
def _bisect_key_left(self, key: Any) -> int: ...
def bisect_key_right(self, key: Any) -> int: ...
def _bisect_key_right(self, key: Any) -> int: ...
def bisect_key(self, key: Any) -> int: ...
def count(self, value: _T) -> int: ...
def copy(self: _SKL) -> _SKL: ...
def __copy__(self: _SKL) -> _SKL: ...
def index(
self, value: _T, start: Optional[int] = ..., stop: Optional[int] = ...
) -> int: ...
def __add__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
def __radd__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
def __iadd__(self: _SKL, other: Iterable[_T]) -> _SKL: ...
def __mul__(self: _SKL, num: int) -> _SKL: ...
def __rmul__(self: _SKL, num: int) -> _SKL: ...
def __imul__(self: _SKL, num: int) -> _SKL: ...
def __repr__(self) -> str: ...
def _check(self) -> None: ...
SortedListWithKey = SortedKeyList


@ -48,7 +48,7 @@ try:
except ImportError:
pass
__version__ = "1.21.2"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when


@ -14,14 +14,15 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Iterable, List, Match, Optional
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.appservice.api import ApplicationServiceApi from synapse.events import EventBase
from synapse.types import GroupID, get_domain_from_id from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.appservice.api import ApplicationServiceApi
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -32,38 +33,6 @@ class ApplicationServiceState:
UP = "up" UP = "up"
class AppServiceTransaction:
"""Represents an application service transaction."""
def __init__(self, service, id, events):
self.service = service
self.id = id
self.events = events
async def send(self, as_api: ApplicationServiceApi) -> bool:
"""Sends this transaction using the provided AS API interface.
Args:
as_api: The API to use to send.
Returns:
True if the transaction was sent.
"""
return await as_api.push_bulk(
service=self.service, events=self.events, txn_id=self.id
)
async def complete(self, store: "DataStore") -> None:
"""Completes this transaction as successful.
Marks this transaction ID on the application service and removes the
transaction contents from the database.
Args:
store: The database store to operate on.
"""
await store.complete_appservice_txn(service=self.service, txn_id=self.id)
class ApplicationService: class ApplicationService:
"""Defines an application service. This definition is mostly what is """Defines an application service. This definition is mostly what is
provided to the /register AS API. provided to the /register AS API.
@ -91,6 +60,7 @@ class ApplicationService:
protocols=None, protocols=None,
rate_limited=True, rate_limited=True,
ip_range_whitelist=None, ip_range_whitelist=None,
supports_ephemeral=False,
): ):
self.token = token self.token = token
self.url = ( self.url = (
@ -102,6 +72,7 @@ class ApplicationService:
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
self.ip_range_whitelist = ip_range_whitelist self.ip_range_whitelist = ip_range_whitelist
self.supports_ephemeral = supports_ephemeral
if "|" in self.id: if "|" in self.id:
raise Exception("application service ID cannot contain '|' character") raise Exception("application service ID cannot contain '|' character")
@ -161,19 +132,21 @@ class ApplicationService:
raise ValueError("Expected string for 'regex' in ns '%s'" % ns) raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
return namespaces return namespaces
def _matches_regex(self, test_string, namespace_key): def _matches_regex(self, test_string: str, namespace_key: str) -> Optional[Match]:
for regex_obj in self.namespaces[namespace_key]: for regex_obj in self.namespaces[namespace_key]:
if regex_obj["regex"].match(test_string): if regex_obj["regex"].match(test_string):
return regex_obj return regex_obj
return None return None
def _is_exclusive(self, ns_key, test_string): def _is_exclusive(self, ns_key: str, test_string: str) -> bool:
regex_obj = self._matches_regex(test_string, ns_key) regex_obj = self._matches_regex(test_string, ns_key)
if regex_obj: if regex_obj:
return regex_obj["exclusive"] return regex_obj["exclusive"]
return False return False
async def _matches_user(self, event, store): async def _matches_user(
self, event: Optional[EventBase], store: Optional["DataStore"] = None
) -> bool:
if not event: if not event:
return False return False
@ -188,14 +161,23 @@ class ApplicationService:
if not store: if not store:
return False return False
does_match = await self._matches_user_in_member_list(event.room_id, store) does_match = await self.matches_user_in_member_list(event.room_id, store)
return does_match return does_match
@cached(num_args=1, cache_context=True) @cached(num_args=1)
async def _matches_user_in_member_list(self, room_id, store, cache_context): async def matches_user_in_member_list(
member_list = await store.get_users_in_room( self, room_id: str, store: "DataStore"
room_id, on_invalidate=cache_context.invalidate ) -> bool:
) """Check if this service is interested a room based upon it's membership
Args:
room_id: The room to check.
store: The datastore to query.
Returns:
True if this service would like to know about this room.
"""
member_list = await store.get_users_in_room(room_id)
# check joined member events # check joined member events
for user_id in member_list: for user_id in member_list:
@ -203,12 +185,14 @@ class ApplicationService:
return True return True
return False return False
def _matches_room_id(self, event): def _matches_room_id(self, event: EventBase) -> bool:
if hasattr(event, "room_id"): if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id) return self.is_interested_in_room(event.room_id)
return False return False
async def _matches_aliases(self, event, store): async def _matches_aliases(
self, event: EventBase, store: Optional["DataStore"] = None
) -> bool:
if not store or not event: if not store or not event:
return False return False
@ -218,12 +202,15 @@ class ApplicationService:
return True return True
return False return False
async def is_interested(self, event, store=None) -> bool: async def is_interested(
self, event: EventBase, store: Optional["DataStore"] = None
) -> bool:
"""Check if this service is interested in this event. """Check if this service is interested in this event.
Args: Args:
event(Event): The event to check. event: The event to check.
store(DataStore) store: The datastore to query.
Returns: Returns:
True if this service would like to know about this event. True if this service would like to know about this event.
""" """
@ -231,39 +218,66 @@ class ApplicationService:
if self._matches_room_id(event): if self._matches_room_id(event):
return True return True
if await self._matches_aliases(event, store): # This will check the namespaces first before
# checking the store, so should be run before _matches_aliases
if await self._matches_user(event, store):
return True return True
if await self._matches_user(event, store): # This will check the store, so should be run last
if await self._matches_aliases(event, store):
return True return True
return False return False
def is_interested_in_user(self, user_id): @cached(num_args=1)
async def is_interested_in_presence(
self, user_id: UserID, store: "DataStore"
) -> bool:
"""Check if this service is interested a user's presence
Args:
user_id: The user to check.
store: The datastore to query.
Returns:
True if this service would like to know about presence for this user.
"""
# Find all the rooms the sender is in
if self.is_interested_in_user(user_id.to_string()):
return True
room_ids = await store.get_rooms_for_user(user_id.to_string())
# Then find out if the appservice is interested in any of those rooms
for room_id in room_ids:
if await self.matches_user_in_member_list(room_id, store):
return True
return False
def is_interested_in_user(self, user_id: str) -> bool:
return ( return (
self._matches_regex(user_id, ApplicationService.NS_USERS) bool(self._matches_regex(user_id, ApplicationService.NS_USERS))
or user_id == self.sender or user_id == self.sender
) )
def is_interested_in_alias(self, alias): def is_interested_in_alias(self, alias: str) -> bool:
return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES)) return bool(self._matches_regex(alias, ApplicationService.NS_ALIASES))
def is_interested_in_room(self, room_id): def is_interested_in_room(self, room_id: str) -> bool:
return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS)) return bool(self._matches_regex(room_id, ApplicationService.NS_ROOMS))
def is_exclusive_user(self, user_id): def is_exclusive_user(self, user_id: str) -> bool:
return ( return (
self._is_exclusive(ApplicationService.NS_USERS, user_id) self._is_exclusive(ApplicationService.NS_USERS, user_id)
or user_id == self.sender or user_id == self.sender
) )
def is_interested_in_protocol(self, protocol): def is_interested_in_protocol(self, protocol: str) -> bool:
return protocol in self.protocols return protocol in self.protocols
def is_exclusive_alias(self, alias): def is_exclusive_alias(self, alias: str) -> bool:
return self._is_exclusive(ApplicationService.NS_ALIASES, alias) return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
def is_exclusive_room(self, room_id): def is_exclusive_room(self, room_id: str) -> bool:
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id) return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
def get_exclusive_user_regexes(self): def get_exclusive_user_regexes(self):
@ -276,14 +290,14 @@ class ApplicationService:
if regex_obj["exclusive"] if regex_obj["exclusive"]
] ]
def get_groups_for_user(self, user_id): def get_groups_for_user(self, user_id: str) -> Iterable[str]:
"""Get the groups that this user is associated with by this AS """Get the groups that this user is associated with by this AS
Args: Args:
user_id (str): The ID of the user. user_id: The ID of the user.
Returns: Returns:
iterable[str]: an iterable that yields group_id strings. An iterable that yields group_id strings.
""" """
return ( return (
regex_obj["group_id"] regex_obj["group_id"]
@ -291,7 +305,7 @@ class ApplicationService:
if "group_id" in regex_obj and regex_obj["regex"].match(user_id) if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
) )
def is_rate_limited(self): def is_rate_limited(self) -> bool:
return self.rate_limited return self.rate_limited
def __str__(self): def __str__(self):
@ -300,3 +314,45 @@ class ApplicationService:
dict_copy["token"] = "<redacted>" dict_copy["token"] = "<redacted>"
dict_copy["hs_token"] = "<redacted>" dict_copy["hs_token"] = "<redacted>"
return "ApplicationService: %s" % (dict_copy,) return "ApplicationService: %s" % (dict_copy,)
class AppServiceTransaction:
"""Represents an application service transaction."""
def __init__(
self,
service: ApplicationService,
id: int,
events: List[EventBase],
ephemeral: List[JsonDict],
):
self.service = service
self.id = id
self.events = events
self.ephemeral = ephemeral
async def send(self, as_api: "ApplicationServiceApi") -> bool:
"""Sends this transaction using the provided AS API interface.
Args:
as_api: The API to use to send.
Returns:
True if the transaction was sent.
"""
return await as_api.push_bulk(
service=self.service,
events=self.events,
ephemeral=self.ephemeral,
txn_id=self.id,
)
async def complete(self, store: "DataStore") -> None:
"""Completes this transaction as successful.
Marks this transaction ID on the application service and removes the
transaction contents from the database.
Args:
store: The database store to operate on.
"""
await store.complete_appservice_txn(service=self.service, txn_id=self.id)


@ -14,12 +14,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
import urllib import urllib
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
from prometheus_client import Counter from prometheus_client import Counter
from synapse.api.constants import EventTypes, ThirdPartyEntityKind from synapse.api.constants import EventTypes, ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.events import EventBase
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.types import JsonDict, ThirdPartyInstanceID
@ -201,7 +202,13 @@ class ApplicationServiceApi(SimpleHttpClient):
key = (service.id, protocol) key = (service.id, protocol)
return await self.protocol_meta_cache.wrap(key, _get) return await self.protocol_meta_cache.wrap(key, _get)
async def push_bulk(self, service, events, txn_id=None): async def push_bulk(
self,
service: "ApplicationService",
events: List[EventBase],
ephemeral: List[JsonDict],
txn_id: Optional[int] = None,
):
if service.url is None: if service.url is None:
return True return True
@ -211,15 +218,19 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning( logger.warning(
"push_bulk: Missing txn ID sending events to %s", service.url "push_bulk: Missing txn ID sending events to %s", service.url
) )
txn_id = str(0) txn_id = 0
txn_id = str(txn_id)
uri = service.url + ("/transactions/%s" % urllib.parse.quote(str(txn_id)))
# Never send ephemeral events to appservices that do not support it
if service.supports_ephemeral:
body = {"events": events, "de.sorunome.msc2409.ephemeral": ephemeral}
else:
body = {"events": events}
uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
try: try:
await self.put_json( await self.put_json(
uri=uri, uri=uri, json_body=body, args={"access_token": service.hs_token},
json_body={"events": events},
args={"access_token": service.hs_token},
) )
sent_transactions_counter.labels(service.id).inc() sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events)) sent_events_counter.labels(service.id).inc(len(events))


@ -49,10 +49,13 @@ This is all tied together by the AppServiceScheduler which DIs the required
components. components.
""" """
import logging import logging
from typing import List
from synapse.appservice import ApplicationServiceState from synapse.appservice import ApplicationService, ApplicationServiceState
from synapse.events import EventBase
from synapse.logging.context import run_in_background from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import JsonDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -82,8 +85,13 @@ class ApplicationServiceScheduler:
for service in services: for service in services:
self.txn_ctrl.start_recoverer(service) self.txn_ctrl.start_recoverer(service)
def submit_event_for_as(self, service, event): def submit_event_for_as(self, service: ApplicationService, event: EventBase):
self.queuer.enqueue(service, event) self.queuer.enqueue_event(service, event)
def submit_ephemeral_events_for_as(
self, service: ApplicationService, events: List[JsonDict]
):
self.queuer.enqueue_ephemeral(service, events)
class _ServiceQueuer: class _ServiceQueuer:
@ -96,17 +104,15 @@ class _ServiceQueuer:
def __init__(self, txn_ctrl, clock): def __init__(self, txn_ctrl, clock):
self.queued_events = {} # dict of {service_id: [events]} self.queued_events = {} # dict of {service_id: [events]}
self.queued_ephemeral = {} # dict of {service_id: [events]}
# the appservices which currently have a transaction in flight # the appservices which currently have a transaction in flight
self.requests_in_flight = set() self.requests_in_flight = set()
self.txn_ctrl = txn_ctrl self.txn_ctrl = txn_ctrl
self.clock = clock self.clock = clock
def enqueue(self, service, event): def _start_background_request(self, service):
self.queued_events.setdefault(service.id, []).append(event)
# start a sender for this appservice if we don't already have one # start a sender for this appservice if we don't already have one
if service.id in self.requests_in_flight: if service.id in self.requests_in_flight:
return return
@ -114,7 +120,15 @@ class _ServiceQueuer:
"as-sender-%s" % (service.id,), self._send_request, service "as-sender-%s" % (service.id,), self._send_request, service
) )
async def _send_request(self, service): def enqueue_event(self, service: ApplicationService, event: EventBase):
self.queued_events.setdefault(service.id, []).append(event)
self._start_background_request(service)
def enqueue_ephemeral(self, service: ApplicationService, events: List[JsonDict]):
self.queued_ephemeral.setdefault(service.id, []).extend(events)
self._start_background_request(service)
async def _send_request(self, service: ApplicationService):
# sanity-check: we shouldn't get here if this service already has a sender # sanity-check: we shouldn't get here if this service already has a sender
# running. # running.
assert service.id not in self.requests_in_flight assert service.id not in self.requests_in_flight
@ -123,10 +137,11 @@ class _ServiceQueuer:
try: try:
while True: while True:
events = self.queued_events.pop(service.id, []) events = self.queued_events.pop(service.id, [])
if not events: ephemeral = self.queued_ephemeral.pop(service.id, [])
if not events and not ephemeral:
return return
try: try:
await self.txn_ctrl.send(service, events) await self.txn_ctrl.send(service, events, ephemeral)
except Exception: except Exception:
logger.exception("AS request failed") logger.exception("AS request failed")
finally: finally:
@ -158,9 +173,16 @@ class _TransactionController:
# for UTs # for UTs
self.RECOVERER_CLASS = _Recoverer self.RECOVERER_CLASS = _Recoverer
async def send(self, service, events): async def send(
self,
service: ApplicationService,
events: List[EventBase],
ephemeral: List[JsonDict] = [],
):
try: try:
txn = await self.store.create_appservice_txn(service=service, events=events) txn = await self.store.create_appservice_txn(
service=service, events=events, ephemeral=ephemeral
)
service_is_up = await self._is_service_up(service) service_is_up = await self._is_service_up(service)
if service_is_up: if service_is_up:
sent = await txn.send(self.as_api) sent = await txn.send(self.as_api)
@ -204,7 +226,7 @@ class _TransactionController:
recoverer.recover() recoverer.recover()
logger.info("Now %i active recoverers", len(self.recoverers)) logger.info("Now %i active recoverers", len(self.recoverers))
async def _is_service_up(self, service): async def _is_service_up(self, service: ApplicationService) -> bool:
state = await self.store.get_appservice_state(service) state = await self.store.get_appservice_state(service)
return state == ApplicationServiceState.UP or state is None return state == ApplicationServiceState.UP or state is None


@ -160,6 +160,8 @@ def _load_appservice(hostname, as_info, config_filename):
if as_info.get("ip_range_whitelist"): if as_info.get("ip_range_whitelist"):
ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist")) ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist"))
supports_ephemeral = as_info.get("de.sorunome.msc2409.push_ephemeral", False)
return ApplicationService( return ApplicationService(
token=as_info["as_token"], token=as_info["as_token"],
hostname=hostname, hostname=hostname,
@ -168,6 +170,7 @@ def _load_appservice(hostname, as_info, config_filename):
hs_token=as_info["hs_token"], hs_token=as_info["hs_token"],
sender=user_id, sender=user_id,
id=as_info["id"], id=as_info["id"],
supports_ephemeral=supports_ephemeral,
protocols=protocols, protocols=protocols,
rate_limited=rate_limited, rate_limited=rate_limited,
ip_range_whitelist=ip_range_whitelist, ip_range_whitelist=ip_range_whitelist,
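For context, an appservice opts into these ephemeral events through its registration
file. A minimal, hypothetical registration showing the flag read above might look like
this (all values are placeholders):
```yaml
id: example-bridge
url: http://localhost:9999
as_token: <as_token>
hs_token: <hs_token>
sender_localpart: examplebot
namespaces:
  users:
    - exclusive: true
      regex: "@example_.*:example.org"
# Opt in to receiving typing, receipts and presence (MSC2409, unstable prefix)
de.sorunome.msc2409.push_ephemeral: true
```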


@ -312,6 +312,12 @@ class EventBase(metaclass=abc.ABCMeta):
""" """
return [e for e, _ in self.auth_events] return [e for e, _ in self.auth_events]
def freeze(self):
"""'Freeze' the event dict, so it cannot be modified by accident"""
# this will be a no-op if the event dict is already frozen.
self._dict = freeze(self._dict)
class FrozenEvent(EventBase): class FrozenEvent(EventBase):
format_version = EventFormatVersions.V1 # All events of this type are V1 format_version = EventFormatVersions.V1 # All events of this type are V1


@ -97,32 +97,37 @@ class EventBuilder:
def is_state(self): def is_state(self):
return self._state_key is not None return self._state_key is not None
async def build(self, prev_event_ids: List[str]) -> EventBase: async def build(
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]]
) -> EventBase:
"""Transform into a fully signed and hashed event """Transform into a fully signed and hashed event
Args: Args:
prev_event_ids: The event IDs to use as the prev events prev_event_ids: The event IDs to use as the prev events
auth_event_ids: The event IDs to use as the auth events.
Should normally be set to None, which will cause them to be calculated
based on the room state at the prev_events.
Returns: Returns:
The signed and hashed event. The signed and hashed event.
""" """
if auth_event_ids is None:
state_ids = await self._state.get_current_state_ids( state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids self.room_id, prev_event_ids
) )
auth_ids = self._auth.compute_auth_events(self, state_ids) auth_event_ids = self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1: if format_version == EventFormatVersions.V1:
# The types of auth/prev events changes between event versions. # The types of auth/prev events changes between event versions.
auth_events = await self._store.add_event_hashes( auth_events = await self._store.add_event_hashes(
auth_ids auth_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
prev_events = await self._store.add_event_hashes( prev_events = await self._store.add_event_hashes(
prev_event_ids prev_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
else: else:
auth_events = auth_ids auth_events = auth_event_ids
prev_events = prev_event_ids prev_events = prev_event_ids
old_depth = await self._store.get_max_depth_of(prev_event_ids) old_depth = await self._store.get_max_depth_of(prev_event_ids)


@ -12,7 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Union
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
@ -44,15 +45,20 @@ class ThirdPartyEventRules:
async def check_event_allowed(
self, event: EventBase, context: EventContext
) -> Union[bool, dict]:
"""Check if a provided event should be allowed in the given context.
The module can return:
* True: the event is allowed.
* False: the event is not allowed, and should be rejected with M_FORBIDDEN.
* a dict: replacement event data.
Args:
event: The event to be checked.
context: The context of the event.
Returns:
The result from the ThirdPartyRules module, as above
"""
if self.third_party_rules is None:
return True
@ -63,9 +69,10 @@ class ThirdPartyEventRules:
events = await self.store.get_events(prev_state_ids.values())
state_events = {(ev.type, ev.state_key): ev for ev in events.values()}
# Ensure that the event is frozen, to make sure that the module is not tempted
# to try to modify it. Any attempt to modify it at this point will invalidate
# the hashes and signatures.
event.freeze()
return await self.third_party_rules.check_event_allowed(event, state_events)
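For illustration, a third-party rules module using the new contract might look roughly
like the sketch below. The module class, the event types, the replacement-dict shape and
the constructor signature are illustrative assumptions, not part of the diff above:
```python
class ExampleRulesModule:
    """Rough sketch of a third-party rules module (names are hypothetical)."""

    def __init__(self, config, http_client):
        self._config = config

    async def check_event_allowed(self, event, state_events):
        # Returning False rejects the event with M_FORBIDDEN.
        if event.type == "org.example.forbidden":
            return False

        # Returning a dict asks Synapse to rebuild the event from this data.
        if event.type == "m.room.message":
            return {
                "type": event.type,
                "room_id": event.room_id,
                "sender": event.sender,
                "content": {**event.content, "org.example.reviewed": True},
            }

        # Returning True allows the event through unchanged.
        return True
```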


@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, List, Optional
from prometheus_client import Counter from prometheus_client import Counter
@ -21,13 +22,16 @@ from twisted.internet import defer
import synapse import synapse
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import ( from synapse.metrics import (
event_processing_loop_counter, event_processing_loop_counter,
event_processing_loop_room_count, event_processing_loop_room_count,
) )
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import RoomStreamToken from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,6 +48,7 @@ class ApplicationServicesHandler:
self.started_scheduler = False self.started_scheduler = False
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notify_appservices = hs.config.notify_appservices self.notify_appservices = hs.config.notify_appservices
self.event_sources = hs.get_event_sources()
self.current_max = 0 self.current_max = 0
self.is_processing = False self.is_processing = False
@ -82,7 +87,7 @@ class ApplicationServicesHandler:
if not events: if not events:
break break
events_by_room = {} events_by_room = {} # type: Dict[str, List[EventBase]]
for event in events: for event in events:
events_by_room.setdefault(event.room_id, []).append(event) events_by_room.setdefault(event.room_id, []).append(event)
@ -161,6 +166,104 @@ class ApplicationServicesHandler:
finally: finally:
self.is_processing = False self.is_processing = False
async def notify_interested_services_ephemeral(
self, stream_key: str, new_token: Optional[int], users: Collection[UserID] = [],
):
"""This is called by the notifier in the background
when a ephemeral event handled by the homeserver.
This will determine which appservices
are interested in the event, and submit them.
Events will only be pushed to appservices
that have opted into ephemeral events
Args:
stream_key: The stream the event came from.
new_token: The latest stream token
users: The user(s) involved with the event.
"""
services = [
service
for service in self.store.get_app_services()
if service.supports_ephemeral
]
if not services or not self.notify_appservices:
return
logger.info("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
# Only handle typing if we have the latest token
if stream_key == "typing_key" and new_token is not None:
events = await self._handle_typing(service, new_token)
if events:
self.scheduler.submit_ephemeral_events_for_as(service, events)
# We don't persist the token for typing_key for performance reasons
elif stream_key == "receipt_key":
events = await self._handle_receipts(service)
if events:
self.scheduler.submit_ephemeral_events_for_as(service, events)
await self.store.set_type_stream_id_for_appservice(
service, "read_receipt", new_token
)
elif stream_key == "presence_key":
events = await self._handle_presence(service, users)
if events:
self.scheduler.submit_ephemeral_events_for_as(service, events)
await self.store.set_type_stream_id_for_appservice(
service, "presence", new_token
)
async def _handle_typing(self, service: ApplicationService, new_token: int):
typing_source = self.event_sources.sources["typing"]
# Get the typing events from just before current
typing, _ = await typing_source.get_new_events_as(
service=service,
# For performance reasons, we don't persist the previous
# token in the DB and instead fetch the latest typing information
# for appservices.
from_key=new_token - 1,
)
return typing
async def _handle_receipts(self, service: ApplicationService):
from_key = await self.store.get_type_stream_id_for_appservice(
service, "read_receipt"
)
receipts_source = self.event_sources.sources["receipt"]
receipts, _ = await receipts_source.get_new_events_as(
service=service, from_key=from_key
)
return receipts
async def _handle_presence(
self, service: ApplicationService, users: Collection[UserID]
):
events = [] # type: List[JsonDict]
presence_source = self.event_sources.sources["presence"]
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
)
for user in users:
interested = await service.is_interested_in_presence(user, self.store)
if not interested:
continue
presence_events, _ = await presence_source.get_new_events(
user=user, service=service, from_key=from_key,
)
time_now = self.clock.time_msec()
presence_events = [
{
"type": "m.presence",
"sender": event.user_id,
"content": format_user_presence_state(
event, time_now, include_user_id=False
),
}
for event in presence_events
]
events = events + presence_events
async def query_user_exists(self, user_id): async def query_user_exists(self, user_id):
"""Check if any application service knows this user_id exists. """Check if any application service knows this user_id exists.
@ -223,7 +326,7 @@ class ApplicationServicesHandler:
async def get_3pe_protocols(self, only_protocol=None): async def get_3pe_protocols(self, only_protocol=None):
services = self.store.get_app_services() services = self.store.get_app_services()
protocols = {} protocols = {} # type: Dict[str, List[JsonDict]]
# Collect up all the individual protocol responses out of the ASes # Collect up all the individual protocol responses out of the ASes
for s in services: for s in services:


@ -1507,18 +1507,9 @@ class FederationHandler(BaseHandler):
event, context = await self.event_creation_handler.create_new_client_event( event, context = await self.event_creation_handler.create_new_client_event(
builder=builder builder=builder
) )
except AuthError as e: except SynapseError as e:
logger.warning("Failed to create join to %s because %s", room_id, e) logger.warning("Failed to create join to %s because %s", room_id, e)
raise e raise
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.info("Creation of join %s forbidden by third-party rules", event)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
# The remote hasn't signed it yet, obviously. We'll do the full checks # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request` # when we get the event back in `on_send_join_request`
@ -1567,15 +1558,6 @@ class FederationHandler(BaseHandler):
context = await self._handle_new_event(origin, event) context = await self._handle_new_event(origin, event)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.info("Sending of join %s forbidden by third-party rules", event)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
logger.debug( logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s", "on_send_join_request: After _handle_new_event: %s, sigs: %s",
event.event_id, event.event_id,
@ -1748,15 +1730,6 @@ class FederationHandler(BaseHandler):
builder=builder builder=builder
) )
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.warning("Creation of leave %s forbidden by third-party rules", event)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
try: try:
# The remote hasn't signed it yet, obviously. We'll do the full checks # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request` # when we get the event back in `on_send_leave_request`
@ -1789,16 +1762,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False event.internal_metadata.outlier = False
context = await self._handle_new_event(origin, event) await self._handle_new_event(origin, event)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.info("Sending of leave %s forbidden by third-party rules", event)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
logger.debug( logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s", "on_send_leave_request: After _handle_new_event: %s, sigs: %s",
@ -2694,18 +2658,6 @@ class FederationHandler(BaseHandler):
builder=builder builder=builder
) )
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.info(
"Creation of threepid invite %s forbidden by third-party rules",
event,
)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
event, context = await self.add_display_name_to_third_party_invite( event, context = await self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context room_version, event_dict, event, context
) )
@ -2756,18 +2708,6 @@ class FederationHandler(BaseHandler):
event, context = await self.event_creation_handler.create_new_client_event( event, context = await self.event_creation_handler.create_new_client_event(
builder=builder builder=builder
) )
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
logger.warning(
"Exchange of threepid invite %s forbidden by third-party rules", event
)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
event, context = await self.add_display_name_to_third_party_invite( event, context = await self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context room_version, event_dict, event, context
) )


@ -437,9 +437,9 @@ class EventCreationHandler:
self, self,
requester: Requester, requester: Requester,
event_dict: dict, event_dict: dict,
token_id: Optional[str] = None,
txn_id: Optional[str] = None, txn_id: Optional[str] = None,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
require_consent: bool = True, require_consent: bool = True,
) -> Tuple[EventBase, EventContext]: ) -> Tuple[EventBase, EventContext]:
""" """
@ -453,13 +453,18 @@ class EventCreationHandler:
Args: Args:
requester requester
event_dict: An entire event event_dict: An entire event
token_id
txn_id txn_id
prev_event_ids: prev_event_ids:
the forward extremities to use as the prev_events for the the forward extremities to use as the prev_events for the
new event. new event.
If None, they will be requested from the database. If None, they will be requested from the database.
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
require_consent: Whether to check if the requester has require_consent: Whether to check if the requester has
consented to the privacy policy. consented to the privacy policy.
Raises: Raises:
@ -511,14 +516,17 @@ class EventCreationHandler:
if require_consent and not is_exempt: if require_consent and not is_exempt:
await self.assert_accepted_privacy_policy(requester) await self.assert_accepted_privacy_policy(requester)
if token_id is not None: if requester.access_token_id is not None:
builder.internal_metadata.token_id = token_id builder.internal_metadata.token_id = requester.access_token_id
if txn_id is not None: if txn_id is not None:
builder.internal_metadata.txn_id = txn_id builder.internal_metadata.txn_id = txn_id
event, context = await self.create_new_client_event( event, context = await self.create_new_client_event(
builder=builder, requester=requester, prev_event_ids=prev_event_ids, builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
) )
# In an ideal world we wouldn't need the second part of this condition. However, # In an ideal world we wouldn't need the second part of this condition. However,
@ -726,7 +734,7 @@ class EventCreationHandler:
return event, event.internal_metadata.stream_ordering return event, event.internal_metadata.stream_ordering
event, context = await self.create_event( event, context = await self.create_event(
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id requester, event_dict, txn_id=txn_id
) )
assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
@ -757,6 +765,7 @@ class EventCreationHandler:
builder: EventBuilder, builder: EventBuilder,
requester: Optional[Requester] = None, requester: Optional[Requester] = None,
prev_event_ids: Optional[List[str]] = None, prev_event_ids: Optional[List[str]] = None,
auth_event_ids: Optional[List[str]] = None,
) -> Tuple[EventBase, EventContext]: ) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client """Create a new event for a local client
@ -769,6 +778,11 @@ class EventCreationHandler:
If None, they will be requested from the database. If None, they will be requested from the database.
auth_event_ids:
The event ids to use as the auth_events for the new event.
Should normally be left as None, which will cause them to be calculated
based on the room state at the prev_events.
Returns: Returns:
Tuple of created event, context Tuple of created event, context
""" """
@ -790,11 +804,30 @@ class EventCreationHandler:
builder.type == EventTypes.Create or len(prev_event_ids) > 0 builder.type == EventTypes.Create or len(prev_event_ids) > 0
), "Attempting to create an event with no prev_events" ), "Attempting to create an event with no prev_events"
event = await builder.build(prev_event_ids=prev_event_ids) event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
)
context = await self.state.compute_event_context(event) context = await self.state.compute_event_context(event)
if requester: if requester:
context.app_service = requester.app_service context.app_service = requester.app_service
third_party_result = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not third_party_result:
logger.info(
"Event %s forbidden by third-party rules", event,
)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
elif isinstance(third_party_result, dict):
# the third-party rules want to replace the event. We'll need to build a new
# event.
event, context = await self._rebuild_event_after_third_party_rules(
third_party_result, event
)
self.validator.validate_new(event, self.config) self.validator.validate_new(event, self.config)
# If this event is an annotation then we check that the sender # If this event is an annotation then we check that the sender
@ -881,14 +914,6 @@ class EventCreationHandler:
else: else:
room_version = await self.store.get_room_version_id(event.room_id) room_version = await self.store.get_room_version_id(event.room_id)
event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
if event.internal_metadata.is_out_of_band_membership(): if event.internal_metadata.is_out_of_band_membership():
# the only sort of out-of-band-membership events we expect to see here # the only sort of out-of-band-membership events we expect to see here
# are invite rejections we have generated ourselves. # are invite rejections we have generated ourselves.
@ -1291,3 +1316,57 @@ class EventCreationHandler:
room_id, room_id,
) )
del self._rooms_to_exclude_from_dummy_event_insertion[room_id] del self._rooms_to_exclude_from_dummy_event_insertion[room_id]
async def _rebuild_event_after_third_party_rules(
self, third_party_result: dict, original_event: EventBase
) -> Tuple[EventBase, EventContext]:
# the third_party_event_rules want to replace the event.
# we do some basic checks, and then return the replacement event and context.
# Construct a new EventBuilder and validate it, which helps with the
# rest of these checks.
try:
builder = self.event_builder_factory.for_room_version(
original_event.room_version, third_party_result
)
self.validator.validate_builder(builder)
except SynapseError as e:
raise Exception(
"Third party rules module created an invalid event: " + e.msg,
)
immutable_fields = [
# changing the room is going to break things: we've already checked that the
# room exists, and are holding a concurrency limiter token for that room.
# Also, we might need to use a different room version.
"room_id",
# changing the type or state key might work, but we'd need to check that the
# calling functions aren't making assumptions about them.
"type",
"state_key",
]
for k in immutable_fields:
if getattr(builder, k, None) != original_event.get(k):
raise Exception(
"Third party rules module created an invalid event: "
"cannot change field " + k
)
# check that the new sender belongs to this HS
if not self.hs.is_mine_id(builder.sender):
raise Exception(
"Third party rules module created an invalid event: "
"invalid sender " + builder.sender
)
# copy over the original internal metadata
for k, v in original_event.internal_metadata.get_dict().items():
setattr(builder.internal_metadata, k, v)
event = await builder.build(prev_event_ids=original_event.prev_event_ids())
# we rebuild the event context, to be on the safe side. If nothing else,
# delta_ids might need an update.
context = await self.state.compute_event_context(event)
return event, context
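To illustrate the new contract, here is a minimal sketch of a third-party rules module under this behaviour: returning `False` rejects the event, returning `True` leaves it untouched, and returning an event dict asks Synapse to rebuild the event from that dict, subject to the checks above (unchanged `room_id`, `type` and `state_key`, and a local sender). The class, its config handling and the `com.example.*` identifiers are illustrative assumptions, not part of this change.

```python
class ExampleThirdPartyRules:
    """Hypothetical third-party rules module (illustrative sketch only)."""

    def __init__(self, config, http_client):
        self._config = config
        self._http_client = http_client

    @staticmethod
    def parse_config(config):
        return config

    async def check_event_allowed(self, event, state_events):
        # Outright reject a hypothetical forbidden event type.
        if event.type == "com.example.forbidden":
            return False

        # Ask Synapse to rebuild the event with amended content. The room_id,
        # type and state_key must stay the same, and the sender must still
        # belong to this homeserver.
        if event.type == "m.room.message":
            replacement = event.get_dict()
            replacement["content"] = dict(replacement["content"])
            replacement["content"]["com.example.tagged"] = True
            return replacement

        # Anything else is allowed through unchanged.
        return True
```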
@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Tuple
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt, get_domain_from_id from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -140,5 +142,36 @@ class ReceiptEventSource:
return (events, to_key) return (events, to_key)
async def get_new_events_as(
self, from_key: int, service: ApplicationService
) -> Tuple[List[JsonDict], int]:
"""Returns a set of new receipt events that an appservice
may be interested in.
Args:
from_key: the stream position to fetch events from
service: The appservice which may be interested
"""
from_key = int(from_key)
to_key = self.get_current_key()
if from_key == to_key:
return [], to_key
# We first need to fetch all new receipts
rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
from_key=from_key, to_key=to_key
)
# Then filter down to rooms that the AS can read
events = []
for room_id, event in rooms_to_events.items():
if not await service.matches_user_in_member_list(room_id, self.store):
continue
events.append(event)
return (events, to_key)
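As a rough sketch (not code from this change), a consumer such as the appservice handler could combine this source with the per-appservice stream positions added further down in this diff; `_send_ephemeral_to_as` and the method name itself are placeholders for whatever actually queues the transaction.

```python
async def _push_new_receipts_to_as(self, service, receipts_source, store):
    # Placeholder sketch: find where this appservice got to on the receipts
    # stream, fetch anything newer that it is allowed to see, hand it off
    # for delivery, and then advance the stored position.
    from_key = await store.get_type_stream_id_for_appservice(service, "read_receipt")
    events, new_key = await receipts_source.get_new_events_as(
        from_key=from_key, service=service
    )
    if events:
        await self._send_ephemeral_to_as(service, events)  # hypothetical helper
    await store.set_type_stream_id_for_appservice(service, "read_receipt", new_key)
```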
def get_current_key(self, direction="f"): def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id() return self.store.get_max_receipt_stream_id()
@ -214,7 +214,6 @@ class RoomCreationHandler(BaseHandler):
"replacement_room": new_room_id, "replacement_room": new_room_id,
}, },
}, },
token_id=requester.access_token_id,
) )
old_room_version = await self.store.get_room_version_id(old_room_id) old_room_version = await self.store.get_room_version_id(old_room_id)
await self.auth.check_from_context( await self.auth.check_from_context(
@ -17,12 +17,10 @@ import abc
import logging import logging
import random import random
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from unpaddedbase64 import encode_base64
from synapse import types from synapse import types
from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -31,12 +29,8 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import EventFormatVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -194,7 +188,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# For backwards compatibility: # For backwards compatibility:
"membership": membership, "membership": membership,
}, },
token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
prev_event_ids=prev_event_ids, prev_event_ids=prev_event_ids,
require_consent=require_consent, require_consent=require_consent,
@ -1153,31 +1146,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
room_id = invite_event.room_id room_id = invite_event.room_id
target_user = invite_event.state_key target_user = invite_event.state_key
room_version = await self.store.get_room_version(room_id)
content["membership"] = Membership.LEAVE content["membership"] = Membership.LEAVE
# the auth events for the new event are the same as that of the invite, plus
# the invite itself.
#
# the prev_events are just the invite.
invite_hash = invite_event.event_id # type: Union[str, Tuple]
if room_version.event_format == EventFormatVersions.V1:
alg, h = compute_event_reference_hash(invite_event)
invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
auth_events = tuple(invite_event.auth_events) + (invite_hash,)
prev_events = (invite_hash,)
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(invite_event.depth + 1, MAX_DEPTH)
event_dict = { event_dict = {
"depth": depth,
"auth_events": auth_events,
"prev_events": prev_events,
"type": EventTypes.Member, "type": EventTypes.Member,
"room_id": room_id, "room_id": room_id,
"sender": target_user, "sender": target_user,
@ -1185,24 +1157,23 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user, "state_key": target_user,
} }
event = create_local_event_from_event_dict( # the auth events for the new event are the same as that of the invite, plus
clock=self.clock, # the invite itself.
hostname=self.hs.hostname, #
signing_key=self.hs.signing_key, # the prev_events are just the invite.
room_version=room_version, prev_event_ids = [invite_event.event_id]
event_dict=event_dict, auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
event, context = await self.event_creation_handler.create_event(
requester,
event_dict,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
) )
event.internal_metadata.outlier = True event.internal_metadata.outlier = True
event.internal_metadata.out_of_band_membership = True event.internal_metadata.out_of_band_membership = True
if txn_id is not None:
event.internal_metadata.txn_id = txn_id
if requester.access_token_id is not None:
event.internal_metadata.token_id = requester.access_token_id
EventValidator().validate_new(event, self.config)
context = await self.state_handler.compute_event_context(event)
context.app_service = requester.app_service
result_event = await self.event_creation_handler.handle_new_client_event( result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[UserID.from_string(target_user)], requester, event, context, extra_users=[UserID.from_string(target_user)],
) )
@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools import itertools
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
@ -12,16 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import random import random
from collections import namedtuple from collections import namedtuple
from typing import TYPE_CHECKING, List, Set, Tuple from typing import TYPE_CHECKING, List, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream from synapse.replication.tcp.streams import TypingStream
from synapse.types import UserID, get_domain_from_id from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer from synapse.util.wheel_timer import WheelTimer
@ -430,6 +430,33 @@ class TypingNotificationEventSource:
"content": {"user_ids": list(typing)}, "content": {"user_ids": list(typing)},
} }
async def get_new_events_as(
self, from_key: int, service: ApplicationService
) -> Tuple[List[JsonDict], int]:
"""Returns a set of new typing events that an appservice
may be interested in.
Args:
from_key: the stream position to fetch events from
service: The appservice which may be interested
"""
with Measure(self.clock, "typing.get_new_events_as"):
from_key = int(from_key)
handler = self.get_typing_handler()
events = []
for room_id in handler._room_serials.keys():
if handler._room_serials[room_id] <= from_key:
continue
if not await service.matches_user_in_member_list(
room_id, handler.store
):
continue
events.append(self._make_event_for(room_id))
return (events, handler._latest_room_serial)
async def get_new_events(self, from_key, room_ids, **kwargs): async def get_new_events(self, from_key, room_ids, **kwargs):
with Measure(self.clock, "typing.get_new_events"): with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key) from_key = int(from_key)
@ -329,6 +329,22 @@ class Notifier:
except Exception: except Exception:
logger.exception("Error notifying application services of event") logger.exception("Error notifying application services of event")
async def _notify_app_services_ephemeral(
self,
stream_key: str,
new_token: Union[int, RoomStreamToken],
users: Collection[UserID] = [],
):
try:
stream_token = None
if isinstance(new_token, int):
stream_token = new_token
await self.appservice_handler.notify_interested_services_ephemeral(
stream_key, stream_token, users
)
except Exception:
logger.exception("Error notifying application services of event")
async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken): async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
try: try:
await self._pusher_pool.on_new_notifications(max_room_stream_token) await self._pusher_pool.on_new_notifications(max_room_stream_token)
@ -367,6 +383,15 @@ class Notifier:
self.notify_replication() self.notify_replication()
# Notify appservices
run_as_background_process(
"_notify_app_services_ephemeral",
self._notify_app_services_ephemeral,
stream_key,
new_token,
users,
)
def on_new_replication_data(self) -> None: def on_new_replication_data(self) -> None:
"""Used to inform replication listeners that something has happend """Used to inform replication listeners that something has happend
without waking up any of the normal user event streams""" without waking up any of the normal user event streams"""
@ -15,7 +15,7 @@
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.descriptors import Cache from synapse.util.caches.deferred_cache import DeferredCache
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache( self.client_ip_last_seen = DeferredCache(
name="client_ip_last_seen", keylen=4, max_entries=50000 name="client_ip_last_seen", keylen=4, max_entries=50000
) ) # type: DeferredCache[tuple, int]
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())
@ -205,7 +205,13 @@ class HomeServer(metaclass=abc.ABCMeta):
# instantiated during setup() for future return by get_datastore() # instantiated during setup() for future return by get_datastore()
DATASTORE_CLASS = abc.abstractproperty() DATASTORE_CLASS = abc.abstractproperty()
def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs): def __init__(
self,
hostname: str,
config: HomeServerConfig,
reactor=None,
version_string="Synapse",
):
""" """
Args: Args:
hostname : The hostname for the server. hostname : The hostname for the server.
@ -236,11 +242,9 @@ class HomeServer(metaclass=abc.ABCMeta):
burst_count=config.rc_registration.burst_count, burst_count=config.rc_registration.burst_count,
) )
self.datastores = None # type: Optional[Databases] self.version_string = version_string
# Other kwargs are explicit dependencies self.datastores = None # type: Optional[Databases]
for depname in kwargs:
setattr(self, depname, kwargs[depname])
def get_instance_id(self) -> str: def get_instance_id(self) -> str:
"""A unique ID for this synapse process instance. """A unique ID for this synapse process instance.
@ -893,6 +893,12 @@ class DatabasePool:
attempts = 0 attempts = 0
while True: while True:
try: try:
# We can autocommit if we are going to use native upserts
autocommit = (
self.engine.can_native_upsert
and table not in self._unsafe_to_upsert_tables
)
return await self.runInteraction( return await self.runInteraction(
desc, desc,
self.simple_upsert_txn, self.simple_upsert_txn,
@ -901,6 +907,7 @@ class DatabasePool:
values, values,
insertion_values, insertion_values,
lock=lock, lock=lock,
db_autocommit=autocommit,
) )
except self.engine.module.IntegrityError as e: except self.engine.module.IntegrityError as e:
attempts += 1 attempts += 1
@ -1063,6 +1070,43 @@ class DatabasePool:
) )
txn.execute(sql, list(allvalues.values())) txn.execute(sql, list(allvalues.values()))
async def simple_upsert_many(
self,
table: str,
key_names: Collection[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
desc: str,
) -> None:
"""
Upsert, many times.
Args:
table: The table to upsert into
key_names: The key column names.
key_values: A list of each row's key column values.
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
"""
# We can autocommit if we are going to use native upserts
autocommit = (
self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
)
return await self.runInteraction(
desc,
self.simple_upsert_many_txn,
table,
key_names,
key_values,
value_names,
value_values,
db_autocommit=autocommit,
)
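For orientation, a hedged example of how a caller might use this helper; the table and column names are made up, but real call sites (such as the user directory changes later in this diff) follow the same shape.

```python
# Hypothetical call site: upsert several (user_id, room_id) -> display_name rows
# in one interaction, letting the helper decide whether autocommit is safe.
await self.db_pool.simple_upsert_many(
    table="example_room_names",  # made-up table for illustration
    key_names=("user_id", "room_id"),
    key_values=[
        ("@alice:example.com", "!room1:example.com"),
        ("@bob:example.com", "!room1:example.com"),
    ],
    value_names=("display_name",),
    value_values=[("Alice",), ("Bob",)],
    desc="update_example_room_names",
)
```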
def simple_upsert_many_txn( def simple_upsert_many_txn(
self, self,
txn: LoggingTransaction, txn: LoggingTransaction,
@ -1214,7 +1258,13 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics desc: description of the transaction, for logging and metrics
""" """
return await self.runInteraction( return await self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none desc,
self.simple_select_one_txn,
table,
keyvalues,
retcols,
allow_none,
db_autocommit=True,
) )
@overload @overload
@ -1265,6 +1315,7 @@ class DatabasePool:
keyvalues, keyvalues,
retcol, retcol,
allow_none=allow_none, allow_none=allow_none,
db_autocommit=True,
) )
@overload @overload
@ -1346,7 +1397,12 @@ class DatabasePool:
Results in a list Results in a list
""" """
return await self.runInteraction( return await self.runInteraction(
desc, self.simple_select_onecol_txn, table, keyvalues, retcol desc,
self.simple_select_onecol_txn,
table,
keyvalues,
retcol,
db_autocommit=True,
) )
async def simple_select_list( async def simple_select_list(
@ -1371,7 +1427,12 @@ class DatabasePool:
A list of dictionaries. A list of dictionaries.
""" """
return await self.runInteraction( return await self.runInteraction(
desc, self.simple_select_list_txn, table, keyvalues, retcols desc,
self.simple_select_list_txn,
table,
keyvalues,
retcols,
db_autocommit=True,
) )
@classmethod @classmethod
@ -1450,6 +1511,7 @@ class DatabasePool:
chunk, chunk,
keyvalues, keyvalues,
retcols, retcols,
db_autocommit=True,
) )
results.extend(rows) results.extend(rows)
@ -1548,7 +1610,12 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics desc: description of the transaction, for logging and metrics
""" """
await self.runInteraction( await self.runInteraction(
desc, self.simple_update_one_txn, table, keyvalues, updatevalues desc,
self.simple_update_one_txn,
table,
keyvalues,
updatevalues,
db_autocommit=True,
) )
@classmethod @classmethod
@ -1607,7 +1674,9 @@ class DatabasePool:
keyvalues: dict of column names and values to select the row with keyvalues: dict of column names and values to select the row with
desc: description of the transaction, for logging and metrics desc: description of the transaction, for logging and metrics
""" """
await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues) await self.runInteraction(
desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True,
)
@staticmethod @staticmethod
def simple_delete_one_txn( def simple_delete_one_txn(
@ -1646,7 +1715,9 @@ class DatabasePool:
Returns: Returns:
The number of deleted rows. The number of deleted rows.
""" """
return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues) return await self.runInteraction(
desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True
)
@staticmethod @staticmethod
def simple_delete_txn( def simple_delete_txn(
@ -1694,7 +1765,13 @@ class DatabasePool:
Number rows deleted Number rows deleted
""" """
return await self.runInteraction( return await self.runInteraction(
desc, self.simple_delete_many_txn, table, column, iterable, keyvalues desc,
self.simple_delete_many_txn,
table,
column,
iterable,
keyvalues,
db_autocommit=True,
) )
@staticmethod @staticmethod
@ -1860,7 +1937,13 @@ class DatabasePool:
""" """
return await self.runInteraction( return await self.runInteraction(
desc, self.simple_search_list_txn, table, term, col, retcols desc,
self.simple_search_list_txn,
table,
term,
col,
retcols,
db_autocommit=True,
) )
@classmethod @classmethod
@ -15,12 +15,15 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
from typing import List
from synapse.appservice import AppServiceTransaction from synapse.appservice import ApplicationService, AppServiceTransaction
from synapse.config.appservice import load_appservices from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -172,15 +175,23 @@ class ApplicationServiceTransactionWorkerStore(
"application_services_state", {"as_id": service.id}, {"state": state} "application_services_state", {"as_id": service.id}, {"state": state}
) )
async def create_appservice_txn(self, service, events): async def create_appservice_txn(
self,
service: ApplicationService,
events: List[EventBase],
ephemeral: List[JsonDict],
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service """Atomically creates a new transaction for this application service
with the given list of events. with the given list of events. Ephemeral events are NOT persisted to the
database and are not resent if a transaction is retried.
Args: Args:
service(ApplicationService): The service who the transaction is for. service: The service who the transaction is for.
events(list<Event>): A list of events to put in the transaction. events: A list of persistent events to put in the transaction.
ephemeral: A list of ephemeral events to put in the transaction.
Returns: Returns:
AppServiceTransaction: A new transaction. A new transaction.
""" """
def _create_appservice_txn(txn): def _create_appservice_txn(txn):
@ -207,7 +218,9 @@ class ApplicationServiceTransactionWorkerStore(
"VALUES(?,?,?)", "VALUES(?,?,?)",
(service.id, new_txn_id, event_ids), (service.id, new_txn_id, event_ids),
) )
return AppServiceTransaction(service=service, id=new_txn_id, events=events) return AppServiceTransaction(
service=service, id=new_txn_id, events=events, ephemeral=ephemeral
)
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"create_appservice_txn", _create_appservice_txn "create_appservice_txn", _create_appservice_txn
@ -296,7 +309,9 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids) events = await self.get_events_as_list(event_ids)
return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) return AppServiceTransaction(
service=service, id=entry["txn_id"], events=events, ephemeral=[]
)
def _get_last_txn(self, txn, service_id): def _get_last_txn(self, txn, service_id):
txn.execute( txn.execute(
@ -320,7 +335,7 @@ class ApplicationServiceTransactionWorkerStore(
) )
async def get_new_events_for_appservice(self, current_id, limit): async def get_new_events_for_appservice(self, current_id, limit):
"""Get all new evnets""" """Get all new events for an appservice"""
def get_new_events_for_appservice_txn(txn): def get_new_events_for_appservice_txn(txn):
sql = ( sql = (
@ -351,6 +366,39 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, events return upper_bound, events
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
def get_type_stream_id_for_appservice_txn(txn):
stream_id_type = "%s_stream_id" % type
# the column name cannot be passed as a bind parameter, so interpolate it directly
txn.execute(
"SELECT %s FROM application_services_state WHERE as_id=?" % stream_id_type,
(service.id,),
)
last_txn_id = txn.fetchone()
if last_txn_id is None or last_txn_id[0] is None: # no row exists
return 0
else:
return int(last_txn_id[0])
return await self.db_pool.runInteraction(
"get_type_stream_id_for_appservice", get_type_stream_id_for_appservice_txn
)
async def set_type_stream_id_for_appservice(
self, service: ApplicationService, type: str, pos: int
) -> None:
def set_type_stream_id_for_appservice_txn(txn):
stream_id_type = "%s_stream_id" % type
# likewise the column name is interpolated; the table name must be a literal
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?" % stream_id_type,
(pos, service.id),
)
await self.db_pool.runInteraction(
"set_type_stream_id_for_appservice", set_type_stream_id_for_appservice_txn
)
class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
# This is currently empty due to there not being any AS storage functions # This is currently empty due to there not being any AS storage functions
@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.util.caches.descriptors import Cache from synapse.util.caches.deferred_cache import DeferredCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -410,7 +410,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
class ClientIpStore(ClientIpWorkerStore): class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = Cache( self.client_ip_last_seen = DeferredCache(
name="client_ip_last_seen", keylen=4, max_entries=50000 name="client_ip_last_seen", keylen=4, max_entries=50000
) )
@ -34,7 +34,8 @@ from synapse.storage.database import (
) )
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import Cache, cached, cachedList from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr from synapse.util.stringutils import shortstr
@ -1004,7 +1005,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies # Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists. # the device exists.
self.device_id_exists_cache = Cache( self.device_id_exists_cache = DeferredCache(
name="device_id_exists", keylen=2, max_entries=10000 name="device_id_exists", keylen=2, max_entries=10000
) )
@ -42,7 +42,8 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -145,7 +146,7 @@ class EventsWorkerStore(SQLBaseStore):
self._cleanup_old_transaction_ids, self._cleanup_old_transaction_ids,
) )
self._get_event_cache = Cache( self._get_event_cache = DeferredCache(
"*getEvent*", "*getEvent*",
keylen=3, keylen=3,
max_entries=hs.config.caches.event_cache_size, max_entries=hs.config.caches.event_cache_size,
@ -122,9 +122,7 @@ class KeyStore(SQLBaseStore):
# param, which is itself the 2-tuple (server_name, key_id). # param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id)) invalidations.append((server_name, key_id))
await self.db_pool.runInteraction( await self.db_pool.simple_upsert_many(
"store_server_verify_keys",
self.db_pool.simple_upsert_many_txn,
table="server_signature_keys", table="server_signature_keys",
key_names=("server_name", "key_id"), key_names=("server_name", "key_id"),
key_values=key_values, key_values=key_values,
@ -135,6 +133,7 @@ class KeyStore(SQLBaseStore):
"verify_key", "verify_key",
), ),
value_values=value_values, value_values=value_values,
desc="store_server_verify_keys",
) )
invalidate = self._get_server_verify_key.invalidate invalidate = self._get_server_verify_key.invalidate
@ -281,9 +281,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
a_day_in_milliseconds = 24 * 60 * 60 * 1000 a_day_in_milliseconds = 24 * 60 * 60 * 1000
now = self._clock.time_msec() now = self._clock.time_msec()
# A note on user_agent. Technically a given device can have multiple
# user agents, so we need to decide which one to pick. We could have handled this
# in a number of ways, but given that we don't care _that_ much we have gone for MAX().
# For more details of the other options considered see
# https://github.com/matrix-org/synapse/pull/8503#discussion_r502306111
sql = """ sql = """
INSERT INTO user_daily_visits (user_id, device_id, timestamp) INSERT INTO user_daily_visits (user_id, device_id, timestamp, user_agent)
SELECT u.user_id, u.device_id, ? SELECT u.user_id, u.device_id, ?, MAX(u.user_agent)
FROM user_ips AS u FROM user_ips AS u
LEFT JOIN ( LEFT JOIN (
SELECT user_id, device_id, timestamp FROM user_daily_visits SELECT user_id, device_id, timestamp FROM user_daily_visits
@ -294,7 +299,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
WHERE last_seen > ? AND last_seen <= ? WHERE last_seen > ? AND last_seen <= ?
AND udv.timestamp IS NULL AND users.is_guest=0 AND udv.timestamp IS NULL AND users.is_guest=0
AND users.appservice_id IS NULL AND users.appservice_id IS NULL
GROUP BY u.user_id, u.device_id GROUP BY u.user_id, u.device_id, u.user_agent
""" """
# This means that the day has rolled over but there could still # This means that the day has rolled over but there could still
@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -274,6 +275,60 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
} }
return results return results
@cached(num_args=2,)
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
) -> Dict[str, JsonDict]:
"""Get receipts for all rooms between two stream_ids.
Args:
to_key: Max stream id to fetch receipts up to.
from_key: Min stream id to fetch receipts from. None fetches
from the start.
Returns:
A dictionary mapping each room id to a single batched m.receipt event containing that room's receipts.
"""
def f(txn):
if from_key:
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
"""
txn.execute(sql, [from_key, to_key])
else:
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id <= ?
"""
txn.execute(sql, [to_key])
return self.db_pool.cursor_to_dict(txn)
txn_results = await self.db_pool.runInteraction(
"get_linearized_receipts_for_all_rooms", f
)
results = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(
row["room_id"],
{"type": "m.receipt", "room_id": row["room_id"], "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
receipt_type[row["user_id"]] = db_to_json(row["data"])
return results
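To make the batching concrete, the return value groups every receipt between the two stream positions into one `m.receipt` event per room, roughly of this shape (illustrative identifiers and timestamp):

```python
{
    "!room1:example.com": {
        "type": "m.receipt",
        "room_id": "!room1:example.com",
        "content": {
            "$event1:example.com": {
                "m.read": {
                    "@alice:example.com": {"ts": 1603000000000},
                },
            },
        },
    },
}
```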
async def get_users_sent_receipts_between( async def get_users_sent_receipts_between(
self, last_id: int, current_id: int self, last_id: int, current_id: int
) -> List[str]: ) -> List[str]:
@ -0,0 +1,18 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add new column to user_daily_visits to track user agent
ALTER TABLE user_daily_visits
ADD COLUMN user_agent TEXT;
@ -0,0 +1,18 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE application_services_state
ADD COLUMN read_receipt_stream_id INT,
ADD COLUMN presence_stream_id INT;
@ -208,42 +208,56 @@ class TransactionStore(TransactionWorkerStore):
""" """
self._destination_retry_cache.pop(destination, None) self._destination_retry_cache.pop(destination, None)
return await self.db_pool.runInteraction( if self.database_engine.can_native_upsert:
"set_destination_retry_timings", return await self.db_pool.runInteraction(
self._set_destination_retry_timings, "set_destination_retry_timings",
destination, self._set_destination_retry_timings_native,
failure_ts, destination,
retry_last_ts, failure_ts,
retry_interval, retry_last_ts,
) retry_interval,
db_autocommit=True, # Safe as its a single upsert
)
else:
return await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings_emulated,
destination,
failure_ts,
retry_last_ts,
retry_interval,
)
def _set_destination_retry_timings( def _set_destination_retry_timings_native(
self, txn, destination, failure_ts, retry_last_ts, retry_interval self, txn, destination, failure_ts, retry_last_ts, retry_interval
): ):
assert self.database_engine.can_native_upsert
if self.database_engine.can_native_upsert: # Upsert retry time interval if retry_interval is zero (i.e. we're
# Upsert retry time interval if retry_interval is zero (i.e. we're # resetting it) or greater than the existing retry interval.
# resetting it) or greater than the existing retry interval. #
# WARNING: This is executed in autocommit, so we shouldn't add any more
# SQL calls in here (without being very careful).
sql = """
INSERT INTO destinations (
destination, failure_ts, retry_last_ts, retry_interval
)
VALUES (?, ?, ?, ?)
ON CONFLICT (destination) DO UPDATE SET
failure_ts = EXCLUDED.failure_ts,
retry_last_ts = EXCLUDED.retry_last_ts,
retry_interval = EXCLUDED.retry_interval
WHERE
EXCLUDED.retry_interval = 0
OR destinations.retry_interval IS NULL
OR destinations.retry_interval < EXCLUDED.retry_interval
"""
sql = """ txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
INSERT INTO destinations (
destination, failure_ts, retry_last_ts, retry_interval
)
VALUES (?, ?, ?, ?)
ON CONFLICT (destination) DO UPDATE SET
failure_ts = EXCLUDED.failure_ts,
retry_last_ts = EXCLUDED.retry_last_ts,
retry_interval = EXCLUDED.retry_interval
WHERE
EXCLUDED.retry_interval = 0
OR destinations.retry_interval IS NULL
OR destinations.retry_interval < EXCLUDED.retry_interval
"""
txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
return
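A plain-Python restatement of the rule that the `ON CONFLICT ... WHERE` clause above encodes (a hedged sketch for clarity, not code from this change): stored timings are only overwritten when the caller is resetting them or backing off further.

```python
def _should_overwrite_retry_timings(existing_interval, new_interval):
    # Mirrors the WHERE clause on the upsert above (illustrative only):
    # overwrite when resetting (new interval is 0), when nothing is stored
    # yet, or when the new backoff interval is longer than the stored one.
    return (
        new_interval == 0
        or existing_interval is None
        or existing_interval < new_interval
    )
```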
def _set_destination_retry_timings_emulated(
self, txn, destination, failure_ts, retry_last_ts, retry_interval
):
self.database_engine.lock_table(txn, "destinations") self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us # We need to be careful here as the data may have changed from under us
@ -480,21 +480,16 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
user_id_tuples: iterable of 2-tuple of user IDs. user_id_tuples: iterable of 2-tuple of user IDs.
""" """
def _add_users_who_share_room_txn(txn): await self.db_pool.simple_upsert_many(
self.db_pool.simple_upsert_many_txn( table="users_who_share_private_rooms",
txn, key_names=["user_id", "other_user_id", "room_id"],
table="users_who_share_private_rooms", key_values=[
key_names=["user_id", "other_user_id", "room_id"], (user_id, other_user_id, room_id)
key_values=[ for user_id, other_user_id in user_id_tuples
(user_id, other_user_id, room_id) ],
for user_id, other_user_id in user_id_tuples value_names=(),
], value_values=None,
value_names=(), desc="add_users_who_share_room",
value_values=None,
)
await self.db_pool.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
) )
async def add_users_in_public_rooms( async def add_users_in_public_rooms(
@ -508,19 +503,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
user_ids user_ids
""" """
def _add_users_in_public_rooms_txn(txn): await self.db_pool.simple_upsert_many(
table="users_in_public_rooms",
self.db_pool.simple_upsert_many_txn( key_names=["user_id", "room_id"],
txn, key_values=[(user_id, room_id) for user_id in user_ids],
table="users_in_public_rooms", value_names=(),
key_names=["user_id", "room_id"], value_values=None,
key_values=[(user_id, room_id) for user_id in user_ids], desc="add_users_in_public_rooms",
value_names=(),
value_values=None,
)
await self.db_pool.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
) )
async def delete_all_from_user_dir(self) -> None: async def delete_all_from_user_dir(self) -> None:
@ -618,14 +618,7 @@ class _MultiWriterCtxManager:
db_autocommit=True, db_autocommit=True,
) )
# Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync
# somehow.
with self.id_gen._lock: with self.id_gen._lock:
assert max(self.id_gen._current_positions.values(), default=0) < min(
self.stream_ids
)
self.id_gen._unfinished_ids.update(self.stream_ids) self.id_gen._unfinished_ids.update(self.stream_ids)
if self.multiple_ids is None: if self.multiple_ids is None:
@ -0,0 +1,292 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import threading
from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast
from prometheus_client import Gauge
from twisted.internet import defer
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
["name"],
)
KT = TypeVar("KT")
VT = TypeVar("VT")
class _Sentinel(enum.Enum):
# defining a sentinel in this way allows mypy to correctly handle the
# type of a dictionary lookup.
sentinel = object()
class DeferredCache(Generic[KT, VT]):
"""Wraps an LruCache, adding support for Deferred results.
It expects that each entry added with set() will be a Deferred; likewise get()
may return an ObservableDeferred.
"""
__slots__ = (
"cache",
"name",
"keylen",
"thread",
"metrics",
"_pending_deferred_cache",
)
def __init__(
self,
name: str,
max_entries: int = 1000,
keylen: int = 1,
tree: bool = False,
iterable: bool = False,
apply_cache_factor_from_config: bool = True,
):
"""
Args:
name: The name of the cache
max_entries: Maximum amount of entries that the cache will hold
keylen: The length of the tuple used as the cache key. Ignored unless
`tree` is True.
tree: Use a TreeCache instead of a dict as the underlying cache type
iterable: If True, count each item in the cached object as an entry,
rather than each cached object
apply_cache_factor_from_config: Whether cache factors specified in the
config file affect `max_entries`
"""
cache_type = TreeCache if tree else dict
# _pending_deferred_cache maps from the key value to a `CacheEntry` object.
self._pending_deferred_cache = (
cache_type()
) # type: MutableMapping[KT, CacheEntry]
# cache is used for completed results and maps to the result itself, rather than
# a Deferred.
self.cache = LruCache(
max_size=max_entries,
keylen=keylen,
cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
apply_cache_factor_from_config=apply_cache_factor_from_config,
)
self.name = name
self.keylen = keylen
self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache(
"cache",
name,
self.cache,
collect_callback=self._metrics_collection_callback,
)
@property
def max_entries(self):
return self.cache.max_size
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
def _metrics_collection_callback(self):
cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(
self,
key: KT,
default=_Sentinel.sentinel,
callback: Optional[Callable[[], None]] = None,
update_metrics: bool = True,
):
"""Looks the key up in the caches.
Args:
key(tuple)
default: What is returned if key is not in the caches. If not
specified then function throws KeyError instead
callback(fn): Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either an ObservableDeferred or the result itself
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks)
if update_metrics:
self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
if val is not _Sentinel.sentinel:
self.metrics.inc_hits()
return val
if update_metrics:
self.metrics.inc_misses()
if default is _Sentinel.sentinel:
raise KeyError()
else:
return default
def set(
self,
key: KT,
value: defer.Deferred,
callback: Optional[Callable[[], None]] = None,
) -> ObservableDeferred:
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread()
observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache[key] = entry
def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
return False
def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
def eb(_fail):
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable
def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key):
self.check_thread()
self.cache.pop(key, None)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
# for future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry:
entry.invalidate()
def invalidate_many(self, key: KT):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
def invalidate_all(self):
self.check_thread()
self.cache.clear()
for entry in self._pending_deferred_cache.values():
entry.invalidate()
self._pending_deferred_cache.clear()
class CacheEntry:
__slots__ = ["deferred", "callbacks", "invalidated"]
def __init__(
self, deferred: ObservableDeferred, callbacks: Iterable[Callable[[], None]]
):
self.deferred = deferred
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self):
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
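A short usage sketch of `DeferredCache`, assuming a fresh cache instance and a Twisted reactor context: pending lookups are stored as Deferreds and promoted to plain values once they resolve, and the compare-and-pop logic above stops a racing `invalidate()` from letting a stale result land in the cache.

```python
from twisted.internet import defer

from synapse.util.caches.deferred_cache import DeferredCache

cache = DeferredCache(name="example_cache")  # hypothetical cache name

# A pending lookup: set() takes a Deferred and returns an ObservableDeferred.
d = defer.Deferred()
cache.set("key", d)

# While d is unresolved, get() hands back the pending ObservableDeferred.
pending = cache.get("key")

# Once the Deferred fires, the value is promoted into the underlying LruCache
# and get() returns it directly.
d.callback("value")
assert cache.get("key") == "value"

# prefill() skips the pending stage and stores a completed value immediately.
cache.prefill("other", "precomputed")
assert cache.get("other") == "precomputed"
```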
@ -13,25 +13,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import functools import functools
import inspect import inspect
import logging import logging
import threading
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from prometheus_client import Gauge
from twisted.internet import defer from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from . import register_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -55,239 +48,6 @@ class _CachedFunction(Generic[F]):
__call__ = None # type: F __call__ = None # type: F
cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache",
["name"],
)
_CacheSentinel = object()
class CacheEntry:
__slots__ = ["deferred", "callbacks", "invalidated"]
def __init__(self, deferred, callbacks):
self.deferred = deferred
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self):
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
class Cache:
__slots__ = (
"cache",
"name",
"keylen",
"thread",
"metrics",
"_pending_deferred_cache",
)
def __init__(
self,
name: str,
max_entries: int = 1000,
keylen: int = 1,
tree: bool = False,
iterable: bool = False,
apply_cache_factor_from_config: bool = True,
):
"""
Args:
name: The name of the cache
max_entries: Maximum amount of entries that the cache will hold
keylen: The length of the tuple used as the cache key
tree: Use a TreeCache instead of a dict as the underlying cache type
iterable: If True, count each item in the cached object as an entry,
rather than each cached object
apply_cache_factor_from_config: Whether cache factors specified in the
config file affect `max_entries`
Returns:
Cache
"""
cache_type = TreeCache if tree else dict
self._pending_deferred_cache = cache_type()
self.cache = LruCache(
max_size=max_entries,
keylen=keylen,
cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
apply_cache_factor_from_config=apply_cache_factor_from_config,
)
self.name = name
self.keylen = keylen
self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache(
"cache",
name,
self.cache,
collect_callback=self._metrics_collection_callback,
)
@property
def max_entries(self):
return self.cache.max_size
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
def _metrics_collection_callback(self):
cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, key, default=_CacheSentinel, callback=None, update_metrics=True):
"""Looks the key up in the caches.
Args:
key(tuple)
default: What is returned if key is not in the caches. If not
specified then function throws KeyError instead
callback(fn): Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
val.callbacks.update(callbacks)
if update_metrics:
self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
if update_metrics:
self.metrics.inc_misses()
if default is _CacheSentinel:
raise KeyError()
else:
return default
def set(self, key, value, callback=None):
if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred")
callbacks = [callback] if callback else []
self.check_thread()
observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache[key] = entry
def compare_and_pop():
"""Check if our entry is still the one in _pending_deferred_cache, and
if so, pop it.
Returns true if the entries matched.
"""
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
return True
# oops, the _pending_deferred_cache has been updated since
# we started our query, so we are out of date.
#
# Better put back whatever we took out. (We do it this way
# round, rather than peeking into the _pending_deferred_cache
# and then removing on a match, to make the common case faster)
if existing_entry is not None:
self._pending_deferred_cache[key] = existing_entry
return False
def cb(result):
if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
# updated, but it's possible that `set` was called without
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
def eb(_fail):
compare_and_pop()
entry.invalidate()
# once the deferred completes, we can move the entry from the
# _pending_deferred_cache to the real cache.
#
observer.addCallbacks(cb, eb)
return observable
def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key):
self.check_thread()
self.cache.pop(key, None)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, which will (a) stop it being returned
# for future queries and (b) stop it being persisted as a proper entry
# in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
# run the invalidation callbacks now, rather than waiting for the
# deferred to resolve.
if entry:
entry.invalidate()
def invalidate_many(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
def invalidate_all(self):
self.check_thread()
self.cache.clear()
for entry in self._pending_deferred_cache.values():
entry.invalidate()
self._pending_deferred_cache.clear()
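The `Cache` class removed above is the two-tier deferred cache that this changeset relocates to `synapse.util.caches.deferred_cache.DeferredCache` (note the new import at the top of this file and the new test file later in the diff): lookups still in flight live in `_pending_deferred_cache`, and only once their Deferred resolves is the result promoted into the backing `LruCache`. A minimal, hedged sketch of that pattern, with simplified names that are not Synapse's real implementation:

```python
# Illustrative sketch of the two-tier "deferred cache" pattern shown above;
# class and attribute names are simplified assumptions, not Synapse's API.
from twisted.internet import defer


class TinyDeferredCache:
    def __init__(self):
        self._pending = {}  # key -> Deferred for a lookup still in flight
        self._results = {}  # key -> completed result

    def get(self, key):
        if key in self._pending:
            # join the in-flight lookup rather than starting another
            return self._pending[key]
        if key in self._results:
            return defer.succeed(self._results[key])
        raise KeyError(key)

    def set(self, key, deferred):
        self._pending[key] = deferred

        def _promote(result):
            # once the lookup completes, move it into the "real" cache,
            # unless something replaced or invalidated the pending entry
            if self._pending.get(key) is deferred:
                del self._pending[key]
                self._results[key] = result
            return result

        deferred.addCallback(_promote)
        return deferred
```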
class _CacheDescriptorBase:
def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
self.orig = orig
@@ -390,13 +150,13 @@ class CacheDescriptor(_CacheDescriptorBase):
self.iterable = iterable
def __get__(self, obj, owner):
cache = DeferredCache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
) # type: DeferredCache[Tuple, Any]
def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
@@ -640,9 +400,9 @@ class _CacheContext:
_cache_context_objects = (
WeakValueDictionary()
) # type: WeakValueDictionary[Tuple[DeferredCache, CacheKey], _CacheContext]
def __init__(self, cache, cache_key): # type: (DeferredCache, CacheKey) -> None
self._cache = cache
self._cache_key = cache_key
@@ -651,7 +411,9 @@ class _CacheContext:
self._cache.invalidate(self._cache_key)
@classmethod
def get_instance(
cls, cache, cache_key
): # type: (DeferredCache, CacheKey) -> _CacheContext
"""Returns an instance constructed with the given arguments.
A new instance is only created if none already exists.
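In the hunks above, `CacheDescriptor.__get__` now builds one `DeferredCache` per decorated method, and `_CacheContext` is the small handle handed to functions decorated with `cache_context=True` so they can invalidate their own entry. A hedged usage sketch of that decorator pattern; the store, method and helper names here are made up for illustration and do not come from this diff:

```python
# Hedged usage sketch of the @cached descriptor with cache_context=True.
from synapse.util.caches.descriptors import cached


class ExampleStore:
    @cached(cache_context=True)
    async def get_user_names(self, room_id, cache_context):
        # `cache_context.invalidate` can be chained into nested lookups
        # (conventionally as `on_invalidate=`) so that invalidating them
        # also drops this entry from this method's DeferredCache.
        return await self._db_fetch(room_id, on_invalidate=cache_context.invalidate)

    async def _db_fetch(self, room_id, on_invalidate=None):
        # stand-in for a real database query
        return ["@alice:example.org"]
```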
@@ -64,7 +64,8 @@ class LruCache:
Args:
max_size: The maximum amount of entries the cache can hold
keylen: The length of the tuple used as the cache key. Ignored unless
cache_type is `TreeCache`.
cache_type (type):
type of underlying cache to be used. Typically one of dict
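The clarified docstring above notes that `keylen` only matters when the backing store is a `TreeCache`, which is what allows prefix invalidation such as the `del_multi` call used by `invalidate_many` earlier in this diff. A hedged sketch of that combination; the key values are made up for illustration:

```python
# Hedged sketch: an LruCache backed by a TreeCache so that tuple keys can be
# invalidated by prefix (as invalidate_many() does via del_multi() above).
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache

cache = LruCache(max_size=1000, keylen=2, cache_type=TreeCache)
cache[("!room:test", "@alice:test")] = "join"
cache[("!room:test", "@bob:test")] = "leave"

# Drop every entry whose key starts with ("!room:test",) in one call.
cache.del_multi(("!room:test",))
```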
@@ -34,7 +34,7 @@ class TTLCache:
self._data = {}
# the _CacheEntries, sorted by expiry time
self._expiry_list = SortedList() # type: SortedList[_CacheEntry]
self._timer = timer
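The `SortedList` annotation above is what lets a TTL cache evict expired entries cheaply: the list is ordered by expiry time, so only its head ever needs checking. A simplified, hedged sketch of that pattern (not Synapse's actual `TTLCache` API):

```python
# Hedged sketch of the "expiry list sorted by deadline" pattern.
import time

from sortedcontainers import SortedList


class TinyTTLCache:
    def __init__(self, timer=time.time):
        self._data = {}                   # key -> (expiry, value)
        self._expiry_list = SortedList()  # (expiry, key) tuples, sorted by expiry
        self._timer = timer

    def set(self, key, value, ttl):
        if key in self._data:
            old_expiry, _ = self._data[key]
            self._expiry_list.remove((old_expiry, key))
        expiry = self._timer() + ttl
        self._data[key] = (expiry, value)
        self._expiry_list.add((expiry, key))

    def get(self, key):
        self._evict_expired()
        return self._data[key][1]

    def _evict_expired(self):
        # only the head of the sorted list can have expired
        now = self._timer()
        while self._expiry_list and self._expiry_list[0][0] <= now:
            _, key = self._expiry_list.pop(0)
            self._data.pop(key, None)
```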
@@ -22,7 +22,7 @@ class FrontendProxyTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
@@ -26,7 +26,7 @@ from tests.unittest import HomeserverTestCase
class FederationReaderOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
http_client=None, homeserver_to_use=GenericWorkerServer
)
return hs
@@ -84,7 +84,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
http_client=None, homeserver_to_use=SynapseHomeServer
)
return hs
@@ -60,7 +60,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events, ephemeral=[] # txn made and saved
)
self.assertEquals(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed
@@ -81,7 +81,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events, ephemeral=[] # txn made and saved
)
self.assertEquals(0, txn.send.call_count) # txn not sent though
self.assertEquals(0, txn.complete.call_count) # or completed
@@ -106,7 +106,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
self.store.create_appservice_txn.assert_called_once_with(
service=service, events=events, ephemeral=[]
)
self.assertEquals(1, self.recoverer_fn.call_count) # recoverer made
self.assertEquals(1, self.recoverer.recover.call_count) # and invoked
@@ -202,26 +202,28 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
# Expect the event to be sent immediately.
service = Mock(id=4)
event = Mock()
self.queuer.enqueue_event(service, event)
self.txn_ctrl.send.assert_called_once_with(service, [event], [])
def test_send_single_event_with_queue(self):
d = defer.Deferred()
self.txn_ctrl.send = Mock(
side_effect=lambda x, y, z: make_deferred_yieldable(d)
)
service = Mock(id=4)
event = Mock(event_id="first")
event2 = Mock(event_id="second")
event3 = Mock(event_id="third")
# Send an event and don't resolve it just yet.
self.queuer.enqueue_event(service, event)
# Send more events: expect send() to NOT be called multiple times.
self.queuer.enqueue_event(service, event2)
self.queuer.enqueue_event(service, event3)
self.txn_ctrl.send.assert_called_with(service, [event], [])
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent
d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [event2, event3], [])
self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_multiple_service_queues(self):
@@ -239,21 +241,58 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
send_return_list = [srv_1_defer, srv_2_defer]
def do_send(x, y, z):
return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send)
# send events for different ASes and make sure they are sent
self.queuer.enqueue_event(srv1, srv_1_event)
self.queuer.enqueue_event(srv1, srv_1_event2)
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [])
self.queuer.enqueue_event(srv2, srv_2_event)
self.queuer.enqueue_event(srv2, srv_2_event2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [])
# make sure callbacks for a service only send queued events for THAT
# service
srv_2_defer.callback(srv2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_single_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
event_list = [Mock(name="event")]
self.queuer.enqueue_ephemeral(service, event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
def test_send_multiple_ephemeral_no_queue(self):
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
self.queuer.enqueue_ephemeral(service, event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], event_list)
def test_send_single_ephemeral_with_queue(self):
d = defer.Deferred()
self.txn_ctrl.send = Mock(
side_effect=lambda x, y, z: make_deferred_yieldable(d)
)
service = Mock(id=4)
event_list_1 = [Mock(event_id="event1"), Mock(event_id="event2")]
event_list_2 = [Mock(event_id="event3"), Mock(event_id="event4")]
event_list_3 = [Mock(event_id="event5"), Mock(event_id="event6")]
# Send an event and don't resolve it just yet.
self.queuer.enqueue_ephemeral(service, event_list_1)
# Send more events: expect send() to NOT be called multiple times.
self.queuer.enqueue_ephemeral(service, event_list_2)
self.queuer.enqueue_ephemeral(service, event_list_3)
self.txn_ctrl.send.assert_called_with(service, [], event_list_1)
self.assertEquals(1, self.txn_ctrl.send.call_count)
# Resolve txn_ctrl.send
d.callback(service)
# Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
self.assertEquals(2, self.txn_ctrl.send.call_count)
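The new tests above exercise the appservice scheduler's per-service queue: while one transaction is in flight, further events and ephemeral events are batched and handed to the next `txn_ctrl.send(service, events, ephemeral)` call. A hedged sketch of that queueing pattern; class and attribute names are simplified assumptions, not Synapse's actual scheduler:

```python
# Illustrative per-service queuer: only one transaction in flight per service,
# with events and ephemeral events batched for the next send.
class TinyServiceQueuer:
    def __init__(self, txn_ctrl):
        self.txn_ctrl = txn_ctrl
        self.queued_events = {}      # service id -> [events]
        self.queued_ephemeral = {}   # service id -> [ephemeral events]
        self.requests_in_flight = set()

    async def enqueue_event(self, service, event):
        self.queued_events.setdefault(service.id, []).append(event)
        await self._send_request(service)

    async def enqueue_ephemeral(self, service, events):
        self.queued_ephemeral.setdefault(service.id, []).extend(events)
        await self._send_request(service)

    async def _send_request(self, service):
        if service.id in self.requests_in_flight:
            return  # a transaction is already being sent; batch for later
        self.requests_in_flight.add(service.id)
        try:
            while True:
                events = self.queued_events.pop(service.id, [])
                ephemeral = self.queued_ephemeral.pop(service.id, [])
                if not events and not ephemeral:
                    return
                await self.txn_ctrl.send(service, events, ephemeral)
        finally:
            self.requests_in_flight.discard(service.id)
```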
@@ -66,7 +66,6 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
"sender": self.requester.user.to_string(),
"content": {"msgtype": "m.text", "body": random_string(5)},
},
token_id=self.token_id,
txn_id=txn_id,
)
)
@@ -615,7 +615,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.store.get_latest_event_ids_in_room(room_id)
)
event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
@@ -59,7 +59,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
http_client=None,
homeserver_to_use=GenericWorkerServer,
config=self._get_worker_hs_config(),
reactor=self.reactor,
)
@@ -266,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
config.update(extra_config)
worker_hs = self.setup_test_homeserver(
homeserver_to_use=GenericWorkerServer,
config=config,
reactor=self.reactor,
**kwargs
@@ -31,7 +31,7 @@ class FederationAckTestCase(HomeserverTestCase):
return config
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer)
return hs
@@ -226,7 +226,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
}
builder = factory.for_room_version(room_version, event_dict)
join_event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(federation.on_send_join_request(remote_server, join_event))
self.replicate()
@@ -14,8 +14,12 @@
# limitations under the License.
import logging
from mock import patch
from synapse.api.room_versions import RoomVersion
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS
@@ -36,6 +40,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
sync.register_servlets,
]
def prepare(self, reactor, clock, hs):
@@ -43,6 +48,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
self.room_creator = self.hs.get_room_creation_handler()
self.store = hs.get_datastore()
def default_config(self):
conf = super().default_config()
conf["redis"] = {"enabled": "true"}
@@ -53,6 +61,29 @@
}
return conf
def _create_room(self, room_id: str, user_id: str, tok: str):
"""Create a room with given room_id
"""
# We control the room ID generation by patching out the
# `_generate_room_id` method
async def generate_room(
creator_id: str, is_public: bool, room_version: RoomVersion
):
await self.store.store_room(
room_id=room_id,
room_creator_user_id=creator_id,
is_public=is_public,
room_version=room_version,
)
return room_id
with patch(
"synapse.handlers.room.RoomCreationHandler._generate_room_id"
) as mock:
mock.side_effect = generate_room
self.helper.create_room_as(user_id, tok=tok)
def test_basic(self):
"""Simple test to ensure that multiple rooms can be created and joined,
and that different rooms get handled by different instances.
@@ -100,3 +131,189 @@
self.assertTrue(persisted_on_1)
self.assertTrue(persisted_on_2)
def test_vector_clock_token(self):
"""Tests that using a stream token with a vector clock component works
correctly with basic /sync and /messages usage.
"""
self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "worker1"},
)
worker_hs2 = self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "worker2"},
)
sync_hs = self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "sync"},
)
# Specially selected room IDs that get persisted on different workers.
room_id1 = "!foo:test"
room_id2 = "!baz:test"
self.assertEqual(
self.hs.config.worker.events_shard_config.get_instance(room_id1), "worker1"
)
self.assertEqual(
self.hs.config.worker.events_shard_config.get_instance(room_id2), "worker2"
)
user_id = self.register_user("user", "pass")
access_token = self.login("user", "pass")
store = self.hs.get_datastore()
# Create two rooms on different workers.
self._create_room(room_id1, user_id, access_token)
self._create_room(room_id2, user_id, access_token)
# The other user joins
self.helper.join(
room=room_id1, user=self.other_user_id, tok=self.other_access_token
)
self.helper.join(
room=room_id2, user=self.other_user_id, tok=self.other_access_token
)
# Do an initial sync so that we're up to date.
request, channel = self.make_request("GET", "/sync", access_token=access_token)
self.render_on_worker(sync_hs, request)
next_batch = channel.json_body["next_batch"]
# We now gut wrench into the events stream MultiWriterIdGenerator on
# worker2 to mimic it getting stuck persisting an event. This ensures
# that when we send an event on worker1 we end up in a state where
# worker2's event stream position lags that on worker1, resulting in a
# RoomStreamToken with a non-empty instance map component.
#
# Worker2's event stream position will not advance until we call
# __aexit__ again.
actx = worker_hs2.get_datastore()._stream_id_gen.get_next()
self.get_success(actx.__aenter__())
response = self.helper.send(room_id1, body="Hi!", tok=self.other_access_token)
first_event_in_room1 = response["event_id"]
# Assert that the current stream token has an instance map component, as
# we are trying to test vector clock tokens.
room_stream_token = store.get_room_max_token()
self.assertNotEqual(len(room_stream_token.instance_map), 0)
# Check that syncing still gets the new event, despite the gap in the
# stream IDs.
request, channel = self.make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token
)
self.render_on_worker(sync_hs, request)
# We should only see the new event and nothing else
self.assertIn(room_id1, channel.json_body["rooms"]["join"])
self.assertNotIn(room_id2, channel.json_body["rooms"]["join"])
events = channel.json_body["rooms"]["join"][room_id1]["timeline"]["events"]
self.assertListEqual(
[first_event_in_room1], [event["event_id"] for event in events]
)
# Get the next batch and make sure it's a vector clock style token.
vector_clock_token = channel.json_body["next_batch"]
self.assertTrue(vector_clock_token.startswith("m"))
# Now that we've got a vector clock token we finish the fake persisting
# an event we started above.
self.get_success(actx.__aexit__(None, None, None))
# Now try to send an event to the other room so that we can test that
# the vector clock style token works as a `since` token.
response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
first_event_in_room2 = response["event_id"]
request, channel = self.make_request(
"GET",
"/sync?since={}".format(vector_clock_token),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)
self.assertNotIn(room_id1, channel.json_body["rooms"]["join"])
self.assertIn(room_id2, channel.json_body["rooms"]["join"])
events = channel.json_body["rooms"]["join"][room_id2]["timeline"]["events"]
self.assertListEqual(
[first_event_in_room2], [event["event_id"] for event in events]
)
next_batch = channel.json_body["next_batch"]
# We also want to test that the vector clock style token works with
# pagination. We do this by sending a couple of new events into the room
# and syncing again to get a prev_batch token for each room, then
# paginating from there back to the vector clock token.
self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
request, channel = self.make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token
)
self.render_on_worker(sync_hs, request)
prev_batch1 = channel.json_body["rooms"]["join"][room_id1]["timeline"][
"prev_batch"
]
prev_batch2 = channel.json_body["rooms"]["join"][room_id2]["timeline"][
"prev_batch"
]
# Paginating back in the first room should not produce any results, as
# no events have happened in it. This tests that we are correctly
# filtering results based on the vector clock portion.
request, channel = self.make_request(
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id1, prev_batch1, vector_clock_token
),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)
self.assertListEqual([], channel.json_body["chunk"])
# Paginating back on the second room should produce the first event
# again. This tests that pagination isn't completely broken.
request, channel = self.make_request(
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id2, prev_batch2, vector_clock_token
),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)
self.assertEqual(len(channel.json_body["chunk"]), 1)
self.assertEqual(
channel.json_body["chunk"][0]["event_id"], first_event_in_room2
)
# Paginating forwards should give the same results
request, channel = self.make_request(
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id1, vector_clock_token, prev_batch1
),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)
self.assertListEqual([], channel.json_body["chunk"])
request, channel = self.make_request(
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id2, vector_clock_token, prev_batch2,
),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)
self.assertEqual(len(channel.json_body["chunk"]), 1)
self.assertEqual(
channel.json_body["chunk"][0]["event_id"], first_event_in_room2
)
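The test above depends on the sync token gaining an "instance map" component once the sharded persisters' positions diverge; the assertion that `next_batch` starts with `"m"` checks for that serialised form. A hedged sketch of what such a vector-clock token conceptually carries; the serialisation format below is illustrative only, not Synapse's exact encoding:

```python
# Illustrative vector-clock stream token: a global position plus each
# writer's own position, so a reader knows which writers' events it has seen.
from dataclasses import dataclass, field
from typing import Dict


@dataclass(frozen=True)
class SketchRoomStreamToken:
    stream: int                                   # max position across all writers
    instance_map: Dict[str, int] = field(default_factory=dict)

    def serialise(self) -> str:
        if not self.instance_map:
            return f"s{self.stream}"
        parts = ".".join(
            f"{name}~{pos}" for name, pos in sorted(self.instance_map.items())
        )
        return f"m{self.stream}~{parts}"


token = SketchRoomStreamToken(stream=105, instance_map={"worker1": 105, "worker2": 103})
print(token.serialise())  # -> "m105~worker1~105.worker2~103"
```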
@@ -114,16 +114,36 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEquals(channel.result["code"], b"403", channel.result)
def test_cannot_modify_event(self):
"""cannot accidentally modify an event before it is persisted"""
# first patch the event checker so that it will try to modify the event
async def check(ev: EventBase, state):
ev.content = {"x": "y"}
return True
current_rules_module().check_event_allowed = check
# now send the event
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/modifyme/1" % self.room_id,
{"x": "x"},
access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.result["code"], b"500", channel.result)
def test_modify_event(self):
"""The module can return a modified version of the event"""
# first patch the event checker so that it will modify the event
async def check(ev: EventBase, state):
d = ev.get_dict()
d["content"] = {"x": "y"}
return d
current_rules_module().check_event_allowed = check
# now send the event
request, channel = self.make_request(
"PUT",
@@ -20,82 +20,11 @@ from mock import Mock
from twisted.internet import defer
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached
from tests import unittest
class CacheTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
self.cache = Cache("test")
def test_empty(self):
failed = False
try:
self.cache.get("foo")
except KeyError:
failed = True
self.assertTrue(failed)
def test_hit(self):
self.cache.prefill("foo", 123)
self.assertEquals(self.cache.get("foo"), 123)
def test_invalidate(self):
self.cache.prefill(("foo",), 123)
self.cache.invalidate(("foo",))
failed = False
try:
self.cache.get(("foo",))
except KeyError:
failed = True
self.assertTrue(failed)
def test_eviction(self):
cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
cache.prefill(3, "three") # 1 will be evicted
failed = False
try:
cache.get(1)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(2)
cache.get(3)
def test_eviction_lru(self):
cache = Cache("test", max_entries=2)
cache.prefill(1, "one")
cache.prefill(2, "two")
# Now access 1 again, thus causing 2 to be least-recently used
cache.get(1)
cache.prefill(3, "three")
failed = False
try:
cache.get(2)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(1)
cache.get(3)
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def test_passthrough(self):
@@ -244,7 +244,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
service = Mock(id=self.as_list[0]["id"])
events = [Mock(event_id="e1"), Mock(event_id="e2")]
txn = yield defer.ensureDeferred(
self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 1)
self.assertEquals(txn.events, events)
@@ -258,7 +258,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(service.id, 9644, events)
yield self._insert_txn(service.id, 9645, events)
txn = yield defer.ensureDeferred(
self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9646)
self.assertEquals(txn.events, events)
@@ -270,7 +270,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
events = [Mock(event_id="e1"), Mock(event_id="e2")]
yield self._set_last_txn(service.id, 9643)
txn = yield defer.ensureDeferred(
self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
@@ -293,7 +293,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._insert_txn(self.as_list[3]["id"], 9643, events)
txn = yield defer.ensureDeferred(
self.store.create_appservice_txn(service, events, [])
)
self.assertEquals(txn.id, 9644)
self.assertEquals(txn.events, events)
@@ -236,9 +236,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self._event_id = event_id
@defer.inlineCallbacks
def build(self, prev_event_ids, auth_event_ids):
built_event = yield defer.ensureDeferred(
self._base_builder.build(prev_event_ids, auth_event_ids)
)
built_event._event_id = self._event_id
@@ -15,7 +15,7 @@
# limitations under the License.
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
from synapse.util.caches.deferred_cache import DeferredCache
from tests import unittest
@@ -138,7 +138,7 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
Caches produce metrics reflecting their state when scraped.
"""
CACHE_NAME = "cache_metrics_test_fgjkbdfg"
cache = DeferredCache(CACHE_NAME, max_entries=777)
items = {
x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii")
@@ -20,7 +20,7 @@ import hmac
import inspect
import logging
import time
from typing import Optional, Tuple, Type, TypeVar, Union, overload
from mock import Mock, patch
@@ -364,6 +364,36 @@ class HomeserverTestCase(TestCase):
Function to optionally be overridden in subclasses.
"""
# Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest
# when the `request` arg isn't given, so we define an explicit override to
# cover that case.
@overload
def make_request(
self,
method: Union[bytes, str],
path: Union[bytes, str],
content: Union[bytes, dict] = b"",
access_token: Optional[str] = None,
shorthand: bool = True,
federation_auth_origin: str = None,
content_is_form: bool = False,
) -> Tuple[SynapseRequest, FakeChannel]:
...
@overload
def make_request(
self,
method: Union[bytes, str],
path: Union[bytes, str],
content: Union[bytes, dict] = b"",
access_token: Optional[str] = None,
request: Type[T] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: str = None,
content_is_form: bool = False,
) -> Tuple[T, FakeChannel]:
...
def make_request(
self,
method: Union[bytes, str],
@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from twisted.internet import defer
from synapse.util.caches.deferred_cache import DeferredCache
class DeferredCacheTestCase(unittest.TestCase):
def test_empty(self):
cache = DeferredCache("test")
failed = False
try:
cache.get("foo")
except KeyError:
failed = True
self.assertTrue(failed)
def test_hit(self):
cache = DeferredCache("test")
cache.prefill("foo", 123)
self.assertEquals(cache.get("foo"), 123)
def test_invalidate(self):
cache = DeferredCache("test")
cache.prefill(("foo",), 123)
cache.invalidate(("foo",))
failed = False
try:
cache.get(("foo",))
except KeyError:
failed = True
self.assertTrue(failed)
def test_invalidate_all(self):
cache = DeferredCache("testcache")
callback_record = [False, False]
def record_callback(idx):
callback_record[idx] = True
# add a couple of pending entries
d1 = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0))
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return observable deferreds
self.assertFalse(cache.get("key1").has_called())
self.assertFalse(cache.get("key2").has_called())
# let one of the lookups complete
d2.callback("result2")
# for now at least, the cache will return real results rather than an
# observabledeferred
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
cache.invalidate_all()
# lookup should return none
self.assertIsNone(cache.get("key1", None))
self.assertIsNone(cache.get("key2", None))
# both callbacks should have been callbacked
self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
# letting the other lookup complete should do nothing
d1.callback("result1")
self.assertIsNone(cache.get("key1", None))
def test_eviction(self):
cache = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
cache.prefill(1, "one")
cache.prefill(2, "two")
cache.prefill(3, "three") # 1 will be evicted
failed = False
try:
cache.get(1)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(2)
cache.get(3)
def test_eviction_lru(self):
cache = DeferredCache(
"test", max_entries=2, apply_cache_factor_from_config=False
)
cache.prefill(1, "one")
cache.prefill(2, "two")
# Now access 1 again, thus causing 2 to be least-recently used
cache.get(1)
cache.prefill(3, "three")
failed = False
try:
cache.get(2)
except KeyError:
failed = True
self.assertTrue(failed)
cache.get(1)
cache.get(3)
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from functools import partial
import mock
@@ -42,49 +41,6 @@ def run_on_reactor():
return make_deferred_yieldable(d)
class CacheTestCase(unittest.TestCase):
def test_invalidate_all(self):
cache = descriptors.Cache("testcache")
callback_record = [False, False]
def record_callback(idx):
callback_record[idx] = True
# add a couple of pending entries
d1 = defer.Deferred()
cache.set("key1", d1, partial(record_callback, 0))
d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1))
# lookup should return observable deferreds
self.assertFalse(cache.get("key1").has_called())
self.assertFalse(cache.get("key2").has_called())
# let one of the lookups complete
d2.callback("result2")
# for now at least, the cache will return real results rather than an
# observabledeferred
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation
cache.invalidate_all()
# lookup should return none
self.assertIsNone(cache.get("key1", None))
self.assertIsNone(cache.get("key2", None))
# both callbacks should have been callbacked
self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
# letting the other lookup complete should do nothing
d1.callback("result1")
self.assertIsNone(cache.get("key1", None))
class DescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
@@ -21,6 +21,7 @@ import time
import uuid
import warnings
from inspect import getcallargs
from typing import Type
from urllib import parse as urlparse
from mock import Mock, patch
@@ -194,8 +195,8 @@ def setup_test_homeserver(
name="test",
config=None,
reactor=None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
**kwargs
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
@@ -218,8 +219,8 @@
config.ldap_enabled = False
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
@@ -264,18 +265,20 @@
cur.close()
db_conn.close()
hs = homeserver_to_use(
name, config=config, version_string="Synapse/tests", reactor=reactor,
config=config,
version_string="Synapse/tests",
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,
**kargs
)
# Install @cache_in_self attributes
for key, val in kwargs.items():
setattr(hs, key, val)
# Mock TLS
hs.tls_server_context_factory = Mock()
hs.tls_client_options_factory = Mock()
hs.setup()
if homeserver_to_use == TestHomeServer:
hs.setup_background_tasks()
if isinstance(db_engine, PostgresEngine):
@@ -339,7 +342,7 @@
hs.get_auth_handler().validate_hash = validate_hash
fed = kwargs.get("resource_for_federation", None)
if fed:
register_federation_servlets(hs, fed)
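The rewritten `setup_test_homeserver` above no longer forwards arbitrary keyword arguments into the `HomeServer` constructor; instead it sets them as attributes (the "Install @cache_in_self attributes" loop) so the homeserver's cached getters pick them up, and mocks the TLS factories the same way. A hedged sketch of the caching-getter pattern this relies on; the decorator and attribute names here are illustrative, not Synapse's actual `@cache_in_self`:

```python
# Illustrative "cache the built dependency on the instance" getter, which a
# test can pre-seed simply by assigning the attribute before calling it.
from unittest.mock import Mock


def cache_in_self(builder):
    depname = builder.__name__[len("get_"):]

    def getter(self):
        existing = getattr(self, depname, None)
        if existing is None:
            existing = builder(self)
            setattr(self, depname, existing)
        return existing

    return getter


class SketchHomeServer:
    @cache_in_self
    def get_http_client(self):
        return object()  # stand-in for building a real client


hs = SketchHomeServer()
hs.http_client = Mock()  # what the setattr loop in setup_test_homeserver does
assert hs.get_http_client() is hs.http_client
```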