146 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			146 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
| # -*- coding: utf-8 -*-
 | |
| #  Copyright 2021 The Matrix.org Foundation C.I.C.
 | |
| #
 | |
| #  Licensed under the Apache License, Version 2.0 (the "License");
 | |
| #  you may not use this file except in compliance with the License.
 | |
| #  You may obtain a copy of the License at
 | |
| #
 | |
| #      http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| #  Unless required by applicable law or agreed to in writing, software
 | |
| #  distributed under the License is distributed on an "AS IS" BASIS,
 | |
| #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| #  See the License for the specific language governing permissions and
 | |
| #  limitations under the License.
 | |
| from typing import TYPE_CHECKING
 | |
| 
 | |
| import synapse.util.stringutils as stringutils
 | |
| from synapse.api.errors import StoreError
 | |
| from synapse.metrics.background_process_metrics import wrap_as_background_process
 | |
| from synapse.storage._base import SQLBaseStore, db_to_json
 | |
| from synapse.storage.database import (
 | |
|     DatabasePool,
 | |
|     LoggingDatabaseConnection,
 | |
|     LoggingTransaction,
 | |
| )
 | |
| from synapse.types import JsonDict
 | |
| from synapse.util import json_encoder
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from synapse.server import HomeServer
 | |
| 
 | |
| 
 | |
class SessionStore(SQLBaseStore):
    """
    A store for generic session data.

    Each type of session should provide a unique type (to separate sessions).

    Sessions are stored in the ``sessions`` table together with an absolute
    expiry timestamp. Expired sessions are removed by a periodic job rather
    than at the moment they expire, so reads also filter on the expiry time.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # Create a background job for culling expired sessions. Only the
        # process configured to run background tasks schedules it.
        # NOTE(review): the interval is 30 minutes assuming looping_call takes
        # milliseconds — confirm against the Clock implementation.
        if hs.config.worker.run_background_tasks:
            self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000)

    async def create_session(
        self, session_type: str, value: JsonDict, expiry_ms: int
    ) -> str:
        """
        Create a new session of the given type and persist it.

        Args:
            session_type: The type for this session.
            value: The value to store. It is JSON-encoded before insertion.
            expiry_ms: How long, in milliseconds from now, until the session
                expires. This is added to the current time to produce the
                stored expiry timestamp, so a value of 0 (or less) creates a
                session that is effectively already expired; it does not
                mean "never expire".

        Returns:
            The newly created session ID.

        Raises:
            StoreError if a unique session ID cannot be generated.
        """
        # Autogenerate a session ID and try to create it. We may clash with an
        # existing ID, so try a few times until one goes through, giving up
        # eventually.
        for _ in range(5):
            session_id = stringutils.random_string(24)

            try:
                await self.db_pool.simple_insert(
                    table="sessions",
                    values={
                        "session_id": session_id,
                        "session_type": session_type,
                        "value": json_encoder.encode(value),
                        # Computed per attempt so the expiry is relative to
                        # the insert that actually succeeds.
                        "expiry_time_ms": self._clock.time_msec() + expiry_ms,
                    },
                    desc="create_session",
                )
            except self.db_pool.engine.module.IntegrityError:
                # The insert was rejected by the database (presumably a
                # session ID clash); retry with a fresh random ID.
                continue

            return session_id

        raise StoreError(500, "Couldn't generate a session ID.")

    async def get_session(self, session_type: str, session_id: str) -> JsonDict:
        """
        Retrieve data stored with create_session.

        Args:
            session_type: The type for this session.
            session_id: The session ID returned from create_session.

        Returns:
            The stored value, decoded from JSON.

        Raises:
            StoreError (404) if no unexpired session with this type and ID
            exists.
        """

        def _get_session(
            txn: LoggingTransaction, session_type: str, session_id: str, ts: int
        ) -> JsonDict:
            # This includes the expiry time since items are only periodically
            # deleted, not upon expiry.
            select_sql = """
            SELECT value FROM sessions WHERE
            session_type = ? AND session_id = ? AND expiry_time_ms > ?
            """
            txn.execute(select_sql, [session_type, session_id, ts])
            row = txn.fetchone()

            if not row:
                raise StoreError(404, "No session")

            return db_to_json(row[0])

        return await self.db_pool.runInteraction(
            "get_session",
            _get_session,
            session_type,
            session_id,
            self._clock.time_msec(),
        )

    @wrap_as_background_process("delete_expired_sessions")
    async def _delete_expired_sessions(self) -> None:
        """Remove sessions with expiry dates that have passed."""

        def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None:
            # `<=` is the exact complement of the `>` filter in get_session:
            # every row deleted here is one get_session would already refuse
            # to return for the same timestamp.
            sql = "DELETE FROM sessions WHERE expiry_time_ms <= ?"
            txn.execute(sql, (ts,))

        await self.db_pool.runInteraction(
            "delete_expired_sessions",
            _delete_expired_sessions_txn,
            self._clock.time_msec(),
        )
 |