You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

440 lines
14KB

  1. # testing/assertsql.py
  2. # Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. import collections
  8. import contextlib
  9. import re
  10. from .. import event
  11. from .. import util
  12. from ..engine import url
  13. from ..engine.default import DefaultDialect
  14. from ..engine.util import _distill_cursor_params
  15. from ..schema import _DDLCompiles
  16. class AssertRule(object):
  17. is_consumed = False
  18. errormessage = None
  19. consume_statement = True
  20. def process_statement(self, execute_observed):
  21. pass
  22. def no_more_statements(self):
  23. assert False, (
  24. "All statements are complete, but pending "
  25. "assertion rules remain"
  26. )
  27. class SQLMatchRule(AssertRule):
  28. pass
  29. class CursorSQL(SQLMatchRule):
  30. def __init__(self, statement, params=None, consume_statement=True):
  31. self.statement = statement
  32. self.params = params
  33. self.consume_statement = consume_statement
  34. def process_statement(self, execute_observed):
  35. stmt = execute_observed.statements[0]
  36. if self.statement != stmt.statement or (
  37. self.params is not None and self.params != stmt.parameters
  38. ):
  39. self.errormessage = (
  40. "Testing for exact SQL %s parameters %s received %s %s"
  41. % (
  42. self.statement,
  43. self.params,
  44. stmt.statement,
  45. stmt.parameters,
  46. )
  47. )
  48. else:
  49. execute_observed.statements.pop(0)
  50. self.is_consumed = True
  51. if not execute_observed.statements:
  52. self.consume_statement = True
  53. class CompiledSQL(SQLMatchRule):
  54. def __init__(self, statement, params=None, dialect="default"):
  55. self.statement = statement
  56. self.params = params
  57. self.dialect = dialect
  58. def _compare_sql(self, execute_observed, received_statement):
  59. stmt = re.sub(r"[\n\t]", "", self.statement)
  60. return received_statement == stmt
  61. def _compile_dialect(self, execute_observed):
  62. if self.dialect == "default":
  63. dialect = DefaultDialect()
  64. # this is currently what tests are expecting
  65. # dialect.supports_default_values = True
  66. dialect.supports_default_metavalue = True
  67. return dialect
  68. else:
  69. # ugh
  70. if self.dialect == "postgresql":
  71. params = {"implicit_returning": True}
  72. else:
  73. params = {}
  74. return url.URL.create(self.dialect).get_dialect()(**params)
  75. def _received_statement(self, execute_observed):
  76. """reconstruct the statement and params in terms
  77. of a target dialect, which for CompiledSQL is just DefaultDialect."""
  78. context = execute_observed.context
  79. compare_dialect = self._compile_dialect(execute_observed)
  80. if "schema_translate_map" in context.execution_options:
  81. map_ = context.execution_options["schema_translate_map"]
  82. else:
  83. map_ = None
  84. if isinstance(execute_observed.clauseelement, _DDLCompiles):
  85. compiled = execute_observed.clauseelement.compile(
  86. dialect=compare_dialect, schema_translate_map=map_
  87. )
  88. else:
  89. compiled = execute_observed.clauseelement.compile(
  90. dialect=compare_dialect,
  91. column_keys=context.compiled.column_keys,
  92. for_executemany=context.compiled.for_executemany,
  93. schema_translate_map=map_,
  94. )
  95. _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
  96. parameters = execute_observed.parameters
  97. if not parameters:
  98. _received_parameters = [compiled.construct_params()]
  99. else:
  100. _received_parameters = [
  101. compiled.construct_params(m) for m in parameters
  102. ]
  103. return _received_statement, _received_parameters
  104. def process_statement(self, execute_observed):
  105. context = execute_observed.context
  106. _received_statement, _received_parameters = self._received_statement(
  107. execute_observed
  108. )
  109. params = self._all_params(context)
  110. equivalent = self._compare_sql(execute_observed, _received_statement)
  111. if equivalent:
  112. if params is not None:
  113. all_params = list(params)
  114. all_received = list(_received_parameters)
  115. while all_params and all_received:
  116. param = dict(all_params.pop(0))
  117. for idx, received in enumerate(list(all_received)):
  118. # do a positive compare only
  119. for param_key in param:
  120. # a key in param did not match current
  121. # 'received'
  122. if (
  123. param_key not in received
  124. or received[param_key] != param[param_key]
  125. ):
  126. break
  127. else:
  128. # all keys in param matched 'received';
  129. # onto next param
  130. del all_received[idx]
  131. break
  132. else:
  133. # param did not match any entry
  134. # in all_received
  135. equivalent = False
  136. break
  137. if all_params or all_received:
  138. equivalent = False
  139. if equivalent:
  140. self.is_consumed = True
  141. self.errormessage = None
  142. else:
  143. self.errormessage = self._failure_message(params) % {
  144. "received_statement": _received_statement,
  145. "received_parameters": _received_parameters,
  146. }
  147. def _all_params(self, context):
  148. if self.params:
  149. if callable(self.params):
  150. params = self.params(context)
  151. else:
  152. params = self.params
  153. if not isinstance(params, list):
  154. params = [params]
  155. return params
  156. else:
  157. return None
  158. def _failure_message(self, expected_params):
  159. return (
  160. "Testing for compiled statement\n%r partial params %s, "
  161. "received\n%%(received_statement)r with params "
  162. "%%(received_parameters)r"
  163. % (
  164. self.statement.replace("%", "%%"),
  165. repr(expected_params).replace("%", "%%"),
  166. )
  167. )
  168. class RegexSQL(CompiledSQL):
  169. def __init__(self, regex, params=None, dialect="default"):
  170. SQLMatchRule.__init__(self)
  171. self.regex = re.compile(regex)
  172. self.orig_regex = regex
  173. self.params = params
  174. self.dialect = dialect
  175. def _failure_message(self, expected_params):
  176. return (
  177. "Testing for compiled statement ~%r partial params %s, "
  178. "received %%(received_statement)r with params "
  179. "%%(received_parameters)r"
  180. % (
  181. self.orig_regex.replace("%", "%%"),
  182. repr(expected_params).replace("%", "%%"),
  183. )
  184. )
  185. def _compare_sql(self, execute_observed, received_statement):
  186. return bool(self.regex.match(received_statement))
  187. class DialectSQL(CompiledSQL):
  188. def _compile_dialect(self, execute_observed):
  189. return execute_observed.context.dialect
  190. def _compare_no_space(self, real_stmt, received_stmt):
  191. stmt = re.sub(r"[\n\t]", "", real_stmt)
  192. return received_stmt == stmt
  193. def _received_statement(self, execute_observed):
  194. received_stmt, received_params = super(
  195. DialectSQL, self
  196. )._received_statement(execute_observed)
  197. # TODO: why do we need this part?
  198. for real_stmt in execute_observed.statements:
  199. if self._compare_no_space(real_stmt.statement, received_stmt):
  200. break
  201. else:
  202. raise AssertionError(
  203. "Can't locate compiled statement %r in list of "
  204. "statements actually invoked" % received_stmt
  205. )
  206. return received_stmt, execute_observed.context.compiled_parameters
  207. def _compare_sql(self, execute_observed, received_statement):
  208. stmt = re.sub(r"[\n\t]", "", self.statement)
  209. # convert our comparison statement to have the
  210. # paramstyle of the received
  211. paramstyle = execute_observed.context.dialect.paramstyle
  212. if paramstyle == "pyformat":
  213. stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
  214. else:
  215. # positional params
  216. repl = None
  217. if paramstyle == "qmark":
  218. repl = "?"
  219. elif paramstyle == "format":
  220. repl = r"%s"
  221. elif paramstyle == "numeric":
  222. repl = None
  223. stmt = re.sub(r":([\w_]+)", repl, stmt)
  224. return received_statement == stmt
  225. class CountStatements(AssertRule):
  226. def __init__(self, count):
  227. self.count = count
  228. self._statement_count = 0
  229. def process_statement(self, execute_observed):
  230. self._statement_count += 1
  231. def no_more_statements(self):
  232. if self.count != self._statement_count:
  233. assert False, "desired statement count %d does not match %d" % (
  234. self.count,
  235. self._statement_count,
  236. )
  237. class AllOf(AssertRule):
  238. def __init__(self, *rules):
  239. self.rules = set(rules)
  240. def process_statement(self, execute_observed):
  241. for rule in list(self.rules):
  242. rule.errormessage = None
  243. rule.process_statement(execute_observed)
  244. if rule.is_consumed:
  245. self.rules.discard(rule)
  246. if not self.rules:
  247. self.is_consumed = True
  248. break
  249. elif not rule.errormessage:
  250. # rule is not done yet
  251. self.errormessage = None
  252. break
  253. else:
  254. self.errormessage = list(self.rules)[0].errormessage
  255. class EachOf(AssertRule):
  256. def __init__(self, *rules):
  257. self.rules = list(rules)
  258. def process_statement(self, execute_observed):
  259. while self.rules:
  260. rule = self.rules[0]
  261. rule.process_statement(execute_observed)
  262. if rule.is_consumed:
  263. self.rules.pop(0)
  264. elif rule.errormessage:
  265. self.errormessage = rule.errormessage
  266. if rule.consume_statement:
  267. break
  268. if not self.rules:
  269. self.is_consumed = True
  270. def no_more_statements(self):
  271. if self.rules and not self.rules[0].is_consumed:
  272. self.rules[0].no_more_statements()
  273. elif self.rules:
  274. super(EachOf, self).no_more_statements()
  275. class Conditional(EachOf):
  276. def __init__(self, condition, rules, else_rules):
  277. if condition:
  278. super(Conditional, self).__init__(*rules)
  279. else:
  280. super(Conditional, self).__init__(*else_rules)
  281. class Or(AllOf):
  282. def process_statement(self, execute_observed):
  283. for rule in self.rules:
  284. rule.process_statement(execute_observed)
  285. if rule.is_consumed:
  286. self.is_consumed = True
  287. break
  288. else:
  289. self.errormessage = list(self.rules)[0].errormessage
  290. class SQLExecuteObserved(object):
  291. def __init__(self, context, clauseelement, multiparams, params):
  292. self.context = context
  293. self.clauseelement = clauseelement
  294. self.parameters = _distill_cursor_params(
  295. context.connection, tuple(multiparams), params
  296. )
  297. self.statements = []
  298. def __repr__(self):
  299. return str(self.statements)
  300. class SQLCursorExecuteObserved(
  301. collections.namedtuple(
  302. "SQLCursorExecuteObserved",
  303. ["statement", "parameters", "context", "executemany"],
  304. )
  305. ):
  306. pass
  307. class SQLAsserter(object):
  308. def __init__(self):
  309. self.accumulated = []
  310. def _close(self):
  311. self._final = self.accumulated
  312. del self.accumulated
  313. def assert_(self, *rules):
  314. rule = EachOf(*rules)
  315. observed = list(self._final)
  316. while observed:
  317. statement = observed.pop(0)
  318. rule.process_statement(statement)
  319. if rule.is_consumed:
  320. break
  321. elif rule.errormessage:
  322. assert False, rule.errormessage
  323. if observed:
  324. assert False, "Additional SQL statements remain:\n%s" % observed
  325. elif not rule.is_consumed:
  326. rule.no_more_statements()
  327. @contextlib.contextmanager
  328. def assert_engine(engine):
  329. asserter = SQLAsserter()
  330. orig = []
  331. @event.listens_for(engine, "before_execute")
  332. def connection_execute(
  333. conn, clauseelement, multiparams, params, execution_options
  334. ):
  335. # grab the original statement + params before any cursor
  336. # execution
  337. orig[:] = clauseelement, multiparams, params
  338. @event.listens_for(engine, "after_cursor_execute")
  339. def cursor_execute(
  340. conn, cursor, statement, parameters, context, executemany
  341. ):
  342. if not context:
  343. return
  344. # then grab real cursor statements and associate them all
  345. # around a single context
  346. if (
  347. asserter.accumulated
  348. and asserter.accumulated[-1].context is context
  349. ):
  350. obs = asserter.accumulated[-1]
  351. else:
  352. obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
  353. asserter.accumulated.append(obs)
  354. obs.statements.append(
  355. SQLCursorExecuteObserved(
  356. statement, parameters, context, executemany
  357. )
  358. )
  359. try:
  360. yield asserter
  361. finally:
  362. event.remove(engine, "after_cursor_execute", cursor_execute)
  363. event.remove(engine, "before_execute", connection_execute)
  364. asserter._close()