Move `tests.utils.setup_test_homeserver` to `tests.server`
It had no users. We have just taken the identity of a previous function but don't provide the same behaviour, so we need to fix this in the next commit...pull/11505/head
parent
f7ec6e7d9e
commit
b3fd99b74a
185
tests/server.py
185
tests/server.py
|
@ -11,9 +11,12 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from io import SEEK_END, BytesIO
|
from io import SEEK_END, BytesIO
|
||||||
from typing import (
|
from typing import (
|
||||||
|
@ -27,6 +30,7 @@ from typing import (
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from typing_extensions import Deque
|
from typing_extensions import Deque
|
||||||
|
@ -53,10 +57,24 @@ from twisted.web.http_headers import Headers
|
||||||
from twisted.web.resource import IResource
|
from twisted.web.resource import IResource
|
||||||
from twisted.web.server import Request, Site
|
from twisted.web.server import Request, Site
|
||||||
|
|
||||||
|
from synapse.config.database import DatabaseConnectionConfig
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage import DataStore
|
||||||
|
from synapse.storage.engines import PostgresEngine, create_engine
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
from tests.utils import (
|
||||||
|
LEAVE_DB,
|
||||||
|
POSTGRES_BASE_DB,
|
||||||
|
POSTGRES_HOST,
|
||||||
|
POSTGRES_PASSWORD,
|
||||||
|
POSTGRES_USER,
|
||||||
|
USE_POSTGRES_FOR_TESTS,
|
||||||
|
MockClock,
|
||||||
|
default_config,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -668,3 +686,168 @@ def connect_client(
|
||||||
client.makeConnection(FakeTransport(server, reactor))
|
client.makeConnection(FakeTransport(server, reactor))
|
||||||
|
|
||||||
return client, server
|
return client, server
|
||||||
|
|
||||||
|
|
||||||
|
class TestHomeServer(HomeServer):
|
||||||
|
DATASTORE_CLASS = DataStore
|
||||||
|
|
||||||
|
|
||||||
|
def setup_test_homeserver(
|
||||||
|
cleanup_func,
|
||||||
|
name="test",
|
||||||
|
config=None,
|
||||||
|
reactor=None,
|
||||||
|
homeserver_to_use: Type[HomeServer] = TestHomeServer,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Setup a homeserver suitable for running tests against. Keyword arguments
|
||||||
|
are passed to the Homeserver constructor.
|
||||||
|
|
||||||
|
If no datastore is supplied, one is created and given to the homeserver.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cleanup_func : The function used to register a cleanup routine for
|
||||||
|
after the test.
|
||||||
|
|
||||||
|
Calling this method directly is deprecated: you should instead derive from
|
||||||
|
HomeserverTestCase.
|
||||||
|
"""
|
||||||
|
if reactor is None:
|
||||||
|
from twisted.internet import reactor
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
config = default_config(name, parse=True)
|
||||||
|
|
||||||
|
config.ldap_enabled = False
|
||||||
|
|
||||||
|
if "clock" not in kwargs:
|
||||||
|
kwargs["clock"] = MockClock()
|
||||||
|
|
||||||
|
if USE_POSTGRES_FOR_TESTS:
|
||||||
|
test_db = "synapse_test_%s" % uuid.uuid4().hex
|
||||||
|
|
||||||
|
database_config = {
|
||||||
|
"name": "psycopg2",
|
||||||
|
"args": {
|
||||||
|
"database": test_db,
|
||||||
|
"host": POSTGRES_HOST,
|
||||||
|
"password": POSTGRES_PASSWORD,
|
||||||
|
"user": POSTGRES_USER,
|
||||||
|
"cp_min": 1,
|
||||||
|
"cp_max": 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
database_config = {
|
||||||
|
"name": "sqlite3",
|
||||||
|
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
if "db_txn_limit" in kwargs:
|
||||||
|
database_config["txn_limit"] = kwargs["db_txn_limit"]
|
||||||
|
|
||||||
|
database = DatabaseConnectionConfig("master", database_config)
|
||||||
|
config.database.databases = [database]
|
||||||
|
|
||||||
|
db_engine = create_engine(database.config)
|
||||||
|
|
||||||
|
# Create the database before we actually try and connect to it, based off
|
||||||
|
# the template database we generate in setupdb()
|
||||||
|
if isinstance(db_engine, PostgresEngine):
|
||||||
|
db_conn = db_engine.module.connect(
|
||||||
|
database=POSTGRES_BASE_DB,
|
||||||
|
user=POSTGRES_USER,
|
||||||
|
host=POSTGRES_HOST,
|
||||||
|
password=POSTGRES_PASSWORD,
|
||||||
|
)
|
||||||
|
db_conn.autocommit = True
|
||||||
|
cur = db_conn.cursor()
|
||||||
|
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
|
||||||
|
cur.execute(
|
||||||
|
"CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
|
||||||
|
)
|
||||||
|
cur.close()
|
||||||
|
db_conn.close()
|
||||||
|
|
||||||
|
hs = homeserver_to_use(
|
||||||
|
name,
|
||||||
|
config=config,
|
||||||
|
version_string="Synapse/tests",
|
||||||
|
reactor=reactor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Install @cache_in_self attributes
|
||||||
|
for key, val in kwargs.items():
|
||||||
|
setattr(hs, "_" + key, val)
|
||||||
|
|
||||||
|
# Mock TLS
|
||||||
|
hs.tls_server_context_factory = Mock()
|
||||||
|
hs.tls_client_options_factory = Mock()
|
||||||
|
|
||||||
|
hs.setup()
|
||||||
|
if homeserver_to_use == TestHomeServer:
|
||||||
|
hs.setup_background_tasks()
|
||||||
|
|
||||||
|
if isinstance(db_engine, PostgresEngine):
|
||||||
|
database = hs.get_datastores().databases[0]
|
||||||
|
|
||||||
|
# We need to do cleanup on PostgreSQL
|
||||||
|
def cleanup():
|
||||||
|
import psycopg2
|
||||||
|
|
||||||
|
# Close all the db pools
|
||||||
|
database._db_pool.close()
|
||||||
|
|
||||||
|
dropped = False
|
||||||
|
|
||||||
|
# Drop the test database
|
||||||
|
db_conn = db_engine.module.connect(
|
||||||
|
database=POSTGRES_BASE_DB,
|
||||||
|
user=POSTGRES_USER,
|
||||||
|
host=POSTGRES_HOST,
|
||||||
|
password=POSTGRES_PASSWORD,
|
||||||
|
)
|
||||||
|
db_conn.autocommit = True
|
||||||
|
cur = db_conn.cursor()
|
||||||
|
|
||||||
|
# Try a few times to drop the DB. Some things may hold on to the
|
||||||
|
# database for a few more seconds due to flakiness, preventing
|
||||||
|
# us from dropping it when the test is over. If we can't drop
|
||||||
|
# it, warn and move on.
|
||||||
|
for _ in range(5):
|
||||||
|
try:
|
||||||
|
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
|
||||||
|
db_conn.commit()
|
||||||
|
dropped = True
|
||||||
|
except psycopg2.OperationalError as e:
|
||||||
|
warnings.warn(
|
||||||
|
"Couldn't drop old db: " + str(e), category=UserWarning
|
||||||
|
)
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
cur.close()
|
||||||
|
db_conn.close()
|
||||||
|
|
||||||
|
if not dropped:
|
||||||
|
warnings.warn("Failed to drop old DB.", category=UserWarning)
|
||||||
|
|
||||||
|
if not LEAVE_DB:
|
||||||
|
# Register the cleanup hook
|
||||||
|
cleanup_func(cleanup)
|
||||||
|
|
||||||
|
# bcrypt is far too slow to be doing in unit tests
|
||||||
|
# Need to let the HS build an auth handler and then mess with it
|
||||||
|
# because AuthHandler's constructor requires the HS, so we can't make one
|
||||||
|
# beforehand and pass it in to the HS's constructor (chicken / egg)
|
||||||
|
async def hash(p):
|
||||||
|
return hashlib.md5(p.encode("utf8")).hexdigest()
|
||||||
|
|
||||||
|
hs.get_auth_handler().hash = hash
|
||||||
|
|
||||||
|
async def validate_hash(p, h):
|
||||||
|
return hashlib.md5(p.encode("utf8")).hexdigest() == h
|
||||||
|
|
||||||
|
hs.get_auth_handler().validate_hash = validate_hash
|
||||||
|
|
||||||
|
return hs
|
||||||
|
|
|
@ -23,7 +23,8 @@ from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.utils import TestHomeServer, default_config
|
from tests.server import TestHomeServer
|
||||||
|
from tests.utils import default_config
|
||||||
|
|
||||||
|
|
||||||
class SQLBaseStoreTestCase(unittest.TestCase):
|
class SQLBaseStoreTestCase(unittest.TestCase):
|
||||||
|
|
|
@ -19,8 +19,8 @@ from synapse.rest.client import login, room
|
||||||
from synapse.types import UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.server import TestHomeServer
|
||||||
from tests.test_utils import event_injection
|
from tests.test_utils import event_injection
|
||||||
from tests.utils import TestHomeServer
|
|
||||||
|
|
||||||
|
|
||||||
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
||||||
|
|
175
tests/utils.py
175
tests/utils.py
|
@ -14,12 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import atexit
|
import atexit
|
||||||
import hashlib
|
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
import warnings
|
|
||||||
from typing import Type
|
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
from urllib import parse as urlparse
|
from urllib import parse as urlparse
|
||||||
|
|
||||||
|
@ -28,14 +23,11 @@ from twisted.internet import defer
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import CodeMessageException, cs_error
|
from synapse.api.errors import CodeMessageException, cs_error
|
||||||
from synapse.api.room_versions import RoomVersions
|
from synapse.api.room_versions import RoomVersions
|
||||||
from synapse.config.database import DatabaseConnectionConfig
|
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
from synapse.config.server import DEFAULT_ROOM_VERSION
|
from synapse.config.server import DEFAULT_ROOM_VERSION
|
||||||
from synapse.logging.context import current_context, set_current_context
|
from synapse.logging.context import current_context, set_current_context
|
||||||
from synapse.server import HomeServer
|
|
||||||
from synapse.storage import DataStore
|
|
||||||
from synapse.storage.database import LoggingDatabaseConnection
|
from synapse.storage.database import LoggingDatabaseConnection
|
||||||
from synapse.storage.engines import PostgresEngine, create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.storage.prepare_database import prepare_database
|
from synapse.storage.prepare_database import prepare_database
|
||||||
|
|
||||||
# set this to True to run the tests against postgres instead of sqlite.
|
# set this to True to run the tests against postgres instead of sqlite.
|
||||||
|
@ -182,171 +174,6 @@ def default_config(name, parse=False):
|
||||||
return config_dict
|
return config_dict
|
||||||
|
|
||||||
|
|
||||||
class TestHomeServer(HomeServer):
|
|
||||||
DATASTORE_CLASS = DataStore
|
|
||||||
|
|
||||||
|
|
||||||
def setup_test_homeserver(
|
|
||||||
cleanup_func,
|
|
||||||
name="test",
|
|
||||||
config=None,
|
|
||||||
reactor=None,
|
|
||||||
homeserver_to_use: Type[HomeServer] = TestHomeServer,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Setup a homeserver suitable for running tests against. Keyword arguments
|
|
||||||
are passed to the Homeserver constructor.
|
|
||||||
|
|
||||||
If no datastore is supplied, one is created and given to the homeserver.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cleanup_func : The function used to register a cleanup routine for
|
|
||||||
after the test.
|
|
||||||
|
|
||||||
Calling this method directly is deprecated: you should instead derive from
|
|
||||||
HomeserverTestCase.
|
|
||||||
"""
|
|
||||||
if reactor is None:
|
|
||||||
from twisted.internet import reactor
|
|
||||||
|
|
||||||
if config is None:
|
|
||||||
config = default_config(name, parse=True)
|
|
||||||
|
|
||||||
config.ldap_enabled = False
|
|
||||||
|
|
||||||
if "clock" not in kwargs:
|
|
||||||
kwargs["clock"] = MockClock()
|
|
||||||
|
|
||||||
if USE_POSTGRES_FOR_TESTS:
|
|
||||||
test_db = "synapse_test_%s" % uuid.uuid4().hex
|
|
||||||
|
|
||||||
database_config = {
|
|
||||||
"name": "psycopg2",
|
|
||||||
"args": {
|
|
||||||
"database": test_db,
|
|
||||||
"host": POSTGRES_HOST,
|
|
||||||
"password": POSTGRES_PASSWORD,
|
|
||||||
"user": POSTGRES_USER,
|
|
||||||
"cp_min": 1,
|
|
||||||
"cp_max": 5,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
database_config = {
|
|
||||||
"name": "sqlite3",
|
|
||||||
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
|
|
||||||
}
|
|
||||||
|
|
||||||
if "db_txn_limit" in kwargs:
|
|
||||||
database_config["txn_limit"] = kwargs["db_txn_limit"]
|
|
||||||
|
|
||||||
database = DatabaseConnectionConfig("master", database_config)
|
|
||||||
config.database.databases = [database]
|
|
||||||
|
|
||||||
db_engine = create_engine(database.config)
|
|
||||||
|
|
||||||
# Create the database before we actually try and connect to it, based off
|
|
||||||
# the template database we generate in setupdb()
|
|
||||||
if isinstance(db_engine, PostgresEngine):
|
|
||||||
db_conn = db_engine.module.connect(
|
|
||||||
database=POSTGRES_BASE_DB,
|
|
||||||
user=POSTGRES_USER,
|
|
||||||
host=POSTGRES_HOST,
|
|
||||||
password=POSTGRES_PASSWORD,
|
|
||||||
)
|
|
||||||
db_conn.autocommit = True
|
|
||||||
cur = db_conn.cursor()
|
|
||||||
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
|
|
||||||
cur.execute(
|
|
||||||
"CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
|
|
||||||
)
|
|
||||||
cur.close()
|
|
||||||
db_conn.close()
|
|
||||||
|
|
||||||
hs = homeserver_to_use(
|
|
||||||
name,
|
|
||||||
config=config,
|
|
||||||
version_string="Synapse/tests",
|
|
||||||
reactor=reactor,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Install @cache_in_self attributes
|
|
||||||
for key, val in kwargs.items():
|
|
||||||
setattr(hs, "_" + key, val)
|
|
||||||
|
|
||||||
# Mock TLS
|
|
||||||
hs.tls_server_context_factory = Mock()
|
|
||||||
hs.tls_client_options_factory = Mock()
|
|
||||||
|
|
||||||
hs.setup()
|
|
||||||
if homeserver_to_use == TestHomeServer:
|
|
||||||
hs.setup_background_tasks()
|
|
||||||
|
|
||||||
if isinstance(db_engine, PostgresEngine):
|
|
||||||
database = hs.get_datastores().databases[0]
|
|
||||||
|
|
||||||
# We need to do cleanup on PostgreSQL
|
|
||||||
def cleanup():
|
|
||||||
import psycopg2
|
|
||||||
|
|
||||||
# Close all the db pools
|
|
||||||
database._db_pool.close()
|
|
||||||
|
|
||||||
dropped = False
|
|
||||||
|
|
||||||
# Drop the test database
|
|
||||||
db_conn = db_engine.module.connect(
|
|
||||||
database=POSTGRES_BASE_DB,
|
|
||||||
user=POSTGRES_USER,
|
|
||||||
host=POSTGRES_HOST,
|
|
||||||
password=POSTGRES_PASSWORD,
|
|
||||||
)
|
|
||||||
db_conn.autocommit = True
|
|
||||||
cur = db_conn.cursor()
|
|
||||||
|
|
||||||
# Try a few times to drop the DB. Some things may hold on to the
|
|
||||||
# database for a few more seconds due to flakiness, preventing
|
|
||||||
# us from dropping it when the test is over. If we can't drop
|
|
||||||
# it, warn and move on.
|
|
||||||
for _ in range(5):
|
|
||||||
try:
|
|
||||||
cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
|
|
||||||
db_conn.commit()
|
|
||||||
dropped = True
|
|
||||||
except psycopg2.OperationalError as e:
|
|
||||||
warnings.warn(
|
|
||||||
"Couldn't drop old db: " + str(e), category=UserWarning
|
|
||||||
)
|
|
||||||
time.sleep(0.5)
|
|
||||||
|
|
||||||
cur.close()
|
|
||||||
db_conn.close()
|
|
||||||
|
|
||||||
if not dropped:
|
|
||||||
warnings.warn("Failed to drop old DB.", category=UserWarning)
|
|
||||||
|
|
||||||
if not LEAVE_DB:
|
|
||||||
# Register the cleanup hook
|
|
||||||
cleanup_func(cleanup)
|
|
||||||
|
|
||||||
# bcrypt is far too slow to be doing in unit tests
|
|
||||||
# Need to let the HS build an auth handler and then mess with it
|
|
||||||
# because AuthHandler's constructor requires the HS, so we can't make one
|
|
||||||
# beforehand and pass it in to the HS's constructor (chicken / egg)
|
|
||||||
async def hash(p):
|
|
||||||
return hashlib.md5(p.encode("utf8")).hexdigest()
|
|
||||||
|
|
||||||
hs.get_auth_handler().hash = hash
|
|
||||||
|
|
||||||
async def validate_hash(p, h):
|
|
||||||
return hashlib.md5(p.encode("utf8")).hexdigest() == h
|
|
||||||
|
|
||||||
hs.get_auth_handler().validate_hash = validate_hash
|
|
||||||
|
|
||||||
return hs
|
|
||||||
|
|
||||||
|
|
||||||
def mock_getRawHeaders(headers=None):
|
def mock_getRawHeaders(headers=None):
|
||||||
headers = headers if headers is not None else {}
|
headers = headers if headers is not None else {}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue