You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

242 lines
6.7KB

  1. # orm/evaluator.py
  2. # Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. import operator
  8. from .. import inspect
  9. from .. import util
  10. from ..sql import and_
  11. from ..sql import operators
  12. class UnevaluatableError(Exception):
  13. pass
  14. class _NoObject(operators.ColumnOperators):
  15. def operate(self, *arg, **kw):
  16. return None
  17. def reverse_operate(self, *arg, **kw):
  18. return None
  19. _NO_OBJECT = _NoObject()
  20. _straight_ops = set(
  21. getattr(operators, op)
  22. for op in (
  23. "add",
  24. "mul",
  25. "sub",
  26. "div",
  27. "mod",
  28. "truediv",
  29. "lt",
  30. "le",
  31. "ne",
  32. "gt",
  33. "ge",
  34. "eq",
  35. )
  36. )
  37. _extended_ops = {
  38. operators.in_op: (lambda a, b: a in b if a is not _NO_OBJECT else None),
  39. operators.not_in_op: (
  40. lambda a, b: a not in b if a is not _NO_OBJECT else None
  41. ),
  42. }
  43. _notimplemented_ops = set(
  44. getattr(operators, op)
  45. for op in (
  46. "like_op",
  47. "not_like_op",
  48. "ilike_op",
  49. "not_ilike_op",
  50. "startswith_op",
  51. "between_op",
  52. "endswith_op",
  53. "concat_op",
  54. )
  55. )
  56. class EvaluatorCompiler(object):
  57. def __init__(self, target_cls=None):
  58. self.target_cls = target_cls
  59. def process(self, *clauses):
  60. if len(clauses) > 1:
  61. clause = and_(*clauses)
  62. elif clauses:
  63. clause = clauses[0]
  64. meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
  65. if not meth:
  66. raise UnevaluatableError(
  67. "Cannot evaluate %s" % type(clause).__name__
  68. )
  69. return meth(clause)
  70. def visit_grouping(self, clause):
  71. return self.process(clause.element)
  72. def visit_null(self, clause):
  73. return lambda obj: None
  74. def visit_false(self, clause):
  75. return lambda obj: False
  76. def visit_true(self, clause):
  77. return lambda obj: True
  78. def visit_column(self, clause):
  79. if "parentmapper" in clause._annotations:
  80. parentmapper = clause._annotations["parentmapper"]
  81. if self.target_cls and not issubclass(
  82. self.target_cls, parentmapper.class_
  83. ):
  84. raise UnevaluatableError(
  85. "Can't evaluate criteria against alternate class %s"
  86. % parentmapper.class_
  87. )
  88. key = parentmapper._columntoproperty[clause].key
  89. else:
  90. key = clause.key
  91. if (
  92. self.target_cls
  93. and key in inspect(self.target_cls).column_attrs
  94. ):
  95. util.warn(
  96. "Evaluating non-mapped column expression '%s' onto "
  97. "ORM instances; this is a deprecated use case. Please "
  98. "make use of the actual mapped columns in ORM-evaluated "
  99. "UPDATE / DELETE expressions." % clause
  100. )
  101. else:
  102. raise UnevaluatableError("Cannot evaluate column: %s" % clause)
  103. get_corresponding_attr = operator.attrgetter(key)
  104. return (
  105. lambda obj: get_corresponding_attr(obj)
  106. if obj is not None
  107. else _NO_OBJECT
  108. )
  109. def visit_tuple(self, clause):
  110. return self.visit_clauselist(clause)
  111. def visit_clauselist(self, clause):
  112. evaluators = list(map(self.process, clause.clauses))
  113. if clause.operator is operators.or_:
  114. def evaluate(obj):
  115. has_null = False
  116. for sub_evaluate in evaluators:
  117. value = sub_evaluate(obj)
  118. if value:
  119. return True
  120. has_null = has_null or value is None
  121. if has_null:
  122. return None
  123. return False
  124. elif clause.operator is operators.and_:
  125. def evaluate(obj):
  126. for sub_evaluate in evaluators:
  127. value = sub_evaluate(obj)
  128. if not value:
  129. if value is None or value is _NO_OBJECT:
  130. return None
  131. return False
  132. return True
  133. elif clause.operator is operators.comma_op:
  134. def evaluate(obj):
  135. values = []
  136. for sub_evaluate in evaluators:
  137. value = sub_evaluate(obj)
  138. if value is None or value is _NO_OBJECT:
  139. return None
  140. values.append(value)
  141. return tuple(values)
  142. else:
  143. raise UnevaluatableError(
  144. "Cannot evaluate clauselist with operator %s" % clause.operator
  145. )
  146. return evaluate
  147. def visit_binary(self, clause):
  148. eval_left, eval_right = list(
  149. map(self.process, [clause.left, clause.right])
  150. )
  151. operator = clause.operator
  152. if operator is operators.is_:
  153. def evaluate(obj):
  154. return eval_left(obj) == eval_right(obj)
  155. elif operator is operators.is_not:
  156. def evaluate(obj):
  157. return eval_left(obj) != eval_right(obj)
  158. elif operator in _extended_ops:
  159. def evaluate(obj):
  160. left_val = eval_left(obj)
  161. right_val = eval_right(obj)
  162. if left_val is None or right_val is None:
  163. return None
  164. return _extended_ops[operator](left_val, right_val)
  165. elif operator in _straight_ops:
  166. def evaluate(obj):
  167. left_val = eval_left(obj)
  168. right_val = eval_right(obj)
  169. if left_val is None or right_val is None:
  170. return None
  171. return operator(eval_left(obj), eval_right(obj))
  172. else:
  173. raise UnevaluatableError(
  174. "Cannot evaluate %s with operator %s"
  175. % (type(clause).__name__, clause.operator)
  176. )
  177. return evaluate
  178. def visit_unary(self, clause):
  179. eval_inner = self.process(clause.element)
  180. if clause.operator is operators.inv:
  181. def evaluate(obj):
  182. value = eval_inner(obj)
  183. if value is None:
  184. return None
  185. return not value
  186. return evaluate
  187. raise UnevaluatableError(
  188. "Cannot evaluate %s with operator %s"
  189. % (type(clause).__name__, clause.operator)
  190. )
  191. def visit_bindparam(self, clause):
  192. if clause.callable:
  193. val = clause.callable()
  194. else:
  195. val = clause.value
  196. return lambda obj: val