415 lines
12 KiB
Python
415 lines
12 KiB
Python
|
import collections
|
||
|
import logging
|
||
|
|
||
|
from . import config
|
||
|
from . import engines
|
||
|
from . import util
|
||
|
from .. import exc
|
||
|
from .. import inspect
|
||
|
from ..engine import url as sa_url
|
||
|
from ..sql import ddl
|
||
|
from ..sql import schema
|
||
|
from ..util import compat
|
||
|
|
||
|
|
||
|
# Module-level logger; the CREATE / DROP database notices below go through it.
log = logging.getLogger(__name__)

# Not referenced in this chunk; presumably assigned elsewhere (e.g. by the
# test plugin) to the current worker's follower ident.
# NOTE(review): confirm where this is set.
FOLLOWER_IDENT = None
|
||
|
|
||
|
|
||
|
class register(object):
    """Registry that dispatches a provisioning hook by database backend.

    ``register.init`` turns a plain function into a registry whose ``"*"``
    (default) entry is that function.  Backend-specific implementations are
    attached with ``@<hook>.for_db("backend", ...)``.  Calling the registry
    picks the implementation registered for the backend of the given
    config/URL, falling back to the ``"*"`` entry.
    """

    def __init__(self):
        # backend name (or "*" for the default) -> implementation
        self.fns = {}

    @classmethod
    def init(cls, fn):
        """Build a fresh registry whose default ("*") implementation is ``fn``."""
        return register().for_db("*")(fn)

    def for_db(self, *dbnames):
        """Return a decorator that registers its function for every name in
        ``dbnames``.

        The decorator returns this registry (not the function), so the
        decorated module-level name is rebound to the registry itself.
        """

        def decorate(fn):
            self.fns.update(dict.fromkeys(dbnames, fn))
            return self

        return decorate

    def __call__(self, cfg, *arg):
        """Invoke the implementation registered for ``cfg``'s backend.

        ``cfg`` may be a URL string, a ``URL`` object, or any object exposing
        ``.db.url``; it is passed through unchanged to the implementation,
        followed by ``*arg``.
        """
        if isinstance(cfg, compat.string_types):
            url = sa_url.make_url(cfg)
        elif isinstance(cfg, sa_url.URL):
            url = cfg
        else:
            url = cfg.db.url

        backend = url.get_backend_name()
        impl = self.fns[backend] if backend in self.fns else self.fns["*"]
        return impl(cfg, *arg)
|
||
|
|
||
|
|
||
|
def create_follower_db(follower_ident):
    """Create the follower database named ``follower_ident``.

    One ``create_db`` call is made per distinct
    (backend, username, host, database) combination among the registered
    configs, as selected by ``_configs_for_db_operation``.
    """
    for conf in _configs_for_db_operation():
        engine = conf.db
        log.info("CREATE database %s, URI %r", follower_ident, engine.url)
        create_db(conf, engine, follower_ident)
|
||
|
|
||
|
|
||
|
def setup_config(db_url, options, file_config, follower_ident):
    """Build the testing engine for ``db_url`` and register it as a Config.

    When ``follower_ident`` is given, the engine points at the follower
    database (URL derived via ``follower_url_from_main``), ``config.ident``
    is set to it, and ``configure_follower`` is invoked on the new config.

    :return: the ``config.Config`` returned by ``config.Config.register``.
    """
    # Loading the dialect also gives it the chance to install its
    # provisioning hooks (its ``for_db`` overrides of the functions here).
    sa_url.make_url(db_url).get_dialect().load_provisioning()

    if follower_ident:
        db_url = follower_url_from_main(db_url, follower_ident)

    db_opts = {}
    update_db_opts(db_url, db_opts)
    db_opts["scope"] = "global"

    engine = engines.testing_engine(db_url, db_opts)
    post_configure_engine(db_url, engine, follower_ident)
    # open and immediately release one connection to verify connectivity
    engine.connect().close()

    cfg = config.Config.register(engine, db_opts, options, file_config)

    if follower_ident:
        # a symbolic name that tests can use if they need to disambiguate
        # names across databases
        config.ident = follower_ident
        configure_follower(cfg, follower_ident)
    return cfg
|
||
|
|
||
|
|
||
|
def drop_follower_db(follower_ident):
    """Drop the follower database named ``follower_ident``.

    One ``drop_db`` call is made per distinct
    (backend, username, host, database) combination among the registered
    configs, as selected by ``_configs_for_db_operation``.
    """
    for conf in _configs_for_db_operation():
        engine = conf.db
        log.info("DROP database %s, URI %r", follower_ident, engine.url)
        drop_db(conf, engine, follower_ident)
|
||
|
|
||
|
|
||
|
def generate_db_urls(db_urls, extra_drivers):
    r"""Generate a set of URLs to test given configured URLs plus additional
    driver names.

    Given::

        --dburi postgresql://db1 \
        --dburi postgresql://db2 \
        --dburi postgresql://db3 \
        --dbdriver=psycopg2 --dbdriver=asyncpg?async_fallback=true

    Noting that the default postgresql driver is psycopg2, the output
    would be::

        postgresql+psycopg2://db1
        postgresql+asyncpg://db1?async_fallback=true
        postgresql+psycopg2://db2
        postgresql+psycopg2://db3

    That is, a driver named (or implied) by a ``--dburi`` is kept and used
    for every URL it is part of.  A driver that appears only in
    ``--dbdriver`` is used just once, for one of the URLs of that backend.
    A driver that comes from both a ``--dburi`` and ``--dbdriver`` is kept
    only in that dburi.

    Driver-specific query options can be specified by appending them to
    the driver name.  For example, to enable the async fallback option for
    asyncpg::

        --dburi postgresql://db1 \
        --dbdriver=asyncpg?async_fallback=true

    This is a generator; each URL is yielded once, as a string.
    """
    urls = set()

    # backend name -> drivers already named/implied by some --dburi
    backend_to_driver_we_already_have = collections.defaultdict(set)

    urls_plus_dialects = [
        (url_obj, url_obj.get_dialect())
        for url_obj in [sa_url.make_url(db_url) for db_url in db_urls]
    ]

    for url_obj, dialect in urls_plus_dialects:
        backend_to_driver_we_already_have[dialect.name].add(dialect.driver)

    # backend name -> extra drivers still to be assigned to some URL.  The
    # same set object is shared by all URLs of a backend and is consumed by
    # _generate_driver_urls, which is what makes an extra driver appear
    # only once per backend.
    backend_to_driver_we_need = {}

    for url_obj, dialect in urls_plus_dialects:
        backend = dialect.name
        dialect.load_provisioning()

        if backend not in backend_to_driver_we_need:
            backend_to_driver_we_need[backend] = extra_per_backend = set(
                extra_drivers
            ).difference(backend_to_driver_we_already_have[backend])
        else:
            extra_per_backend = backend_to_driver_we_need[backend]

        for driver_url in _generate_driver_urls(url_obj, extra_per_backend):
            if driver_url in urls:
                continue
            urls.add(driver_url)
            yield driver_url
|
||
|
|
||
|
|
||
|
def _generate_driver_urls(url, extra_drivers):
    """Yield ``url`` for its own driver, then for each usable extra driver.

    ``extra_drivers`` is mutated in place.  The URL's own driver is
    discarded from it, and every extra driver that produces a loadable URL
    is removed, so a set shared across URLs hands each extra driver out
    only once.  Entries may carry a query string, e.g.
    ``"asyncpg?async_fallback=true"``.  Results are yielded as strings.
    """
    primary = url.get_driver_name()
    extra_drivers.discard(primary)

    base_url = generate_driver_url(url, primary, "")
    yield str(base_url)

    # iterate over a snapshot, since successful entries are removed below
    for candidate in list(extra_drivers):
        if "?" in candidate:
            driver_name, query = candidate.split("?", 1)
        else:
            driver_name, query = candidate, None

        alt_url = generate_driver_url(base_url, driver_name, query)
        if not alt_url:
            # driver not importable here; leave it for another URL
            continue

        extra_drivers.remove(candidate)
        yield str(alt_url)
|
||
|
|
||
|
|
||
|
@register.init
def generate_driver_url(url, driver, query_str):
    """Return ``url`` rewritten to use ``driver``, or ``None`` if unusable.

    :param url: a ``URL`` object whose backend name is kept.
    :param driver: driver name to combine as ``"<backend>+<driver>"``.
    :param query_str: optional query string merged via
        ``update_query_string``; ignored when empty or ``None``.
    :return: the new ``URL``, or ``None`` when its dialect module cannot be
        loaded (``exc.NoSuchModuleError``).
    """
    new_url = url.set(drivername="%s+%s" % (url.get_backend_name(), driver))

    if query_str:
        new_url = new_url.update_query_string(query_str)

    try:
        # only used to probe that the dialect/driver is importable
        new_url.get_dialect()
    except exc.NoSuchModuleError:
        return None
    return new_url
|
||
|
|
||
|
|
||
|
def _configs_for_db_operation():
    """Yield one config per distinct (backend, username, host, database).

    Every registered config's engine is disposed before iteration starts and
    again once iteration ends, presumably so that pooled connections don't
    interfere with CREATE / DROP DATABASE statements issued by the caller.
    The trailing dispose sits in a ``finally`` block, so it also runs when
    the consumer raises or stops early and the generator is closed.
    """
    hosts = set()

    for cfg in config.Config.all_configs():
        cfg.db.dispose()

    try:
        for cfg in config.Config.all_configs():
            url = cfg.db.url
            backend = url.get_backend_name()
            host_conf = (backend, url.username, url.host, url.database)

            if host_conf not in hosts:
                yield cfg
                hosts.add(host_conf)
    finally:
        for cfg in config.Config.all_configs():
            cfg.db.dispose()
|
||
|
|
||
|
|
||
|
@register.init
def drop_all_schema_objects_pre_tables(cfg, eng):
    """Hook run by ``drop_all_schema_objects`` before views/tables are dropped.

    The default implementation does nothing; backends may register their own
    with ``drop_all_schema_objects_pre_tables.for_db(...)``.
    """
|
||
|
|
||
|
|
||
|
@register.init
def drop_all_schema_objects_post_tables(cfg, eng):
    """Hook run by ``drop_all_schema_objects`` after tables are dropped and
    before sequences are dropped.

    The default implementation does nothing; backends may register their own
    with ``drop_all_schema_objects_post_tables.for_db(...)``.
    """
|
||
|
|
||
|
|
||
|
def _drop_views(eng, inspector, schema_name=None):
    """Drop every view the inspector reports in ``schema_name``.

    ``schema_name=None`` means the default schema.  Inspectors that raise
    ``NotImplementedError`` for view reflection are skipped silently.
    """
    try:
        view_names = inspector.get_view_names(schema=schema_name)
    except NotImplementedError:
        return

    with eng.begin() as conn:
        for vname in view_names:
            conn.execute(
                ddl._DropView(
                    schema.Table(vname, schema.MetaData(), schema=schema_name)
                )
            )


def _drop_sequences(conn, inspector, schema_name=None):
    """Drop every sequence the inspector reports in ``schema_name``."""
    for seq in inspector.get_sequence_names(schema=schema_name):
        conn.execute(
            ddl.DropSequence(schema.Sequence(seq, schema=schema_name))
        )


def drop_all_schema_objects(cfg, eng):
    """Drop views, tables and sequences left in the test database.

    Order of operations:

    1. the backend's ``drop_all_schema_objects_pre_tables`` hook;
    2. views in the default schema, plus ``cfg.test_schema`` and
       ``cfg.test_schema_2`` when the "schemas" requirement is enabled;
    3. tables in those same schemas via ``util.drop_all_tables``;
    4. the backend's ``drop_all_schema_objects_post_tables`` hook;
    5. when the "sequences" requirement is enabled, sequences in those
       same schemas, in a single transaction.

    :param cfg: testing config whose requirements gate the schema/sequence
        steps and which supplies the schema names.
    :param eng: engine for the database to clean.
    """
    drop_all_schema_objects_pre_tables(cfg, eng)

    inspector = inspect(eng)

    _drop_views(eng, inspector)
    if config.requirements.schemas.enabled_for_config(cfg):
        # Use the configured schema names (not a hard-coded "test_schema"),
        # and cover test_schema_2 too, so views are removed from the same
        # schemas whose tables and sequences are dropped below.
        for schema_name in (cfg.test_schema, cfg.test_schema_2):
            _drop_views(eng, inspector, schema_name)

    util.drop_all_tables(eng, inspector)
    if config.requirements.schemas.enabled_for_config(cfg):
        util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
        util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)

    drop_all_schema_objects_post_tables(cfg, eng)

    if config.requirements.sequences.enabled_for_config(cfg):
        with eng.begin() as conn:
            _drop_sequences(conn, inspector)
            if config.requirements.schemas.enabled_for_config(cfg):
                for schema_name in (cfg.test_schema, cfg.test_schema_2):
                    _drop_sequences(conn, inspector, schema_name)
|
||
|
|
||
|
|
||
|
@register.init
def create_db(cfg, eng, ident):
    """Dynamically create a database for testing.

    Used when a test run will employ multiple processes, e.g., when run
    via `tox` or `pytest -n4`.  There is no generic implementation; backends
    provide one via ``create_db.for_db(...)``.

    :raises NotImplementedError: always, for backends without an override.
    """
    message = "no DB creation routine for cfg: %s" % eng.url
    raise NotImplementedError(message)
|
||
|
|
||
|
|
||
|
@register.init
def drop_db(cfg, eng, ident):
    """Drop a database that we dynamically created for testing.

    There is no generic implementation; backends provide one via
    ``drop_db.for_db(...)``.

    :raises NotImplementedError: always, for backends without an override.
    """
    message = "no DB drop routine for cfg: %s" % eng.url
    raise NotImplementedError(message)
|
||
|
|
||
|
|
||
|
@register.init
def update_db_opts(db_url, db_opts):
    """Set database options (db_opts) for a test database that we created.

    Backend overrides mutate ``db_opts`` in place; the default leaves it
    untouched.
    """
|
||
|
|
||
|
|
||
|
@register.init
def post_configure_engine(url, engine, follower_ident):
    """Perform extra steps after configuring an engine for testing.

    ``setup_config`` calls this right after building the testing engine and
    before its connectivity check.  The default does nothing.

    (For the internal dialects, currently only used by sqlite, oracle)
    """
|
||
|
|
||
|
|
||
|
@register.init
def follower_url_from_main(url, ident):
    """Create a connection URL for a dynamically-created test database.

    The default keeps everything from ``url`` and replaces only the
    database name.

    :param url: the connection URL specified when the test run was invoked
        (string or ``URL``)
    :param ident: the pytest-xdist "worker identifier" to be used as the
        database name
    :return: a new ``URL`` with ``database=ident``.
    """
    main_url = sa_url.make_url(url)
    follower_url = main_url.set(database=ident)
    return follower_url
|
||
|
|
||
|
|
||
|
@register.init
def configure_follower(cfg, ident):
    """Create dialect-specific config settings for a follower database.

    Invoked by ``setup_config`` when a follower ident is in use.  The
    default does nothing.
    """
|
||
|
|
||
|
|
||
|
@register.init
def run_reap_dbs(url, ident):
    """Remove databases that were created during the test process, after the
    process has ended.

    This is an optional step that is invoked for certain backends that do not
    reliably release locks on the database as long as a process is still in
    use. For the internal dialects, this is currently only necessary for
    mssql and oracle.

    ``reap_dbs`` calls this once per (backend, host) group, passing one URL
    string from that group as ``url`` and the set of idents as ``ident``.
    The default does nothing.
    """
|
||
|
|
||
|
|
||
|
def reap_dbs(idents_file):
    """Drop the databases recorded in ``idents_file``.

    Each non-blank line of the file must be ``"<ident> <url>"``, separated
    by a single space.  Provisioning hooks are loaded for each ident's
    dialect.  Entries are then grouped by (backend name, host), and
    ``run_reap_dbs`` is called once per group with one URL from the group
    and the set of idents in it.

    Blank lines (e.g. a trailing empty line) are skipped instead of
    aborting the whole reap with a ``ValueError``.
    """
    log.info("Reaping databases...")

    urls = collections.defaultdict(set)
    idents = collections.defaultdict(set)
    dialects = {}

    with open(idents_file) as file_:
        for line in file_:
            line = line.strip()
            if not line:
                continue
            db_name, db_url = line.split(" ")
            url_obj = sa_url.make_url(db_url)
            if db_name not in dialects:
                dialects[db_name] = url_obj.get_dialect()
                dialects[db_name].load_provisioning()
            url_key = (url_obj.get_backend_name(), url_obj.host)
            urls[url_key].add(db_url)
            idents[url_key].add(db_name)

    for url_key, url_set in urls.items():
        # any URL of the group will do for connecting to that server
        run_reap_dbs(next(iter(url_set)), idents[url_key])
|
||
|
|
||
|
|
||
|
@register.init
def temp_table_keyword_args(cfg, eng):
    """Specify keyword arguments for creating a temporary Table.

    Dialect-specific implementations of this method will return the
    kwargs that are passed to the Table method when creating a temporary
    table for testing, e.g., in the define_temp_tables method of the
    ComponentReflectionTest class in suite/test_reflection.py

    :raises NotImplementedError: always, for backends without an override.
    """
    message = "no temp table keyword args routine for cfg: %s" % eng.url
    raise NotImplementedError(message)
|
||
|
|
||
|
|
||
|
@register.init
def prepare_for_drop_tables(config, connection):
    """Backend hook; the default implementation does nothing.

    Not called within this chunk.  The name suggests it runs before test
    tables are dropped.  NOTE(review): confirm against its callers.
    """
|
||
|
|
||
|
|
||
|
@register.init
def stop_test_class_outside_fixtures(config, db, testcls):
    """Backend hook; the default implementation does nothing.

    Not called within this chunk.  The name suggests it runs when a test
    class finishes, outside the fixture system.  NOTE(review): confirm
    against its callers.
    """
|
||
|
|
||
|
|
||
|
@register.init
def get_temp_table_name(cfg, eng, base_name):
    """Specify table name for creating a temporary Table.

    Dialect-specific implementations of this method will return the
    name to use when creating a temporary table for testing,
    e.g., in the define_temp_tables method of the
    ComponentReflectionTest class in suite/test_reflection.py

    The default returns ``base_name`` unchanged, since that's what most
    dialects use.  The mssql dialect's implementation needs a "#" prepended.
    """
    temp_name = base_name
    return temp_name
|
||
|
|
||
|
|
||
|
@register.init
def set_default_schema_on_connection(cfg, dbapi_connection, schema_name):
    """Make ``schema_name`` the default schema on a DBAPI connection.

    There is no generic implementation; backends provide one via
    ``set_default_schema_on_connection.for_db(...)``.

    :raises NotImplementedError: always, for backends without an override.
    """
    url = cfg.db.url
    raise NotImplementedError(
        "backend does not implement a schema name set function: %s" % (url,)
    )
|