MatrixSynapse/synapse/storage/databases/main/session.py

146 lines
4.9 KiB
Python
Raw Normal View History

# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
import synapse.util.stringutils as stringutils
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
class SessionStore(SQLBaseStore):
    """
    A store for generic session data.

    Each type of session should provide a unique type (to separate sessions).

    Sessions are automatically removed when they expire.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Create a background job for culling expired sessions (every 30
        # minutes). Only processes configured to run background tasks do this;
        # get_session filters on expiry itself, so stale rows are never served.
        if hs.config.worker.run_background_tasks:
            self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)

    async def create_session(
        self, session_type: str, value: JsonDict, expiry_ms: int
    ) -> str:
        """
        Creates a new session of the given type, storing `value` against a
        freshly generated random session ID.

        Args:
            session_type: The type for this session. The same type must be
                supplied to get_session to retrieve it.
            value: The JSON-serialisable value to store.
            expiry_ms: How long, in milliseconds from now, the session remains
                retrievable via get_session. A value of 0 or less creates a
                session that is already expired. Expired rows are deleted
                periodically by a background job rather than immediately.

        Returns:
            The newly created session ID.

        Raises:
            StoreError if a unique session ID cannot be generated.
        """
        # Neither the serialised value nor the expiry depends on which session
        # ID we end up using, so compute them once up front.
        encoded_value = json_encoder.encode(value)
        expiry_time_ms = self._clock.time_msec() + expiry_ms

        # autogen a session ID and try to create it. We may clash, so just
        # try a few times till one goes through, giving up eventually.
        for _ in range(5):
            session_id = stringutils.random_string(24)
            try:
                await self.db_pool.simple_insert(
                    table="sessions",
                    values={
                        "session_id": session_id,
                        "session_type": session_type,
                        "value": encoded_value,
                        "expiry_time_ms": expiry_time_ms,
                    },
                    desc="create_session",
                )
            except self.db_pool.engine.module.IntegrityError:
                # Presumably a clash with an existing session ID; retry with a
                # new random ID.
                continue
            return session_id

        raise StoreError(500, "Couldn't generate a session ID.")

    async def get_session(self, session_type: str, session_id: str) -> JsonDict:
        """
        Retrieve data stored with create_session.

        Args:
            session_type: The type for this session.
            session_id: The session ID returned from create_session.

        Returns:
            The value stored by create_session, decoded from JSON.

        Raises:
            StoreError (404) if the session cannot be found or has expired.
        """

        def _get_session(
            txn: LoggingTransaction, session_type: str, session_id: str, ts: int
        ) -> JsonDict:
            # This includes the expiry time since items are only periodically
            # deleted, not upon expiry.
            select_sql = """
            SELECT value FROM sessions WHERE
            session_type = ? AND session_id = ? AND expiry_time_ms > ?
            """
            txn.execute(select_sql, [session_type, session_id, ts])
            row = txn.fetchone()

            if not row:
                raise StoreError(404, "No session")

            return db_to_json(row[0])

        return await self.db_pool.runInteraction(
            "get_session",
            _get_session,
            session_type,
            session_id,
            self._clock.time_msec(),
        )

    @wrap_as_background_process("delete_expired_sessions")
    async def _delete_expired_sessions(self) -> None:
        """Remove sessions with expiry dates that have passed."""

        def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
            # `<=` is the exact complement of get_session's `expiry_time_ms > ?`
            # filter, so every row deleted here was already unreadable.
            sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
            txn.execute(sql, (ts,))

        await self.db_pool.runInteraction(
            "delete_expired_sessions",
            _delete_expired_sessions_txn,
            self._clock.time_msec(),
        )