You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1126 lines
43KB

  1. """Rewrite assertion AST to produce nice error messages."""
  2. import ast
  3. import errno
  4. import functools
  5. import importlib.abc
  6. import importlib.machinery
  7. import importlib.util
  8. import io
  9. import itertools
  10. import marshal
  11. import os
  12. import struct
  13. import sys
  14. import tokenize
  15. import types
  16. from pathlib import Path
  17. from pathlib import PurePath
  18. from typing import Callable
  19. from typing import Dict
  20. from typing import IO
  21. from typing import Iterable
  22. from typing import List
  23. from typing import Optional
  24. from typing import Sequence
  25. from typing import Set
  26. from typing import Tuple
  27. from typing import TYPE_CHECKING
  28. from typing import Union
  29. import py
  30. from _pytest._io.saferepr import saferepr
  31. from _pytest._version import version
  32. from _pytest.assertion import util
  33. from _pytest.assertion.util import ( # noqa: F401
  34. format_explanation as _format_explanation,
  35. )
  36. from _pytest.config import Config
  37. from _pytest.main import Session
  38. from _pytest.pathlib import fnmatch_ex
  39. from _pytest.store import StoreKey
  40. if TYPE_CHECKING:
  41. from _pytest.assertion import AssertionState
  42. assertstate_key = StoreKey["AssertionState"]()
  43. # pytest caches rewritten pycs in pycache dirs
  44. PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
  45. PYC_EXT = ".py" + (__debug__ and "c" or "o")
  46. PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
  47. class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
  48. """PEP302/PEP451 import hook which rewrites asserts."""
  49. def __init__(self, config: Config) -> None:
  50. self.config = config
  51. try:
  52. self.fnpats = config.getini("python_files")
  53. except ValueError:
  54. self.fnpats = ["test_*.py", "*_test.py"]
  55. self.session: Optional[Session] = None
  56. self._rewritten_names: Set[str] = set()
  57. self._must_rewrite: Set[str] = set()
  58. # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
  59. # which might result in infinite recursion (#3506)
  60. self._writing_pyc = False
  61. self._basenames_to_check_rewrite = {"conftest"}
  62. self._marked_for_rewrite_cache: Dict[str, bool] = {}
  63. self._session_paths_checked = False
  64. def set_session(self, session: Optional[Session]) -> None:
  65. self.session = session
  66. self._session_paths_checked = False
  67. # Indirection so we can mock calls to find_spec originated from the hook during testing
  68. _find_spec = importlib.machinery.PathFinder.find_spec
  69. def find_spec(
  70. self,
  71. name: str,
  72. path: Optional[Sequence[Union[str, bytes]]] = None,
  73. target: Optional[types.ModuleType] = None,
  74. ) -> Optional[importlib.machinery.ModuleSpec]:
  75. if self._writing_pyc:
  76. return None
  77. state = self.config._store[assertstate_key]
  78. if self._early_rewrite_bailout(name, state):
  79. return None
  80. state.trace("find_module called for: %s" % name)
  81. # Type ignored because mypy is confused about the `self` binding here.
  82. spec = self._find_spec(name, path) # type: ignore
  83. if (
  84. # the import machinery could not find a file to import
  85. spec is None
  86. # this is a namespace package (without `__init__.py`)
  87. # there's nothing to rewrite there
  88. # python3.6: `namespace`
  89. # python3.7+: `None`
  90. or spec.origin == "namespace"
  91. or spec.origin is None
  92. # we can only rewrite source files
  93. or not isinstance(spec.loader, importlib.machinery.SourceFileLoader)
  94. # if the file doesn't exist, we can't rewrite it
  95. or not os.path.exists(spec.origin)
  96. ):
  97. return None
  98. else:
  99. fn = spec.origin
  100. if not self._should_rewrite(name, fn, state):
  101. return None
  102. return importlib.util.spec_from_file_location(
  103. name,
  104. fn,
  105. loader=self,
  106. submodule_search_locations=spec.submodule_search_locations,
  107. )
  108. def create_module(
  109. self, spec: importlib.machinery.ModuleSpec
  110. ) -> Optional[types.ModuleType]:
  111. return None # default behaviour is fine
  112. def exec_module(self, module: types.ModuleType) -> None:
  113. assert module.__spec__ is not None
  114. assert module.__spec__.origin is not None
  115. fn = Path(module.__spec__.origin)
  116. state = self.config._store[assertstate_key]
  117. self._rewritten_names.add(module.__name__)
  118. # The requested module looks like a test file, so rewrite it. This is
  119. # the most magical part of the process: load the source, rewrite the
  120. # asserts, and load the rewritten source. We also cache the rewritten
  121. # module code in a special pyc. We must be aware of the possibility of
  122. # concurrent pytest processes rewriting and loading pycs. To avoid
  123. # tricky race conditions, we maintain the following invariant: The
  124. # cached pyc is always a complete, valid pyc. Operations on it must be
  125. # atomic. POSIX's atomic rename comes in handy.
  126. write = not sys.dont_write_bytecode
  127. cache_dir = get_cache_dir(fn)
  128. if write:
  129. ok = try_makedirs(cache_dir)
  130. if not ok:
  131. write = False
  132. state.trace(f"read only directory: {cache_dir}")
  133. cache_name = fn.name[:-3] + PYC_TAIL
  134. pyc = cache_dir / cache_name
  135. # Notice that even if we're in a read-only directory, I'm going
  136. # to check for a cached pyc. This may not be optimal...
  137. co = _read_pyc(fn, pyc, state.trace)
  138. if co is None:
  139. state.trace(f"rewriting {fn!r}")
  140. source_stat, co = _rewrite_test(fn, self.config)
  141. if write:
  142. self._writing_pyc = True
  143. try:
  144. _write_pyc(state, co, source_stat, pyc)
  145. finally:
  146. self._writing_pyc = False
  147. else:
  148. state.trace(f"found cached rewritten pyc for {fn}")
  149. exec(co, module.__dict__)
  150. def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
  151. """A fast way to get out of rewriting modules.
  152. Profiling has shown that the call to PathFinder.find_spec (inside of
  153. the find_spec from this class) is a major slowdown, so, this method
  154. tries to filter what we're sure won't be rewritten before getting to
  155. it.
  156. """
  157. if self.session is not None and not self._session_paths_checked:
  158. self._session_paths_checked = True
  159. for initial_path in self.session._initialpaths:
  160. # Make something as c:/projects/my_project/path.py ->
  161. # ['c:', 'projects', 'my_project', 'path.py']
  162. parts = str(initial_path).split(os.path.sep)
  163. # add 'path' to basenames to be checked.
  164. self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])
  165. # Note: conftest already by default in _basenames_to_check_rewrite.
  166. parts = name.split(".")
  167. if parts[-1] in self._basenames_to_check_rewrite:
  168. return False
  169. # For matching the name it must be as if it was a filename.
  170. path = PurePath(os.path.sep.join(parts) + ".py")
  171. for pat in self.fnpats:
  172. # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
  173. # on the name alone because we need to match against the full path
  174. if os.path.dirname(pat):
  175. return False
  176. if fnmatch_ex(pat, path):
  177. return False
  178. if self._is_marked_for_rewrite(name, state):
  179. return False
  180. state.trace(f"early skip of rewriting module: {name}")
  181. return True
  182. def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
  183. # always rewrite conftest files
  184. if os.path.basename(fn) == "conftest.py":
  185. state.trace(f"rewriting conftest file: {fn!r}")
  186. return True
  187. if self.session is not None:
  188. if self.session.isinitpath(py.path.local(fn)):
  189. state.trace(f"matched test file (was specified on cmdline): {fn!r}")
  190. return True
  191. # modules not passed explicitly on the command line are only
  192. # rewritten if they match the naming convention for test files
  193. fn_path = PurePath(fn)
  194. for pat in self.fnpats:
  195. if fnmatch_ex(pat, fn_path):
  196. state.trace(f"matched test file {fn!r}")
  197. return True
  198. return self._is_marked_for_rewrite(name, state)
  199. def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
  200. try:
  201. return self._marked_for_rewrite_cache[name]
  202. except KeyError:
  203. for marked in self._must_rewrite:
  204. if name == marked or name.startswith(marked + "."):
  205. state.trace(f"matched marked file {name!r} (from {marked!r})")
  206. self._marked_for_rewrite_cache[name] = True
  207. return True
  208. self._marked_for_rewrite_cache[name] = False
  209. return False
  210. def mark_rewrite(self, *names: str) -> None:
  211. """Mark import names as needing to be rewritten.
  212. The named module or package as well as any nested modules will
  213. be rewritten on import.
  214. """
  215. already_imported = (
  216. set(names).intersection(sys.modules).difference(self._rewritten_names)
  217. )
  218. for name in already_imported:
  219. mod = sys.modules[name]
  220. if not AssertionRewriter.is_rewrite_disabled(
  221. mod.__doc__ or ""
  222. ) and not isinstance(mod.__loader__, type(self)):
  223. self._warn_already_imported(name)
  224. self._must_rewrite.update(names)
  225. self._marked_for_rewrite_cache.clear()
  226. def _warn_already_imported(self, name: str) -> None:
  227. from _pytest.warning_types import PytestAssertRewriteWarning
  228. self.config.issue_config_time_warning(
  229. PytestAssertRewriteWarning(
  230. "Module already imported so cannot be rewritten: %s" % name
  231. ),
  232. stacklevel=5,
  233. )
  234. def get_data(self, pathname: Union[str, bytes]) -> bytes:
  235. """Optional PEP302 get_data API."""
  236. with open(pathname, "rb") as f:
  237. return f.read()
  238. def _write_pyc_fp(
  239. fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
  240. ) -> None:
  241. # Technically, we don't have to have the same pyc format as
  242. # (C)Python, since these "pycs" should never be seen by builtin
  243. # import. However, there's little reason to deviate.
  244. fp.write(importlib.util.MAGIC_NUMBER)
  245. # https://www.python.org/dev/peps/pep-0552/
  246. if sys.version_info >= (3, 7):
  247. flags = b"\x00\x00\x00\x00"
  248. fp.write(flags)
  249. # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
  250. mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
  251. size = source_stat.st_size & 0xFFFFFFFF
  252. # "<LL" stands for 2 unsigned longs, little-endian.
  253. fp.write(struct.pack("<LL", mtime, size))
  254. fp.write(marshal.dumps(co))
  255. if sys.platform == "win32":
  256. from atomicwrites import atomic_write
  257. def _write_pyc(
  258. state: "AssertionState",
  259. co: types.CodeType,
  260. source_stat: os.stat_result,
  261. pyc: Path,
  262. ) -> bool:
  263. try:
  264. with atomic_write(os.fspath(pyc), mode="wb", overwrite=True) as fp:
  265. _write_pyc_fp(fp, source_stat, co)
  266. except OSError as e:
  267. state.trace(f"error writing pyc file at {pyc}: {e}")
  268. # we ignore any failure to write the cache file
  269. # there are many reasons, permission-denied, pycache dir being a
  270. # file etc.
  271. return False
  272. return True
  273. else:
  274. def _write_pyc(
  275. state: "AssertionState",
  276. co: types.CodeType,
  277. source_stat: os.stat_result,
  278. pyc: Path,
  279. ) -> bool:
  280. proc_pyc = f"{pyc}.{os.getpid()}"
  281. try:
  282. fp = open(proc_pyc, "wb")
  283. except OSError as e:
  284. state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
  285. return False
  286. try:
  287. _write_pyc_fp(fp, source_stat, co)
  288. os.rename(proc_pyc, os.fspath(pyc))
  289. except OSError as e:
  290. state.trace(f"error writing pyc file at {pyc}: {e}")
  291. # we ignore any failure to write the cache file
  292. # there are many reasons, permission-denied, pycache dir being a
  293. # file etc.
  294. return False
  295. finally:
  296. fp.close()
  297. return True
  298. def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
  299. """Read and rewrite *fn* and return the code object."""
  300. fn_ = os.fspath(fn)
  301. stat = os.stat(fn_)
  302. with open(fn_, "rb") as f:
  303. source = f.read()
  304. tree = ast.parse(source, filename=fn_)
  305. rewrite_asserts(tree, source, fn_, config)
  306. co = compile(tree, fn_, "exec", dont_inherit=True)
  307. return stat, co
  308. def _read_pyc(
  309. source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
  310. ) -> Optional[types.CodeType]:
  311. """Possibly read a pytest pyc containing rewritten code.
  312. Return rewritten code if successful or None if not.
  313. """
  314. try:
  315. fp = open(os.fspath(pyc), "rb")
  316. except OSError:
  317. return None
  318. with fp:
  319. # https://www.python.org/dev/peps/pep-0552/
  320. has_flags = sys.version_info >= (3, 7)
  321. try:
  322. stat_result = os.stat(os.fspath(source))
  323. mtime = int(stat_result.st_mtime)
  324. size = stat_result.st_size
  325. data = fp.read(16 if has_flags else 12)
  326. except OSError as e:
  327. trace(f"_read_pyc({source}): OSError {e}")
  328. return None
  329. # Check for invalid or out of date pyc file.
  330. if len(data) != (16 if has_flags else 12):
  331. trace("_read_pyc(%s): invalid pyc (too short)" % source)
  332. return None
  333. if data[:4] != importlib.util.MAGIC_NUMBER:
  334. trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
  335. return None
  336. if has_flags and data[4:8] != b"\x00\x00\x00\x00":
  337. trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
  338. return None
  339. mtime_data = data[8 if has_flags else 4 : 12 if has_flags else 8]
  340. if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
  341. trace("_read_pyc(%s): out of date" % source)
  342. return None
  343. size_data = data[12 if has_flags else 8 : 16 if has_flags else 12]
  344. if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
  345. trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
  346. return None
  347. try:
  348. co = marshal.load(fp)
  349. except Exception as e:
  350. trace(f"_read_pyc({source}): marshal.load error {e}")
  351. return None
  352. if not isinstance(co, types.CodeType):
  353. trace("_read_pyc(%s): not a code object" % source)
  354. return None
  355. return co
  356. def rewrite_asserts(
  357. mod: ast.Module,
  358. source: bytes,
  359. module_path: Optional[str] = None,
  360. config: Optional[Config] = None,
  361. ) -> None:
  362. """Rewrite the assert statements in mod."""
  363. AssertionRewriter(module_path, config, source).run(mod)
  364. def _saferepr(obj: object) -> str:
  365. r"""Get a safe repr of an object for assertion error messages.
  366. The assertion formatting (util.format_explanation()) requires
  367. newlines to be escaped since they are a special character for it.
  368. Normally assertion.util.format_explanation() does this but for a
  369. custom repr it is possible to contain one of the special escape
  370. sequences, especially '\n{' and '\n}' are likely to be present in
  371. JSON reprs.
  372. """
  373. return saferepr(obj).replace("\n", "\\n")
  374. def _format_assertmsg(obj: object) -> str:
  375. r"""Format the custom assertion message given.
  376. For strings this simply replaces newlines with '\n~' so that
  377. util.format_explanation() will preserve them instead of escaping
  378. newlines. For other objects saferepr() is used first.
  379. """
  380. # reprlib appears to have a bug which means that if a string
  381. # contains a newline it gets escaped, however if an object has a
  382. # .__repr__() which contains newlines it does not get escaped.
  383. # However in either case we want to preserve the newline.
  384. replaces = [("\n", "\n~"), ("%", "%%")]
  385. if not isinstance(obj, str):
  386. obj = saferepr(obj)
  387. replaces.append(("\\n", "\n~"))
  388. for r1, r2 in replaces:
  389. obj = obj.replace(r1, r2)
  390. return obj
  391. def _should_repr_global_name(obj: object) -> bool:
  392. if callable(obj):
  393. return False
  394. try:
  395. return not hasattr(obj, "__name__")
  396. except Exception:
  397. return True
  398. def _format_boolop(explanations: Iterable[str], is_or: bool) -> str:
  399. explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
  400. return explanation.replace("%", "%%")
  401. def _call_reprcompare(
  402. ops: Sequence[str],
  403. results: Sequence[bool],
  404. expls: Sequence[str],
  405. each_obj: Sequence[object],
  406. ) -> str:
  407. for i, res, expl in zip(range(len(ops)), results, expls):
  408. try:
  409. done = not res
  410. except Exception:
  411. done = True
  412. if done:
  413. break
  414. if util._reprcompare is not None:
  415. custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1])
  416. if custom is not None:
  417. return custom
  418. return expl
  419. def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
  420. if util._assertion_pass is not None:
  421. util._assertion_pass(lineno, orig, expl)
  422. def _check_if_assertion_pass_impl() -> bool:
  423. """Check if any plugins implement the pytest_assertion_pass hook
  424. in order not to generate explanation unecessarily (might be expensive)."""
  425. return True if util._assertion_pass else False
  426. UNARY_MAP = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
  427. BINOP_MAP = {
  428. ast.BitOr: "|",
  429. ast.BitXor: "^",
  430. ast.BitAnd: "&",
  431. ast.LShift: "<<",
  432. ast.RShift: ">>",
  433. ast.Add: "+",
  434. ast.Sub: "-",
  435. ast.Mult: "*",
  436. ast.Div: "/",
  437. ast.FloorDiv: "//",
  438. ast.Mod: "%%", # escaped for string formatting
  439. ast.Eq: "==",
  440. ast.NotEq: "!=",
  441. ast.Lt: "<",
  442. ast.LtE: "<=",
  443. ast.Gt: ">",
  444. ast.GtE: ">=",
  445. ast.Pow: "**",
  446. ast.Is: "is",
  447. ast.IsNot: "is not",
  448. ast.In: "in",
  449. ast.NotIn: "not in",
  450. ast.MatMult: "@",
  451. }
  452. def set_location(node, lineno, col_offset):
  453. """Set node location information recursively."""
  454. def _fix(node, lineno, col_offset):
  455. if "lineno" in node._attributes:
  456. node.lineno = lineno
  457. if "col_offset" in node._attributes:
  458. node.col_offset = col_offset
  459. for child in ast.iter_child_nodes(node):
  460. _fix(child, lineno, col_offset)
  461. _fix(node, lineno, col_offset)
  462. return node
def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
    """Return a mapping from {lineno: "assertion test expression"}.

    Tokenizes *src* and, for every ``assert`` statement, records the source
    text of its test expression (without a trailing ``, message`` part).
    Entries are keyed by the line number of the ``assert`` keyword.
    """
    ret: Dict[int, str] = {}

    # State for the assert statement currently being collected:
    depth = 0  # bracket nesting depth, so commas inside (), [], {} are skipped
    lines: List[str] = []  # collected source fragments of the test expression
    assert_lineno: Optional[int] = None  # line of the current `assert`, None outside one
    seen_lines: Set[int] = set()  # physical lines already added to `lines`

    def _write_and_reset() -> None:
        # Store the collected text for the current assert (trailing whitespace
        # and a trailing line-continuation backslash removed), then reset state.
        nonlocal depth, lines, assert_lineno, seen_lines
        assert assert_lineno is not None
        ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
        depth = 0
        lines = []
        assert_lineno = None
        seen_lines = set()

    tokens = tokenize.tokenize(io.BytesIO(src).readline)
    for tp, source, (lineno, offset), _, line in tokens:
        if tp == tokenize.NAME and source == "assert":
            assert_lineno = lineno
        elif assert_lineno is not None:
            # keep track of depth for the assert-message `,` lookup
            if tp == tokenize.OP and source in "([{":
                depth += 1
            elif tp == tokenize.OP and source in ")]}":
                depth -= 1

            if not lines:
                # first token after `assert`: keep its line from this column on
                lines.append(line[offset:])
                seen_lines.add(lineno)
            # a non-nested comma separates the expression from the message
            elif depth == 0 and tp == tokenize.OP and source == ",":
                # one line assert with message
                if lineno in seen_lines and len(lines) == 1:
                    # lines[0] was trimmed to start at the expression, so shift
                    # the comma's column left by the number of dropped characters
                    offset_in_trimmed = offset + len(lines[-1]) - len(line)
                    lines[-1] = lines[-1][:offset_in_trimmed]
                # multi-line assert with message
                elif lineno in seen_lines:
                    lines[-1] = lines[-1][:offset]
                # multi line assert with escaped newline before message
                else:
                    lines.append(line[:offset])
                _write_and_reset()
            elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
                # end of an assert statement that has no message
                _write_and_reset()
            elif lines and lineno not in seen_lines:
                # first token on a further physical line: keep the whole line
                lines.append(line)
                seen_lines.add(lineno)

    return ret
  510. class AssertionRewriter(ast.NodeVisitor):
  511. """Assertion rewriting implementation.
  512. The main entrypoint is to call .run() with an ast.Module instance,
  513. this will then find all the assert statements and rewrite them to
  514. provide intermediate values and a detailed assertion error. See
  515. http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
  516. for an overview of how this works.
  517. The entry point here is .run() which will iterate over all the
  518. statements in an ast.Module and for each ast.Assert statement it
  519. finds call .visit() with it. Then .visit_Assert() takes over and
  520. is responsible for creating new ast statements to replace the
  521. original assert statement: it rewrites the test of an assertion
  522. to provide intermediate values and replace it with an if statement
  523. which raises an assertion error with a detailed explanation in
  524. case the expression is false and calls pytest_assertion_pass hook
  525. if expression is true.
  526. For this .visit_Assert() uses the visitor pattern to visit all the
  527. AST nodes of the ast.Assert.test field, each visit call returning
  528. an AST node and the corresponding explanation string. During this
  529. state is kept in several instance attributes:
  530. :statements: All the AST statements which will replace the assert
  531. statement.
  532. :variables: This is populated by .variable() with each variable
  533. used by the statements so that they can all be set to None at
  534. the end of the statements.
  535. :variable_counter: Counter to create new unique variables needed
  536. by statements. Variables are created using .variable() and
  537. have the form of "@py_assert0".
  538. :expl_stmts: The AST statements which will be executed to get
  539. data from the assertion. This is the code which will construct
  540. the detailed assertion message that is used in the AssertionError
  541. or for the pytest_assertion_pass hook.
  542. :explanation_specifiers: A dict filled by .explanation_param()
  543. with %-formatting placeholders and their corresponding
  544. expressions to use in the building of an assertion message.
  545. This is used by .pop_format_context() to build a message.
  546. :stack: A stack of the explanation_specifiers dicts maintained by
  547. .push_format_context() and .pop_format_context() which allows
  548. to build another %-formatted string while already building one.
  549. This state is reset on every new assert statement visited and used
  550. by the other visitors.
  551. """
  552. def __init__(
  553. self, module_path: Optional[str], config: Optional[Config], source: bytes
  554. ) -> None:
  555. super().__init__()
  556. self.module_path = module_path
  557. self.config = config
  558. if config is not None:
  559. self.enable_assertion_pass_hook = config.getini(
  560. "enable_assertion_pass_hook"
  561. )
  562. else:
  563. self.enable_assertion_pass_hook = False
  564. self.source = source
  565. @functools.lru_cache(maxsize=1)
  566. def _assert_expr_to_lineno(self) -> Dict[int, str]:
  567. return _get_assertion_exprs(self.source)
  568. def run(self, mod: ast.Module) -> None:
  569. """Find all assert statements in *mod* and rewrite them."""
  570. if not mod.body:
  571. # Nothing to do.
  572. return
  573. # We'll insert some special imports at the top of the module, but after any
  574. # docstrings and __future__ imports, so first figure out where that is.
  575. doc = getattr(mod, "docstring", None)
  576. expect_docstring = doc is None
  577. if doc is not None and self.is_rewrite_disabled(doc):
  578. return
  579. pos = 0
  580. lineno = 1
  581. for item in mod.body:
  582. if (
  583. expect_docstring
  584. and isinstance(item, ast.Expr)
  585. and isinstance(item.value, ast.Str)
  586. ):
  587. doc = item.value.s
  588. if self.is_rewrite_disabled(doc):
  589. return
  590. expect_docstring = False
  591. elif (
  592. isinstance(item, ast.ImportFrom)
  593. and item.level == 0
  594. and item.module == "__future__"
  595. ):
  596. pass
  597. else:
  598. break
  599. pos += 1
  600. # Special case: for a decorated function, set the lineno to that of the
  601. # first decorator, not the `def`. Issue #4984.
  602. if isinstance(item, ast.FunctionDef) and item.decorator_list:
  603. lineno = item.decorator_list[0].lineno
  604. else:
  605. lineno = item.lineno
  606. # Now actually insert the special imports.
  607. if sys.version_info >= (3, 10):
  608. aliases = [
  609. ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
  610. ast.alias(
  611. "_pytest.assertion.rewrite",
  612. "@pytest_ar",
  613. lineno=lineno,
  614. col_offset=0,
  615. ),
  616. ]
  617. else:
  618. aliases = [
  619. ast.alias("builtins", "@py_builtins"),
  620. ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
  621. ]
  622. imports = [
  623. ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
  624. ]
  625. mod.body[pos:pos] = imports
  626. # Collect asserts.
  627. nodes: List[ast.AST] = [mod]
  628. while nodes:
  629. node = nodes.pop()
  630. for name, field in ast.iter_fields(node):
  631. if isinstance(field, list):
  632. new: List[ast.AST] = []
  633. for i, child in enumerate(field):
  634. if isinstance(child, ast.Assert):
  635. # Transform assert.
  636. new.extend(self.visit(child))
  637. else:
  638. new.append(child)
  639. if isinstance(child, ast.AST):
  640. nodes.append(child)
  641. setattr(node, name, new)
  642. elif (
  643. isinstance(field, ast.AST)
  644. # Don't recurse into expressions as they can't contain
  645. # asserts.
  646. and not isinstance(field, ast.expr)
  647. ):
  648. nodes.append(field)
    @staticmethod
    def is_rewrite_disabled(docstring: str) -> bool:
        """Return True if *docstring* contains the PYTEST_DONT_REWRITE marker.

        Modules opt out of assertion rewriting by putting that marker in their
        docstring.
        """
        return "PYTEST_DONT_REWRITE" in docstring
  652. def variable(self) -> str:
  653. """Get a new variable."""
  654. # Use a character invalid in python identifiers to avoid clashing.
  655. name = "@py_assert" + str(next(self.variable_counter))
  656. self.variables.append(name)
  657. return name
  658. def assign(self, expr: ast.expr) -> ast.Name:
  659. """Give *expr* a name."""
  660. name = self.variable()
  661. self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
  662. return ast.Name(name, ast.Load())
  663. def display(self, expr: ast.expr) -> ast.expr:
  664. """Call saferepr on the expression."""
  665. return self.helper("_saferepr", expr)
  666. def helper(self, name: str, *args: ast.expr) -> ast.expr:
  667. """Call a helper in this module."""
  668. py_name = ast.Name("@pytest_ar", ast.Load())
  669. attr = ast.Attribute(py_name, name, ast.Load())
  670. return ast.Call(attr, list(args), [])
  671. def builtin(self, name: str) -> ast.Attribute:
  672. """Return the builtin called *name*."""
  673. builtin_name = ast.Name("@py_builtins", ast.Load())
  674. return ast.Attribute(builtin_name, name, ast.Load())
  675. def explanation_param(self, expr: ast.expr) -> str:
  676. """Return a new named %-formatting placeholder for expr.
  677. This creates a %-formatting placeholder for expr in the
  678. current formatting context, e.g. ``%(py0)s``. The placeholder
  679. and expr are placed in the current format context so that it
  680. can be used on the next call to .pop_format_context().
  681. """
  682. specifier = "py" + str(next(self.variable_counter))
  683. self.explanation_specifiers[specifier] = expr
  684. return "%(" + specifier + ")s"
  685. def push_format_context(self) -> None:
  686. """Create a new formatting context.
  687. The format context is used for when an explanation wants to
  688. have a variable value formatted in the assertion message. In
  689. this case the value required can be added using
  690. .explanation_param(). Finally .pop_format_context() is used
  691. to format a string of %-formatted values as added by
  692. .explanation_param().
  693. """
  694. self.explanation_specifiers: Dict[str, ast.expr] = {}
  695. self.stack.append(self.explanation_specifiers)
  696. def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
  697. """Format the %-formatted string with current format context.
  698. The expl_expr should be an str ast.expr instance constructed from
  699. the %-placeholders created by .explanation_param(). This will
  700. add the required code to format said string to .expl_stmts and
  701. return the ast.Name instance of the formatted string.
  702. """
  703. current = self.stack.pop()
  704. if self.stack:
  705. self.explanation_specifiers = self.stack[-1]
  706. keys = [ast.Str(key) for key in current.keys()]
  707. format_dict = ast.Dict(keys, list(current.values()))
  708. form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
  709. name = "@py_format" + str(next(self.variable_counter))
  710. if self.enable_assertion_pass_hook:
  711. self.format_variables.append(name)
  712. self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
  713. return ast.Name(name, ast.Load())
  714. def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
  715. """Handle expressions we don't have custom code for."""
  716. assert isinstance(node, ast.expr)
  717. res = self.assign(node)
  718. return res, self.explanation_param(self.display(res))
  719. def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
  720. """Return the AST statements to replace the ast.Assert instance.
  721. This rewrites the test of an assertion to provide
  722. intermediate values and replace it with an if statement which
  723. raises an assertion error with a detailed explanation in case
  724. the expression is false.
  725. """
  726. if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
  727. from _pytest.warning_types import PytestAssertRewriteWarning
  728. import warnings
  729. # TODO: This assert should not be needed.
  730. assert self.module_path is not None
  731. warnings.warn_explicit(
  732. PytestAssertRewriteWarning(
  733. "assertion is always true, perhaps remove parentheses?"
  734. ),
  735. category=None,
  736. filename=os.fspath(self.module_path),
  737. lineno=assert_.lineno,
  738. )
  739. self.statements: List[ast.stmt] = []
  740. self.variables: List[str] = []
  741. self.variable_counter = itertools.count()
  742. if self.enable_assertion_pass_hook:
  743. self.format_variables: List[str] = []
  744. self.stack: List[Dict[str, ast.expr]] = []
  745. self.expl_stmts: List[ast.stmt] = []
  746. self.push_format_context()
  747. # Rewrite assert into a bunch of statements.
  748. top_condition, explanation = self.visit(assert_.test)
  749. negation = ast.UnaryOp(ast.Not(), top_condition)
  750. if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
  751. msg = self.pop_format_context(ast.Str(explanation))
  752. # Failed
  753. if assert_.msg:
  754. assertmsg = self.helper("_format_assertmsg", assert_.msg)
  755. gluestr = "\n>assert "
  756. else:
  757. assertmsg = ast.Str("")
  758. gluestr = "assert "
  759. err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
  760. err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
  761. err_name = ast.Name("AssertionError", ast.Load())
  762. fmt = self.helper("_format_explanation", err_msg)
  763. exc = ast.Call(err_name, [fmt], [])
  764. raise_ = ast.Raise(exc, None)
  765. statements_fail = []
  766. statements_fail.extend(self.expl_stmts)
  767. statements_fail.append(raise_)
  768. # Passed
  769. fmt_pass = self.helper("_format_explanation", msg)
  770. orig = self._assert_expr_to_lineno()[assert_.lineno]
  771. hook_call_pass = ast.Expr(
  772. self.helper(
  773. "_call_assertion_pass",
  774. ast.Num(assert_.lineno),
  775. ast.Str(orig),
  776. fmt_pass,
  777. )
  778. )
  779. # If any hooks implement assert_pass hook
  780. hook_impl_test = ast.If(
  781. self.helper("_check_if_assertion_pass_impl"),
  782. self.expl_stmts + [hook_call_pass],
  783. [],
  784. )
  785. statements_pass = [hook_impl_test]
  786. # Test for assertion condition
  787. main_test = ast.If(negation, statements_fail, statements_pass)
  788. self.statements.append(main_test)
  789. if self.format_variables:
  790. variables = [
  791. ast.Name(name, ast.Store()) for name in self.format_variables
  792. ]
  793. clear_format = ast.Assign(variables, ast.NameConstant(None))
  794. self.statements.append(clear_format)
  795. else: # Original assertion rewriting
  796. # Create failure message.
  797. body = self.expl_stmts
  798. self.statements.append(ast.If(negation, body, []))
  799. if assert_.msg:
  800. assertmsg = self.helper("_format_assertmsg", assert_.msg)
  801. explanation = "\n>assert " + explanation
  802. else:
  803. assertmsg = ast.Str("")
  804. explanation = "assert " + explanation
  805. template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
  806. msg = self.pop_format_context(template)
  807. fmt = self.helper("_format_explanation", msg)
  808. err_name = ast.Name("AssertionError", ast.Load())
  809. exc = ast.Call(err_name, [fmt], [])
  810. raise_ = ast.Raise(exc, None)
  811. body.append(raise_)
  812. # Clear temporary variables by setting them to None.
  813. if self.variables:
  814. variables = [ast.Name(name, ast.Store()) for name in self.variables]
  815. clear = ast.Assign(variables, ast.NameConstant(None))
  816. self.statements.append(clear)
  817. # Fix line numbers.
  818. for stmt in self.statements:
  819. set_location(stmt, assert_.lineno, assert_.col_offset)
  820. return self.statements
  821. def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
  822. # Display the repr of the name if it's a local variable or
  823. # _should_repr_global_name() thinks it's acceptable.
  824. locs = ast.Call(self.builtin("locals"), [], [])
  825. inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
  826. dorepr = self.helper("_should_repr_global_name", name)
  827. test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
  828. expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
  829. return name, self.explanation_param(expr)
  830. def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
  831. res_var = self.variable()
  832. expl_list = self.assign(ast.List([], ast.Load()))
  833. app = ast.Attribute(expl_list, "append", ast.Load())
  834. is_or = int(isinstance(boolop.op, ast.Or))
  835. body = save = self.statements
  836. fail_save = self.expl_stmts
  837. levels = len(boolop.values) - 1
  838. self.push_format_context()
  839. # Process each operand, short-circuiting if needed.
  840. for i, v in enumerate(boolop.values):
  841. if i:
  842. fail_inner: List[ast.stmt] = []
  843. # cond is set in a prior loop iteration below
  844. self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
  845. self.expl_stmts = fail_inner
  846. self.push_format_context()
  847. res, expl = self.visit(v)
  848. body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
  849. expl_format = self.pop_format_context(ast.Str(expl))
  850. call = ast.Call(app, [expl_format], [])
  851. self.expl_stmts.append(ast.Expr(call))
  852. if i < levels:
  853. cond: ast.expr = res
  854. if is_or:
  855. cond = ast.UnaryOp(ast.Not(), cond)
  856. inner: List[ast.stmt] = []
  857. self.statements.append(ast.If(cond, inner, []))
  858. self.statements = body = inner
  859. self.statements = save
  860. self.expl_stmts = fail_save
  861. expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
  862. expl = self.pop_format_context(expl_template)
  863. return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
  864. def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
  865. pattern = UNARY_MAP[unary.op.__class__]
  866. operand_res, operand_expl = self.visit(unary.operand)
  867. res = self.assign(ast.UnaryOp(unary.op, operand_res))
  868. return res, pattern % (operand_expl,)
  869. def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
  870. symbol = BINOP_MAP[binop.op.__class__]
  871. left_expr, left_expl = self.visit(binop.left)
  872. right_expr, right_expl = self.visit(binop.right)
  873. explanation = f"({left_expl} {symbol} {right_expl})"
  874. res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
  875. return res, explanation
  876. def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
  877. new_func, func_expl = self.visit(call.func)
  878. arg_expls = []
  879. new_args = []
  880. new_kwargs = []
  881. for arg in call.args:
  882. res, expl = self.visit(arg)
  883. arg_expls.append(expl)
  884. new_args.append(res)
  885. for keyword in call.keywords:
  886. res, expl = self.visit(keyword.value)
  887. new_kwargs.append(ast.keyword(keyword.arg, res))
  888. if keyword.arg:
  889. arg_expls.append(keyword.arg + "=" + expl)
  890. else: # **args have `arg` keywords with an .arg of None
  891. arg_expls.append("**" + expl)
  892. expl = "{}({})".format(func_expl, ", ".join(arg_expls))
  893. new_call = ast.Call(new_func, new_args, new_kwargs)
  894. res = self.assign(new_call)
  895. res_expl = self.explanation_param(self.display(res))
  896. outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
  897. return res, outer_expl
  898. def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
  899. # A Starred node can appear in a function call.
  900. res, expl = self.visit(starred.value)
  901. new_starred = ast.Starred(res, starred.ctx)
  902. return new_starred, "*" + expl
  903. def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
  904. if not isinstance(attr.ctx, ast.Load):
  905. return self.generic_visit(attr)
  906. value, value_expl = self.visit(attr.value)
  907. res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
  908. res_expl = self.explanation_param(self.display(res))
  909. pat = "%s\n{%s = %s.%s\n}"
  910. expl = pat % (res_expl, res_expl, value_expl, attr.attr)
  911. return res, expl
  912. def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
  913. self.push_format_context()
  914. left_res, left_expl = self.visit(comp.left)
  915. if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
  916. left_expl = f"({left_expl})"
  917. res_variables = [self.variable() for i in range(len(comp.ops))]
  918. load_names = [ast.Name(v, ast.Load()) for v in res_variables]
  919. store_names = [ast.Name(v, ast.Store()) for v in res_variables]
  920. it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
  921. expls = []
  922. syms = []
  923. results = [left_res]
  924. for i, op, next_operand in it:
  925. next_res, next_expl = self.visit(next_operand)
  926. if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
  927. next_expl = f"({next_expl})"
  928. results.append(next_res)
  929. sym = BINOP_MAP[op.__class__]
  930. syms.append(ast.Str(sym))
  931. expl = f"{left_expl} {sym} {next_expl}"
  932. expls.append(ast.Str(expl))
  933. res_expr = ast.Compare(left_res, [op], [next_res])
  934. self.statements.append(ast.Assign([store_names[i]], res_expr))
  935. left_res, left_expl = next_res, next_expl
  936. # Use pytest.assertion.util._reprcompare if that's available.
  937. expl_call = self.helper(
  938. "_call_reprcompare",
  939. ast.Tuple(syms, ast.Load()),
  940. ast.Tuple(load_names, ast.Load()),
  941. ast.Tuple(expls, ast.Load()),
  942. ast.Tuple(results, ast.Load()),
  943. )
  944. if len(comp.ops) > 1:
  945. res: ast.expr = ast.BoolOp(ast.And(), load_names)
  946. else:
  947. res = load_names[0]
  948. return res, self.explanation_param(self.pop_format_context(expl_call))
  949. def try_makedirs(cache_dir: Path) -> bool:
  950. """Attempt to create the given directory and sub-directories exist.
  951. Returns True if successful or if it already exists.
  952. """
  953. try:
  954. os.makedirs(os.fspath(cache_dir), exist_ok=True)
  955. except (FileNotFoundError, NotADirectoryError, FileExistsError):
  956. # One of the path components was not a directory:
  957. # - we're in a zip file
  958. # - it is a file
  959. return False
  960. except PermissionError:
  961. return False
  962. except OSError as e:
  963. # as of now, EROFS doesn't have an equivalent OSError-subclass
  964. if e.errno == errno.EROFS:
  965. return False
  966. raise
  967. return True
  968. def get_cache_dir(file_path: Path) -> Path:
  969. """Return the cache directory to write .pyc files for the given .py file path."""
  970. if sys.version_info >= (3, 8) and sys.pycache_prefix:
  971. # given:
  972. # prefix = '/tmp/pycs'
  973. # path = '/home/user/proj/test_app.py'
  974. # we want:
  975. # '/tmp/pycs/home/user/proj'
  976. return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1])
  977. else:
  978. # classic pycache directory
  979. return file_path.parent / "__pycache__"