You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1254 lines
42KB

  1. # sql/lambdas.py
  2. # Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. import itertools
  8. import operator
  9. import sys
  10. import types
  11. import weakref
  12. from . import coercions
  13. from . import elements
  14. from . import roles
  15. from . import schema
  16. from . import traversals
  17. from . import type_api
  18. from . import visitors
  19. from .base import _clone
  20. from .base import Options
  21. from .operators import ColumnOperators
  22. from .. import exc
  23. from .. import inspection
  24. from .. import util
  25. from ..util import collections_abc
  26. from ..util import compat
# Default process-wide cache used by LambdaElement._retrieve_tracker_rec()
# when LambdaOptions.lambda_cache is None.  Keys are a lambda's tracker key
# (a tuple of code objects) concatenated with its closure cache key; values
# are AnalyzedFunction records.  The argument presumably caps the LRU cache
# at 1000 entries -- confirm against util.LRUCache.
_closure_per_cache_key = util.LRUCache(1000)
  28. class LambdaOptions(Options):
  29. enable_tracking = True
  30. track_closure_variables = True
  31. track_on = None
  32. global_track_bound_values = True
  33. track_bound_values = True
  34. lambda_cache = None
  35. def lambda_stmt(
  36. lmb,
  37. enable_tracking=True,
  38. track_closure_variables=True,
  39. track_on=None,
  40. global_track_bound_values=True,
  41. track_bound_values=True,
  42. lambda_cache=None,
  43. ):
  44. """Produce a SQL statement that is cached as a lambda.
  45. The Python code object within the lambda is scanned for both Python
  46. literals that will become bound parameters as well as closure variables
  47. that refer to Core or ORM constructs that may vary. The lambda itself
  48. will be invoked only once per particular set of constructs detected.
  49. E.g.::
  50. from sqlalchemy import lambda_stmt
  51. stmt = lambda_stmt(lambda: table.select())
  52. stmt += lambda s: s.where(table.c.id == 5)
  53. result = connection.execute(stmt)
  54. The object returned is an instance of :class:`_sql.StatementLambdaElement`.
  55. .. versionadded:: 1.4
  56. :param lmb: a Python function, typically a lambda, which takes no arguments
  57. and returns a SQL expression construct
  58. :param enable_tracking: when False, all scanning of the given lambda for
  59. changes in closure variables or bound parameters is disabled. Use for
  60. a lambda that produces the identical results in all cases with no
  61. parameterization.
  62. :param track_closure_variables: when False, changes in closure variables
  63. within the lambda will not be scanned. Use for a lambda where the
  64. state of its closure variables will never change the SQL structure
  65. returned by the lambda.
  66. :param track_bound_values: when False, bound parameter tracking will
  67. be disabled for the given lambda. Use for a lambda that either does
  68. not produce any bound values, or where the initial bound values never
  69. change.
  70. :param global_track_bound_values: when False, bound parameter tracking
  71. will be disabled for the entire statement including additional links
  72. added via the :meth:`_sql.StatementLambdaElement.add_criteria` method.
  73. :param lambda_cache: a dictionary or other mapping-like object where
  74. information about the lambda's Python code as well as the tracked closure
  75. variables in the lambda itself will be stored. Defaults
  76. to a global LRU cache. This cache is independent of the "compiled_cache"
  77. used by the :class:`_engine.Connection` object.
  78. .. seealso::
  79. :ref:`engine_lambda_caching`
  80. """
  81. return StatementLambdaElement(
  82. lmb,
  83. roles.StatementRole,
  84. LambdaOptions(
  85. enable_tracking=enable_tracking,
  86. track_on=track_on,
  87. track_closure_variables=track_closure_variables,
  88. global_track_bound_values=global_track_bound_values,
  89. track_bound_values=track_bound_values,
  90. lambda_cache=lambda_cache,
  91. ),
  92. )
  93. class LambdaElement(elements.ClauseElement):
  94. """A SQL construct where the state is stored as an un-invoked lambda.
  95. The :class:`_sql.LambdaElement` is produced transparently whenever
  96. passing lambda expressions into SQL constructs, such as::
  97. stmt = select(table).where(lambda: table.c.col == parameter)
  98. The :class:`_sql.LambdaElement` is the base of the
  99. :class:`_sql.StatementLambdaElement` which represents a full statement
  100. within a lambda.
  101. .. versionadded:: 1.4
  102. .. seealso::
  103. :ref:`engine_lambda_caching`
  104. """
  105. __visit_name__ = "lambda_element"
  106. _is_lambda_element = True
  107. _traverse_internals = [
  108. ("_resolved", visitors.InternalTraversal.dp_clauseelement)
  109. ]
  110. _transforms = ()
  111. parent_lambda = None
  112. def __repr__(self):
  113. return "%s(%r)" % (self.__class__.__name__, self.fn.__code__)
  114. def __init__(
  115. self, fn, role, opts=LambdaOptions, apply_propagate_attrs=None
  116. ):
  117. self.fn = fn
  118. self.role = role
  119. self.tracker_key = (fn.__code__,)
  120. self.opts = opts
  121. if apply_propagate_attrs is None and (role is roles.StatementRole):
  122. apply_propagate_attrs = self
  123. rec = self._retrieve_tracker_rec(fn, apply_propagate_attrs, opts)
  124. if apply_propagate_attrs is not None:
  125. propagate_attrs = rec.propagate_attrs
  126. if propagate_attrs:
  127. apply_propagate_attrs._propagate_attrs = propagate_attrs
  128. def _retrieve_tracker_rec(self, fn, apply_propagate_attrs, opts):
  129. lambda_cache = opts.lambda_cache
  130. if lambda_cache is None:
  131. lambda_cache = _closure_per_cache_key
  132. tracker_key = self.tracker_key
  133. fn = self.fn
  134. closure = fn.__closure__
  135. tracker = AnalyzedCode.get(
  136. fn,
  137. self,
  138. opts,
  139. )
  140. self._resolved_bindparams = bindparams = []
  141. anon_map = traversals.anon_map()
  142. cache_key = tuple(
  143. [
  144. getter(closure, opts, anon_map, bindparams)
  145. for getter in tracker.closure_trackers
  146. ]
  147. )
  148. if self.parent_lambda is not None:
  149. cache_key = self.parent_lambda.closure_cache_key + cache_key
  150. self.closure_cache_key = cache_key
  151. try:
  152. rec = lambda_cache[tracker_key + cache_key]
  153. except KeyError:
  154. rec = None
  155. if rec is None:
  156. rec = AnalyzedFunction(tracker, self, apply_propagate_attrs, fn)
  157. rec.closure_bindparams = bindparams
  158. lambda_cache[tracker_key + cache_key] = rec
  159. else:
  160. bindparams[:] = [
  161. orig_bind._with_value(new_bind.value, maintain_key=True)
  162. for orig_bind, new_bind in zip(
  163. rec.closure_bindparams, bindparams
  164. )
  165. ]
  166. if self.parent_lambda is not None:
  167. bindparams[:0] = self.parent_lambda._resolved_bindparams
  168. self._rec = rec
  169. lambda_element = self
  170. while lambda_element is not None:
  171. rec = lambda_element._rec
  172. if rec.bindparam_trackers:
  173. tracker_instrumented_fn = rec.tracker_instrumented_fn
  174. for tracker in rec.bindparam_trackers:
  175. tracker(
  176. lambda_element.fn, tracker_instrumented_fn, bindparams
  177. )
  178. lambda_element = lambda_element.parent_lambda
  179. return rec
  180. def __getattr__(self, key):
  181. return getattr(self._rec.expected_expr, key)
  182. @property
  183. def _is_sequence(self):
  184. return self._rec.is_sequence
  185. @property
  186. def _select_iterable(self):
  187. if self._is_sequence:
  188. return itertools.chain.from_iterable(
  189. [element._select_iterable for element in self._resolved]
  190. )
  191. else:
  192. return self._resolved._select_iterable
  193. @property
  194. def _from_objects(self):
  195. if self._is_sequence:
  196. return itertools.chain.from_iterable(
  197. [element._from_objects for element in self._resolved]
  198. )
  199. else:
  200. return self._resolved._from_objects
  201. def _param_dict(self):
  202. return {b.key: b.value for b in self._resolved_bindparams}
  203. def _setup_binds_for_tracked_expr(self, expr):
  204. bindparam_lookup = {b.key: b for b in self._resolved_bindparams}
  205. def replace(thing):
  206. if isinstance(thing, elements.BindParameter):
  207. if thing.key in bindparam_lookup:
  208. bind = bindparam_lookup[thing.key]
  209. if thing.expanding:
  210. bind.expanding = True
  211. bind.expand_op = thing.expand_op
  212. bind.type = thing.type
  213. return bind
  214. if self._rec.is_sequence:
  215. expr = [
  216. visitors.replacement_traverse(sub_expr, {}, replace)
  217. for sub_expr in expr
  218. ]
  219. elif getattr(expr, "is_clause_element", False):
  220. expr = visitors.replacement_traverse(expr, {}, replace)
  221. return expr
  222. def _copy_internals(
  223. self, clone=_clone, deferred_copy_internals=None, **kw
  224. ):
  225. # TODO: this needs A LOT of tests
  226. self._resolved = clone(
  227. self._resolved,
  228. deferred_copy_internals=deferred_copy_internals,
  229. **kw
  230. )
  231. @util.memoized_property
  232. def _resolved(self):
  233. expr = self._rec.expected_expr
  234. if self._resolved_bindparams:
  235. expr = self._setup_binds_for_tracked_expr(expr)
  236. return expr
  237. def _gen_cache_key(self, anon_map, bindparams):
  238. cache_key = (
  239. self.fn.__code__,
  240. self.__class__,
  241. ) + self.closure_cache_key
  242. parent = self.parent_lambda
  243. while parent is not None:
  244. cache_key = (
  245. (parent.fn.__code__,) + parent.closure_cache_key + cache_key
  246. )
  247. parent = parent.parent_lambda
  248. if self._resolved_bindparams:
  249. bindparams.extend(self._resolved_bindparams)
  250. return cache_key
  251. def _invoke_user_fn(self, fn, *arg):
  252. return fn()
  253. class DeferredLambdaElement(LambdaElement):
  254. """A LambdaElement where the lambda accepts arguments and is
  255. invoked within the compile phase with special context.
  256. This lambda doesn't normally produce its real SQL expression outside of the
  257. compile phase. It is passed a fixed set of initial arguments
  258. so that it can generate a sample expression.
  259. """
  260. def __init__(self, fn, role, opts=LambdaOptions, lambda_args=()):
  261. self.lambda_args = lambda_args
  262. super(DeferredLambdaElement, self).__init__(fn, role, opts)
  263. def _invoke_user_fn(self, fn, *arg):
  264. return fn(*self.lambda_args)
  265. def _resolve_with_args(self, *lambda_args):
  266. tracker_fn = self._rec.tracker_instrumented_fn
  267. expr = tracker_fn(*lambda_args)
  268. expr = coercions.expect(self.role, expr)
  269. expr = self._setup_binds_for_tracked_expr(expr)
  270. # this validation is getting very close, but not quite, to achieving
  271. # #5767. The problem is if the base lambda uses an unnamed column
  272. # as is very common with mixins, the parameter name is different
  273. # and it produces a false positive; that is, for the documented case
  274. # that is exactly what people will be doing, it doesn't work, so
  275. # I'm not really sure how to handle this right now.
  276. # expected_binds = [
  277. # b._orig_key
  278. # for b in self._rec.expr._generate_cache_key()[1]
  279. # if b.required
  280. # ]
  281. # got_binds = [
  282. # b._orig_key for b in expr._generate_cache_key()[1] if b.required
  283. # ]
  284. # if expected_binds != got_binds:
  285. # raise exc.InvalidRequestError(
  286. # "Lambda callable at %s produced a different set of bound "
  287. # "parameters than its original run: %s"
  288. # % (self.fn.__code__, ", ".join(got_binds))
  289. # )
  290. # TODO: TEST TEST TEST, this is very out there
  291. for deferred_copy_internals in self._transforms:
  292. expr = deferred_copy_internals(expr)
  293. return expr
  294. def _copy_internals(
  295. self, clone=_clone, deferred_copy_internals=None, **kw
  296. ):
  297. super(DeferredLambdaElement, self)._copy_internals(
  298. clone=clone,
  299. deferred_copy_internals=deferred_copy_internals, # **kw
  300. opts=kw,
  301. )
  302. # TODO: A LOT A LOT of tests. for _resolve_with_args, we don't know
  303. # our expression yet. so hold onto the replacement
  304. if deferred_copy_internals:
  305. self._transforms += (deferred_copy_internals,)
  306. class StatementLambdaElement(roles.AllowsLambdaRole, LambdaElement):
  307. """Represent a composable SQL statement as a :class:`_sql.LambdaElement`.
  308. The :class:`_sql.StatementLambdaElement` is constructed using the
  309. :func:`_sql.lambda_stmt` function::
  310. from sqlalchemy import lambda_stmt
  311. stmt = lambda_stmt(lambda: select(table))
  312. Once constructed, additional criteria can be built onto the statement
  313. by adding subsequent lambdas, which accept the existing statement
  314. object as a single parameter::
  315. stmt += lambda s: s.where(table.c.col == parameter)
  316. .. versionadded:: 1.4
  317. .. seealso::
  318. :ref:`engine_lambda_caching`
  319. """
  320. def __add__(self, other):
  321. return self.add_criteria(other)
  322. def add_criteria(
  323. self,
  324. other,
  325. enable_tracking=True,
  326. track_on=None,
  327. track_closure_variables=True,
  328. track_bound_values=True,
  329. ):
  330. """Add new criteria to this :class:`_sql.StatementLambdaElement`.
  331. E.g.::
  332. >>> def my_stmt(parameter):
  333. ... stmt = lambda_stmt(
  334. ... lambda: select(table.c.x, table.c.y),
  335. ... )
  336. ... stmt = stmt.add_criteria(
  337. ... lambda: table.c.x > parameter
  338. ... )
  339. ... return stmt
  340. The :meth:`_sql.StatementLambdaElement.add_criteria` method is
  341. equivalent to using the Python addition operator to add a new
  342. lambda, except that additional arguments may be added including
  343. ``track_closure_values`` and ``track_on``::
  344. >>> def my_stmt(self, foo):
  345. ... stmt = lambda_stmt(
  346. ... lambda: select(func.max(foo.x, foo.y)),
  347. ... track_closure_variables=False
  348. ... )
  349. ... stmt = stmt.add_criteria(
  350. ... lambda: self.where_criteria,
  351. ... track_on=[self]
  352. ... )
  353. ... return stmt
  354. See :func:`_sql.lambda_stmt` for a description of the parameters
  355. accepted.
  356. """
  357. opts = self.opts + dict(
  358. enable_tracking=enable_tracking,
  359. track_closure_variables=track_closure_variables,
  360. global_track_bound_values=self.opts.global_track_bound_values,
  361. track_on=track_on,
  362. track_bound_values=track_bound_values,
  363. )
  364. return LinkedLambdaElement(other, parent_lambda=self, opts=opts)
  365. def _execute_on_connection(
  366. self, connection, multiparams, params, execution_options
  367. ):
  368. if self._rec.expected_expr.supports_execution:
  369. return connection._execute_clauseelement(
  370. self, multiparams, params, execution_options
  371. )
  372. else:
  373. raise exc.ObjectNotExecutableError(self)
  374. @property
  375. def _with_options(self):
  376. return self._rec.expected_expr._with_options
  377. @property
  378. def _effective_plugin_target(self):
  379. return self._rec.expected_expr._effective_plugin_target
  380. @property
  381. def _execution_options(self):
  382. return self._rec.expected_expr._execution_options
  383. def spoil(self):
  384. """Return a new :class:`.StatementLambdaElement` that will run
  385. all lambdas unconditionally each time.
  386. """
  387. return NullLambdaStatement(self.fn())
  388. class NullLambdaStatement(roles.AllowsLambdaRole, elements.ClauseElement):
  389. """Provides the :class:`.StatementLambdaElement` API but does not
  390. cache or analyze lambdas.
  391. the lambdas are instead invoked immediately.
  392. The intended use is to isolate issues that may arise when using
  393. lambda statements.
  394. """
  395. __visit_name__ = "lambda_element"
  396. _is_lambda_element = True
  397. _traverse_internals = [
  398. ("_resolved", visitors.InternalTraversal.dp_clauseelement)
  399. ]
  400. def __init__(self, statement):
  401. self._resolved = statement
  402. self._propagate_attrs = statement._propagate_attrs
  403. def __getattr__(self, key):
  404. return getattr(self._resolved, key)
  405. def __add__(self, other):
  406. statement = other(self._resolved)
  407. return NullLambdaStatement(statement)
  408. def add_criteria(self, other, **kw):
  409. statement = other(self._resolved)
  410. return NullLambdaStatement(statement)
  411. def _execute_on_connection(
  412. self, connection, multiparams, params, execution_options
  413. ):
  414. if self._resolved.supports_execution:
  415. return connection._execute_clauseelement(
  416. self, multiparams, params, execution_options
  417. )
  418. else:
  419. raise exc.ObjectNotExecutableError(self)
  420. class LinkedLambdaElement(StatementLambdaElement):
  421. """Represent subsequent links of a :class:`.StatementLambdaElement`."""
  422. role = None
  423. def __init__(self, fn, parent_lambda, opts):
  424. self.opts = opts
  425. self.fn = fn
  426. self.parent_lambda = parent_lambda
  427. self.tracker_key = parent_lambda.tracker_key + (fn.__code__,)
  428. self._retrieve_tracker_rec(fn, self, opts)
  429. self._propagate_attrs = parent_lambda._propagate_attrs
  430. def _invoke_user_fn(self, fn, *arg):
  431. return fn(self.parent_lambda._resolved)
  432. class AnalyzedCode(object):
  433. __slots__ = (
  434. "track_closure_variables",
  435. "track_bound_values",
  436. "bindparam_trackers",
  437. "closure_trackers",
  438. "build_py_wrappers",
  439. )
  440. _fns = weakref.WeakKeyDictionary()
  441. @classmethod
  442. def get(cls, fn, lambda_element, lambda_kw, **kw):
  443. try:
  444. # TODO: validate kw haven't changed?
  445. return cls._fns[fn.__code__]
  446. except KeyError:
  447. pass
  448. cls._fns[fn.__code__] = analyzed = AnalyzedCode(
  449. fn, lambda_element, lambda_kw, **kw
  450. )
  451. return analyzed
  452. def __init__(self, fn, lambda_element, opts):
  453. closure = fn.__closure__
  454. self.track_bound_values = (
  455. opts.track_bound_values and opts.global_track_bound_values
  456. )
  457. enable_tracking = opts.enable_tracking
  458. track_on = opts.track_on
  459. track_closure_variables = opts.track_closure_variables
  460. self.track_closure_variables = track_closure_variables and not track_on
  461. # a list of callables generated from _bound_parameter_getter_*
  462. # functions. Each of these uses a PyWrapper object to retrieve
  463. # a parameter value
  464. self.bindparam_trackers = []
  465. # a list of callables generated from _cache_key_getter_* functions
  466. # these callables work to generate a cache key for the lambda
  467. # based on what's inside its closure variables.
  468. self.closure_trackers = []
  469. self.build_py_wrappers = []
  470. if enable_tracking:
  471. if track_on:
  472. self._init_track_on(track_on)
  473. self._init_globals(fn)
  474. if closure:
  475. self._init_closure(fn)
  476. self._setup_additional_closure_trackers(fn, lambda_element, opts)
  477. def _init_track_on(self, track_on):
  478. self.closure_trackers.extend(
  479. self._cache_key_getter_track_on(idx, elem)
  480. for idx, elem in enumerate(track_on)
  481. )
  482. def _init_globals(self, fn):
  483. build_py_wrappers = self.build_py_wrappers
  484. bindparam_trackers = self.bindparam_trackers
  485. track_bound_values = self.track_bound_values
  486. for name in fn.__code__.co_names:
  487. if name not in fn.__globals__:
  488. continue
  489. _bound_value = self._roll_down_to_literal(fn.__globals__[name])
  490. if coercions._deep_is_literal(_bound_value):
  491. build_py_wrappers.append((name, None))
  492. if track_bound_values:
  493. bindparam_trackers.append(
  494. self._bound_parameter_getter_func_globals(name)
  495. )
  496. def _init_closure(self, fn):
  497. build_py_wrappers = self.build_py_wrappers
  498. closure = fn.__closure__
  499. track_bound_values = self.track_bound_values
  500. track_closure_variables = self.track_closure_variables
  501. bindparam_trackers = self.bindparam_trackers
  502. closure_trackers = self.closure_trackers
  503. for closure_index, (fv, cell) in enumerate(
  504. zip(fn.__code__.co_freevars, closure)
  505. ):
  506. _bound_value = self._roll_down_to_literal(cell.cell_contents)
  507. if coercions._deep_is_literal(_bound_value):
  508. build_py_wrappers.append((fv, closure_index))
  509. if track_bound_values:
  510. bindparam_trackers.append(
  511. self._bound_parameter_getter_func_closure(
  512. fv, closure_index
  513. )
  514. )
  515. else:
  516. # for normal cell contents, add them to a list that
  517. # we can compare later when we get new lambdas. if
  518. # any identities have changed, then we will
  519. # recalculate the whole lambda and run it again.
  520. if track_closure_variables:
  521. closure_trackers.append(
  522. self._cache_key_getter_closure_variable(
  523. fn, fv, closure_index, cell.cell_contents
  524. )
  525. )
  526. def _setup_additional_closure_trackers(self, fn, lambda_element, opts):
  527. # an additional step is to actually run the function, then
  528. # go through the PyWrapper objects that were set up to catch a bound
  529. # parameter. then if they *didn't* make a param, oh they're another
  530. # object in the closure we have to track for our cache key. so
  531. # create trackers to catch those.
  532. analyzed_function = AnalyzedFunction(
  533. self,
  534. lambda_element,
  535. None,
  536. fn,
  537. )
  538. closure_trackers = self.closure_trackers
  539. for pywrapper in analyzed_function.closure_pywrappers:
  540. if not pywrapper._sa__has_param:
  541. closure_trackers.append(
  542. self._cache_key_getter_tracked_literal(fn, pywrapper)
  543. )
  544. @classmethod
  545. def _roll_down_to_literal(cls, element):
  546. is_clause_element = hasattr(element, "__clause_element__")
  547. if is_clause_element:
  548. while not isinstance(
  549. element, (elements.ClauseElement, schema.SchemaItem, type)
  550. ):
  551. try:
  552. element = element.__clause_element__()
  553. except AttributeError:
  554. break
  555. if not is_clause_element:
  556. insp = inspection.inspect(element, raiseerr=False)
  557. if insp is not None:
  558. try:
  559. return insp.__clause_element__()
  560. except AttributeError:
  561. return insp
  562. # TODO: should we coerce consts None/True/False here?
  563. return element
  564. else:
  565. return element
  566. def _bound_parameter_getter_func_globals(self, name):
  567. """Return a getter that will extend a list of bound parameters
  568. with new entries from the ``__globals__`` collection of a particular
  569. lambda.
  570. """
  571. def extract_parameter_value(
  572. current_fn, tracker_instrumented_fn, result
  573. ):
  574. wrapper = tracker_instrumented_fn.__globals__[name]
  575. object.__getattribute__(wrapper, "_extract_bound_parameters")(
  576. current_fn.__globals__[name], result
  577. )
  578. return extract_parameter_value
  579. def _bound_parameter_getter_func_closure(self, name, closure_index):
  580. """Return a getter that will extend a list of bound parameters
  581. with new entries from the ``__closure__`` collection of a particular
  582. lambda.
  583. """
  584. def extract_parameter_value(
  585. current_fn, tracker_instrumented_fn, result
  586. ):
  587. wrapper = tracker_instrumented_fn.__closure__[
  588. closure_index
  589. ].cell_contents
  590. object.__getattribute__(wrapper, "_extract_bound_parameters")(
  591. current_fn.__closure__[closure_index].cell_contents, result
  592. )
  593. return extract_parameter_value
  594. def _cache_key_getter_track_on(self, idx, elem):
  595. """Return a getter that will extend a cache key with new entries
  596. from the "track_on" parameter passed to a :class:`.LambdaElement`.
  597. """
  598. if isinstance(elem, tuple):
  599. # tuple must contain hascachekey elements
  600. def get(closure, opts, anon_map, bindparams):
  601. return tuple(
  602. tup_elem._gen_cache_key(anon_map, bindparams)
  603. for tup_elem in opts.track_on[idx]
  604. )
  605. elif isinstance(elem, traversals.HasCacheKey):
  606. def get(closure, opts, anon_map, bindparams):
  607. return opts.track_on[idx]._gen_cache_key(anon_map, bindparams)
  608. else:
  609. def get(closure, opts, anon_map, bindparams):
  610. return opts.track_on[idx]
  611. return get
  612. def _cache_key_getter_closure_variable(
  613. self,
  614. fn,
  615. variable_name,
  616. idx,
  617. cell_contents,
  618. use_clause_element=False,
  619. use_inspect=False,
  620. ):
  621. """Return a getter that will extend a cache key with new entries
  622. from the ``__closure__`` collection of a particular lambda.
  623. """
  624. if isinstance(cell_contents, traversals.HasCacheKey):
  625. def get(closure, opts, anon_map, bindparams):
  626. obj = closure[idx].cell_contents
  627. if use_inspect:
  628. obj = inspection.inspect(obj)
  629. elif use_clause_element:
  630. while hasattr(obj, "__clause_element__"):
  631. if not getattr(obj, "is_clause_element", False):
  632. obj = obj.__clause_element__()
  633. return obj._gen_cache_key(anon_map, bindparams)
  634. elif isinstance(cell_contents, types.FunctionType):
  635. def get(closure, opts, anon_map, bindparams):
  636. return closure[idx].cell_contents.__code__
  637. elif isinstance(cell_contents, collections_abc.Sequence):
  638. def get(closure, opts, anon_map, bindparams):
  639. contents = closure[idx].cell_contents
  640. try:
  641. return tuple(
  642. elem._gen_cache_key(anon_map, bindparams)
  643. for elem in contents
  644. )
  645. except AttributeError as ae:
  646. self._raise_for_uncacheable_closure_variable(
  647. variable_name, fn, from_=ae
  648. )
  649. else:
  650. # if the object is a mapped class or aliased class, or some
  651. # other object in the ORM realm of things like that, imitate
  652. # the logic used in coercions.expect() to roll it down to the
  653. # SQL element
  654. element = cell_contents
  655. is_clause_element = False
  656. while hasattr(element, "__clause_element__"):
  657. is_clause_element = True
  658. if not getattr(element, "is_clause_element", False):
  659. element = element.__clause_element__()
  660. else:
  661. break
  662. if not is_clause_element:
  663. insp = inspection.inspect(element, raiseerr=False)
  664. if insp is not None:
  665. return self._cache_key_getter_closure_variable(
  666. fn, variable_name, idx, insp, use_inspect=True
  667. )
  668. else:
  669. return self._cache_key_getter_closure_variable(
  670. fn, variable_name, idx, element, use_clause_element=True
  671. )
  672. self._raise_for_uncacheable_closure_variable(variable_name, fn)
  673. return get
  674. def _raise_for_uncacheable_closure_variable(
  675. self, variable_name, fn, from_=None
  676. ):
  677. util.raise_(
  678. exc.InvalidRequestError(
  679. "Closure variable named '%s' inside of lambda callable %s "
  680. "does not refer to a cachable SQL element, and also does not "
  681. "appear to be serving as a SQL literal bound value based on "
  682. "the default "
  683. "SQL expression returned by the function. This variable "
  684. "needs to remain outside the scope of a SQL-generating lambda "
  685. "so that a proper cache key may be generated from the "
  686. "lambda's state. Evaluate this variable outside of the "
  687. "lambda, set track_on=[<elements>] to explicitly select "
  688. "closure elements to track, or set "
  689. "track_closure_variables=False to exclude "
  690. "closure variables from being part of the cache key."
  691. % (variable_name, fn.__code__),
  692. ),
  693. from_=from_,
  694. )
  695. def _cache_key_getter_tracked_literal(self, fn, pytracker):
  696. """Return a getter that will extend a cache key with new entries
  697. from the ``__closure__`` collection of a particular lambda.
  698. this getter differs from _cache_key_getter_closure_variable
  699. in that these are detected after the function is run, and PyWrapper
  700. objects have recorded that a particular literal value is in fact
  701. not being interpreted as a bound parameter.
  702. """
  703. elem = pytracker._sa__to_evaluate
  704. closure_index = pytracker._sa__closure_index
  705. variable_name = pytracker._sa__name
  706. return self._cache_key_getter_closure_variable(
  707. fn, variable_name, closure_index, elem
  708. )
  709. class AnalyzedFunction(object):
  710. __slots__ = (
  711. "analyzed_code",
  712. "fn",
  713. "closure_pywrappers",
  714. "tracker_instrumented_fn",
  715. "expr",
  716. "bindparam_trackers",
  717. "expected_expr",
  718. "is_sequence",
  719. "propagate_attrs",
  720. "closure_bindparams",
  721. )
  722. def __init__(
  723. self,
  724. analyzed_code,
  725. lambda_element,
  726. apply_propagate_attrs,
  727. fn,
  728. ):
  729. self.analyzed_code = analyzed_code
  730. self.fn = fn
  731. self.bindparam_trackers = analyzed_code.bindparam_trackers
  732. self._instrument_and_run_function(lambda_element)
  733. self._coerce_expression(lambda_element, apply_propagate_attrs)
  734. def _instrument_and_run_function(self, lambda_element):
  735. analyzed_code = self.analyzed_code
  736. fn = self.fn
  737. self.closure_pywrappers = closure_pywrappers = []
  738. build_py_wrappers = analyzed_code.build_py_wrappers
  739. if not build_py_wrappers:
  740. self.tracker_instrumented_fn = tracker_instrumented_fn = fn
  741. self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
  742. else:
  743. track_closure_variables = analyzed_code.track_closure_variables
  744. closure = fn.__closure__
  745. # will form the __closure__ of the function when we rebuild it
  746. if closure:
  747. new_closure = {
  748. fv: cell.cell_contents
  749. for fv, cell in zip(fn.__code__.co_freevars, closure)
  750. }
  751. else:
  752. new_closure = {}
  753. # will form the __globals__ of the function when we rebuild it
  754. new_globals = fn.__globals__.copy()
  755. for name, closure_index in build_py_wrappers:
  756. if closure_index is not None:
  757. value = closure[closure_index].cell_contents
  758. new_closure[name] = bind = PyWrapper(
  759. fn,
  760. name,
  761. value,
  762. closure_index=closure_index,
  763. track_bound_values=(
  764. self.analyzed_code.track_bound_values
  765. ),
  766. )
  767. if track_closure_variables:
  768. closure_pywrappers.append(bind)
  769. else:
  770. value = fn.__globals__[name]
  771. new_globals[name] = bind = PyWrapper(fn, name, value)
  772. # rewrite the original fn. things that look like they will
  773. # become bound parameters are wrapped in a PyWrapper.
  774. self.tracker_instrumented_fn = (
  775. tracker_instrumented_fn
  776. ) = self._rewrite_code_obj(
  777. fn,
  778. [new_closure[name] for name in fn.__code__.co_freevars],
  779. new_globals,
  780. )
  781. # now invoke the function. This will give us a new SQL
  782. # expression, but all the places that there would be a bound
  783. # parameter, the PyWrapper in its place will give us a bind
  784. # with a predictable name we can match up later.
  785. # additionally, each PyWrapper will log that it did in fact
  786. # create a parameter, otherwise, it's some kind of Python
  787. # object in the closure and we want to track that, to make
  788. # sure it doesn't change to something else, or if it does,
  789. # that we create a different tracked function with that
  790. # variable.
  791. self.expr = lambda_element._invoke_user_fn(tracker_instrumented_fn)
  792. def _coerce_expression(self, lambda_element, apply_propagate_attrs):
  793. """Run the tracker-generated expression through coercion rules.
  794. After the user-defined lambda has been invoked to produce a statement
  795. for re-use, run it through coercion rules to both check that it's the
  796. correct type of object and also to coerce it to its useful form.
  797. """
  798. parent_lambda = lambda_element.parent_lambda
  799. expr = self.expr
  800. if parent_lambda is None:
  801. if isinstance(expr, collections_abc.Sequence):
  802. self.expected_expr = [
  803. coercions.expect(
  804. lambda_element.role,
  805. sub_expr,
  806. apply_propagate_attrs=apply_propagate_attrs,
  807. )
  808. for sub_expr in expr
  809. ]
  810. self.is_sequence = True
  811. else:
  812. self.expected_expr = coercions.expect(
  813. lambda_element.role,
  814. expr,
  815. apply_propagate_attrs=apply_propagate_attrs,
  816. )
  817. self.is_sequence = False
  818. else:
  819. self.expected_expr = expr
  820. self.is_sequence = False
  821. if apply_propagate_attrs is not None:
  822. self.propagate_attrs = apply_propagate_attrs._propagate_attrs
  823. else:
  824. self.propagate_attrs = util.EMPTY_DICT
  825. def _rewrite_code_obj(self, f, cell_values, globals_):
  826. """Return a copy of f, with a new closure and new globals
  827. yes it works in pypy :P
  828. """
  829. argrange = range(len(cell_values))
  830. code = "def make_cells():\n"
  831. if cell_values:
  832. code += " (%s) = (%s)\n" % (
  833. ", ".join("i%d" % i for i in argrange),
  834. ", ".join("o%d" % i for i in argrange),
  835. )
  836. code += " def closure():\n"
  837. code += " return %s\n" % ", ".join("i%d" % i for i in argrange)
  838. code += " return closure.__closure__"
  839. vars_ = {"o%d" % i: cell_values[i] for i in argrange}
  840. compat.exec_(code, vars_, vars_)
  841. closure = vars_["make_cells"]()
  842. func = type(f)(
  843. f.__code__, globals_, f.__name__, f.__defaults__, closure
  844. )
  845. if sys.version_info >= (3,):
  846. func.__annotations__ = f.__annotations__
  847. func.__kwdefaults__ = f.__kwdefaults__
  848. func.__doc__ = f.__doc__
  849. func.__module__ = f.__module__
  850. return func
  851. class PyWrapper(ColumnOperators):
  852. """A wrapper object that is injected into the ``__globals__`` and
  853. ``__closure__`` of a Python function.
  854. When the function is instrumented with :class:`.PyWrapper` objects, it is
  855. then invoked just once in order to set up the wrappers. We look through
  856. all the :class:`.PyWrapper` objects we made to find the ones that generated
  857. a :class:`.BindParameter` object, e.g. the expression system interpreted
  858. something as a literal. Those positions in the globals/closure are then
  859. ones that we will look at, each time a new lambda comes in that refers to
  860. the same ``__code__`` object. In this way, we keep a single version of
  861. the SQL expression that this lambda produced, without calling upon the
  862. Python function that created it more than once, unless its other closure
  863. variables have changed. The expression is then transformed to have the
  864. new bound values embedded into it.
  865. """
  866. def __init__(
  867. self,
  868. fn,
  869. name,
  870. to_evaluate,
  871. closure_index=None,
  872. getter=None,
  873. track_bound_values=True,
  874. ):
  875. self.fn = fn
  876. self._name = name
  877. self._to_evaluate = to_evaluate
  878. self._param = None
  879. self._has_param = False
  880. self._bind_paths = {}
  881. self._getter = getter
  882. self._closure_index = closure_index
  883. self.track_bound_values = track_bound_values
  884. def __call__(self, *arg, **kw):
  885. elem = object.__getattribute__(self, "_to_evaluate")
  886. value = elem(*arg, **kw)
  887. if (
  888. self._sa_track_bound_values
  889. and coercions._deep_is_literal(value)
  890. and not isinstance(
  891. # TODO: coverage where an ORM option or similar is here
  892. value,
  893. traversals.HasCacheKey,
  894. )
  895. ):
  896. name = object.__getattribute__(self, "_name")
  897. raise exc.InvalidRequestError(
  898. "Can't invoke Python callable %s() inside of lambda "
  899. "expression argument at %s; lambda SQL constructs should "
  900. "not invoke functions from closure variables to produce "
  901. "literal values since the "
  902. "lambda SQL system normally extracts bound values without "
  903. "actually "
  904. "invoking the lambda or any functions within it. Call the "
  905. "function outside of the "
  906. "lambda and assign to a local variable that is used in the "
  907. "lambda as a closure variable, or set "
  908. "track_bound_values=False if the return value of this "
  909. "function is used in some other way other than a SQL bound "
  910. "value." % (name, self._sa_fn.__code__)
  911. )
  912. else:
  913. return value
  914. def operate(self, op, *other, **kwargs):
  915. elem = object.__getattribute__(self, "__clause_element__")()
  916. return op(elem, *other, **kwargs)
  917. def reverse_operate(self, op, other, **kwargs):
  918. elem = object.__getattribute__(self, "__clause_element__")()
  919. return op(other, elem, **kwargs)
  920. def _extract_bound_parameters(self, starting_point, result_list):
  921. param = object.__getattribute__(self, "_param")
  922. if param is not None:
  923. param = param._with_value(starting_point, maintain_key=True)
  924. result_list.append(param)
  925. for pywrapper in object.__getattribute__(self, "_bind_paths").values():
  926. getter = object.__getattribute__(pywrapper, "_getter")
  927. element = getter(starting_point)
  928. pywrapper._sa__extract_bound_parameters(element, result_list)
  929. def __clause_element__(self):
  930. param = object.__getattribute__(self, "_param")
  931. to_evaluate = object.__getattribute__(self, "_to_evaluate")
  932. if param is None:
  933. name = object.__getattribute__(self, "_name")
  934. self._param = param = elements.BindParameter(
  935. name, required=False, unique=True
  936. )
  937. self._has_param = True
  938. param.type = type_api._resolve_value_to_type(to_evaluate)
  939. return param._with_value(to_evaluate, maintain_key=True)
  940. def __bool__(self):
  941. to_evaluate = object.__getattribute__(self, "_to_evaluate")
  942. return bool(to_evaluate)
  943. def __nonzero__(self):
  944. to_evaluate = object.__getattribute__(self, "_to_evaluate")
  945. return bool(to_evaluate)
  946. def __getattribute__(self, key):
  947. if key.startswith("_sa_"):
  948. return object.__getattribute__(self, key[4:])
  949. elif key in (
  950. "__clause_element__",
  951. "operate",
  952. "reverse_operate",
  953. "__class__",
  954. "__dict__",
  955. ):
  956. return object.__getattribute__(self, key)
  957. if key.startswith("__"):
  958. elem = object.__getattribute__(self, "_to_evaluate")
  959. return getattr(elem, key)
  960. else:
  961. return self._sa__add_getter(key, operator.attrgetter)
  962. def __iter__(self):
  963. elem = object.__getattribute__(self, "_to_evaluate")
  964. return iter(elem)
  965. def __getitem__(self, key):
  966. elem = object.__getattribute__(self, "_to_evaluate")
  967. if not hasattr(elem, "__getitem__"):
  968. raise AttributeError("__getitem__")
  969. if isinstance(key, PyWrapper):
  970. # TODO: coverage
  971. raise exc.InvalidRequestError(
  972. "Dictionary keys / list indexes inside of a cached "
  973. "lambda must be Python literals only"
  974. )
  975. return self._sa__add_getter(key, operator.itemgetter)
  976. def _add_getter(self, key, getter_fn):
  977. bind_paths = object.__getattribute__(self, "_bind_paths")
  978. bind_path_key = (key, getter_fn)
  979. if bind_path_key in bind_paths:
  980. return bind_paths[bind_path_key]
  981. getter = getter_fn(key)
  982. elem = object.__getattribute__(self, "_to_evaluate")
  983. value = getter(elem)
  984. rolled_down_value = AnalyzedCode._roll_down_to_literal(value)
  985. if coercions._deep_is_literal(rolled_down_value):
  986. wrapper = PyWrapper(self._sa_fn, key, value, getter=getter)
  987. bind_paths[bind_path_key] = wrapper
  988. return wrapper
  989. else:
  990. return value
  991. @inspection._inspects(LambdaElement)
  992. def insp(lmb):
  993. return inspection.inspect(lmb._resolved)