377 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			377 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
# Copyright 2014-2016 OpenMarket Ltd
 | 
						|
# Copyright 2018 New Vector Ltd
 | 
						|
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
 | 
						|
import logging
 | 
						|
from typing import TYPE_CHECKING, List, Optional, Tuple
 | 
						|
 | 
						|
from synapse.config.homeserver import HomeServerConfig
 | 
						|
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 | 
						|
from synapse.storage.databases.main.stats import UserSortOrder
 | 
						|
from synapse.storage.engines import PostgresEngine
 | 
						|
from synapse.storage.util.id_generators import (
 | 
						|
    IdGenerator,
 | 
						|
    MultiWriterIdGenerator,
 | 
						|
    StreamIdGenerator,
 | 
						|
)
 | 
						|
from synapse.types import JsonDict, get_domain_from_id
 | 
						|
from synapse.util.caches.stream_change_cache import StreamChangeCache
 | 
						|
 | 
						|
from .account_data import AccountDataStore
 | 
						|
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
 | 
						|
from .cache import CacheInvalidationWorkerStore
 | 
						|
from .censor_events import CensorEventsStore
 | 
						|
from .client_ips import ClientIpStore
 | 
						|
from .deviceinbox import DeviceInboxStore
 | 
						|
from .devices import DeviceStore
 | 
						|
from .directory import DirectoryStore
 | 
						|
from .e2e_room_keys import EndToEndRoomKeyStore
 | 
						|
from .end_to_end_keys import EndToEndKeyStore
 | 
						|
from .event_federation import EventFederationStore
 | 
						|
from .event_push_actions import EventPushActionsStore
 | 
						|
from .events_bg_updates import EventsBackgroundUpdatesStore
 | 
						|
from .events_forward_extremities import EventForwardExtremitiesStore
 | 
						|
from .filtering import FilteringStore
 | 
						|
from .group_server import GroupServerStore
 | 
						|
from .keys import KeyStore
 | 
						|
from .lock import LockStore
 | 
						|
from .media_repository import MediaRepositoryStore
 | 
						|
from .metrics import ServerMetricsStore
 | 
						|
from .monthly_active_users import MonthlyActiveUsersStore
 | 
						|
from .openid import OpenIdStore
 | 
						|
from .presence import PresenceStore
 | 
						|
from .profile import ProfileStore
 | 
						|
from .purge_events import PurgeEventsStore
 | 
						|
from .push_rule import PushRuleStore
 | 
						|
from .pusher import PusherStore
 | 
						|
from .receipts import ReceiptsStore
 | 
						|
from .registration import RegistrationStore
 | 
						|
from .rejections import RejectionsStore
 | 
						|
from .relations import RelationsStore
 | 
						|
from .room import RoomStore
 | 
						|
from .room_batch import RoomBatchStore
 | 
						|
from .roommember import RoomMemberStore
 | 
						|
from .search import SearchStore
 | 
						|
from .session import SessionStore
 | 
						|
from .signatures import SignatureStore
 | 
						|
from .state import StateStore
 | 
						|
from .stats import StatsStore
 | 
						|
from .stream import StreamWorkerStore
 | 
						|
from .tags import TagsStore
 | 
						|
from .transactions import TransactionWorkerStore
 | 
						|
from .ui_auth import UIAuthStore
 | 
						|
from .user_directory import UserDirectoryStore
 | 
						|
from .user_erasure_store import UserErasureStore
 | 
						|
 | 
						|
if TYPE_CHECKING:
    # Only needed for the `"HomeServer"` string annotation below, so it is
    # imported for type checkers and never at runtime.
    from synapse.server import HomeServer

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
 | 
						|
 | 
						|
 | 
						|
class DataStore(
    EventsBackgroundUpdatesStore,
    RoomMemberStore,
    RoomStore,
    RoomBatchStore,
    RegistrationStore,
    StreamWorkerStore,
    ProfileStore,
    PresenceStore,
    TransactionWorkerStore,
    DirectoryStore,
    KeyStore,
    StateStore,
    SignatureStore,
    ApplicationServiceStore,
    PurgeEventsStore,
    EventFederationStore,
    MediaRepositoryStore,
    RejectionsStore,
    FilteringStore,
    PusherStore,
    PushRuleStore,
    ApplicationServiceTransactionStore,
    ReceiptsStore,
    EndToEndKeyStore,
    EndToEndRoomKeyStore,
    SearchStore,
    TagsStore,
    AccountDataStore,
    EventPushActionsStore,
    OpenIdStore,
    ClientIpStore,
    DeviceStore,
    DeviceInboxStore,
    UserDirectoryStore,
    GroupServerStore,
    UserErasureStore,
    MonthlyActiveUsersStore,
    StatsStore,
    RelationsStore,
    CensorEventsStore,
    UIAuthStore,
    EventForwardExtremitiesStore,
    CacheInvalidationWorkerStore,
    ServerMetricsStore,
    LockStore,
    SessionStore,
):
    """Single store class combining every per-feature store listed as a base.

    On top of what the bases provide, this class sets up a handful of ID
    generators and stream change caches in `__init__` and adds a few
    user-listing queries against the `users` table.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        self.hs = hs
        self._clock = hs.get_clock()
        self.database_engine = database.engine

        # NOTE: the ID generators below are deliberately assigned *before*
        # `super().__init__` runs; keep that ordering when editing.
        device_list_extra_tables = [
            ("user_signature_stream", "stream_id"),
            ("device_lists_outbound_pokes", "stream_id"),
        ]
        self._device_list_id_gen = StreamIdGenerator(
            db_conn,
            "device_lists_stream",
            "stream_id",
            extra_tables=device_list_extra_tables,
        )

        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
        self._group_updates_id_gen = StreamIdGenerator(
            db_conn, "local_group_updates", "stream_id"
        )

        # The cache invalidation stream only gets an ID generator when the
        # database engine is Postgres; otherwise the attribute is None.
        self._cache_id_gen: Optional[MultiWriterIdGenerator]
        if not isinstance(self.database_engine, PostgresEngine):
            self._cache_id_gen = None
        else:
            # `writers` is left empty on purpose: updates missed across a
            # restart don't matter, because a freshly started process has
            # nothing in its caches to invalidate. (Doing so also cuts down
            # on the number of writes to the DB.)
            self._cache_id_gen = MultiWriterIdGenerator(
                db_conn,
                database,
                stream_name="caches",
                instance_name=hs.get_instance_name(),
                tables=[
                    (
                        "cache_invalidation_stream_by_instance",
                        "instance_name",
                        "stream_id",
                    )
                ],
                sequence_name="cache_invalidation_stream_seq",
                writers=[],
            )

        super().__init__(database, db_conn, hs)

        # All three device-related change caches start from the same token.
        device_list_token = self._device_list_id_gen.get_current_token()
        self._device_list_stream_cache = StreamChangeCache(
            "DeviceListStreamChangeCache", device_list_token
        )
        self._user_signature_stream_cache = StreamChangeCache(
            "UserSignatureStreamChangeCache", device_list_token
        )
        self._device_list_federation_stream_cache = StreamChangeCache(
            "DeviceListFederationStreamChangeCache", device_list_token
        )

        # The current-state delta stream shares its stream IDs with the events
        # token, hence the events token is used as the upper bound here.
        events_token = self._stream_id_gen.get_current_token()
        state_delta_prefill, state_delta_min_id = self.db_pool.get_cache_dict(
            db_conn,
            "current_state_delta_stream",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=events_token,
            limit=1000,
        )
        self._curr_state_delta_stream_cache = StreamChangeCache(
            "_curr_state_delta_stream_cache",
            state_delta_min_id,
            prefilled_cache=state_delta_prefill,
        )

        group_updates_prefill, group_updates_min_id = self.db_pool.get_cache_dict(
            db_conn,
            "local_group_updates",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=self._group_updates_id_gen.get_current_token(),
            limit=1000,
        )
        self._group_updates_stream_cache = StreamChangeCache(
            "_group_updates_stream_cache",
            group_updates_min_id,
            prefilled_cache=group_updates_prefill,
        )

        # Snapshot of the room stream ordering bounds taken at construction.
        self._stream_order_on_start = self.get_room_max_stream_ordering()
        self._min_stream_order_on_start = self.get_room_min_stream_ordering()

    def get_device_stream_token(self) -> int:
        """Return the current token of the device list ID generator."""
        device_list_id_gen = self._device_list_id_gen
        return device_list_id_gen.get_current_token()

    async def get_users(self) -> List[JsonDict]:
        """Fetch every row of the `users` table.

        Returns:
            One dictionary per user, holding the selected columns.
        """
        user_columns = [
            "name",
            "password_hash",
            "is_guest",
            "admin",
            "user_type",
            "deactivated",
        ]
        # Empty `keyvalues` means no filtering: all rows are selected.
        return await self.db_pool.simple_select_list(
            table="users",
            keyvalues={},
            retcols=user_columns,
            desc="get_users",
        )

    async def get_users_paginate(
        self,
        start: int,
        limit: int,
        user_id: Optional[str] = None,
        name: Optional[str] = None,
        guests: bool = True,
        deactivated: bool = False,
        order_by: str = UserSortOrder.USER_ID.value,
        direction: str = "f",
    ) -> Tuple[List[JsonDict], int]:
        """Fetch one page of users together with the total number of matches.

        Args:
            start: offset of the first row to return
            limit: maximum number of rows to return
            user_id: substring to look for in the user ID; only used when
                `name` is empty
            name: substring to look for in the localpart of the user ID or
                in the display name
            guests: whether guest users are included
            deactivated: whether deactivated users are included
            order_by: column to sort by, as a `UserSortOrder` value
            direction: "b" sorts descending, anything else ascending

        Returns:
            A tuple of the page of users (one dictionary each) and the count
            of all users matching the filters.
        """

        def get_users_paginate_txn(txn):
            # The first bind parameter is always the server name, consumed by
            # the LEFT JOIN condition in `sql_base`.
            args = [self.hs.config.server.server_name]
            filters = []

            # Going through `UserSortOrder` means only a value it accepts is
            # interpolated into the ORDER BY clause (presumably it rejects
            # anything else -- confirm against its definition).
            order_by_column = UserSortOrder(order_by).value
            order = "DESC" if direction == "b" else "ASC"

            # `name` is stored in lower case, so the search term is lowered
            # to match.
            if name:
                needle = name.lower()
                filters.append("(name LIKE ? OR LOWER(displayname) LIKE ?)")
                args.extend(["@%" + needle + "%:%", "%" + needle + "%"])
            elif user_id:
                filters.append("name LIKE ?")
                args.append("%" + user_id.lower() + "%")

            if not guests:
                filters.append("is_guest = 0")
            if not deactivated:
                filters.append("deactivated = 0")

            where_clause = ""
            if filters:
                where_clause = "WHERE " + " AND ".join(filters)

            # Shared FROM/JOIN/WHERE part of the count and the page queries.
            sql_base = f"""
                FROM users as u
                LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
                {where_clause}
                """

            count_sql = "SELECT COUNT(*) as total_users " + sql_base
            txn.execute(count_sql, args)
            count = txn.fetchone()[0]

            page_sql = f"""
                SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
                displayname, avatar_url, creation_ts * 1000 as creation_ts
                {sql_base}
                ORDER BY {order_by_column} {order}, u.name ASC
                LIMIT ? OFFSET ?
            """
            # LIMIT/OFFSET binds come after the parameters used by `sql_base`.
            args.extend([limit, start])
            txn.execute(page_sql, args)
            return self.db_pool.cursor_to_dict(txn), count

        return await self.db_pool.runInteraction(
            "get_users_paginate_txn", get_users_paginate_txn
        )

    async def search_users(self, term: str) -> Optional[List[JsonDict]]:
        """Search the `users` table by the `name` column.

        Args:
            term: search term

        Returns:
            A list of dictionaries, one per matching user, or None.
        """
        wanted_columns = ["name", "password_hash", "is_guest", "admin", "user_type"]
        return await self.db_pool.simple_search_list(
            table="users",
            term=term,
            col="name",
            retcols=wanted_columns,
            desc="search_users",
        )
 | 
						|
 | 
						|
 | 
						|
def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
    """Sanity-check an existing database against the config before upgrading.

    Reads a single row from the `users` table and compares the domain of that
    user ID with the configured server name. A database without any users
    passes the check.

    Args:
        cur: cursor used to run the query
        database_engine: not used by this check
        config: homeserver config; `config.server.server_name` is the name
            the stored user is compared against

    Raises:
        Exception: if the sampled user belongs to a different domain than
            the configured server name.
    """
    logger.info("Checking database for consistency with configuration...")

    # A single user is all that gets inspected: if any users exist, the one
    # sampled here must carry the configured server name.
    cur.execute("SELECT name FROM users LIMIT 1")
    sample = cur.fetchall()

    if sample and get_domain_from_id(sample[0][0]) != config.server.server_name:
        raise Exception(
            "Found users in database not native to %s!\n"
            "You cannot change a synapse server_name after it's been configured"
            % (config.server.server_name,)
        )
 | 
						|
 | 
						|
 | 
						|
# Names exported by `from <this module> import *`.
__all__ = ["DataStore", "check_database_before_upgrade"]
 |