You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

775 lines
24KB

  1. try:
  2. # installed by bootstrap.py
  3. import sqla_plugin_base as plugin_base
  4. except ImportError:
  5. # assume we're a package, use traditional import
  6. from . import plugin_base
  7. import argparse
  8. import collections
  9. from functools import update_wrapper
  10. import inspect
  11. import itertools
  12. import operator
  13. import os
  14. import re
  15. import sys
  16. import pytest
# Optional pytest-xdist support: the xdist-specific hooks further down in
# this module are only defined when the plugin can be imported.
try:
    import xdist  # noqa

    has_xdist = True
except ImportError:
    has_xdist = False

# True when running under Python 2.
py2k = sys.version_info < (3, 0)

if py2k:
    # Python 2 only: load the home-grown fixture helpers, preferring the
    # bootstrap-installed top-level module name and falling back to the
    # package-relative import (same pattern as the plugin_base import).
    try:
        import sqla_reinvent_fixtures as reinvent_fixtures_py2k
    except ImportError:
        from . import reinvent_fixtures_py2k
  28. def pytest_addoption(parser):
  29. group = parser.getgroup("sqlalchemy")
  30. def make_option(name, **kw):
  31. callback_ = kw.pop("callback", None)
  32. if callback_:
  33. class CallableAction(argparse.Action):
  34. def __call__(
  35. self, parser, namespace, values, option_string=None
  36. ):
  37. callback_(option_string, values, parser)
  38. kw["action"] = CallableAction
  39. zeroarg_callback = kw.pop("zeroarg_callback", None)
  40. if zeroarg_callback:
  41. class CallableAction(argparse.Action):
  42. def __init__(
  43. self,
  44. option_strings,
  45. dest,
  46. default=False,
  47. required=False,
  48. help=None, # noqa
  49. ):
  50. super(CallableAction, self).__init__(
  51. option_strings=option_strings,
  52. dest=dest,
  53. nargs=0,
  54. const=True,
  55. default=default,
  56. required=required,
  57. help=help,
  58. )
  59. def __call__(
  60. self, parser, namespace, values, option_string=None
  61. ):
  62. zeroarg_callback(option_string, values, parser)
  63. kw["action"] = CallableAction
  64. group.addoption(name, **kw)
  65. plugin_base.setup_options(make_option)
  66. plugin_base.read_config()
  67. def pytest_configure(config):
  68. if hasattr(config, "workerinput"):
  69. plugin_base.restore_important_follower_config(config.workerinput)
  70. plugin_base.configure_follower(config.workerinput["follower_ident"])
  71. else:
  72. if config.option.write_idents and os.path.exists(
  73. config.option.write_idents
  74. ):
  75. os.remove(config.option.write_idents)
  76. plugin_base.pre_begin(config.option)
  77. plugin_base.set_coverage_flag(
  78. bool(getattr(config.option, "cov_source", False))
  79. )
  80. plugin_base.set_fixture_functions(PytestFixtureFunctions)
  81. if config.option.dump_pyannotate:
  82. global DUMP_PYANNOTATE
  83. DUMP_PYANNOTATE = True
  84. DUMP_PYANNOTATE = False
  85. @pytest.fixture(autouse=True)
  86. def collect_types_fixture():
  87. if DUMP_PYANNOTATE:
  88. from pyannotate_runtime import collect_types
  89. collect_types.start()
  90. yield
  91. if DUMP_PYANNOTATE:
  92. collect_types.stop()
  93. def pytest_sessionstart(session):
  94. from sqlalchemy.testing import asyncio
  95. asyncio._assume_async(plugin_base.post_begin)
  96. def pytest_sessionfinish(session):
  97. from sqlalchemy.testing import asyncio
  98. asyncio._maybe_async_provisioning(plugin_base.final_process_cleanup)
  99. if session.config.option.dump_pyannotate:
  100. from pyannotate_runtime import collect_types
  101. collect_types.dump_stats(session.config.option.dump_pyannotate)
  102. def pytest_collection_finish(session):
  103. if session.config.option.dump_pyannotate:
  104. from pyannotate_runtime import collect_types
  105. lib_sqlalchemy = os.path.abspath("lib/sqlalchemy")
  106. def _filter(filename):
  107. filename = os.path.normpath(os.path.abspath(filename))
  108. if "lib/sqlalchemy" not in os.path.commonpath(
  109. [filename, lib_sqlalchemy]
  110. ):
  111. return None
  112. if "testing" in filename:
  113. return None
  114. return filename
  115. collect_types.init_types_collection(filter_filename=_filter)
  116. if has_xdist:
  117. import uuid
  118. def pytest_configure_node(node):
  119. from sqlalchemy.testing import provision
  120. from sqlalchemy.testing import asyncio
  121. # the master for each node fills workerinput dictionary
  122. # which pytest-xdist will transfer to the subprocess
  123. plugin_base.memoize_important_follower_config(node.workerinput)
  124. node.workerinput["follower_ident"] = "test_%s" % uuid.uuid4().hex[0:12]
  125. asyncio._maybe_async_provisioning(
  126. provision.create_follower_db, node.workerinput["follower_ident"]
  127. )
  128. def pytest_testnodedown(node, error):
  129. from sqlalchemy.testing import provision
  130. from sqlalchemy.testing import asyncio
  131. asyncio._maybe_async_provisioning(
  132. provision.drop_follower_db, node.workerinput["follower_ident"]
  133. )
def pytest_collection_modifyitems(session, config, items):
    """pytest hook: expand generated sub-classes and sort collected items.

    As implemented below:

    1. Keep only items whose parent is a ``pytest.Instance`` and whose
       enclosing class node (``item.parent.parent``) has a name that does
       not start with an underscore.
    2. For each remaining class, ask ``plugin_base.generate_sub_tests`` for
       sub-classes; collect the test items of each sub-class that is not
       the original class, grouped by original class and test name.
    3. Replace the items of every class that produced such sub-classes with
       the matching sub-class items; other items are kept as-is.
    4. Sort by (module name, class name, test name).

    *items* is modified in place (``items[:] = ...``).
    """
    # look for all those classes that specify __backend__ and
    # expand them out into per-database test cases.

    # this is much easier to do within pytest_pycollect_makeitem, however
    # pytest is iterating through cls.__dict__ as makeitem is
    # called which causes a "dictionary changed size" error on py3k.
    # I'd submit a pullreq for them to turn it into a list first, but
    # it's to suit the rather odd use case here which is that we are adding
    # new classes to a module on the fly.

    from sqlalchemy.testing import asyncio

    # original test class -> {test name -> [items collected from the
    # generated sub-classes]}
    rebuilt_items = collections.defaultdict(
        lambda: collections.defaultdict(list)
    )

    items[:] = [
        item
        for item in items
        if isinstance(item.parent, pytest.Instance)
        and not item.parent.parent.name.startswith("_")
    ]

    # the distinct parent nodes (pytest.Instance objects per the filter
    # above) of the surviving items
    test_classes = set(item.parent for item in items)

    def setup_test_classes():
        for test_class in test_classes:
            for sub_cls in plugin_base.generate_sub_tests(
                test_class.cls, test_class.parent.module
            ):
                if sub_cls is not test_class.cls:
                    per_cls_dict = rebuilt_items[test_class.cls]

                    # support pytest 5.4.0 and above pytest.Class.from_parent
                    ctor = getattr(pytest.Class, "from_parent", pytest.Class)
                    # collect the sub-class under the same parent as the
                    # original class node, then its test items
                    for inst in ctor(
                        name=sub_cls.__name__, parent=test_class.parent.parent
                    ).collect():
                        for t in inst.collect():
                            per_cls_dict[t.name].append(t)

    # class requirements will sometimes need to access the DB to check
    # capabilities, so need to do this for async
    asyncio._maybe_async_provisioning(setup_test_classes)

    newitems = []
    for item in items:
        if item.parent.cls in rebuilt_items:
            # replaced by the same-named tests of the generated sub-classes
            newitems.extend(rebuilt_items[item.parent.cls][item.name])
        else:
            newitems.append(item)

    if py2k:
        for item in newitems:
            reinvent_fixtures_py2k.scan_for_fixtures_to_use_for_class(item)

    # seems like the functions attached to a test class aren't sorted already?
    # is that true and why's that? (when using unittest, they're sorted)
    items[:] = sorted(
        newitems,
        key=lambda item: (
            item.parent.parent.parent.name,
            item.parent.parent.name,
            item.name,
        ),
    )
  190. def pytest_pycollect_makeitem(collector, name, obj):
  191. if inspect.isclass(obj) and plugin_base.want_class(name, obj):
  192. from sqlalchemy.testing import config
  193. if config.any_async:
  194. obj = _apply_maybe_async(obj)
  195. ctor = getattr(pytest.Class, "from_parent", pytest.Class)
  196. return [
  197. ctor(name=parametrize_cls.__name__, parent=collector)
  198. for parametrize_cls in _parametrize_cls(collector.module, obj)
  199. ]
  200. elif (
  201. inspect.isfunction(obj)
  202. and isinstance(collector, pytest.Instance)
  203. and plugin_base.want_method(collector.cls, obj)
  204. ):
  205. # None means, fall back to default logic, which includes
  206. # method-level parametrize
  207. return None
  208. else:
  209. # empty list means skip this item
  210. return []
  211. def _is_wrapped_coroutine_function(fn):
  212. while hasattr(fn, "__wrapped__"):
  213. fn = fn.__wrapped__
  214. return inspect.iscoroutinefunction(fn)
  215. def _apply_maybe_async(obj, recurse=True):
  216. from sqlalchemy.testing import asyncio
  217. for name, value in vars(obj).items():
  218. if (
  219. (callable(value) or isinstance(value, classmethod))
  220. and not getattr(value, "_maybe_async_applied", False)
  221. and (name.startswith("test_"))
  222. and not _is_wrapped_coroutine_function(value)
  223. ):
  224. is_classmethod = False
  225. if isinstance(value, classmethod):
  226. value = value.__func__
  227. is_classmethod = True
  228. @_pytest_fn_decorator
  229. def make_async(fn, *args, **kwargs):
  230. return asyncio._maybe_async(fn, *args, **kwargs)
  231. do_async = make_async(value)
  232. if is_classmethod:
  233. do_async = classmethod(do_async)
  234. do_async._maybe_async_applied = True
  235. setattr(obj, name, do_async)
  236. if recurse:
  237. for cls in obj.mro()[1:]:
  238. if cls != object:
  239. _apply_maybe_async(cls, False)
  240. return obj
  241. def _parametrize_cls(module, cls):
  242. """implement a class-based version of pytest parametrize."""
  243. if "_sa_parametrize" not in cls.__dict__:
  244. return [cls]
  245. _sa_parametrize = cls._sa_parametrize
  246. classes = []
  247. for full_param_set in itertools.product(
  248. *[params for argname, params in _sa_parametrize]
  249. ):
  250. cls_variables = {}
  251. for argname, param in zip(
  252. [_sa_param[0] for _sa_param in _sa_parametrize], full_param_set
  253. ):
  254. if not argname:
  255. raise TypeError("need argnames for class-based combinations")
  256. argname_split = re.split(r",\s*", argname)
  257. for arg, val in zip(argname_split, param.values):
  258. cls_variables[arg] = val
  259. parametrized_name = "_".join(
  260. # token is a string, but in py2k pytest is giving us a unicode,
  261. # so call str() on it.
  262. str(re.sub(r"\W", "", token))
  263. for param in full_param_set
  264. for token in param.id.split("-")
  265. )
  266. name = "%s_%s" % (cls.__name__, parametrized_name)
  267. newcls = type.__new__(type, name, (cls,), cls_variables)
  268. setattr(module, name, newcls)
  269. classes.append(newcls)
  270. return classes
  271. _current_class = None
  272. def pytest_runtest_setup(item):
  273. from sqlalchemy.testing import asyncio
  274. if not isinstance(item, pytest.Function):
  275. return
  276. # pytest_runtest_setup runs *before* pytest fixtures with scope="class".
  277. # plugin_base.start_test_class_outside_fixtures may opt to raise SkipTest
  278. # for the whole class and has to run things that are across all current
  279. # databases, so we run this outside of the pytest fixture system altogether
  280. # and ensure asyncio greenlet if any engines are async
  281. global _current_class
  282. if _current_class is None:
  283. asyncio._maybe_async_provisioning(
  284. plugin_base.start_test_class_outside_fixtures,
  285. item.parent.parent.cls,
  286. )
  287. _current_class = item.parent.parent
  288. def finalize():
  289. global _current_class
  290. _current_class = None
  291. asyncio._maybe_async_provisioning(
  292. plugin_base.stop_test_class_outside_fixtures,
  293. item.parent.parent.cls,
  294. )
  295. item.parent.parent.addfinalizer(finalize)
  296. def pytest_runtest_call(item):
  297. # runs inside of pytest function fixture scope
  298. # before test function runs
  299. from sqlalchemy.testing import asyncio
  300. asyncio._maybe_async(
  301. plugin_base.before_test,
  302. item,
  303. item.parent.module.__name__,
  304. item.parent.cls,
  305. item.name,
  306. )
  307. def pytest_runtest_teardown(item, nextitem):
  308. # runs inside of pytest function fixture scope
  309. # after test function runs
  310. from sqlalchemy.testing import asyncio
  311. asyncio._maybe_async(plugin_base.after_test, item)
  312. @pytest.fixture(scope="class")
  313. def setup_class_methods(request):
  314. from sqlalchemy.testing import asyncio
  315. cls = request.cls
  316. if hasattr(cls, "setup_test_class"):
  317. asyncio._maybe_async(cls.setup_test_class)
  318. if py2k:
  319. reinvent_fixtures_py2k.run_class_fixture_setup(request)
  320. yield
  321. if py2k:
  322. reinvent_fixtures_py2k.run_class_fixture_teardown(request)
  323. if hasattr(cls, "teardown_test_class"):
  324. asyncio._maybe_async(cls.teardown_test_class)
  325. asyncio._maybe_async(plugin_base.stop_test_class, cls)
@pytest.fixture(scope="function")
def setup_test_methods(request):
    """Function-scoped fixture running per-test setup/teardown hooks.

    Listed in ``PytestFixtureFunctions.mark_base_test_class``.  Every hook
    found on the test instance is invoked through
    ``sqlalchemy.testing.asyncio._maybe_async``.  The numbered comments
    below describe where each step falls relative to the test itself.
    """
    from sqlalchemy.testing import asyncio

    # called for each test

    self = request.instance

    # 1. run outer xunit-style setup hooks defined on the test class
    if hasattr(self, "setup_test"):
        asyncio._maybe_async(self.setup_test)

    # alembic test suite is using setUp and tearDown
    # (unittest / xunit-style methods); support these in the test suite
    # for the near term
    if hasattr(self, "setUp"):
        asyncio._maybe_async(self.setUp)

    # 2. run homegrown function level "autouse" fixtures under py2k
    if py2k:
        reinvent_fixtures_py2k.run_fn_fixture_setup(request)

    # inside the yield:
    # 3. function level "autouse" fixtures under py3k (examples: TablesTest
    #    define tables / data, MappedTest define tables / mappers / data)
    # 4. function level fixtures defined on test functions themselves,
    #    e.g. "connection", "metadata" run next
    # 5. pytest hook pytest_runtest_call then runs
    # 6. test itself runs
    yield

    # yield finishes:

    # 7. pytest hook pytest_runtest_teardown runs; per the original notes
    #    this is where fixtures close all sessions,
    #    provisioning.stop_test_class(), engines.testing_reaper ensures all
    #    connection pool connections are returned, and engines created by
    #    testing_engine that aren't the config engine are disposed
    # 8. function level fixtures defined on test functions
    #    themselves, e.g. "connection" rolls back the transaction, "metadata"
    #    emits drop all
    # 9. function level "autouse" fixtures under py3k (examples: TablesTest /
    #    MappedTest delete table data, possibly drop tables and clear mappers
    #    depending on the flags defined by the test class)

    # 10. run homegrown function-level "autouse" fixtures under py2k
    if py2k:
        reinvent_fixtures_py2k.run_fn_fixture_teardown(request)

    asyncio._maybe_async(plugin_base.after_test_fixtures, self)

    # 11. run outer xunit-style teardown hooks (reverse order of step 1)
    if hasattr(self, "tearDown"):
        asyncio._maybe_async(self.tearDown)

    if hasattr(self, "teardown_test"):
        asyncio._maybe_async(self.teardown_test)
  371. def getargspec(fn):
  372. if sys.version_info.major == 3:
  373. return inspect.getfullargspec(fn)
  374. else:
  375. return inspect.getargspec(fn)
def _pytest_fn_decorator(target):
    """Port of langhelpers.decorator with pytest-specific tricks.

    Returns ``decorate(fn, add_positional_parameters=())``.  ``decorate``
    builds, by exec-ing generated source, a new function with *fn*'s name
    and argument list (as rendered by ``format_argspec_plus``) whose body
    calls ``target(fn, <arguments>)``.  Generating a real signature rather
    than using ``*args, **kw`` presumably keeps pytest's argument
    introspection working -- NOTE(review): inferred, confirm.

    Without ``add_positional_parameters`` the original defaults are copied
    and ``functools.update_wrapper`` is applied (which also sets
    ``__wrapped__``).  With extra parameters, they are appended to the
    generated signature and only ``__module__``, ``__name__`` and any
    ``pytestmark`` are copied over.
    """
    from sqlalchemy.util.langhelpers import format_argspec_plus
    from sqlalchemy.util.compat import inspect_getfullargspec

    def _exec_code_in_env(code, env, fn_name):
        # run the generated source in *env* and return the function it
        # defined
        exec(code, env)
        return env[fn_name]

    def decorate(fn, add_positional_parameters=()):

        spec = inspect_getfullargspec(fn)
        if add_positional_parameters:
            spec.args.extend(add_positional_parameters)

        # names under which target / fn are visible to the generated code
        metadata = dict(
            __target_fn="__target_fn", __orig_fn="__orig_fn", name=fn.__name__
        )
        metadata.update(format_argspec_plus(spec, grouped=False))
        code = (
            """\
def %(name)s(%(args)s):
    return %(__target_fn)s(%(__orig_fn)s, %(apply_kw)s)
"""
            % metadata
        )
        decorated = _exec_code_in_env(
            code, {"__target_fn": target, "__orig_fn": fn}, fn.__name__
        )
        if not add_positional_parameters:
            decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
            decorated.__wrapped__ = fn
            return update_wrapper(decorated, fn)
        else:
            # this is the pytest hacky part.  don't do a full update wrapper
            # because pytest is really being sneaky about finding the args
            # for the wrapped function
            decorated.__module__ = fn.__module__
            decorated.__name__ = fn.__name__
            if hasattr(fn, "pytestmark"):
                decorated.pytestmark = fn.pytestmark
            return decorated

    return decorate
  415. class PytestFixtureFunctions(plugin_base.FixtureFunctions):
  416. def skip_test_exception(self, *arg, **kw):
  417. return pytest.skip.Exception(*arg, **kw)
  418. def mark_base_test_class(self):
  419. return pytest.mark.usefixtures(
  420. "setup_class_methods", "setup_test_methods"
  421. )
  422. _combination_id_fns = {
  423. "i": lambda obj: obj,
  424. "r": repr,
  425. "s": str,
  426. "n": lambda obj: obj.__name__
  427. if hasattr(obj, "__name__")
  428. else type(obj).__name__,
  429. }
  430. def combinations(self, *arg_sets, **kw):
  431. """Facade for pytest.mark.parametrize.
  432. Automatically derives argument names from the callable which in our
  433. case is always a method on a class with positional arguments.
  434. ids for parameter sets are derived using an optional template.
  435. """
  436. from sqlalchemy.testing import exclusions
  437. if sys.version_info.major == 3:
  438. if len(arg_sets) == 1 and hasattr(arg_sets[0], "__next__"):
  439. arg_sets = list(arg_sets[0])
  440. else:
  441. if len(arg_sets) == 1 and hasattr(arg_sets[0], "next"):
  442. arg_sets = list(arg_sets[0])
  443. argnames = kw.pop("argnames", None)
  444. def _filter_exclusions(args):
  445. result = []
  446. gathered_exclusions = []
  447. for a in args:
  448. if isinstance(a, exclusions.compound):
  449. gathered_exclusions.append(a)
  450. else:
  451. result.append(a)
  452. return result, gathered_exclusions
  453. id_ = kw.pop("id_", None)
  454. tobuild_pytest_params = []
  455. has_exclusions = False
  456. if id_:
  457. _combination_id_fns = self._combination_id_fns
  458. # because itemgetter is not consistent for one argument vs.
  459. # multiple, make it multiple in all cases and use a slice
  460. # to omit the first argument
  461. _arg_getter = operator.itemgetter(
  462. 0,
  463. *[
  464. idx
  465. for idx, char in enumerate(id_)
  466. if char in ("n", "r", "s", "a")
  467. ]
  468. )
  469. fns = [
  470. (operator.itemgetter(idx), _combination_id_fns[char])
  471. for idx, char in enumerate(id_)
  472. if char in _combination_id_fns
  473. ]
  474. for arg in arg_sets:
  475. if not isinstance(arg, tuple):
  476. arg = (arg,)
  477. fn_params, param_exclusions = _filter_exclusions(arg)
  478. parameters = _arg_getter(fn_params)[1:]
  479. if param_exclusions:
  480. has_exclusions = True
  481. tobuild_pytest_params.append(
  482. (
  483. parameters,
  484. param_exclusions,
  485. "-".join(
  486. comb_fn(getter(arg)) for getter, comb_fn in fns
  487. ),
  488. )
  489. )
  490. else:
  491. for arg in arg_sets:
  492. if not isinstance(arg, tuple):
  493. arg = (arg,)
  494. fn_params, param_exclusions = _filter_exclusions(arg)
  495. if param_exclusions:
  496. has_exclusions = True
  497. tobuild_pytest_params.append(
  498. (fn_params, param_exclusions, None)
  499. )
  500. pytest_params = []
  501. for parameters, param_exclusions, id_ in tobuild_pytest_params:
  502. if has_exclusions:
  503. parameters += (param_exclusions,)
  504. param = pytest.param(*parameters, id=id_)
  505. pytest_params.append(param)
  506. def decorate(fn):
  507. if inspect.isclass(fn):
  508. if has_exclusions:
  509. raise NotImplementedError(
  510. "exclusions not supported for class level combinations"
  511. )
  512. if "_sa_parametrize" not in fn.__dict__:
  513. fn._sa_parametrize = []
  514. fn._sa_parametrize.append((argnames, pytest_params))
  515. return fn
  516. else:
  517. if argnames is None:
  518. _argnames = getargspec(fn).args[1:]
  519. else:
  520. _argnames = re.split(r", *", argnames)
  521. if has_exclusions:
  522. _argnames += ["_exclusions"]
  523. @_pytest_fn_decorator
  524. def check_exclusions(fn, *args, **kw):
  525. _exclusions = args[-1]
  526. if _exclusions:
  527. exlu = exclusions.compound().add(*_exclusions)
  528. fn = exlu(fn)
  529. return fn(*args[0:-1], **kw)
  530. def process_metadata(spec):
  531. spec.args.append("_exclusions")
  532. fn = check_exclusions(
  533. fn, add_positional_parameters=("_exclusions",)
  534. )
  535. return pytest.mark.parametrize(_argnames, pytest_params)(fn)
  536. return decorate
  537. def param_ident(self, *parameters):
  538. ident = parameters[0]
  539. return pytest.param(*parameters[1:], id=ident)
  540. def fixture(self, *arg, **kw):
  541. from sqlalchemy.testing import config
  542. from sqlalchemy.testing import asyncio
  543. # wrapping pytest.fixture function. determine if
  544. # decorator was called as @fixture or @fixture().
  545. if len(arg) > 0 and callable(arg[0]):
  546. # was called as @fixture(), we have the function to wrap.
  547. fn = arg[0]
  548. arg = arg[1:]
  549. else:
  550. # was called as @fixture, don't have the function yet.
  551. fn = None
  552. # create a pytest.fixture marker. because the fn is not being
  553. # passed, this is always a pytest.FixtureFunctionMarker()
  554. # object (or whatever pytest is calling it when you read this)
  555. # that is waiting for a function.
  556. fixture = pytest.fixture(*arg, **kw)
  557. # now apply wrappers to the function, including fixture itself
  558. def wrap(fn):
  559. if config.any_async:
  560. fn = asyncio._maybe_async_wrapper(fn)
  561. # other wrappers may be added here
  562. if py2k and "autouse" in kw:
  563. # py2k workaround for too-slow collection of autouse fixtures
  564. # in pytest 4.6.11. See notes in reinvent_fixtures_py2k for
  565. # rationale.
  566. # comment this condition out in order to disable the
  567. # py2k workaround entirely.
  568. reinvent_fixtures_py2k.add_fixture(fn, fixture)
  569. else:
  570. # now apply FixtureFunctionMarker
  571. fn = fixture(fn)
  572. return fn
  573. if fn:
  574. return wrap(fn)
  575. else:
  576. return wrap
  577. def get_current_test_name(self):
  578. return os.environ.get("PYTEST_CURRENT_TEST")
  579. def async_test(self, fn):
  580. from sqlalchemy.testing import asyncio
  581. @_pytest_fn_decorator
  582. def decorate(fn, *args, **kwargs):
  583. asyncio._run_coroutine_function(fn, *args, **kwargs)
  584. return decorate(fn)