You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

809 lines
25KB

  1. # testing/fixtures.py
  2. # Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. import contextlib
  8. import re
  9. import sys
  10. import sqlalchemy as sa
  11. from . import assertions
  12. from . import config
  13. from . import schema
  14. from .entities import BasicEntity
  15. from .entities import ComparableEntity
  16. from .entities import ComparableMixin # noqa
  17. from .util import adict
  18. from .util import drop_all_tables_from_metadata
  19. from .. import event
  20. from .. import util
  21. from ..orm import declarative_base
  22. from ..orm import registry
  23. from ..orm.decl_api import DeclarativeMeta
  24. from ..schema import sort_tables_and_constraints
  25. @config.mark_base_test_class()
  26. class TestBase(object):
  27. # A sequence of requirement names matching testing.requires decorators
  28. __requires__ = ()
  29. # A sequence of dialect names to exclude from the test class.
  30. __unsupported_on__ = ()
  31. # If present, test class is only runnable for the *single* specified
  32. # dialect. If you need multiple, use __unsupported_on__ and invert.
  33. __only_on__ = None
  34. # A sequence of no-arg callables. If any are True, the entire testcase is
  35. # skipped.
  36. __skip_if__ = None
  37. # if True, the testing reaper will not attempt to touch connection
  38. # state after a test is completed and before the outer teardown
  39. # starts
  40. __leave_connections_for_teardown__ = False
  41. def assert_(self, val, msg=None):
  42. assert val, msg
  43. @config.fixture()
  44. def connection_no_trans(self):
  45. eng = getattr(self, "bind", None) or config.db
  46. with eng.connect() as conn:
  47. yield conn
  48. @config.fixture()
  49. def connection(self):
  50. global _connection_fixture_connection
  51. eng = getattr(self, "bind", None) or config.db
  52. conn = eng.connect()
  53. trans = conn.begin()
  54. _connection_fixture_connection = conn
  55. yield conn
  56. _connection_fixture_connection = None
  57. if trans.is_active:
  58. trans.rollback()
  59. # trans would not be active here if the test is using
  60. # the legacy @provide_metadata decorator still, as it will
  61. # run a close all connections.
  62. conn.close()
  63. @config.fixture()
  64. def registry(self, metadata):
  65. reg = registry(metadata=metadata)
  66. yield reg
  67. reg.dispose()
  68. @config.fixture()
  69. def future_connection(self, future_engine, connection):
  70. # integrate the future_engine and connection fixtures so
  71. # that users of the "connection" fixture will get at the
  72. # "future" connection
  73. yield connection
  74. @config.fixture()
  75. def future_engine(self):
  76. eng = getattr(self, "bind", None) or config.db
  77. with _push_future_engine(eng):
  78. yield
  79. @config.fixture()
  80. def testing_engine(self):
  81. from . import engines
  82. def gen_testing_engine(
  83. url=None,
  84. options=None,
  85. future=None,
  86. asyncio=False,
  87. transfer_staticpool=False,
  88. ):
  89. if options is None:
  90. options = {}
  91. options["scope"] = "fixture"
  92. return engines.testing_engine(
  93. url=url,
  94. options=options,
  95. future=future,
  96. asyncio=asyncio,
  97. transfer_staticpool=transfer_staticpool,
  98. )
  99. yield gen_testing_engine
  100. engines.testing_reaper._drop_testing_engines("fixture")
  101. @config.fixture()
  102. def async_testing_engine(self, testing_engine):
  103. def go(**kw):
  104. kw["asyncio"] = True
  105. return testing_engine(**kw)
  106. return go
  107. @config.fixture()
  108. def metadata(self, request):
  109. """Provide bound MetaData for a single test, dropping afterwards."""
  110. from ..sql import schema
  111. metadata = schema.MetaData()
  112. request.instance.metadata = metadata
  113. yield metadata
  114. del request.instance.metadata
  115. if (
  116. _connection_fixture_connection
  117. and _connection_fixture_connection.in_transaction()
  118. ):
  119. trans = _connection_fixture_connection.get_transaction()
  120. trans.rollback()
  121. with _connection_fixture_connection.begin():
  122. drop_all_tables_from_metadata(
  123. metadata, _connection_fixture_connection
  124. )
  125. else:
  126. drop_all_tables_from_metadata(metadata, config.db)
  127. @config.fixture(
  128. params=[
  129. (rollback, second_operation, begin_nested)
  130. for rollback in (True, False)
  131. for second_operation in ("none", "execute", "begin")
  132. for begin_nested in (
  133. True,
  134. False,
  135. )
  136. ]
  137. )
  138. def trans_ctx_manager_fixture(self, request, metadata):
  139. rollback, second_operation, begin_nested = request.param
  140. from sqlalchemy import Table, Column, Integer, func, select
  141. from . import eq_
  142. t = Table("test", metadata, Column("data", Integer))
  143. eng = getattr(self, "bind", None) or config.db
  144. t.create(eng)
  145. def run_test(subject, trans_on_subject, execute_on_subject):
  146. with subject.begin() as trans:
  147. if begin_nested:
  148. if not config.requirements.savepoints.enabled:
  149. config.skip_test("savepoints not enabled")
  150. if execute_on_subject:
  151. nested_trans = subject.begin_nested()
  152. else:
  153. nested_trans = trans.begin_nested()
  154. with nested_trans:
  155. if execute_on_subject:
  156. subject.execute(t.insert(), {"data": 10})
  157. else:
  158. trans.execute(t.insert(), {"data": 10})
  159. # for nested trans, we always commit/rollback on the
  160. # "nested trans" object itself.
  161. # only Session(future=False) will affect savepoint
  162. # transaction for session.commit/rollback
  163. if rollback:
  164. nested_trans.rollback()
  165. else:
  166. nested_trans.commit()
  167. if second_operation != "none":
  168. with assertions.expect_raises_message(
  169. sa.exc.InvalidRequestError,
  170. "Can't operate on closed transaction "
  171. "inside context "
  172. "manager. Please complete the context "
  173. "manager "
  174. "before emitting further commands.",
  175. ):
  176. if second_operation == "execute":
  177. if execute_on_subject:
  178. subject.execute(
  179. t.insert(), {"data": 12}
  180. )
  181. else:
  182. trans.execute(t.insert(), {"data": 12})
  183. elif second_operation == "begin":
  184. if execute_on_subject:
  185. subject.begin_nested()
  186. else:
  187. trans.begin_nested()
  188. # outside the nested trans block, but still inside the
  189. # transaction block, we can run SQL, and it will be
  190. # committed
  191. if execute_on_subject:
  192. subject.execute(t.insert(), {"data": 14})
  193. else:
  194. trans.execute(t.insert(), {"data": 14})
  195. else:
  196. if execute_on_subject:
  197. subject.execute(t.insert(), {"data": 10})
  198. else:
  199. trans.execute(t.insert(), {"data": 10})
  200. if trans_on_subject:
  201. if rollback:
  202. subject.rollback()
  203. else:
  204. subject.commit()
  205. else:
  206. if rollback:
  207. trans.rollback()
  208. else:
  209. trans.commit()
  210. if second_operation != "none":
  211. with assertions.expect_raises_message(
  212. sa.exc.InvalidRequestError,
  213. "Can't operate on closed transaction inside "
  214. "context "
  215. "manager. Please complete the context manager "
  216. "before emitting further commands.",
  217. ):
  218. if second_operation == "execute":
  219. if execute_on_subject:
  220. subject.execute(t.insert(), {"data": 12})
  221. else:
  222. trans.execute(t.insert(), {"data": 12})
  223. elif second_operation == "begin":
  224. if hasattr(trans, "begin"):
  225. trans.begin()
  226. else:
  227. subject.begin()
  228. elif second_operation == "begin_nested":
  229. if execute_on_subject:
  230. subject.begin_nested()
  231. else:
  232. trans.begin_nested()
  233. expected_committed = 0
  234. if begin_nested:
  235. # begin_nested variant, we inserted a row after the nested
  236. # block
  237. expected_committed += 1
  238. if not rollback:
  239. # not rollback variant, our row inserted in the target
  240. # block itself would be committed
  241. expected_committed += 1
  242. if execute_on_subject:
  243. eq_(
  244. subject.scalar(select(func.count()).select_from(t)),
  245. expected_committed,
  246. )
  247. else:
  248. with subject.connect() as conn:
  249. eq_(
  250. conn.scalar(select(func.count()).select_from(t)),
  251. expected_committed,
  252. )
  253. return run_test
# Connection currently yielded by the TestBase.connection fixture, or None
# when that fixture is not active.  Read by the TestBase.metadata fixture
# at teardown to decide which connection to drop tables on.
_connection_fixture_connection = None
  255. @contextlib.contextmanager
  256. def _push_future_engine(engine):
  257. from ..future.engine import Engine
  258. from sqlalchemy import testing
  259. facade = Engine._future_facade(engine)
  260. config._current.push_engine(facade, testing)
  261. yield facade
  262. config._current.pop(testing)
  263. class FutureEngineMixin(object):
  264. @config.fixture(autouse=True, scope="class")
  265. def _push_future_engine(self):
  266. eng = getattr(self, "bind", None) or config.db
  267. with _push_future_engine(eng):
  268. yield
  269. class TablesTest(TestBase):
  270. # 'once', None
  271. run_setup_bind = "once"
  272. # 'once', 'each', None
  273. run_define_tables = "once"
  274. # 'once', 'each', None
  275. run_create_tables = "once"
  276. # 'once', 'each', None
  277. run_inserts = "each"
  278. # 'each', None
  279. run_deletes = "each"
  280. # 'once', None
  281. run_dispose_bind = None
  282. bind = None
  283. _tables_metadata = None
  284. tables = None
  285. other = None
  286. sequences = None
  287. @config.fixture(autouse=True, scope="class")
  288. def _setup_tables_test_class(self):
  289. cls = self.__class__
  290. cls._init_class()
  291. cls._setup_once_tables()
  292. cls._setup_once_inserts()
  293. yield
  294. cls._teardown_once_metadata_bind()
  295. @config.fixture(autouse=True, scope="function")
  296. def _setup_tables_test_instance(self):
  297. self._setup_each_tables()
  298. self._setup_each_inserts()
  299. yield
  300. self._teardown_each_tables()
  301. @property
  302. def tables_test_metadata(self):
  303. return self._tables_metadata
  304. @classmethod
  305. def _init_class(cls):
  306. if cls.run_define_tables == "each":
  307. if cls.run_create_tables == "once":
  308. cls.run_create_tables = "each"
  309. assert cls.run_inserts in ("each", None)
  310. cls.other = adict()
  311. cls.tables = adict()
  312. cls.sequences = adict()
  313. cls.bind = cls.setup_bind()
  314. cls._tables_metadata = sa.MetaData()
  315. @classmethod
  316. def _setup_once_inserts(cls):
  317. if cls.run_inserts == "once":
  318. cls._load_fixtures()
  319. with cls.bind.begin() as conn:
  320. cls.insert_data(conn)
  321. @classmethod
  322. def _setup_once_tables(cls):
  323. if cls.run_define_tables == "once":
  324. cls.define_tables(cls._tables_metadata)
  325. if cls.run_create_tables == "once":
  326. cls._tables_metadata.create_all(cls.bind)
  327. cls.tables.update(cls._tables_metadata.tables)
  328. cls.sequences.update(cls._tables_metadata._sequences)
  329. def _setup_each_tables(self):
  330. if self.run_define_tables == "each":
  331. self.define_tables(self._tables_metadata)
  332. if self.run_create_tables == "each":
  333. self._tables_metadata.create_all(self.bind)
  334. self.tables.update(self._tables_metadata.tables)
  335. self.sequences.update(self._tables_metadata._sequences)
  336. elif self.run_create_tables == "each":
  337. self._tables_metadata.create_all(self.bind)
  338. def _setup_each_inserts(self):
  339. if self.run_inserts == "each":
  340. self._load_fixtures()
  341. with self.bind.begin() as conn:
  342. self.insert_data(conn)
  343. def _teardown_each_tables(self):
  344. if self.run_define_tables == "each":
  345. self.tables.clear()
  346. if self.run_create_tables == "each":
  347. drop_all_tables_from_metadata(self._tables_metadata, self.bind)
  348. self._tables_metadata.clear()
  349. elif self.run_create_tables == "each":
  350. drop_all_tables_from_metadata(self._tables_metadata, self.bind)
  351. # no need to run deletes if tables are recreated on setup
  352. if (
  353. self.run_define_tables != "each"
  354. and self.run_create_tables != "each"
  355. and self.run_deletes == "each"
  356. ):
  357. with self.bind.begin() as conn:
  358. for table in reversed(
  359. [
  360. t
  361. for (t, fks) in sort_tables_and_constraints(
  362. self._tables_metadata.tables.values()
  363. )
  364. if t is not None
  365. ]
  366. ):
  367. try:
  368. conn.execute(table.delete())
  369. except sa.exc.DBAPIError as ex:
  370. util.print_(
  371. ("Error emptying table %s: %r" % (table, ex)),
  372. file=sys.stderr,
  373. )
  374. @classmethod
  375. def _teardown_once_metadata_bind(cls):
  376. if cls.run_create_tables:
  377. drop_all_tables_from_metadata(cls._tables_metadata, cls.bind)
  378. if cls.run_dispose_bind == "once":
  379. cls.dispose_bind(cls.bind)
  380. cls._tables_metadata.bind = None
  381. if cls.run_setup_bind is not None:
  382. cls.bind = None
  383. @classmethod
  384. def setup_bind(cls):
  385. return config.db
  386. @classmethod
  387. def dispose_bind(cls, bind):
  388. if hasattr(bind, "dispose"):
  389. bind.dispose()
  390. elif hasattr(bind, "close"):
  391. bind.close()
  392. @classmethod
  393. def define_tables(cls, metadata):
  394. pass
  395. @classmethod
  396. def fixtures(cls):
  397. return {}
  398. @classmethod
  399. def insert_data(cls, connection):
  400. pass
  401. def sql_count_(self, count, fn):
  402. self.assert_sql_count(self.bind, fn, count)
  403. def sql_eq_(self, callable_, statements):
  404. self.assert_sql(self.bind, callable_, statements)
  405. @classmethod
  406. def _load_fixtures(cls):
  407. """Insert rows as represented by the fixtures() method."""
  408. headers, rows = {}, {}
  409. for table, data in cls.fixtures().items():
  410. if len(data) < 2:
  411. continue
  412. if isinstance(table, util.string_types):
  413. table = cls.tables[table]
  414. headers[table] = data[0]
  415. rows[table] = data[1:]
  416. for table, fks in sort_tables_and_constraints(
  417. cls._tables_metadata.tables.values()
  418. ):
  419. if table is None:
  420. continue
  421. if table not in headers:
  422. continue
  423. with cls.bind.begin() as conn:
  424. conn.execute(
  425. table.insert(),
  426. [
  427. dict(zip(headers[table], column_values))
  428. for column_values in rows[table]
  429. ],
  430. )
  431. class RemovesEvents(object):
  432. @util.memoized_property
  433. def _event_fns(self):
  434. return set()
  435. def event_listen(self, target, name, fn, **kw):
  436. self._event_fns.add((target, name, fn))
  437. event.listen(target, name, fn, **kw)
  438. @config.fixture(autouse=True, scope="function")
  439. def _remove_events(self):
  440. yield
  441. for key in self._event_fns:
  442. event.remove(*key)
# Sessions handed out by fixture_session(); emptied by
# _close_all_sessions().
_fixture_sessions = set()
  444. def fixture_session(**kw):
  445. kw.setdefault("autoflush", True)
  446. kw.setdefault("expire_on_commit", True)
  447. sess = sa.orm.Session(config.db, **kw)
  448. _fixture_sessions.add(sess)
  449. return sess
  450. def _close_all_sessions():
  451. # will close all still-referenced sessions
  452. sa.orm.session.close_all_sessions()
  453. _fixture_sessions.clear()
  454. def stop_test_class_inside_fixtures(cls):
  455. _close_all_sessions()
  456. sa.orm.clear_mappers()
  457. def after_test():
  458. if _fixture_sessions:
  459. _close_all_sessions()
# Base class for ORM tests; adds nothing of its own on top of TestBase.
class ORMTest(TestBase):
    pass
  462. class MappedTest(TablesTest, assertions.AssertsExecutionResults):
  463. # 'once', 'each', None
  464. run_setup_classes = "once"
  465. # 'once', 'each', None
  466. run_setup_mappers = "each"
  467. classes = None
  468. @config.fixture(autouse=True, scope="class")
  469. def _setup_tables_test_class(self):
  470. cls = self.__class__
  471. cls._init_class()
  472. if cls.classes is None:
  473. cls.classes = adict()
  474. cls._setup_once_tables()
  475. cls._setup_once_classes()
  476. cls._setup_once_mappers()
  477. cls._setup_once_inserts()
  478. yield
  479. cls._teardown_once_class()
  480. cls._teardown_once_metadata_bind()
  481. @config.fixture(autouse=True, scope="function")
  482. def _setup_tables_test_instance(self):
  483. self._setup_each_tables()
  484. self._setup_each_classes()
  485. self._setup_each_mappers()
  486. self._setup_each_inserts()
  487. yield
  488. sa.orm.session.close_all_sessions()
  489. self._teardown_each_mappers()
  490. self._teardown_each_classes()
  491. self._teardown_each_tables()
  492. @classmethod
  493. def _teardown_once_class(cls):
  494. cls.classes.clear()
  495. @classmethod
  496. def _setup_once_classes(cls):
  497. if cls.run_setup_classes == "once":
  498. cls._with_register_classes(cls.setup_classes)
  499. @classmethod
  500. def _setup_once_mappers(cls):
  501. if cls.run_setup_mappers == "once":
  502. cls.mapper = cls._generate_mapper()
  503. cls._with_register_classes(cls.setup_mappers)
  504. def _setup_each_mappers(self):
  505. if self.run_setup_mappers == "each":
  506. self.__class__.mapper = self._generate_mapper()
  507. self._with_register_classes(self.setup_mappers)
  508. def _setup_each_classes(self):
  509. if self.run_setup_classes == "each":
  510. self._with_register_classes(self.setup_classes)
  511. @classmethod
  512. def _generate_mapper(cls):
  513. decl = registry()
  514. return decl.map_imperatively
  515. @classmethod
  516. def _with_register_classes(cls, fn):
  517. """Run a setup method, framing the operation with a Base class
  518. that will catch new subclasses to be established within
  519. the "classes" registry.
  520. """
  521. cls_registry = cls.classes
  522. assert cls_registry is not None
  523. class FindFixture(type):
  524. def __init__(cls, classname, bases, dict_):
  525. cls_registry[classname] = cls
  526. type.__init__(cls, classname, bases, dict_)
  527. class _Base(util.with_metaclass(FindFixture, object)):
  528. pass
  529. class Basic(BasicEntity, _Base):
  530. pass
  531. class Comparable(ComparableEntity, _Base):
  532. pass
  533. cls.Basic = Basic
  534. cls.Comparable = Comparable
  535. fn()
  536. def _teardown_each_mappers(self):
  537. # some tests create mappers in the test bodies
  538. # and will define setup_mappers as None -
  539. # clear mappers in any case
  540. if self.run_setup_mappers != "once":
  541. sa.orm.clear_mappers()
  542. def _teardown_each_classes(self):
  543. if self.run_setup_classes != "once":
  544. self.classes.clear()
  545. @classmethod
  546. def setup_classes(cls):
  547. pass
  548. @classmethod
  549. def setup_mappers(cls):
  550. pass
  551. class DeclarativeMappedTest(MappedTest):
  552. run_setup_classes = "once"
  553. run_setup_mappers = "once"
  554. @classmethod
  555. def _setup_once_tables(cls):
  556. pass
  557. @classmethod
  558. def _with_register_classes(cls, fn):
  559. cls_registry = cls.classes
  560. class FindFixtureDeclarative(DeclarativeMeta):
  561. def __init__(cls, classname, bases, dict_):
  562. cls_registry[classname] = cls
  563. DeclarativeMeta.__init__(cls, classname, bases, dict_)
  564. class DeclarativeBasic(object):
  565. __table_cls__ = schema.Table
  566. _DeclBase = declarative_base(
  567. metadata=cls._tables_metadata,
  568. metaclass=FindFixtureDeclarative,
  569. cls=DeclarativeBasic,
  570. )
  571. cls.DeclarativeBasic = _DeclBase
  572. # sets up cls.Basic which is helpful for things like composite
  573. # classes
  574. super(DeclarativeMappedTest, cls)._with_register_classes(fn)
  575. if cls._tables_metadata.tables and cls.run_create_tables:
  576. cls._tables_metadata.create_all(config.db)
  577. class ComputedReflectionFixtureTest(TablesTest):
  578. run_inserts = run_deletes = None
  579. __backend__ = True
  580. __requires__ = ("computed_columns", "table_reflection")
  581. regexp = re.compile(r"[\[\]\(\)\s`'\"]*")
  582. def normalize(self, text):
  583. return self.regexp.sub("", text).lower()
  584. @classmethod
  585. def define_tables(cls, metadata):
  586. from .. import Integer
  587. from .. import testing
  588. from ..schema import Column
  589. from ..schema import Computed
  590. from ..schema import Table
  591. Table(
  592. "computed_default_table",
  593. metadata,
  594. Column("id", Integer, primary_key=True),
  595. Column("normal", Integer),
  596. Column("computed_col", Integer, Computed("normal + 42")),
  597. Column("with_default", Integer, server_default="42"),
  598. )
  599. t = Table(
  600. "computed_column_table",
  601. metadata,
  602. Column("id", Integer, primary_key=True),
  603. Column("normal", Integer),
  604. Column("computed_no_flag", Integer, Computed("normal + 42")),
  605. )
  606. if testing.requires.schemas.enabled:
  607. t2 = Table(
  608. "computed_column_table",
  609. metadata,
  610. Column("id", Integer, primary_key=True),
  611. Column("normal", Integer),
  612. Column("computed_no_flag", Integer, Computed("normal / 42")),
  613. schema=config.test_schema,
  614. )
  615. if testing.requires.computed_columns_virtual.enabled:
  616. t.append_column(
  617. Column(
  618. "computed_virtual",
  619. Integer,
  620. Computed("normal + 2", persisted=False),
  621. )
  622. )
  623. if testing.requires.schemas.enabled:
  624. t2.append_column(
  625. Column(
  626. "computed_virtual",
  627. Integer,
  628. Computed("normal / 2", persisted=False),
  629. )
  630. )
  631. if testing.requires.computed_columns_stored.enabled:
  632. t.append_column(
  633. Column(
  634. "computed_stored",
  635. Integer,
  636. Computed("normal - 42", persisted=True),
  637. )
  638. )
  639. if testing.requires.schemas.enabled:
  640. t2.append_column(
  641. Column(
  642. "computed_stored",
  643. Integer,
  644. Computed("normal * 42", persisted=True),
  645. )
  646. )