You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

91 lines
3.0KB

  1. import threading
  2. import traceback
  3. import warnings
  4. from types import TracebackType
  5. from typing import Any
  6. from typing import Callable
  7. from typing import Generator
  8. from typing import Optional
  9. from typing import Type
  10. import pytest
  11. # Copied from cpython/Lib/test/support/threading_helper.py, with modifications.
  12. class catch_threading_exception:
  13. """Context manager catching threading.Thread exception using
  14. threading.excepthook.
  15. Storing exc_value using a custom hook can create a reference cycle. The
  16. reference cycle is broken explicitly when the context manager exits.
  17. Storing thread using a custom hook can resurrect it if it is set to an
  18. object which is being finalized. Exiting the context manager clears the
  19. stored object.
  20. Usage:
  21. with threading_helper.catch_threading_exception() as cm:
  22. # code spawning a thread which raises an exception
  23. ...
  24. # check the thread exception: use cm.args
  25. ...
  26. # cm.args attribute no longer exists at this point
  27. # (to break a reference cycle)
  28. """
  29. def __init__(self) -> None:
  30. # See https://github.com/python/typeshed/issues/4767 regarding the underscore.
  31. self.args: Optional["threading._ExceptHookArgs"] = None
  32. self._old_hook: Optional[Callable[["threading._ExceptHookArgs"], Any]] = None
  33. def _hook(self, args: "threading._ExceptHookArgs") -> None:
  34. self.args = args
  35. def __enter__(self) -> "catch_threading_exception":
  36. self._old_hook = threading.excepthook
  37. threading.excepthook = self._hook
  38. return self
  39. def __exit__(
  40. self,
  41. exc_type: Optional[Type[BaseException]],
  42. exc_val: Optional[BaseException],
  43. exc_tb: Optional[TracebackType],
  44. ) -> None:
  45. assert self._old_hook is not None
  46. threading.excepthook = self._old_hook
  47. self._old_hook = None
  48. del self.args
  49. def thread_exception_runtest_hook() -> Generator[None, None, None]:
  50. with catch_threading_exception() as cm:
  51. yield
  52. if cm.args:
  53. if cm.args.thread is not None:
  54. thread_name = cm.args.thread.name
  55. else:
  56. thread_name = "<unknown>"
  57. msg = f"Exception in thread {thread_name}\n\n"
  58. msg += "".join(
  59. traceback.format_exception(
  60. cm.args.exc_type, cm.args.exc_value, cm.args.exc_traceback,
  61. )
  62. )
  63. warnings.warn(pytest.PytestUnhandledThreadExceptionWarning(msg))
  64. @pytest.hookimpl(hookwrapper=True, trylast=True)
  65. def pytest_runtest_setup() -> Generator[None, None, None]:
  66. yield from thread_exception_runtest_hook()
  67. @pytest.hookimpl(hookwrapper=True, tryfirst=True)
  68. def pytest_runtest_call() -> Generator[None, None, None]:
  69. yield from thread_exception_runtest_hook()
  70. @pytest.hookimpl(hookwrapper=True, tryfirst=True)
  71. def pytest_runtest_teardown() -> Generator[None, None, None]:
  72. yield from thread_exception_runtest_hook()