You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

85 lines
2.2KB

  1. import abc
  2. import functools
  3. import weakref
  4. from . import exc as async_exc
  5. class ReversibleProxy:
  6. # weakref.ref(async proxy object) -> weakref.ref(sync proxied object)
  7. _proxy_objects = {}
  8. def _assign_proxied(self, target):
  9. if target is not None:
  10. target_ref = weakref.ref(target, ReversibleProxy._target_gced)
  11. proxy_ref = weakref.ref(
  12. self,
  13. functools.partial(ReversibleProxy._target_gced, target_ref),
  14. )
  15. ReversibleProxy._proxy_objects[target_ref] = proxy_ref
  16. return target
  17. @classmethod
  18. def _target_gced(cls, ref, proxy_ref=None):
  19. cls._proxy_objects.pop(ref, None)
  20. @classmethod
  21. def _regenerate_proxy_for_target(cls, target):
  22. raise NotImplementedError()
  23. @classmethod
  24. def _retrieve_proxy_for_target(cls, target, regenerate=True):
  25. try:
  26. proxy_ref = cls._proxy_objects[weakref.ref(target)]
  27. except KeyError:
  28. pass
  29. else:
  30. proxy = proxy_ref()
  31. if proxy is not None:
  32. return proxy
  33. if regenerate:
  34. return cls._regenerate_proxy_for_target(target)
  35. else:
  36. return None
  37. class StartableContext(abc.ABC):
  38. @abc.abstractmethod
  39. async def start(self, is_ctxmanager=False):
  40. pass
  41. def __await__(self):
  42. return self.start().__await__()
  43. async def __aenter__(self):
  44. return await self.start(is_ctxmanager=True)
  45. @abc.abstractmethod
  46. async def __aexit__(self, type_, value, traceback):
  47. pass
  48. def _raise_for_not_started(self):
  49. raise async_exc.AsyncContextNotStarted(
  50. "%s context has not been started and object has not been awaited."
  51. % (self.__class__.__name__)
  52. )
  53. class ProxyComparable(ReversibleProxy):
  54. def __hash__(self):
  55. return id(self)
  56. def __eq__(self, other):
  57. return (
  58. isinstance(other, self.__class__)
  59. and self._proxied == other._proxied
  60. )
  61. def __ne__(self, other):
  62. return (
  63. not isinstance(other, self.__class__)
  64. or self._proxied != other._proxied
  65. )