25개 이상의 토픽을 선택하실 수 없습니다. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1466 lines
48KB

  1. from collections import deque
  2. from collections import namedtuple
  3. import itertools
  4. import operator
  5. from . import operators
  6. from .visitors import ExtendedInternalTraversal
  7. from .visitors import InternalTraversal
  8. from .. import util
  9. from ..inspection import inspect
  10. from ..util import collections_abc
  11. from ..util import HasMemoized
  12. from ..util import py37
  13. SKIP_TRAVERSE = util.symbol("skip_traverse")
  14. COMPARE_FAILED = False
  15. COMPARE_SUCCEEDED = True
  16. NO_CACHE = util.symbol("no_cache")
  17. CACHE_IN_PLACE = util.symbol("cache_in_place")
  18. CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key")
  19. STATIC_CACHE_KEY = util.symbol("static_cache_key")
  20. PROPAGATE_ATTRS = util.symbol("propagate_attrs")
  21. ANON_NAME = util.symbol("anon_name")
  22. def compare(obj1, obj2, **kw):
  23. if kw.get("use_proxies", False):
  24. strategy = ColIdentityComparatorStrategy()
  25. else:
  26. strategy = TraversalComparatorStrategy()
  27. return strategy.compare(obj1, obj2, **kw)
  28. def _preconfigure_traversals(target_hierarchy):
  29. for cls in util.walk_subclasses(target_hierarchy):
  30. if hasattr(cls, "_traverse_internals"):
  31. cls._generate_cache_attrs()
  32. _copy_internals.generate_dispatch(
  33. cls,
  34. cls._traverse_internals,
  35. "_generated_copy_internals_traversal",
  36. )
  37. _get_children.generate_dispatch(
  38. cls,
  39. cls._traverse_internals,
  40. "_generated_get_children_traversal",
  41. )
class HasCacheKey(object):
    """Mixin providing SQL cache key generation.

    A subclass declares the attributes that make up its cache key through
    ``_cache_key_traversal`` or, failing that, ``_traverse_internals``.  A
    dispatcher generated from that declaration is stored on the class as
    ``_generated_cache_key_traversal`` and consumed by
    :meth:`._gen_cache_key`.
    """

    # class-level default; _generate_cache_attrs() decides how the
    # per-class traversal is actually resolved
    _cache_key_traversal = NO_CACHE

    __slots__ = ()

    @classmethod
    def _generate_cache_attrs(cls):
        """generate cache key dispatcher for a new class.

        This sets the _generated_cache_key_traversal attribute once called
        so should only be called once per class.

        If the class's own ``__dict__`` has a true ``inherit_cache``, the
        traversal is looked up with ``getattr`` and so may come from a
        superclass, falling back to ``cls._traverse_internals``.
        Otherwise only ``_cache_key_traversal`` / ``_traverse_internals``
        present in the class's own ``__dict__`` are considered.

        :return: ``NO_CACHE`` (also stored on the class) if no traversal is
         found, else whatever ``generate_dispatch()`` returns, which
         _gen_cache_key() uses as the dispatcher.
        """
        inherit = cls.__dict__.get("inherit_cache", False)

        if inherit:
            _cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
            if _cache_key_traversal is None:
                try:
                    _cache_key_traversal = cls._traverse_internals
                except AttributeError:
                    cls._generated_cache_key_traversal = NO_CACHE
                    return NO_CACHE

            # TODO: wouldn't we instead get this from our superclass?
            # also, our superclass may not have this yet, but in any case,
            # we'd generate for the superclass that has it. this is a little
            # more complicated, so for the moment this is a little less
            # efficient on startup but simpler.
            return _cache_key_traversal_visitor.generate_dispatch(
                cls, _cache_key_traversal, "_generated_cache_key_traversal"
            )
        else:
            _cache_key_traversal = cls.__dict__.get(
                "_cache_key_traversal", None
            )
            if _cache_key_traversal is None:
                _cache_key_traversal = cls.__dict__.get(
                    "_traverse_internals", None
                )
                if _cache_key_traversal is None:
                    cls._generated_cache_key_traversal = NO_CACHE
                    return NO_CACHE

            return _cache_key_traversal_visitor.generate_dispatch(
                cls, _cache_key_traversal, "_generated_cache_key_traversal"
            )

    @util.preload_module("sqlalchemy.sql.elements")
    def _gen_cache_key(self, anon_map, bindparams):
        """return an optional cache key.

        The cache key is a tuple which can contain any series of
        objects that are hashable and also identifies
        this object uniquely within the presence of a larger SQL expression
        or statement, for the purposes of caching the resulting query.

        The cache key should be based on the SQL compiled structure that would
        ultimately be produced. That is, two structures that are composed in
        exactly the same way should produce the same cache key; any difference
        in the structures that would affect the SQL string or the type handlers
        should result in a different cache key.

        If a structure cannot produce a useful cache key, the NO_CACHE
        symbol should be added to the anon_map and the method should
        return None.

        :param anon_map: ``anon_map`` shared by the whole key generation.
         It maps ``id()`` of objects already visited to short string ids,
         supplies the mapping for anonymous labels, and carries the
         ``NO_CACHE`` flag.
        :param bindparams: list shared by the whole key generation; it is
         not modified here, only handed down to nested
         ``_gen_cache_key()`` / visit method calls.
        """

        idself = id(self)
        cls = self.__class__

        if idself in anon_map:
            # already keyed earlier in this traversal: refer back to it by
            # its short id instead of generating its key again
            return (anon_map[idself], cls)
        else:
            # inline of
            # id_ = anon_map[idself]
            anon_map[idself] = id_ = str(anon_map.index)
            anon_map.index += 1

        try:
            dispatcher = cls.__dict__["_generated_cache_key_traversal"]
        except KeyError:
            # most of the dispatchers are generated up front
            # in sqlalchemy/sql/__init__.py ->
            # traversals.py-> _preconfigure_traversals().
            # this block will generate any remaining dispatchers.
            dispatcher = cls._generate_cache_attrs()

        if dispatcher is NO_CACHE:
            anon_map[NO_CACHE] = True
            return None

        result = (id_, cls)

        # inline of _cache_key_traversal_visitor.run_generated_dispatch()
        # The dispatcher yields (attrname, value, meth) triples.  Judging
        # from the _CacheKey class attributes, meth is a marker symbol,
        # an InternalTraversal.dp_* symbol, or a bound _CacheKey visit
        # method returning a tuple fragment.

        for attrname, obj, meth in dispatcher(
            self, _cache_key_traversal_visitor
        ):
            if obj is not None:
                # TODO: see if C code can help here as Python lacks an
                # efficient switch construct

                if meth is STATIC_CACHE_KEY:
                    sck = obj._static_cache_key
                    if sck is NO_CACHE:
                        anon_map[NO_CACHE] = True
                        return None
                    result += (attrname, sck)
                elif meth is ANON_NAME:
                    elements = util.preloaded.sql_elements
                    if isinstance(obj, elements._anonymous_label):
                        obj = obj.apply_map(anon_map)
                    result += (attrname, obj)
                elif meth is CALL_GEN_CACHE_KEY:
                    result += (
                        attrname,
                        obj._gen_cache_key(anon_map, bindparams),
                    )

                # remaining cache functions are against
                # Python tuples, dicts, lists, etc. so we can skip
                # if they are empty
                elif obj:
                    if meth is CACHE_IN_PLACE:
                        result += (attrname, obj)
                    elif meth is PROPAGATE_ATTRS:
                        result += (
                            attrname,
                            obj["compile_state_plugin"],
                            obj["plugin_subject"]._gen_cache_key(
                                anon_map, bindparams
                            )
                            if obj["plugin_subject"]
                            else None,
                        )
                    elif meth is InternalTraversal.dp_annotations_key:
                        # obj is here is the _annotations dict. however, we
                        # want to use the memoized cache key version of it. for
                        # Columns, this should be long lived. For select()
                        # statements, not so much, but they usually won't have
                        # annotations.
                        result += self._annotations_cache_key
                    elif (
                        meth is InternalTraversal.dp_clauseelement_list
                        or meth is InternalTraversal.dp_clauseelement_tuple
                        or meth
                        is InternalTraversal.dp_memoized_select_entities
                    ):
                        result += (
                            attrname,
                            tuple(
                                [
                                    elem._gen_cache_key(anon_map, bindparams)
                                    for elem in obj
                                ]
                            ),
                        )
                    else:
                        # a _CacheKey visit method; its return value is a
                        # tuple that gets concatenated onto the key
                        result += meth(
                            attrname, obj, self, anon_map, bindparams
                        )
        return result

    def _generate_cache_key(self):
        """return a cache key.

        The cache key is a tuple which can contain any series of
        objects that are hashable and also identifies
        this object uniquely within the presence of a larger SQL expression
        or statement, for the purposes of caching the resulting query.

        The cache key should be based on the SQL compiled structure that would
        ultimately be produced. That is, two structures that are composed in
        exactly the same way should produce the same cache key; any difference
        in the structures that would affect the SQL string or the type handlers
        should result in a different cache key.

        The cache key returned by this method is an instance of
        :class:`.CacheKey`, which consists of a tuple representing the
        cache key, as well as a list of :class:`.BindParameter` objects
        which are extracted from the expression. While two expressions
        that produce identical cache key tuples will themselves generate
        identical SQL strings, the list of :class:`.BindParameter` objects
        indicates the bound values which may have different values in
        each one; these bound parameters must be consulted in order to
        execute the statement with the correct parameters.

        a :class:`_expression.ClauseElement` structure that does not implement
        a :meth:`._gen_cache_key` method and does not implement a
        :attr:`.traverse_internals` attribute will not be cacheable; when
        such an element is embedded into a larger structure, this method
        will return None, indicating no cache key is available.

        :return: a :class:`.CacheKey`, or None if any part of the
         structure flagged ``NO_CACHE`` in the anon map.
        """

        bindparams = []

        # ``anon_map`` here is the module-level anon_map class
        _anon_map = anon_map()
        key = self._gen_cache_key(_anon_map, bindparams)
        if NO_CACHE in _anon_map:
            return None
        else:
            return CacheKey(key, bindparams)

    @classmethod
    def _generate_cache_key_for_object(cls, obj):
        """Same as :meth:`._generate_cache_key`, but for an arbitrary
        ``obj`` that provides a ``_gen_cache_key()`` method.

        :return: a :class:`.CacheKey`, or None if ``NO_CACHE`` was flagged.
        """
        bindparams = []

        _anon_map = anon_map()
        key = obj._gen_cache_key(_anon_map, bindparams)
        if NO_CACHE in _anon_map:
            return None
        else:
            return CacheKey(key, bindparams)
  227. class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
  228. @HasMemoized.memoized_instancemethod
  229. def _generate_cache_key(self):
  230. return HasCacheKey._generate_cache_key(self)
  231. class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
  232. def __hash__(self):
  233. """CacheKey itself is not hashable - hash the .key portion"""
  234. return None
  235. def to_offline_string(self, statement_cache, statement, parameters):
  236. """Generate an "offline string" form of this :class:`.CacheKey`
  237. The "offline string" is basically the string SQL for the
  238. statement plus a repr of the bound parameter values in series.
  239. Whereas the :class:`.CacheKey` object is dependent on in-memory
  240. identities in order to work as a cache key, the "offline" version
  241. is suitable for a cache that will work for other processes as well.
  242. The given ``statement_cache`` is a dictionary-like object where the
  243. string form of the statement itself will be cached. This dictionary
  244. should be in a longer lived scope in order to reduce the time spent
  245. stringifying statements.
  246. """
  247. if self.key not in statement_cache:
  248. statement_cache[self.key] = sql_str = str(statement)
  249. else:
  250. sql_str = statement_cache[self.key]
  251. return repr(
  252. (
  253. sql_str,
  254. tuple(
  255. parameters.get(bindparam.key, bindparam.value)
  256. for bindparam in self.bindparams
  257. ),
  258. )
  259. )
  260. def __eq__(self, other):
  261. return self.key == other.key
  262. @classmethod
  263. def _diff_tuples(cls, left, right):
  264. ck1 = CacheKey(left, [])
  265. ck2 = CacheKey(right, [])
  266. return ck1._diff(ck2)
  267. def _whats_different(self, other):
  268. k1 = self.key
  269. k2 = other.key
  270. stack = []
  271. pickup_index = 0
  272. while True:
  273. s1, s2 = k1, k2
  274. for idx in stack:
  275. s1 = s1[idx]
  276. s2 = s2[idx]
  277. for idx, (e1, e2) in enumerate(util.zip_longest(s1, s2)):
  278. if idx < pickup_index:
  279. continue
  280. if e1 != e2:
  281. if isinstance(e1, tuple) and isinstance(e2, tuple):
  282. stack.append(idx)
  283. break
  284. else:
  285. yield "key%s[%d]: %s != %s" % (
  286. "".join("[%d]" % id_ for id_ in stack),
  287. idx,
  288. e1,
  289. e2,
  290. )
  291. else:
  292. pickup_index = stack.pop(-1)
  293. break
  294. def _diff(self, other):
  295. return ", ".join(self._whats_different(other))
  296. def __str__(self):
  297. stack = [self.key]
  298. output = []
  299. sentinel = object()
  300. indent = -1
  301. while stack:
  302. elem = stack.pop(0)
  303. if elem is sentinel:
  304. output.append((" " * (indent * 2)) + "),")
  305. indent -= 1
  306. elif isinstance(elem, tuple):
  307. if not elem:
  308. output.append((" " * ((indent + 1) * 2)) + "()")
  309. else:
  310. indent += 1
  311. stack = list(elem) + [sentinel] + stack
  312. output.append((" " * (indent * 2)) + "(")
  313. else:
  314. if isinstance(elem, HasCacheKey):
  315. repr_ = "<%s object at %s>" % (
  316. type(elem).__name__,
  317. hex(id(elem)),
  318. )
  319. else:
  320. repr_ = repr(elem)
  321. output.append((" " * (indent * 2)) + " " + repr_ + ", ")
  322. return "CacheKey(key=%s)" % ("\n".join(output),)
  323. def _generate_param_dict(self):
  324. """used for testing"""
  325. from .compiler import prefix_anon_map
  326. _anon_map = prefix_anon_map()
  327. return {b.key % _anon_map: b.effective_value for b in self.bindparams}
  328. def _apply_params_to_element(self, original_cache_key, target_element):
  329. translate = {
  330. k.key: v.value
  331. for k, v in zip(original_cache_key.bindparams, self.bindparams)
  332. }
  333. return target_element.params(translate)
  334. def _clone(element, **kw):
  335. return element._clone()
  336. class _CacheKey(ExtendedInternalTraversal):
  337. # very common elements are inlined into the main _get_cache_key() method
  338. # to produce a dramatic savings in Python function call overhead
  339. visit_has_cache_key = visit_clauseelement = CALL_GEN_CACHE_KEY
  340. visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
  341. visit_annotations_key = InternalTraversal.dp_annotations_key
  342. visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
  343. visit_memoized_select_entities = (
  344. InternalTraversal.dp_memoized_select_entities
  345. )
  346. visit_string = (
  347. visit_boolean
  348. ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE
  349. visit_statement_hint_list = CACHE_IN_PLACE
  350. visit_type = STATIC_CACHE_KEY
  351. visit_anon_name = ANON_NAME
  352. visit_propagate_attrs = PROPAGATE_ATTRS
  353. def visit_with_context_options(
  354. self, attrname, obj, parent, anon_map, bindparams
  355. ):
  356. return tuple((fn.__code__, c_key) for fn, c_key in obj)
  357. def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
  358. return (attrname, inspect(obj)._gen_cache_key(anon_map, bindparams))
  359. def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
  360. return tuple(obj)
  361. def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
  362. return (
  363. attrname,
  364. obj._gen_cache_key(anon_map, bindparams)
  365. if isinstance(obj, HasCacheKey)
  366. else obj,
  367. )
  368. def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
  369. return (
  370. attrname,
  371. tuple(
  372. elem._gen_cache_key(anon_map, bindparams)
  373. if isinstance(elem, HasCacheKey)
  374. else elem
  375. for elem in obj
  376. ),
  377. )
  378. def visit_has_cache_key_tuples(
  379. self, attrname, obj, parent, anon_map, bindparams
  380. ):
  381. if not obj:
  382. return ()
  383. return (
  384. attrname,
  385. tuple(
  386. tuple(
  387. elem._gen_cache_key(anon_map, bindparams)
  388. for elem in tup_elem
  389. )
  390. for tup_elem in obj
  391. ),
  392. )
  393. def visit_has_cache_key_list(
  394. self, attrname, obj, parent, anon_map, bindparams
  395. ):
  396. if not obj:
  397. return ()
  398. return (
  399. attrname,
  400. tuple(elem._gen_cache_key(anon_map, bindparams) for elem in obj),
  401. )
  402. visit_executable_options = visit_has_cache_key_list
  403. def visit_inspectable_list(
  404. self, attrname, obj, parent, anon_map, bindparams
  405. ):
  406. return self.visit_has_cache_key_list(
  407. attrname, [inspect(o) for o in obj], parent, anon_map, bindparams
  408. )
  409. def visit_clauseelement_tuples(
  410. self, attrname, obj, parent, anon_map, bindparams
  411. ):
  412. return self.visit_has_cache_key_tuples(
  413. attrname, obj, parent, anon_map, bindparams
  414. )
  415. def visit_fromclause_ordered_set(
  416. self, attrname, obj, parent, anon_map, bindparams
  417. ):
  418. if not obj:
  419. return ()
  420. return (
  421. attrname,
  422. tuple([elem._gen_cache_key(anon_map, bindparams) for elem in obj]),
  423. )
  424. def visit_clauseelement_unordered_set(
  425. self, attrname, obj, parent, anon_map, bindparams
  426. ):
  427. if not obj:
  428. return ()
  429. cache_keys = [
  430. elem._gen_cache_key(anon_map, bindparams) for elem in obj
  431. ]
  432. return (
  433. attrname,
  434. tuple(
  435. sorted(cache_keys)
  436. ), # cache keys all start with (id_, class)
  437. )
  438. def visit_named_ddl_element(
  439. self, attrname, obj, parent, anon_map, bindparams
  440. ):
  441. return (attrname, obj.name)
  442. def visit_prefix_sequence(
  443. self, attrname, obj, parent, anon_map, bindparams
  444. ):
  445. if not obj:
  446. return ()
  447. return (
  448. attrname,
  449. tuple(
  450. [
  451. (clause._gen_cache_key(anon_map, bindparams), strval)
  452. for clause, strval in obj
  453. ]
  454. ),
  455. )
  456. def visit_setup_join_tuple(
  457. self, attrname, obj, parent, anon_map, bindparams
  458. ):
  459. is_legacy = "legacy" in attrname
  460. return tuple(
  461. (
  462. target
  463. if is_legacy and isinstance(target, str)
  464. else target._gen_cache_key(anon_map, bindparams),
  465. onclause
  466. if is_legacy and isinstance(onclause, str)
  467. else onclause._gen_cache_key(anon_map, bindparams)
  468. if onclause is not None
  469. else None,
  470. from_._gen_cache_key(anon_map, bindparams)
  471. if from_ is not None
  472. else None,
  473. tuple([(key, flags[key]) for key in sorted(flags)]),
  474. )
  475. for (target, onclause, from_, flags) in obj
  476. )
  477. def visit_table_hint_list(
  478. self, attrname, obj, parent, anon_map, bindparams
  479. ):
  480. if not obj:
  481. return ()
  482. return (
  483. attrname,
  484. tuple(
  485. [
  486. (
  487. clause._gen_cache_key(anon_map, bindparams),
  488. dialect_name,
  489. text,
  490. )
  491. for (clause, dialect_name), text in obj.items()
  492. ]
  493. ),
  494. )
  495. def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
  496. return (attrname, tuple([(key, obj[key]) for key in sorted(obj)]))
  497. def visit_dialect_options(
  498. self, attrname, obj, parent, anon_map, bindparams
  499. ):
  500. return (
  501. attrname,
  502. tuple(
  503. (
  504. dialect_name,
  505. tuple(
  506. [
  507. (key, obj[dialect_name][key])
  508. for key in sorted(obj[dialect_name])
  509. ]
  510. ),
  511. )
  512. for dialect_name in sorted(obj)
  513. ),
  514. )
  515. def visit_string_clauseelement_dict(
  516. self, attrname, obj, parent, anon_map, bindparams
  517. ):
  518. return (
  519. attrname,
  520. tuple(
  521. (key, obj[key]._gen_cache_key(anon_map, bindparams))
  522. for key in sorted(obj)
  523. ),
  524. )
  525. def visit_string_multi_dict(
  526. self, attrname, obj, parent, anon_map, bindparams
  527. ):
  528. return (
  529. attrname,
  530. tuple(
  531. (
  532. key,
  533. value._gen_cache_key(anon_map, bindparams)
  534. if isinstance(value, HasCacheKey)
  535. else value,
  536. )
  537. for key, value in [(key, obj[key]) for key in sorted(obj)]
  538. ),
  539. )
  540. def visit_fromclause_canonical_column_collection(
  541. self, attrname, obj, parent, anon_map, bindparams
  542. ):
  543. # inlining into the internals of ColumnCollection
  544. return (
  545. attrname,
  546. tuple(
  547. col._gen_cache_key(anon_map, bindparams)
  548. for k, col in obj._collection
  549. ),
  550. )
  551. def visit_unknown_structure(
  552. self, attrname, obj, parent, anon_map, bindparams
  553. ):
  554. anon_map[NO_CACHE] = True
  555. return ()
  556. def visit_dml_ordered_values(
  557. self, attrname, obj, parent, anon_map, bindparams
  558. ):
  559. return (
  560. attrname,
  561. tuple(
  562. (
  563. key._gen_cache_key(anon_map, bindparams)
  564. if hasattr(key, "__clause_element__")
  565. else key,
  566. value._gen_cache_key(anon_map, bindparams),
  567. )
  568. for key, value in obj
  569. ),
  570. )
  571. def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
  572. if py37:
  573. # in py37 we can assume two dictionaries created in the same
  574. # insert ordering will retain that sorting
  575. return (
  576. attrname,
  577. tuple(
  578. (
  579. k._gen_cache_key(anon_map, bindparams)
  580. if hasattr(k, "__clause_element__")
  581. else k,
  582. obj[k]._gen_cache_key(anon_map, bindparams),
  583. )
  584. for k in obj
  585. ),
  586. )
  587. else:
  588. expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
  589. if expr_values:
  590. # expr values can't be sorted deterministically right now,
  591. # so no cache
  592. anon_map[NO_CACHE] = True
  593. return ()
  594. str_values = expr_values.symmetric_difference(obj)
  595. return (
  596. attrname,
  597. tuple(
  598. (k, obj[k]._gen_cache_key(anon_map, bindparams))
  599. for k in sorted(str_values)
  600. ),
  601. )
  602. def visit_dml_multi_values(
  603. self, attrname, obj, parent, anon_map, bindparams
  604. ):
  605. # multivalues are simply not cacheable right now
  606. anon_map[NO_CACHE] = True
  607. return ()
  608. _cache_key_traversal_visitor = _CacheKey()
  609. class HasCopyInternals(object):
  610. def _clone(self, **kw):
  611. raise NotImplementedError()
  612. def _copy_internals(self, omit_attrs=(), **kw):
  613. """Reassign internal elements to be clones of themselves.
  614. Called during a copy-and-traverse operation on newly
  615. shallow-copied elements to create a deep copy.
  616. The given clone function should be used, which may be applying
  617. additional transformations to the element (i.e. replacement
  618. traversal, cloned traversal, annotations).
  619. """
  620. try:
  621. traverse_internals = self._traverse_internals
  622. except AttributeError:
  623. # user-defined classes may not have a _traverse_internals
  624. return
  625. for attrname, obj, meth in _copy_internals.run_generated_dispatch(
  626. self, traverse_internals, "_generated_copy_internals_traversal"
  627. ):
  628. if attrname in omit_attrs:
  629. continue
  630. if obj is not None:
  631. result = meth(attrname, self, obj, **kw)
  632. if result is not None:
  633. setattr(self, attrname, result)
  634. class _CopyInternals(InternalTraversal):
  635. """Generate a _copy_internals internal traversal dispatch for classes
  636. with a _traverse_internals collection."""
  637. def visit_clauseelement(
  638. self, attrname, parent, element, clone=_clone, **kw
  639. ):
  640. return clone(element, **kw)
  641. def visit_clauseelement_list(
  642. self, attrname, parent, element, clone=_clone, **kw
  643. ):
  644. return [clone(clause, **kw) for clause in element]
  645. def visit_clauseelement_tuple(
  646. self, attrname, parent, element, clone=_clone, **kw
  647. ):
  648. return tuple([clone(clause, **kw) for clause in element])
  649. def visit_executable_options(
  650. self, attrname, parent, element, clone=_clone, **kw
  651. ):
  652. return tuple([clone(clause, **kw) for clause in element])
  653. def visit_clauseelement_unordered_set(
  654. self, attrname, parent, element, clone=_clone, **kw
  655. ):
  656. return {clone(clause, **kw) for clause in element}
  657. def visit_clauseelement_tuples(
  658. self, attrname, parent, element, clone=_clone, **kw
  659. ):
  660. return [
  661. tuple(clone(tup_elem, **kw) for tup_elem in elem)
  662. for elem in element
  663. ]
  664. def visit_string_clauseelement_dict(
  665. self, attrname, parent, element, clone=_clone, **kw
  666. ):
  667. return dict(
  668. (key, clone(value, **kw)) for key, value in element.items()
  669. )
  670. def visit_setup_join_tuple(
  671. self, attrname, parent, element, clone=_clone, **kw
  672. ):
  673. return tuple(
  674. (
  675. clone(target, **kw) if target is not None else None,
  676. clone(onclause, **kw) if onclause is not None else None,
  677. clone(from_, **kw) if from_ is not None else None,
  678. flags,
  679. )
  680. for (target, onclause, from_, flags) in element
  681. )
  682. def visit_memoized_select_entities(self, attrname, parent, element, **kw):
  683. return self.visit_clauseelement_tuple(attrname, parent, element, **kw)
  684. def visit_dml_ordered_values(
  685. self, attrname, parent, element, clone=_clone, **kw
  686. ):
  687. # sequence of 2-tuples
  688. return [
  689. (
  690. clone(key, **kw)
  691. if hasattr(key, "__clause_element__")
  692. else key,
  693. clone(value, **kw),
  694. )
  695. for key, value in element
  696. ]
  697. def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw):
  698. return {
  699. (
  700. clone(key, **kw) if hasattr(key, "__clause_element__") else key
  701. ): clone(value, **kw)
  702. for key, value in element.items()
  703. }
  704. def visit_dml_multi_values(
  705. self, attrname, parent, element, clone=_clone, **kw
  706. ):
  707. # sequence of sequences, each sequence contains a list/dict/tuple
  708. def copy(elem):
  709. if isinstance(elem, (list, tuple)):
  710. return [
  711. clone(value, **kw)
  712. if hasattr(value, "__clause_element__")
  713. else value
  714. for value in elem
  715. ]
  716. elif isinstance(elem, dict):
  717. return {
  718. (
  719. clone(key, **kw)
  720. if hasattr(key, "__clause_element__")
  721. else key
  722. ): (
  723. clone(value, **kw)
  724. if hasattr(value, "__clause_element__")
  725. else value
  726. )
  727. for key, value in elem.items()
  728. }
  729. else:
  730. # TODO: use abc classes
  731. assert False
  732. return [
  733. [copy(sub_element) for sub_element in sequence]
  734. for sequence in element
  735. ]
  736. def visit_propagate_attrs(
  737. self, attrname, parent, element, clone=_clone, **kw
  738. ):
  739. return element
  740. _copy_internals = _CopyInternals()
  741. def _flatten_clauseelement(element):
  742. while hasattr(element, "__clause_element__") and not getattr(
  743. element, "is_clause_element", False
  744. ):
  745. element = element.__clause_element__()
  746. return element
  747. class _GetChildren(InternalTraversal):
  748. """Generate a _children_traversal internal traversal dispatch for classes
  749. with a _traverse_internals collection."""
  750. def visit_has_cache_key(self, element, **kw):
  751. # the GetChildren traversal refers explicitly to ClauseElement
  752. # structures. Within these, a plain HasCacheKey is not a
  753. # ClauseElement, so don't include these.
  754. return ()
  755. def visit_clauseelement(self, element, **kw):
  756. return (element,)
  757. def visit_clauseelement_list(self, element, **kw):
  758. return element
  759. def visit_clauseelement_tuple(self, element, **kw):
  760. return element
  761. def visit_clauseelement_tuples(self, element, **kw):
  762. return itertools.chain.from_iterable(element)
  763. def visit_fromclause_canonical_column_collection(self, element, **kw):
  764. return ()
  765. def visit_string_clauseelement_dict(self, element, **kw):
  766. return element.values()
  767. def visit_fromclause_ordered_set(self, element, **kw):
  768. return element
  769. def visit_clauseelement_unordered_set(self, element, **kw):
  770. return element
  771. def visit_setup_join_tuple(self, element, **kw):
  772. for (target, onclause, from_, flags) in element:
  773. if from_ is not None:
  774. yield from_
  775. if not isinstance(target, str):
  776. yield _flatten_clauseelement(target)
  777. if onclause is not None and not isinstance(onclause, str):
  778. yield _flatten_clauseelement(onclause)
  779. def visit_memoized_select_entities(self, element, **kw):
  780. return self.visit_clauseelement_tuple(element, **kw)
  781. def visit_dml_ordered_values(self, element, **kw):
  782. for k, v in element:
  783. if hasattr(k, "__clause_element__"):
  784. yield k
  785. yield v
  786. def visit_dml_values(self, element, **kw):
  787. expr_values = {k for k in element if hasattr(k, "__clause_element__")}
  788. str_values = expr_values.symmetric_difference(element)
  789. for k in sorted(str_values):
  790. yield element[k]
  791. for k in expr_values:
  792. yield k
  793. yield element[k]
  794. def visit_dml_multi_values(self, element, **kw):
  795. return ()
  796. def visit_propagate_attrs(self, element, **kw):
  797. return ()
  798. _get_children = _GetChildren()
  799. @util.preload_module("sqlalchemy.sql.elements")
  800. def _resolve_name_for_compare(element, name, anon_map, **kw):
  801. if isinstance(name, util.preloaded.sql_elements._anonymous_label):
  802. name = name.apply_map(anon_map)
  803. return name
  804. class anon_map(dict):
  805. """A map that creates new keys for missing key access.
  806. Produces an incrementing sequence given a series of unique keys.
  807. This is similar to the compiler prefix_anon_map class although simpler.
  808. Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
  809. is otherwise usually used for this type of operation.
  810. """
  811. def __init__(self):
  812. self.index = 0
  813. def __missing__(self, key):
  814. self[key] = val = str(self.index)
  815. self.index += 1
  816. return val
  817. class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
  818. __slots__ = "stack", "cache", "anon_map"
  819. def __init__(self):
  820. self.stack = deque()
  821. self.cache = set()
  822. def _memoized_attr_anon_map(self):
  823. return (anon_map(), anon_map())
  824. def compare(self, obj1, obj2, **kw):
  825. stack = self.stack
  826. cache = self.cache
  827. compare_annotations = kw.get("compare_annotations", False)
  828. stack.append((obj1, obj2))
  829. while stack:
  830. left, right = stack.popleft()
  831. if left is right:
  832. continue
  833. elif left is None or right is None:
  834. # we know they are different so no match
  835. return False
  836. elif (left, right) in cache:
  837. continue
  838. cache.add((left, right))
  839. visit_name = left.__visit_name__
  840. if visit_name != right.__visit_name__:
  841. return False
  842. meth = getattr(self, "compare_%s" % visit_name, None)
  843. if meth:
  844. attributes_compared = meth(left, right, **kw)
  845. if attributes_compared is COMPARE_FAILED:
  846. return False
  847. elif attributes_compared is SKIP_TRAVERSE:
  848. continue
  849. # attributes_compared is returned as a list of attribute
  850. # names that were "handled" by the comparison method above.
  851. # remaining attribute names in the _traverse_internals
  852. # will be compared.
  853. else:
  854. attributes_compared = ()
  855. for (
  856. (left_attrname, left_visit_sym),
  857. (right_attrname, right_visit_sym),
  858. ) in util.zip_longest(
  859. left._traverse_internals,
  860. right._traverse_internals,
  861. fillvalue=(None, None),
  862. ):
  863. if not compare_annotations and (
  864. (left_attrname == "_annotations")
  865. or (right_attrname == "_annotations")
  866. ):
  867. continue
  868. if (
  869. left_attrname != right_attrname
  870. or left_visit_sym is not right_visit_sym
  871. ):
  872. return False
  873. elif left_attrname in attributes_compared:
  874. continue
  875. dispatch = self.dispatch(left_visit_sym)
  876. left_child = operator.attrgetter(left_attrname)(left)
  877. right_child = operator.attrgetter(right_attrname)(right)
  878. if left_child is None:
  879. if right_child is not None:
  880. return False
  881. else:
  882. continue
  883. comparison = dispatch(
  884. left_attrname, left, left_child, right, right_child, **kw
  885. )
  886. if comparison is COMPARE_FAILED:
  887. return False
  888. return True
  889. def compare_inner(self, obj1, obj2, **kw):
  890. comparator = self.__class__()
  891. return comparator.compare(obj1, obj2, **kw)
  892. def visit_has_cache_key(
  893. self, attrname, left_parent, left, right_parent, right, **kw
  894. ):
  895. if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
  896. self.anon_map[1], []
  897. ):
  898. return COMPARE_FAILED
  899. def visit_propagate_attrs(
  900. self, attrname, left_parent, left, right_parent, right, **kw
  901. ):
  902. return self.compare_inner(
  903. left.get("plugin_subject", None), right.get("plugin_subject", None)
  904. )
  905. def visit_has_cache_key_list(
  906. self, attrname, left_parent, left, right_parent, right, **kw
  907. ):
  908. for l, r in util.zip_longest(left, right, fillvalue=None):
  909. if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
  910. self.anon_map[1], []
  911. ):
  912. return COMPARE_FAILED
  913. visit_executable_options = visit_has_cache_key_list
  914. def visit_clauseelement(
  915. self, attrname, left_parent, left, right_parent, right, **kw
  916. ):
  917. self.stack.append((left, right))
  918. def visit_fromclause_canonical_column_collection(
  919. self, attrname, left_parent, left, right_parent, right, **kw
  920. ):
  921. for lcol, rcol in util.zip_longest(left, right, fillvalue=None):
  922. self.stack.append((lcol, rcol))
  923. def visit_fromclause_derived_column_collection(
  924. self, attrname, left_parent, left, right_parent, right, **kw
  925. ):
  926. pass
  927. def visit_string_clauseelement_dict(
  928. self, attrname, left_parent, left, right_parent, right, **kw
  929. ):
  930. for lstr, rstr in util.zip_longest(
  931. sorted(left), sorted(right), fillvalue=None
  932. ):
  933. if lstr != rstr:
  934. return COMPARE_FAILED
  935. self.stack.append((left[lstr], right[rstr]))
  936. def visit_clauseelement_tuples(
  937. self, attrname, left_parent, left, right_parent, right, **kw
  938. ):
  939. for ltup, rtup in util.zip_longest(left, right, fillvalue=None):
  940. if ltup is None or rtup is None:
  941. return COMPARE_FAILED
  942. for l, r in util.zip_longest(ltup, rtup, fillvalue=None):
  943. self.stack.append((l, r))
  944. def visit_clauseelement_list(
  945. self, attrname, left_parent, left, right_parent, right, **kw
  946. ):
  947. for l, r in util.zip_longest(left, right, fillvalue=None):
  948. self.stack.append((l, r))
  949. def visit_clauseelement_tuple(
  950. self, attrname, left_parent, left, right_parent, right, **kw
  951. ):
  952. for l, r in util.zip_longest(left, right, fillvalue=None):
  953. self.stack.append((l, r))
  954. def _compare_unordered_sequences(self, seq1, seq2, **kw):
  955. if seq1 is None:
  956. return seq2 is None
  957. completed = set()
  958. for clause in seq1:
  959. for other_clause in set(seq2).difference(completed):
  960. if self.compare_inner(clause, other_clause, **kw):
  961. completed.add(other_clause)
  962. break
  963. return len(completed) == len(seq1) == len(seq2)
  964. def visit_clauseelement_unordered_set(
  965. self, attrname, left_parent, left, right_parent, right, **kw
  966. ):
  967. return self._compare_unordered_sequences(left, right, **kw)
  968. def visit_fromclause_ordered_set(
  969. self, attrname, left_parent, left, right_parent, right, **kw
  970. ):
  971. for l, r in util.zip_longest(left, right, fillvalue=None):
  972. self.stack.append((l, r))
  973. def visit_string(
  974. self, attrname, left_parent, left, right_parent, right, **kw
  975. ):
  976. return left == right
  977. def visit_string_list(
  978. self, attrname, left_parent, left, right_parent, right, **kw
  979. ):
  980. return left == right
  981. def visit_anon_name(
  982. self, attrname, left_parent, left, right_parent, right, **kw
  983. ):
  984. return _resolve_name_for_compare(
  985. left_parent, left, self.anon_map[0], **kw
  986. ) == _resolve_name_for_compare(
  987. right_parent, right, self.anon_map[1], **kw
  988. )
  989. def visit_boolean(
  990. self, attrname, left_parent, left, right_parent, right, **kw
  991. ):
  992. return left == right
  993. def visit_operator(
  994. self, attrname, left_parent, left, right_parent, right, **kw
  995. ):
  996. return left is right
  997. def visit_type(
  998. self, attrname, left_parent, left, right_parent, right, **kw
  999. ):
  1000. return left._compare_type_affinity(right)
  1001. def visit_plain_dict(
  1002. self, attrname, left_parent, left, right_parent, right, **kw
  1003. ):
  1004. return left == right
  1005. def visit_dialect_options(
  1006. self, attrname, left_parent, left, right_parent, right, **kw
  1007. ):
  1008. return left == right
  1009. def visit_annotations_key(
  1010. self, attrname, left_parent, left, right_parent, right, **kw
  1011. ):
  1012. if left and right:
  1013. return (
  1014. left_parent._annotations_cache_key
  1015. == right_parent._annotations_cache_key
  1016. )
  1017. else:
  1018. return left == right
  1019. def visit_with_context_options(
  1020. self, attrname, left_parent, left, right_parent, right, **kw
  1021. ):
  1022. return tuple((fn.__code__, c_key) for fn, c_key in left) == tuple(
  1023. (fn.__code__, c_key) for fn, c_key in right
  1024. )
  1025. def visit_plain_obj(
  1026. self, attrname, left_parent, left, right_parent, right, **kw
  1027. ):
  1028. return left == right
  1029. def visit_named_ddl_element(
  1030. self, attrname, left_parent, left, right_parent, right, **kw
  1031. ):
  1032. if left is None:
  1033. if right is not None:
  1034. return COMPARE_FAILED
  1035. return left.name == right.name
  1036. def visit_prefix_sequence(
  1037. self, attrname, left_parent, left, right_parent, right, **kw
  1038. ):
  1039. for (l_clause, l_str), (r_clause, r_str) in util.zip_longest(
  1040. left, right, fillvalue=(None, None)
  1041. ):
  1042. if l_str != r_str:
  1043. return COMPARE_FAILED
  1044. else:
  1045. self.stack.append((l_clause, r_clause))
  1046. def visit_setup_join_tuple(
  1047. self, attrname, left_parent, left, right_parent, right, **kw
  1048. ):
  1049. # TODO: look at attrname for "legacy_join" and use different structure
  1050. for (
  1051. (l_target, l_onclause, l_from, l_flags),
  1052. (r_target, r_onclause, r_from, r_flags),
  1053. ) in util.zip_longest(left, right, fillvalue=(None, None, None, None)):
  1054. if l_flags != r_flags:
  1055. return COMPARE_FAILED
  1056. self.stack.append((l_target, r_target))
  1057. self.stack.append((l_onclause, r_onclause))
  1058. self.stack.append((l_from, r_from))
  1059. def visit_memoized_select_entities(
  1060. self, attrname, left_parent, left, right_parent, right, **kw
  1061. ):
  1062. return self.visit_clauseelement_tuple(
  1063. attrname, left_parent, left, right_parent, right, **kw
  1064. )
  1065. def visit_table_hint_list(
  1066. self, attrname, left_parent, left, right_parent, right, **kw
  1067. ):
  1068. left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
  1069. right_keys = sorted(
  1070. right, key=lambda elem: (elem[0].fullname, elem[1])
  1071. )
  1072. for (ltable, ldialect), (rtable, rdialect) in util.zip_longest(
  1073. left_keys, right_keys, fillvalue=(None, None)
  1074. ):
  1075. if ldialect != rdialect:
  1076. return COMPARE_FAILED
  1077. elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
  1078. return COMPARE_FAILED
  1079. else:
  1080. self.stack.append((ltable, rtable))
  1081. def visit_statement_hint_list(
  1082. self, attrname, left_parent, left, right_parent, right, **kw
  1083. ):
  1084. return left == right
  1085. def visit_unknown_structure(
  1086. self, attrname, left_parent, left, right_parent, right, **kw
  1087. ):
  1088. raise NotImplementedError()
  1089. def visit_dml_ordered_values(
  1090. self, attrname, left_parent, left, right_parent, right, **kw
  1091. ):
  1092. # sequence of tuple pairs
  1093. for (lk, lv), (rk, rv) in util.zip_longest(
  1094. left, right, fillvalue=(None, None)
  1095. ):
  1096. if not self._compare_dml_values_or_ce(lk, rk, **kw):
  1097. return COMPARE_FAILED
  1098. def _compare_dml_values_or_ce(self, lv, rv, **kw):
  1099. lvce = hasattr(lv, "__clause_element__")
  1100. rvce = hasattr(rv, "__clause_element__")
  1101. if lvce != rvce:
  1102. return False
  1103. elif lvce and not self.compare_inner(lv, rv, **kw):
  1104. return False
  1105. elif not lvce and lv != rv:
  1106. return False
  1107. elif not self.compare_inner(lv, rv, **kw):
  1108. return False
  1109. return True
  1110. def visit_dml_values(
  1111. self, attrname, left_parent, left, right_parent, right, **kw
  1112. ):
  1113. if left is None or right is None or len(left) != len(right):
  1114. return COMPARE_FAILED
  1115. if isinstance(left, collections_abc.Sequence):
  1116. for lv, rv in zip(left, right):
  1117. if not self._compare_dml_values_or_ce(lv, rv, **kw):
  1118. return COMPARE_FAILED
  1119. elif isinstance(right, collections_abc.Sequence):
  1120. return COMPARE_FAILED
  1121. elif py37:
  1122. # dictionaries guaranteed to support insert ordering in
  1123. # py37 so that we can compare the keys in order. without
  1124. # this, we can't compare SQL expression keys because we don't
  1125. # know which key is which
  1126. for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
  1127. if not self._compare_dml_values_or_ce(lk, rk, **kw):
  1128. return COMPARE_FAILED
  1129. if not self._compare_dml_values_or_ce(lv, rv, **kw):
  1130. return COMPARE_FAILED
  1131. else:
  1132. for lk in left:
  1133. lv = left[lk]
  1134. if lk not in right:
  1135. return COMPARE_FAILED
  1136. rv = right[lk]
  1137. if not self._compare_dml_values_or_ce(lv, rv, **kw):
  1138. return COMPARE_FAILED
  1139. def visit_dml_multi_values(
  1140. self, attrname, left_parent, left, right_parent, right, **kw
  1141. ):
  1142. for lseq, rseq in util.zip_longest(left, right, fillvalue=None):
  1143. if lseq is None or rseq is None:
  1144. return COMPARE_FAILED
  1145. for ld, rd in util.zip_longest(lseq, rseq, fillvalue=None):
  1146. if (
  1147. self.visit_dml_values(
  1148. attrname, left_parent, ld, right_parent, rd, **kw
  1149. )
  1150. is COMPARE_FAILED
  1151. ):
  1152. return COMPARE_FAILED
  1153. def compare_clauselist(self, left, right, **kw):
  1154. if left.operator is right.operator:
  1155. if operators.is_associative(left.operator):
  1156. if self._compare_unordered_sequences(
  1157. left.clauses, right.clauses, **kw
  1158. ):
  1159. return ["operator", "clauses"]
  1160. else:
  1161. return COMPARE_FAILED
  1162. else:
  1163. return ["operator"]
  1164. else:
  1165. return COMPARE_FAILED
  1166. def compare_binary(self, left, right, **kw):
  1167. if left.operator == right.operator:
  1168. if operators.is_commutative(left.operator):
  1169. if (
  1170. self.compare_inner(left.left, right.left, **kw)
  1171. and self.compare_inner(left.right, right.right, **kw)
  1172. ) or (
  1173. self.compare_inner(left.left, right.right, **kw)
  1174. and self.compare_inner(left.right, right.left, **kw)
  1175. ):
  1176. return ["operator", "negate", "left", "right"]
  1177. else:
  1178. return COMPARE_FAILED
  1179. else:
  1180. return ["operator", "negate"]
  1181. else:
  1182. return COMPARE_FAILED
  1183. def compare_bindparam(self, left, right, **kw):
  1184. compare_keys = kw.pop("compare_keys", True)
  1185. compare_values = kw.pop("compare_values", True)
  1186. if compare_values:
  1187. omit = []
  1188. else:
  1189. # this means, "skip these, we already compared"
  1190. omit = ["callable", "value"]
  1191. if not compare_keys:
  1192. omit.append("key")
  1193. return omit
  1194. class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
  1195. def compare_column_element(
  1196. self, left, right, use_proxies=True, equivalents=(), **kw
  1197. ):
  1198. """Compare ColumnElements using proxies and equivalent collections.
  1199. This is a comparison strategy specific to the ORM.
  1200. """
  1201. to_compare = (right,)
  1202. if equivalents and right in equivalents:
  1203. to_compare = equivalents[right].union(to_compare)
  1204. for oth in to_compare:
  1205. if use_proxies and left.shares_lineage(oth):
  1206. return SKIP_TRAVERSE
  1207. elif hash(left) == hash(right):
  1208. return SKIP_TRAVERSE
  1209. else:
  1210. return COMPARE_FAILED
  1211. def compare_column(self, left, right, **kw):
  1212. return self.compare_column_element(left, right, **kw)
  1213. def compare_label(self, left, right, **kw):
  1214. return self.compare_column_element(left, right, **kw)
  1215. def compare_table(self, left, right, **kw):
  1216. # tables compare on identity, since it's not really feasible to
  1217. # compare them column by column with the above rules
  1218. return SKIP_TRAVERSE if left is right else COMPARE_FAILED