You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

790 lines
21KB

  1. # plugin/plugin_base.py
  2. # Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. """Testing extensions.
  8. this module is designed to work as a testing-framework-agnostic library,
  9. created so that multiple test frameworks can be supported at once
  10. (mostly so that we can migrate to new ones). The current target
  11. is pytest.
  12. """
  13. from __future__ import absolute_import
  14. import abc
  15. import logging
  16. import re
  17. import sys
  18. # flag which indicates we are in the SQLAlchemy testing suite,
  19. # and not that of Alembic or a third party dialect.
  20. bootstrapped_as_sqlalchemy = False
  21. log = logging.getLogger("sqlalchemy.testing.plugin_base")
  22. py3k = sys.version_info >= (3, 0)
  23. if py3k:
  24. import configparser
  25. ABC = abc.ABC
  26. else:
  27. import ConfigParser as configparser
  28. import collections as collections_abc # noqa
  29. class ABC(object):
  30. __metaclass__ = abc.ABCMeta
  31. # late imports
  32. fixtures = None
  33. engines = None
  34. exclusions = None
  35. warnings = None
  36. profiling = None
  37. provision = None
  38. assertions = None
  39. requirements = None
  40. config = None
  41. testing = None
  42. util = None
  43. file_config = None
  44. logging = None
  45. include_tags = set()
  46. exclude_tags = set()
  47. options = None
  48. def setup_options(make_option):
  49. make_option(
  50. "--log-info",
  51. action="callback",
  52. type=str,
  53. callback=_log,
  54. help="turn on info logging for <LOG> (multiple OK)",
  55. )
  56. make_option(
  57. "--log-debug",
  58. action="callback",
  59. type=str,
  60. callback=_log,
  61. help="turn on debug logging for <LOG> (multiple OK)",
  62. )
  63. make_option(
  64. "--db",
  65. action="append",
  66. type=str,
  67. dest="db",
  68. help="Use prefab database uri. Multiple OK, "
  69. "first one is run by default.",
  70. )
  71. make_option(
  72. "--dbs",
  73. action="callback",
  74. zeroarg_callback=_list_dbs,
  75. help="List available prefab dbs",
  76. )
  77. make_option(
  78. "--dburi",
  79. action="append",
  80. type=str,
  81. dest="dburi",
  82. help="Database uri. Multiple OK, " "first one is run by default.",
  83. )
  84. make_option(
  85. "--dbdriver",
  86. action="append",
  87. type="string",
  88. dest="dbdriver",
  89. help="Additional database drivers to include in tests. "
  90. "These are linked to the existing database URLs by the "
  91. "provisioning system.",
  92. )
  93. make_option(
  94. "--dropfirst",
  95. action="store_true",
  96. dest="dropfirst",
  97. help="Drop all tables in the target database first",
  98. )
  99. make_option(
  100. "--disable-asyncio",
  101. action="store_true",
  102. help="disable test / fixtures / provisoning running in asyncio",
  103. )
  104. make_option(
  105. "--backend-only",
  106. action="store_true",
  107. dest="backend_only",
  108. help="Run only tests marked with __backend__ or __sparse_backend__",
  109. )
  110. make_option(
  111. "--nomemory",
  112. action="store_true",
  113. dest="nomemory",
  114. help="Don't run memory profiling tests",
  115. )
  116. make_option(
  117. "--notimingintensive",
  118. action="store_true",
  119. dest="notimingintensive",
  120. help="Don't run timing intensive tests",
  121. )
  122. make_option(
  123. "--profile-sort",
  124. type=str,
  125. default="cumulative",
  126. dest="profilesort",
  127. help="Type of sort for profiling standard output",
  128. )
  129. make_option(
  130. "--profile-dump",
  131. type=str,
  132. dest="profiledump",
  133. help="Filename where a single profile run will be dumped",
  134. )
  135. make_option(
  136. "--postgresql-templatedb",
  137. type=str,
  138. help="name of template database to use for PostgreSQL "
  139. "CREATE DATABASE (defaults to current database)",
  140. )
  141. make_option(
  142. "--low-connections",
  143. action="store_true",
  144. dest="low_connections",
  145. help="Use a low number of distinct connections - "
  146. "i.e. for Oracle TNS",
  147. )
  148. make_option(
  149. "--write-idents",
  150. type=str,
  151. dest="write_idents",
  152. help="write out generated follower idents to <file>, "
  153. "when -n<num> is used",
  154. )
  155. make_option(
  156. "--reversetop",
  157. action="store_true",
  158. dest="reversetop",
  159. default=False,
  160. help="Use a random-ordering set implementation in the ORM "
  161. "(helps reveal dependency issues)",
  162. )
  163. make_option(
  164. "--requirements",
  165. action="callback",
  166. type=str,
  167. callback=_requirements_opt,
  168. help="requirements class for testing, overrides setup.cfg",
  169. )
  170. make_option(
  171. "--with-cdecimal",
  172. action="store_true",
  173. dest="cdecimal",
  174. default=False,
  175. help="Monkeypatch the cdecimal library into Python 'decimal' "
  176. "for all tests",
  177. )
  178. make_option(
  179. "--include-tag",
  180. action="callback",
  181. callback=_include_tag,
  182. type=str,
  183. help="Include tests with tag <tag>",
  184. )
  185. make_option(
  186. "--exclude-tag",
  187. action="callback",
  188. callback=_exclude_tag,
  189. type=str,
  190. help="Exclude tests with tag <tag>",
  191. )
  192. make_option(
  193. "--write-profiles",
  194. action="store_true",
  195. dest="write_profiles",
  196. default=False,
  197. help="Write/update failing profiling data.",
  198. )
  199. make_option(
  200. "--force-write-profiles",
  201. action="store_true",
  202. dest="force_write_profiles",
  203. default=False,
  204. help="Unconditionally write/update profiling data.",
  205. )
  206. make_option(
  207. "--dump-pyannotate",
  208. type=str,
  209. dest="dump_pyannotate",
  210. help="Run pyannotate and dump json info to given file",
  211. )
  212. make_option(
  213. "--mypy-extra-test-path",
  214. type=str,
  215. action="append",
  216. default=[],
  217. dest="mypy_extra_test_paths",
  218. help="Additional test directories to add to the mypy tests. "
  219. "This is used only when running mypy tests. Multiple OK",
  220. )
  221. def configure_follower(follower_ident):
  222. """Configure required state for a follower.
  223. This invokes in the parent process and typically includes
  224. database creation.
  225. """
  226. from sqlalchemy.testing import provision
  227. provision.FOLLOWER_IDENT = follower_ident
  228. def memoize_important_follower_config(dict_):
  229. """Store important configuration we will need to send to a follower.
  230. This invokes in the parent process after normal config is set up.
  231. This is necessary as pytest seems to not be using forking, so we
  232. start with nothing in memory, *but* it isn't running our argparse
  233. callables, so we have to just copy all of that over.
  234. """
  235. dict_["memoized_config"] = {
  236. "include_tags": include_tags,
  237. "exclude_tags": exclude_tags,
  238. }
  239. def restore_important_follower_config(dict_):
  240. """Restore important configuration needed by a follower.
  241. This invokes in the follower process.
  242. """
  243. global include_tags, exclude_tags
  244. include_tags.update(dict_["memoized_config"]["include_tags"])
  245. exclude_tags.update(dict_["memoized_config"]["exclude_tags"])
  246. def read_config():
  247. global file_config
  248. file_config = configparser.ConfigParser()
  249. file_config.read(["setup.cfg", "test.cfg"])
  250. def pre_begin(opt):
  251. """things to set up early, before coverage might be setup."""
  252. global options
  253. options = opt
  254. for fn in pre_configure:
  255. fn(options, file_config)
  256. def set_coverage_flag(value):
  257. options.has_coverage = value
  258. def post_begin():
  259. """things to set up later, once we know coverage is running."""
  260. # Lazy setup of other options (post coverage)
  261. for fn in post_configure:
  262. fn(options, file_config)
  263. # late imports, has to happen after config.
  264. global util, fixtures, engines, exclusions, assertions, provision
  265. global warnings, profiling, config, testing
  266. from sqlalchemy import testing # noqa
  267. from sqlalchemy.testing import fixtures, engines, exclusions # noqa
  268. from sqlalchemy.testing import assertions, warnings, profiling # noqa
  269. from sqlalchemy.testing import config, provision # noqa
  270. from sqlalchemy import util # noqa
  271. warnings.setup_filters()
  272. def _log(opt_str, value, parser):
  273. global logging
  274. if not logging:
  275. import logging
  276. logging.basicConfig()
  277. if opt_str.endswith("-info"):
  278. logging.getLogger(value).setLevel(logging.INFO)
  279. elif opt_str.endswith("-debug"):
  280. logging.getLogger(value).setLevel(logging.DEBUG)
  281. def _list_dbs(*args):
  282. print("Available --db options (use --dburi to override)")
  283. for macro in sorted(file_config.options("db")):
  284. print("%20s\t%s" % (macro, file_config.get("db", macro)))
  285. sys.exit(0)
  286. def _requirements_opt(opt_str, value, parser):
  287. _setup_requirements(value)
  288. def _exclude_tag(opt_str, value, parser):
  289. exclude_tags.add(value.replace("-", "_"))
  290. def _include_tag(opt_str, value, parser):
  291. include_tags.add(value.replace("-", "_"))
  292. pre_configure = []
  293. post_configure = []
  294. def pre(fn):
  295. pre_configure.append(fn)
  296. return fn
  297. def post(fn):
  298. post_configure.append(fn)
  299. return fn
  300. @pre
  301. def _setup_options(opt, file_config):
  302. global options
  303. options = opt
  304. @pre
  305. def _set_nomemory(opt, file_config):
  306. if opt.nomemory:
  307. exclude_tags.add("memory_intensive")
  308. @pre
  309. def _set_notimingintensive(opt, file_config):
  310. if opt.notimingintensive:
  311. exclude_tags.add("timing_intensive")
  312. @pre
  313. def _monkeypatch_cdecimal(options, file_config):
  314. if options.cdecimal:
  315. import cdecimal
  316. sys.modules["decimal"] = cdecimal
  317. @post
  318. def _init_symbols(options, file_config):
  319. from sqlalchemy.testing import config
  320. config._fixture_functions = _fixture_fn_class()
  321. @post
  322. def _set_disable_asyncio(opt, file_config):
  323. if opt.disable_asyncio or not py3k:
  324. from sqlalchemy.testing import asyncio
  325. asyncio.ENABLE_ASYNCIO = False
  326. @post
  327. def _engine_uri(options, file_config):
  328. from sqlalchemy import testing
  329. from sqlalchemy.testing import config
  330. from sqlalchemy.testing import provision
  331. if options.dburi:
  332. db_urls = list(options.dburi)
  333. else:
  334. db_urls = []
  335. extra_drivers = options.dbdriver or []
  336. if options.db:
  337. for db_token in options.db:
  338. for db in re.split(r"[,\s]+", db_token):
  339. if db not in file_config.options("db"):
  340. raise RuntimeError(
  341. "Unknown URI specifier '%s'. "
  342. "Specify --dbs for known uris." % db
  343. )
  344. else:
  345. db_urls.append(file_config.get("db", db))
  346. if not db_urls:
  347. db_urls.append(file_config.get("db", "default"))
  348. config._current = None
  349. expanded_urls = list(provision.generate_db_urls(db_urls, extra_drivers))
  350. for db_url in expanded_urls:
  351. log.info("Adding database URL: %s", db_url)
  352. if options.write_idents and provision.FOLLOWER_IDENT: # != 'master':
  353. with open(options.write_idents, "a") as file_:
  354. file_.write(provision.FOLLOWER_IDENT + " " + db_url + "\n")
  355. cfg = provision.setup_config(
  356. db_url, options, file_config, provision.FOLLOWER_IDENT
  357. )
  358. if not config._current:
  359. cfg.set_as_current(cfg, testing)
  360. @post
  361. def _requirements(options, file_config):
  362. requirement_cls = file_config.get("sqla_testing", "requirement_cls")
  363. _setup_requirements(requirement_cls)
  364. def _setup_requirements(argument):
  365. from sqlalchemy.testing import config
  366. from sqlalchemy import testing
  367. if config.requirements is not None:
  368. return
  369. modname, clsname = argument.split(":")
  370. # importlib.import_module() only introduced in 2.7, a little
  371. # late
  372. mod = __import__(modname)
  373. for component in modname.split(".")[1:]:
  374. mod = getattr(mod, component)
  375. req_cls = getattr(mod, clsname)
  376. config.requirements = testing.requires = req_cls()
  377. config.bootstrapped_as_sqlalchemy = bootstrapped_as_sqlalchemy
  378. @post
  379. def _prep_testing_database(options, file_config):
  380. from sqlalchemy.testing import config
  381. if options.dropfirst:
  382. from sqlalchemy.testing import provision
  383. for cfg in config.Config.all_configs():
  384. provision.drop_all_schema_objects(cfg, cfg.db)
  385. @post
  386. def _reverse_topological(options, file_config):
  387. if options.reversetop:
  388. from sqlalchemy.orm.util import randomize_unitofwork
  389. randomize_unitofwork()
  390. @post
  391. def _post_setup_options(opt, file_config):
  392. from sqlalchemy.testing import config
  393. config.options = options
  394. config.file_config = file_config
  395. @post
  396. def _setup_profiling(options, file_config):
  397. from sqlalchemy.testing import profiling
  398. profiling._profile_stats = profiling.ProfileStatsFile(
  399. file_config.get("sqla_testing", "profile_file"),
  400. sort=options.profilesort,
  401. dump=options.profiledump,
  402. )
  403. def want_class(name, cls):
  404. if not issubclass(cls, fixtures.TestBase):
  405. return False
  406. elif name.startswith("_"):
  407. return False
  408. elif (
  409. config.options.backend_only
  410. and not getattr(cls, "__backend__", False)
  411. and not getattr(cls, "__sparse_backend__", False)
  412. and not getattr(cls, "__only_on__", False)
  413. ):
  414. return False
  415. else:
  416. return True
  417. def want_method(cls, fn):
  418. if not fn.__name__.startswith("test_"):
  419. return False
  420. elif fn.__module__ is None:
  421. return False
  422. elif include_tags:
  423. return (
  424. hasattr(cls, "__tags__")
  425. and exclusions.tags(cls.__tags__).include_test(
  426. include_tags, exclude_tags
  427. )
  428. ) or (
  429. hasattr(fn, "_sa_exclusion_extend")
  430. and fn._sa_exclusion_extend.include_test(
  431. include_tags, exclude_tags
  432. )
  433. )
  434. elif exclude_tags and hasattr(cls, "__tags__"):
  435. return exclusions.tags(cls.__tags__).include_test(
  436. include_tags, exclude_tags
  437. )
  438. elif exclude_tags and hasattr(fn, "_sa_exclusion_extend"):
  439. return fn._sa_exclusion_extend.include_test(include_tags, exclude_tags)
  440. else:
  441. return True
  442. def generate_sub_tests(cls, module):
  443. if getattr(cls, "__backend__", False) or getattr(
  444. cls, "__sparse_backend__", False
  445. ):
  446. sparse = getattr(cls, "__sparse_backend__", False)
  447. for cfg in _possible_configs_for_cls(cls, sparse=sparse):
  448. orig_name = cls.__name__
  449. # we can have special chars in these names except for the
  450. # pytest junit plugin, which is tripped up by the brackets
  451. # and periods, so sanitize
  452. alpha_name = re.sub(r"[_\[\]\.]+", "_", cfg.name)
  453. alpha_name = re.sub(r"_+$", "", alpha_name)
  454. name = "%s_%s" % (cls.__name__, alpha_name)
  455. subcls = type(
  456. name,
  457. (cls,),
  458. {"_sa_orig_cls_name": orig_name, "__only_on_config__": cfg},
  459. )
  460. setattr(module, name, subcls)
  461. yield subcls
  462. else:
  463. yield cls
  464. def start_test_class_outside_fixtures(cls):
  465. _do_skips(cls)
  466. _setup_engine(cls)
  467. def stop_test_class(cls):
  468. # close sessions, immediate connections, etc.
  469. fixtures.stop_test_class_inside_fixtures(cls)
  470. # close outstanding connection pool connections, dispose of
  471. # additional engines
  472. engines.testing_reaper.stop_test_class_inside_fixtures()
  473. def stop_test_class_outside_fixtures(cls):
  474. engines.testing_reaper.stop_test_class_outside_fixtures()
  475. provision.stop_test_class_outside_fixtures(config, config.db, cls)
  476. try:
  477. if not options.low_connections:
  478. assertions.global_cleanup_assertions()
  479. finally:
  480. _restore_engine()
  481. def _restore_engine():
  482. if config._current:
  483. config._current.reset(testing)
  484. def final_process_cleanup():
  485. engines.testing_reaper.final_cleanup()
  486. assertions.global_cleanup_assertions()
  487. _restore_engine()
  488. def _setup_engine(cls):
  489. if getattr(cls, "__engine_options__", None):
  490. opts = dict(cls.__engine_options__)
  491. opts["scope"] = "class"
  492. eng = engines.testing_engine(options=opts)
  493. config._current.push_engine(eng, testing)
  494. def before_test(test, test_module_name, test_class, test_name):
  495. # format looks like:
  496. # "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
  497. name = getattr(test_class, "_sa_orig_cls_name", test_class.__name__)
  498. id_ = "%s.%s.%s" % (test_module_name, name, test_name)
  499. profiling._start_current_test(id_)
  500. def after_test(test):
  501. fixtures.after_test()
  502. engines.testing_reaper.after_test()
  503. def after_test_fixtures(test):
  504. engines.testing_reaper.after_test_outside_fixtures(test)
  505. def _possible_configs_for_cls(cls, reasons=None, sparse=False):
  506. all_configs = set(config.Config.all_configs())
  507. if cls.__unsupported_on__:
  508. spec = exclusions.db_spec(*cls.__unsupported_on__)
  509. for config_obj in list(all_configs):
  510. if spec(config_obj):
  511. all_configs.remove(config_obj)
  512. if getattr(cls, "__only_on__", None):
  513. spec = exclusions.db_spec(*util.to_list(cls.__only_on__))
  514. for config_obj in list(all_configs):
  515. if not spec(config_obj):
  516. all_configs.remove(config_obj)
  517. if getattr(cls, "__only_on_config__", None):
  518. all_configs.intersection_update([cls.__only_on_config__])
  519. if hasattr(cls, "__requires__"):
  520. requirements = config.requirements
  521. for config_obj in list(all_configs):
  522. for requirement in cls.__requires__:
  523. check = getattr(requirements, requirement)
  524. skip_reasons = check.matching_config_reasons(config_obj)
  525. if skip_reasons:
  526. all_configs.remove(config_obj)
  527. if reasons is not None:
  528. reasons.extend(skip_reasons)
  529. break
  530. if hasattr(cls, "__prefer_requires__"):
  531. non_preferred = set()
  532. requirements = config.requirements
  533. for config_obj in list(all_configs):
  534. for requirement in cls.__prefer_requires__:
  535. check = getattr(requirements, requirement)
  536. if not check.enabled_for_config(config_obj):
  537. non_preferred.add(config_obj)
  538. if all_configs.difference(non_preferred):
  539. all_configs.difference_update(non_preferred)
  540. if sparse:
  541. # pick only one config from each base dialect
  542. # sorted so we get the same backend each time selecting the highest
  543. # server version info.
  544. per_dialect = {}
  545. for cfg in reversed(
  546. sorted(
  547. all_configs,
  548. key=lambda cfg: (
  549. cfg.db.name,
  550. cfg.db.driver,
  551. cfg.db.dialect.server_version_info,
  552. ),
  553. )
  554. ):
  555. db = cfg.db.name
  556. if db not in per_dialect:
  557. per_dialect[db] = cfg
  558. return per_dialect.values()
  559. return all_configs
  560. def _do_skips(cls):
  561. reasons = []
  562. all_configs = _possible_configs_for_cls(cls, reasons)
  563. if getattr(cls, "__skip_if__", False):
  564. for c in getattr(cls, "__skip_if__"):
  565. if c():
  566. config.skip_test(
  567. "'%s' skipped by %s" % (cls.__name__, c.__name__)
  568. )
  569. if not all_configs:
  570. msg = "'%s' unsupported on any DB implementation %s%s" % (
  571. cls.__name__,
  572. ", ".join(
  573. "'%s(%s)+%s'"
  574. % (
  575. config_obj.db.name,
  576. ".".join(
  577. str(dig)
  578. for dig in exclusions._server_version(config_obj.db)
  579. ),
  580. config_obj.db.driver,
  581. )
  582. for config_obj in config.Config.all_configs()
  583. ),
  584. ", ".join(reasons),
  585. )
  586. config.skip_test(msg)
  587. elif hasattr(cls, "__prefer_backends__"):
  588. non_preferred = set()
  589. spec = exclusions.db_spec(*util.to_list(cls.__prefer_backends__))
  590. for config_obj in all_configs:
  591. if not spec(config_obj):
  592. non_preferred.add(config_obj)
  593. if all_configs.difference(non_preferred):
  594. all_configs.difference_update(non_preferred)
  595. if config._current not in all_configs:
  596. _setup_config(all_configs.pop(), cls)
  597. def _setup_config(config_obj, ctx):
  598. config._current.push(config_obj, testing)
  599. class FixtureFunctions(ABC):
  600. @abc.abstractmethod
  601. def skip_test_exception(self, *arg, **kw):
  602. raise NotImplementedError()
  603. @abc.abstractmethod
  604. def combinations(self, *args, **kw):
  605. raise NotImplementedError()
  606. @abc.abstractmethod
  607. def param_ident(self, *args, **kw):
  608. raise NotImplementedError()
  609. @abc.abstractmethod
  610. def fixture(self, *arg, **kw):
  611. raise NotImplementedError()
  612. def get_current_test_name(self):
  613. raise NotImplementedError()
  614. @abc.abstractmethod
  615. def mark_base_test_class(self):
  616. raise NotImplementedError()
  617. _fixture_fn_class = None
  618. def set_fixture_functions(fixture_fn_class):
  619. global _fixture_fn_class
  620. _fixture_fn_class = fixture_fn_class