You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

466 lines
13KB

  1. # testing/exclusions.py
  2. # Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. import contextlib
  8. import operator
  9. import re
  10. import sys
  11. from . import config
  12. from .. import util
  13. from ..util import decorator
  14. from ..util.compat import inspect_getfullargspec
  15. def skip_if(predicate, reason=None):
  16. rule = compound()
  17. pred = _as_predicate(predicate, reason)
  18. rule.skips.add(pred)
  19. return rule
  20. def fails_if(predicate, reason=None):
  21. rule = compound()
  22. pred = _as_predicate(predicate, reason)
  23. rule.fails.add(pred)
  24. return rule
  25. class compound(object):
  26. def __init__(self):
  27. self.fails = set()
  28. self.skips = set()
  29. self.tags = set()
  30. def __add__(self, other):
  31. return self.add(other)
  32. def as_skips(self):
  33. rule = compound()
  34. rule.skips.update(self.skips)
  35. rule.skips.update(self.fails)
  36. rule.tags.update(self.tags)
  37. return rule
  38. def add(self, *others):
  39. copy = compound()
  40. copy.fails.update(self.fails)
  41. copy.skips.update(self.skips)
  42. copy.tags.update(self.tags)
  43. for other in others:
  44. copy.fails.update(other.fails)
  45. copy.skips.update(other.skips)
  46. copy.tags.update(other.tags)
  47. return copy
  48. def not_(self):
  49. copy = compound()
  50. copy.fails.update(NotPredicate(fail) for fail in self.fails)
  51. copy.skips.update(NotPredicate(skip) for skip in self.skips)
  52. copy.tags.update(self.tags)
  53. return copy
  54. @property
  55. def enabled(self):
  56. return self.enabled_for_config(config._current)
  57. def enabled_for_config(self, config):
  58. for predicate in self.skips.union(self.fails):
  59. if predicate(config):
  60. return False
  61. else:
  62. return True
  63. def matching_config_reasons(self, config):
  64. return [
  65. predicate._as_string(config)
  66. for predicate in self.skips.union(self.fails)
  67. if predicate(config)
  68. ]
  69. def include_test(self, include_tags, exclude_tags):
  70. return bool(
  71. not self.tags.intersection(exclude_tags)
  72. and (not include_tags or self.tags.intersection(include_tags))
  73. )
  74. def _extend(self, other):
  75. self.skips.update(other.skips)
  76. self.fails.update(other.fails)
  77. self.tags.update(other.tags)
  78. def __call__(self, fn):
  79. if hasattr(fn, "_sa_exclusion_extend"):
  80. fn._sa_exclusion_extend._extend(self)
  81. return fn
  82. @decorator
  83. def decorate(fn, *args, **kw):
  84. return self._do(config._current, fn, *args, **kw)
  85. decorated = decorate(fn)
  86. decorated._sa_exclusion_extend = self
  87. return decorated
  88. @contextlib.contextmanager
  89. def fail_if(self):
  90. all_fails = compound()
  91. all_fails.fails.update(self.skips.union(self.fails))
  92. try:
  93. yield
  94. except Exception as ex:
  95. all_fails._expect_failure(config._current, ex)
  96. else:
  97. all_fails._expect_success(config._current)
  98. def _do(self, cfg, fn, *args, **kw):
  99. for skip in self.skips:
  100. if skip(cfg):
  101. msg = "'%s' : %s" % (
  102. config.get_current_test_name(),
  103. skip._as_string(cfg),
  104. )
  105. config.skip_test(msg)
  106. try:
  107. return_value = fn(*args, **kw)
  108. except Exception as ex:
  109. self._expect_failure(cfg, ex, name=fn.__name__)
  110. else:
  111. self._expect_success(cfg, name=fn.__name__)
  112. return return_value
  113. def _expect_failure(self, config, ex, name="block"):
  114. for fail in self.fails:
  115. if fail(config):
  116. if util.py2k:
  117. str_ex = unicode(ex).encode( # noqa: F821
  118. "utf-8", errors="ignore"
  119. )
  120. else:
  121. str_ex = str(ex)
  122. print(
  123. (
  124. "%s failed as expected (%s): %s "
  125. % (name, fail._as_string(config), str_ex)
  126. )
  127. )
  128. break
  129. else:
  130. util.raise_(ex, with_traceback=sys.exc_info()[2])
  131. def _expect_success(self, config, name="block"):
  132. if not self.fails:
  133. return
  134. for fail in self.fails:
  135. if fail(config):
  136. raise AssertionError(
  137. "Unexpected success for '%s' (%s)"
  138. % (
  139. name,
  140. " and ".join(
  141. fail._as_string(config) for fail in self.fails
  142. ),
  143. )
  144. )
  145. def requires_tag(tagname):
  146. return tags([tagname])
  147. def tags(tagnames):
  148. comp = compound()
  149. comp.tags.update(tagnames)
  150. return comp
  151. def only_if(predicate, reason=None):
  152. predicate = _as_predicate(predicate)
  153. return skip_if(NotPredicate(predicate), reason)
  154. def succeeds_if(predicate, reason=None):
  155. predicate = _as_predicate(predicate)
  156. return fails_if(NotPredicate(predicate), reason)
  157. class Predicate(object):
  158. @classmethod
  159. def as_predicate(cls, predicate, description=None):
  160. if isinstance(predicate, compound):
  161. return cls.as_predicate(predicate.enabled_for_config, description)
  162. elif isinstance(predicate, Predicate):
  163. if description and predicate.description is None:
  164. predicate.description = description
  165. return predicate
  166. elif isinstance(predicate, (list, set)):
  167. return OrPredicate(
  168. [cls.as_predicate(pred) for pred in predicate], description
  169. )
  170. elif isinstance(predicate, tuple):
  171. return SpecPredicate(*predicate)
  172. elif isinstance(predicate, util.string_types):
  173. tokens = re.match(
  174. r"([\+\w]+)\s*(?:(>=|==|!=|<=|<|>)\s*([\d\.]+))?", predicate
  175. )
  176. if not tokens:
  177. raise ValueError(
  178. "Couldn't locate DB name in predicate: %r" % predicate
  179. )
  180. db = tokens.group(1)
  181. op = tokens.group(2)
  182. spec = (
  183. tuple(int(d) for d in tokens.group(3).split("."))
  184. if tokens.group(3)
  185. else None
  186. )
  187. return SpecPredicate(db, op, spec, description=description)
  188. elif callable(predicate):
  189. return LambdaPredicate(predicate, description)
  190. else:
  191. assert False, "unknown predicate type: %s" % predicate
  192. def _format_description(self, config, negate=False):
  193. bool_ = self(config)
  194. if negate:
  195. bool_ = not negate
  196. return self.description % {
  197. "driver": config.db.url.get_driver_name()
  198. if config
  199. else "<no driver>",
  200. "database": config.db.url.get_backend_name()
  201. if config
  202. else "<no database>",
  203. "doesnt_support": "doesn't support" if bool_ else "does support",
  204. "does_support": "does support" if bool_ else "doesn't support",
  205. }
  206. def _as_string(self, config=None, negate=False):
  207. raise NotImplementedError()
  208. class BooleanPredicate(Predicate):
  209. def __init__(self, value, description=None):
  210. self.value = value
  211. self.description = description or "boolean %s" % value
  212. def __call__(self, config):
  213. return self.value
  214. def _as_string(self, config, negate=False):
  215. return self._format_description(config, negate=negate)
  216. class SpecPredicate(Predicate):
  217. def __init__(self, db, op=None, spec=None, description=None):
  218. self.db = db
  219. self.op = op
  220. self.spec = spec
  221. self.description = description
  222. _ops = {
  223. "<": operator.lt,
  224. ">": operator.gt,
  225. "==": operator.eq,
  226. "!=": operator.ne,
  227. "<=": operator.le,
  228. ">=": operator.ge,
  229. "in": operator.contains,
  230. "between": lambda val, pair: val >= pair[0] and val <= pair[1],
  231. }
  232. def __call__(self, config):
  233. if config is None:
  234. return False
  235. engine = config.db
  236. if "+" in self.db:
  237. dialect, driver = self.db.split("+")
  238. else:
  239. dialect, driver = self.db, None
  240. if dialect and engine.name != dialect:
  241. return False
  242. if driver is not None and engine.driver != driver:
  243. return False
  244. if self.op is not None:
  245. assert driver is None, "DBAPI version specs not supported yet"
  246. version = _server_version(engine)
  247. oper = (
  248. hasattr(self.op, "__call__") and self.op or self._ops[self.op]
  249. )
  250. return oper(version, self.spec)
  251. else:
  252. return True
  253. def _as_string(self, config, negate=False):
  254. if self.description is not None:
  255. return self._format_description(config)
  256. elif self.op is None:
  257. if negate:
  258. return "not %s" % self.db
  259. else:
  260. return "%s" % self.db
  261. else:
  262. if negate:
  263. return "not %s %s %s" % (self.db, self.op, self.spec)
  264. else:
  265. return "%s %s %s" % (self.db, self.op, self.spec)
  266. class LambdaPredicate(Predicate):
  267. def __init__(self, lambda_, description=None, args=None, kw=None):
  268. spec = inspect_getfullargspec(lambda_)
  269. if not spec[0]:
  270. self.lambda_ = lambda db: lambda_()
  271. else:
  272. self.lambda_ = lambda_
  273. self.args = args or ()
  274. self.kw = kw or {}
  275. if description:
  276. self.description = description
  277. elif lambda_.__doc__:
  278. self.description = lambda_.__doc__
  279. else:
  280. self.description = "custom function"
  281. def __call__(self, config):
  282. return self.lambda_(config)
  283. def _as_string(self, config, negate=False):
  284. return self._format_description(config)
  285. class NotPredicate(Predicate):
  286. def __init__(self, predicate, description=None):
  287. self.predicate = predicate
  288. self.description = description
  289. def __call__(self, config):
  290. return not self.predicate(config)
  291. def _as_string(self, config, negate=False):
  292. if self.description:
  293. return self._format_description(config, not negate)
  294. else:
  295. return self.predicate._as_string(config, not negate)
  296. class OrPredicate(Predicate):
  297. def __init__(self, predicates, description=None):
  298. self.predicates = predicates
  299. self.description = description
  300. def __call__(self, config):
  301. for pred in self.predicates:
  302. if pred(config):
  303. return True
  304. return False
  305. def _eval_str(self, config, negate=False):
  306. if negate:
  307. conjunction = " and "
  308. else:
  309. conjunction = " or "
  310. return conjunction.join(
  311. p._as_string(config, negate=negate) for p in self.predicates
  312. )
  313. def _negation_str(self, config):
  314. if self.description is not None:
  315. return "Not " + self._format_description(config)
  316. else:
  317. return self._eval_str(config, negate=True)
  318. def _as_string(self, config, negate=False):
  319. if negate:
  320. return self._negation_str(config)
  321. else:
  322. if self.description is not None:
  323. return self._format_description(config)
  324. else:
  325. return self._eval_str(config)
# Module-level shorthand for Predicate.as_predicate, used by skip_if,
# fails_if, only_if and succeeds_if.
_as_predicate = Predicate.as_predicate
  327. def _is_excluded(db, op, spec):
  328. return SpecPredicate(db, op, spec)(config._current)
  329. def _server_version(engine):
  330. """Return a server_version_info tuple."""
  331. # force metadata to be retrieved
  332. conn = engine.connect()
  333. version = getattr(engine.dialect, "server_version_info", None)
  334. if version is None:
  335. version = ()
  336. conn.close()
  337. return version
  338. def db_spec(*dbs):
  339. return OrPredicate([Predicate.as_predicate(db) for db in dbs])
  340. def open(): # noqa
  341. return skip_if(BooleanPredicate(False, "mark as execute"))
  342. def closed():
  343. return skip_if(BooleanPredicate(True, "marked as skip"))
  344. def fails(reason=None):
  345. return fails_if(BooleanPredicate(True, reason or "expected to fail"))
  346. @decorator
  347. def future(fn, *arg):
  348. return fails_if(LambdaPredicate(fn), "Future feature")
  349. def fails_on(db, reason=None):
  350. return fails_if(db, reason)
  351. def fails_on_everything_except(*dbs):
  352. return succeeds_if(OrPredicate([Predicate.as_predicate(db) for db in dbs]))
  353. def skip(db, reason=None):
  354. return skip_if(db, reason)
  355. def only_on(dbs, reason=None):
  356. return only_if(
  357. OrPredicate(
  358. [Predicate.as_predicate(db, reason) for db in util.to_list(dbs)]
  359. )
  360. )
  361. def exclude(db, op, spec, reason=None):
  362. return skip_if(SpecPredicate(db, op, spec), reason)
  363. def against(config, *queries):
  364. assert queries, "no queries sent!"
  365. return OrPredicate([Predicate.as_predicate(query) for query in queries])(
  366. config
  367. )