You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

442 lines
13KB

  1. # ext/declarative/clsregistry.py
  2. # Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. """Routines to handle the string class registry used by declarative.
  8. This system allows specification of classes and expressions used in
  9. :func:`_orm.relationship` using strings.
  10. """
  11. import weakref
  12. from . import attributes
  13. from . import interfaces
  14. from .descriptor_props import SynonymProperty
  15. from .properties import ColumnProperty
  16. from .util import class_mapper
  17. from .. import exc
  18. from .. import inspection
  19. from .. import util
  20. from ..sql.schema import _get_table_key
# Strong references to registries which we place in
# the _decl_class_registry, which is usually weak referencing.
# The internal registries here link to classes with weakrefs and remove
# themselves when all references to contained classes are removed.
_registries = set()
  26. def add_class(classname, cls, decl_class_registry):
  27. """Add a class to the _decl_class_registry associated with the
  28. given declarative class.
  29. """
  30. if classname in decl_class_registry:
  31. # class already exists.
  32. existing = decl_class_registry[classname]
  33. if not isinstance(existing, _MultipleClassMarker):
  34. existing = decl_class_registry[classname] = _MultipleClassMarker(
  35. [cls, existing]
  36. )
  37. else:
  38. decl_class_registry[classname] = cls
  39. try:
  40. root_module = decl_class_registry["_sa_module_registry"]
  41. except KeyError:
  42. decl_class_registry[
  43. "_sa_module_registry"
  44. ] = root_module = _ModuleMarker("_sa_module_registry", None)
  45. tokens = cls.__module__.split(".")
  46. # build up a tree like this:
  47. # modulename: myapp.snacks.nuts
  48. #
  49. # myapp->snack->nuts->(classes)
  50. # snack->nuts->(classes)
  51. # nuts->(classes)
  52. #
  53. # this allows partial token paths to be used.
  54. while tokens:
  55. token = tokens.pop(0)
  56. module = root_module.get_module(token)
  57. for token in tokens:
  58. module = module.get_module(token)
  59. module.add_class(classname, cls)
  60. def remove_class(classname, cls, decl_class_registry):
  61. if classname in decl_class_registry:
  62. existing = decl_class_registry[classname]
  63. if isinstance(existing, _MultipleClassMarker):
  64. existing.remove_item(cls)
  65. else:
  66. del decl_class_registry[classname]
  67. try:
  68. root_module = decl_class_registry["_sa_module_registry"]
  69. except KeyError:
  70. return
  71. tokens = cls.__module__.split(".")
  72. while tokens:
  73. token = tokens.pop(0)
  74. module = root_module.get_module(token)
  75. for token in tokens:
  76. module = module.get_module(token)
  77. module.remove_class(classname, cls)
  78. def _key_is_empty(key, decl_class_registry, test):
  79. """test if a key is empty of a certain object.
  80. used for unit tests against the registry to see if garbage collection
  81. is working.
  82. "test" is a callable that will be passed an object should return True
  83. if the given object is the one we were looking for.
  84. We can't pass the actual object itself b.c. this is for testing garbage
  85. collection; the caller will have to have removed references to the
  86. object itself.
  87. """
  88. if key not in decl_class_registry:
  89. return True
  90. thing = decl_class_registry[key]
  91. if isinstance(thing, _MultipleClassMarker):
  92. for sub_thing in thing.contents:
  93. if test(sub_thing):
  94. return False
  95. else:
  96. return not test(thing)
  97. class _MultipleClassMarker(object):
  98. """refers to multiple classes of the same name
  99. within _decl_class_registry.
  100. """
  101. __slots__ = "on_remove", "contents", "__weakref__"
  102. def __init__(self, classes, on_remove=None):
  103. self.on_remove = on_remove
  104. self.contents = set(
  105. [weakref.ref(item, self._remove_item) for item in classes]
  106. )
  107. _registries.add(self)
  108. def remove_item(self, cls):
  109. self._remove_item(weakref.ref(cls))
  110. def __iter__(self):
  111. return (ref() for ref in self.contents)
  112. def attempt_get(self, path, key):
  113. if len(self.contents) > 1:
  114. raise exc.InvalidRequestError(
  115. 'Multiple classes found for path "%s" '
  116. "in the registry of this declarative "
  117. "base. Please use a fully module-qualified path."
  118. % (".".join(path + [key]))
  119. )
  120. else:
  121. ref = list(self.contents)[0]
  122. cls = ref()
  123. if cls is None:
  124. raise NameError(key)
  125. return cls
  126. def _remove_item(self, ref):
  127. self.contents.discard(ref)
  128. if not self.contents:
  129. _registries.discard(self)
  130. if self.on_remove:
  131. self.on_remove()
  132. def add_item(self, item):
  133. # protect against class registration race condition against
  134. # asynchronous garbage collection calling _remove_item,
  135. # [ticket:3208]
  136. modules = set(
  137. [
  138. cls.__module__
  139. for cls in [ref() for ref in self.contents]
  140. if cls is not None
  141. ]
  142. )
  143. if item.__module__ in modules:
  144. util.warn(
  145. "This declarative base already contains a class with the "
  146. "same class name and module name as %s.%s, and will "
  147. "be replaced in the string-lookup table."
  148. % (item.__module__, item.__name__)
  149. )
  150. self.contents.add(weakref.ref(item, self._remove_item))
  151. class _ModuleMarker(object):
  152. """Refers to a module name within
  153. _decl_class_registry.
  154. """
  155. __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"
  156. def __init__(self, name, parent):
  157. self.parent = parent
  158. self.name = name
  159. self.contents = {}
  160. self.mod_ns = _ModNS(self)
  161. if self.parent:
  162. self.path = self.parent.path + [self.name]
  163. else:
  164. self.path = []
  165. _registries.add(self)
  166. def __contains__(self, name):
  167. return name in self.contents
  168. def __getitem__(self, name):
  169. return self.contents[name]
  170. def _remove_item(self, name):
  171. self.contents.pop(name, None)
  172. if not self.contents and self.parent is not None:
  173. self.parent._remove_item(self.name)
  174. _registries.discard(self)
  175. def resolve_attr(self, key):
  176. return getattr(self.mod_ns, key)
  177. def get_module(self, name):
  178. if name not in self.contents:
  179. marker = _ModuleMarker(name, self)
  180. self.contents[name] = marker
  181. else:
  182. marker = self.contents[name]
  183. return marker
  184. def add_class(self, name, cls):
  185. if name in self.contents:
  186. existing = self.contents[name]
  187. existing.add_item(cls)
  188. else:
  189. existing = self.contents[name] = _MultipleClassMarker(
  190. [cls], on_remove=lambda: self._remove_item(name)
  191. )
  192. def remove_class(self, name, cls):
  193. if name in self.contents:
  194. existing = self.contents[name]
  195. existing.remove_item(cls)
  196. class _ModNS(object):
  197. __slots__ = ("__parent",)
  198. def __init__(self, parent):
  199. self.__parent = parent
  200. def __getattr__(self, key):
  201. try:
  202. value = self.__parent.contents[key]
  203. except KeyError:
  204. pass
  205. else:
  206. if value is not None:
  207. if isinstance(value, _ModuleMarker):
  208. return value.mod_ns
  209. else:
  210. assert isinstance(value, _MultipleClassMarker)
  211. return value.attempt_get(self.__parent.path, key)
  212. raise AttributeError(
  213. "Module %r has no mapped classes "
  214. "registered under the name %r" % (self.__parent.name, key)
  215. )
  216. class _GetColumns(object):
  217. __slots__ = ("cls",)
  218. def __init__(self, cls):
  219. self.cls = cls
  220. def __getattr__(self, key):
  221. mp = class_mapper(self.cls, configure=False)
  222. if mp:
  223. if key not in mp.all_orm_descriptors:
  224. raise AttributeError(
  225. "Class %r does not have a mapped column named %r"
  226. % (self.cls, key)
  227. )
  228. desc = mp.all_orm_descriptors[key]
  229. if desc.extension_type is interfaces.NOT_EXTENSION:
  230. prop = desc.property
  231. if isinstance(prop, SynonymProperty):
  232. key = prop.name
  233. elif not isinstance(prop, ColumnProperty):
  234. raise exc.InvalidRequestError(
  235. "Property %r is not an instance of"
  236. " ColumnProperty (i.e. does not correspond"
  237. " directly to a Column)." % key
  238. )
  239. return getattr(self.cls, key)
# Hook _GetColumns into the inspection system: the registered callable
# inspects the class wrapped by the _GetColumns object.
# NOTE(review): presumably this makes inspect() on a _GetColumns wrapper
# delegate to its class -- confirm against inspection._inspects.
inspection._inspects(_GetColumns)(
    lambda target: inspection.inspect(target.cls)
)
  243. class _GetTable(object):
  244. __slots__ = "key", "metadata"
  245. def __init__(self, key, metadata):
  246. self.key = key
  247. self.metadata = metadata
  248. def __getattr__(self, key):
  249. return self.metadata.tables[_get_table_key(key, self.key)]
  250. def _determine_container(key, value):
  251. if isinstance(value, _MultipleClassMarker):
  252. value = value.attempt_get([], key)
  253. return _GetColumns(value)
  254. class _class_resolver(object):
  255. __slots__ = (
  256. "cls",
  257. "prop",
  258. "arg",
  259. "fallback",
  260. "_dict",
  261. "_resolvers",
  262. "favor_tables",
  263. )
  264. def __init__(self, cls, prop, fallback, arg, favor_tables=False):
  265. self.cls = cls
  266. self.prop = prop
  267. self.arg = arg
  268. self.fallback = fallback
  269. self._dict = util.PopulateDict(self._access_cls)
  270. self._resolvers = ()
  271. self.favor_tables = favor_tables
  272. def _access_cls(self, key):
  273. cls = self.cls
  274. manager = attributes.manager_of_class(cls)
  275. decl_base = manager.registry
  276. decl_class_registry = decl_base._class_registry
  277. metadata = decl_base.metadata
  278. if self.favor_tables:
  279. if key in metadata.tables:
  280. return metadata.tables[key]
  281. elif key in metadata._schemas:
  282. return _GetTable(key, cls.metadata)
  283. if key in decl_class_registry:
  284. return _determine_container(key, decl_class_registry[key])
  285. if not self.favor_tables:
  286. if key in metadata.tables:
  287. return metadata.tables[key]
  288. elif key in metadata._schemas:
  289. return _GetTable(key, cls.metadata)
  290. if (
  291. "_sa_module_registry" in decl_class_registry
  292. and key in decl_class_registry["_sa_module_registry"]
  293. ):
  294. registry = decl_class_registry["_sa_module_registry"]
  295. return registry.resolve_attr(key)
  296. elif self._resolvers:
  297. for resolv in self._resolvers:
  298. value = resolv(key)
  299. if value is not None:
  300. return value
  301. return self.fallback[key]
  302. def _raise_for_name(self, name, err):
  303. util.raise_(
  304. exc.InvalidRequestError(
  305. "When initializing mapper %s, expression %r failed to "
  306. "locate a name (%r). If this is a class name, consider "
  307. "adding this relationship() to the %r class after "
  308. "both dependent classes have been defined."
  309. % (self.prop.parent, self.arg, name, self.cls)
  310. ),
  311. from_=err,
  312. )
  313. def _resolve_name(self):
  314. name = self.arg
  315. d = self._dict
  316. rval = None
  317. try:
  318. for token in name.split("."):
  319. if rval is None:
  320. rval = d[token]
  321. else:
  322. rval = getattr(rval, token)
  323. except KeyError as err:
  324. self._raise_for_name(name, err)
  325. except NameError as n:
  326. self._raise_for_name(n.args[0], n)
  327. else:
  328. if isinstance(rval, _GetColumns):
  329. return rval.cls
  330. else:
  331. return rval
  332. def __call__(self):
  333. try:
  334. x = eval(self.arg, globals(), self._dict)
  335. if isinstance(x, _GetColumns):
  336. return x.cls
  337. else:
  338. return x
  339. except NameError as n:
  340. self._raise_for_name(n.args[0], n)
  341. _fallback_dict = None
  342. def _resolver(cls, prop):
  343. global _fallback_dict
  344. if _fallback_dict is None:
  345. import sqlalchemy
  346. from sqlalchemy.orm import foreign, remote
  347. _fallback_dict = util.immutabledict(sqlalchemy.__dict__).union(
  348. {"foreign": foreign, "remote": remote}
  349. )
  350. def resolve_arg(arg, favor_tables=False):
  351. return _class_resolver(
  352. cls, prop, _fallback_dict, arg, favor_tables=favor_tables
  353. )
  354. def resolve_name(arg):
  355. return _class_resolver(cls, prop, _fallback_dict, arg)._resolve_name
  356. return resolve_name, resolve_arg