You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

2365 lines
77KB

  1. # orm/persistence.py
  2. # Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
  3. # <see AUTHORS file>
  4. #
  5. # This module is part of SQLAlchemy and is released under
  6. # the MIT License: http://www.opensource.org/licenses/mit-license.php
  7. """private module containing functions used to emit INSERT, UPDATE
  8. and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
  9. mappers.
  10. The functions here are called only by the unit of work functions
  11. in unitofwork.py.
  12. """
  13. from itertools import chain
  14. from itertools import groupby
  15. import operator
  16. from . import attributes
  17. from . import evaluator
  18. from . import exc as orm_exc
  19. from . import loading
  20. from . import sync
  21. from .base import state_str
  22. from .. import exc as sa_exc
  23. from .. import future
  24. from .. import sql
  25. from .. import util
  26. from ..engine import result as _result
  27. from ..sql import coercions
  28. from ..sql import expression
  29. from ..sql import operators
  30. from ..sql import roles
  31. from ..sql import select
  32. from ..sql.base import _entity_namespace_key
  33. from ..sql.base import CompileState
  34. from ..sql.base import Options
  35. from ..sql.dml import DeleteDMLState
  36. from ..sql.dml import UpdateDMLState
  37. from ..sql.elements import BooleanClauseList
  38. from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
def _bulk_insert(
    mapper,
    mappings,
    session_transaction,
    isstates,
    return_defaults,
    render_nulls,
):
    """Emit INSERT statements for a collection of states or plain
    mapping dictionaries (the bulk_insert() code path).

    :param mapper: the :class:`_orm.Mapper` targeted by the operation.
    :param mappings: a sequence of instance states (when ``isstates`` is
        True) or plain dictionaries of attribute values.
    :param session_transaction: the current session transaction; supplies
        the connection.
    :param isstates: if True, ``mappings`` contains InstanceState objects
        whose ``.dict`` holds the values.
    :param return_defaults: if True, keep per-row bookkeeping so that
        generated primary key values can be assigned back to the states.
    :param render_nulls: if True, explicitly include None values in the
        INSERT instead of omitting those columns.
    :raises NotImplementedError: if the session routes connections
        per-instance (connection_callable), which a single bulk statement
        cannot honor.
    """
    base_mapper = mapper.base_mapper

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_insert()"
        )

    if isstates:
        if return_defaults:
            # keep (state, dict) pairs so generated PKs can be copied
            # back onto the states at the end
            states = [(state, state.dict) for state in mappings]
            mappings = [dict_ for (state, dict_) in states]
        else:
            mappings = [state.dict for state in mappings]
    else:
        # snapshot the iterable; it is consumed once per table below
        mappings = list(mappings)

    connection = session_transaction.connection(base_mapper)
    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper):
            continue

        # re-pack the collected commands with the "state" slot nulled
        # out; bulk operations do not track individual states here
        records = (
            (
                None,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            )
            for (
                state,
                state_dict,
                params,
                mp,
                conn,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in _collect_insert_commands(
                table,
                ((None, mapping, mapper, connection) for mapping in mappings),
                bulk=True,
                return_defaults=return_defaults,
                render_nulls=render_nulls,
            )
        )
        _emit_insert_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            # bookkeeping only needed if defaults are to be returned
            bookkeeping=return_defaults,
        )

    if return_defaults and isstates:
        # assign identity keys to the states from the values now present
        # in their dictionaries (populated during the INSERTs above)
        identity_cls = mapper._identity_class
        identity_props = [p.key for p in mapper._identity_key_props]
        for state, dict_ in states:
            state.key = (
                identity_cls,
                tuple([dict_[key] for key in identity_props]),
            )
def _bulk_update(
    mapper, mappings, session_transaction, isstates, update_changed_only
):
    """Emit UPDATE statements for a collection of states or plain
    mapping dictionaries (the bulk_update() code path).

    :param mapper: the :class:`_orm.Mapper` targeted by the operation.
    :param mappings: a sequence of instance states (when ``isstates`` is
        True) or plain dictionaries of attribute values.
    :param session_transaction: the current session transaction; supplies
        the connection.
    :param isstates: if True, ``mappings`` contains InstanceState objects.
    :param update_changed_only: if True, only attributes present in each
        state's committed_state (plus the search keys) are rendered in
        the UPDATE.
    :raises NotImplementedError: if the session routes connections
        per-instance (connection_callable).
    """
    base_mapper = mapper.base_mapper

    # the keys always required to locate the row: primary key, plus the
    # version id column when one is mapped
    search_keys = mapper._primary_key_propkeys
    if mapper._version_id_prop:
        search_keys = {mapper._version_id_prop.key}.union(search_keys)

    def _changed_dict(mapper, state):
        # restrict a state's dict to changed attributes + search keys
        return dict(
            (k, v)
            for k, v in state.dict.items()
            if k in state.committed_state or k in search_keys
        )

    if isstates:
        if update_changed_only:
            mappings = [_changed_dict(mapper, state) for state in mappings]
        else:
            mappings = [state.dict for state in mappings]
    else:
        # snapshot the iterable; it is consumed once per table below
        mappings = list(mappings)

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_update()"
        )

    connection = session_transaction.connection(base_mapper)

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper):
            continue

        records = _collect_update_commands(
            None,
            table,
            (
                (
                    None,
                    mapping,
                    mapper,
                    connection,
                    # pass along the current version id value, if the
                    # mapper is versioned
                    (
                        mapping[mapper._version_id_prop.key]
                        if mapper._version_id_prop
                        else None
                    ),
                )
                for mapping in mappings
            ),
            bulk=True,
        )

        _emit_update_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            # bulk updates never assign results back to states
            bookkeeping=False,
        )
def save_obj(base_mapper, states, uowtransaction, single=False):
    """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
    of objects.

    This is called within the context of a UOWTransaction during a
    flush operation, given a list of states to be flushed.  The
    base mapper in an inheritance hierarchy handles the inserts/
    updates for all descendant mappers.

    :param single: internal flag; when the mapper is configured with
        batch=False, this function re-invokes itself once per state
        with ``single=True`` to preserve per-object statement ordering.
    """
    # if batch=false, call _save_obj separately for each object
    if not single and not base_mapper.batch:
        for state in _sort_states(base_mapper, states):
            save_obj(base_mapper, [state], uowtransaction, single=True)
        return

    states_to_update = []
    states_to_insert = []

    # partition the states: existing identity (or a detected "row
    # switch") means UPDATE, otherwise INSERT
    for (
        state,
        dict_,
        mapper,
        connection,
        has_identity,
        row_switch,
        update_version_id,
    ) in _organize_states_for_save(base_mapper, states, uowtransaction):
        if has_identity or row_switch:
            states_to_update.append(
                (state, dict_, mapper, connection, update_version_id)
            )
        else:
            states_to_insert.append((state, dict_, mapper, connection))

    # NOTE: the loop variable ``mapper`` here rebinds the name used in
    # the loop above; per-table mappers come from _sorted_tables
    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue
        insert = _collect_insert_commands(table, states_to_insert)

        update = _collect_update_commands(
            uowtransaction, table, states_to_update
        )

        _emit_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )

        _emit_insert_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            insert,
        )

    # run post-save bookkeeping for every state; the final boolean flag
    # indicates whether the state was an UPDATE (True) or INSERT (False)
    _finalize_insert_update_commands(
        base_mapper,
        uowtransaction,
        chain(
            (
                (state, state_dict, mapper, connection, False)
                for (state, state_dict, mapper, connection) in states_to_insert
            ),
            (
                (state, state_dict, mapper, connection, True)
                for (
                    state,
                    state_dict,
                    mapper,
                    connection,
                    update_version_id,
                ) in states_to_update
            ),
        ),
    )
def post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Issue UPDATE statements on behalf of a relationship() which
    specifies post_update.

    :param post_update_cols: the set of columns to be updated after
        the primary INSERT/UPDATE pass.
    """
    states_to_update = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

        # augment each record with the committed version id value when
        # the mapper is versioned (None otherwise)
        update = (
            (
                state,
                state_dict,
                sub_mapper,
                connection,
                mapper._get_committed_state_attr_by_column(
                    state, state_dict, mapper.version_id_col
                )
                if mapper.version_id_col is not None
                else None,
            )
            for state, state_dict, sub_mapper, connection in states_to_update
            if table in sub_mapper._pks_by_table
        )

        update = _collect_post_update_commands(
            base_mapper, uowtransaction, table, update, post_update_cols
        )

        _emit_post_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )
def delete_obj(base_mapper, states, uowtransaction):
    """Issue ``DELETE`` statements for a list of objects.

    This is called within the context of a UOWTransaction during a
    flush operation.

    Tables are processed in reverse dependency order so that child
    rows are deleted before their parents; after all statements are
    emitted, the after_delete event is fired once per state.
    """
    states_to_delete = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    table_to_mapper = base_mapper._sorted_tables

    for table in reversed(list(table_to_mapper.keys())):
        mapper = table_to_mapper[table]
        if table not in mapper._pks_by_table:
            continue
        elif mapper.inherits and mapper.passive_deletes:
            # rely on database-level cascade for inherited tables with
            # passive_deletes set
            continue

        delete = _collect_delete_commands(
            base_mapper, uowtransaction, table, states_to_delete
        )

        _emit_delete_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            delete,
        )

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        mapper.dispatch.after_delete(mapper, connection, state)
def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for INSERT or
    UPDATE.

    This includes splitting out into distinct lists for
    each, calling before_insert/before_update, obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state,
    and the identity flag.

    Yields 7-tuples of (state, dict, mapper, connection, has_identity,
    row_switch, update_version_id).
    """

    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):

        has_identity = bool(state.key)

        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = update_version_id = None

        # call before_XXX extensions
        if not has_identity:
            mapper.dispatch.before_insert(mapper, connection, state)
        else:
            mapper.dispatch.before_update(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        # detect if we have a "pending" instance (i.e. has
        # no instance_key attached to it), and another instance
        # with the same identity key already exists as persistent.
        # convert to an UPDATE if so.
        if (
            not has_identity
            and instance_key in uowtransaction.session.identity_map
        ):
            instance = uowtransaction.session.identity_map[instance_key]
            existing = attributes.instance_state(instance)

            if not uowtransaction.was_already_deleted(existing):
                if not uowtransaction.is_deleted(existing):
                    util.warn(
                        "New instance %s with identity key %s conflicts "
                        "with persistent instance %s"
                        % (state_str(state), instance_key, state_str(existing))
                    )
                else:
                    # the existing instance is marked deleted in this
                    # flush: treat this as a "row switch", i.e. UPDATE
                    # the existing row in place of DELETE + INSERT
                    base_mapper._log_debug(
                        "detected row switch for identity %s. "
                        "will update %s, remove %s from "
                        "transaction",
                        instance_key,
                        state_str(state),
                        state_str(existing),
                    )

                    # remove the "delete" flag from the existing element
                    uowtransaction.remove_state_actions(existing)
                    row_switch = existing

        if (has_identity or row_switch) and mapper.version_id_col is not None:
            # on a row switch, the committed version id comes from the
            # existing (switched-out) state, not the new one
            update_version_id = mapper._get_committed_state_attr_by_column(
                row_switch if row_switch else state,
                row_switch.dict if row_switch else dict_,
                mapper.version_id_col,
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )
  370. def _organize_states_for_post_update(base_mapper, states, uowtransaction):
  371. """Make an initial pass across a set of states for UPDATE
  372. corresponding to post_update.
  373. This includes obtaining key information for each state
  374. including its dictionary, mapper, the connection to use for
  375. the execution per state.
  376. """
  377. return _connections_for_states(base_mapper, uowtransaction, states)
  378. def _organize_states_for_delete(base_mapper, states, uowtransaction):
  379. """Make an initial pass across a set of states for DELETE.
  380. This includes calling out before_delete and obtaining
  381. key information for each state including its dictionary,
  382. mapper, the connection to use for the execution per state.
  383. """
  384. for state, dict_, mapper, connection in _connections_for_states(
  385. base_mapper, uowtransaction, states
  386. ):
  387. mapper.dispatch.before_delete(mapper, connection, state)
  388. if mapper.version_id_col is not None:
  389. update_version_id = mapper._get_committed_state_attr_by_column(
  390. state, dict_, mapper.version_id_col
  391. )
  392. else:
  393. update_version_id = None
  394. yield (state, dict_, mapper, connection, update_version_id)
def _collect_insert_commands(
    table,
    states_to_insert,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
):
    """Identify sets of values to use in INSERT statements for a
    list of states.

    :param table: the :class:`.Table` currently being processed.
    :param states_to_insert: iterable of (state, dict, mapper,
        connection) tuples.
    :param bulk: if True, SQL expression values and the legacy
        explicit-None behavior are skipped.
    :param return_defaults: if True, compute has_all_pks /
        has_all_defaults bookkeeping even in bulk mode.
    :param render_nulls: if True, None values are included in the
        INSERT parameters rather than omitted.

    Yields 8-tuples consumed by _emit_insert_statements().
    """
    for state, state_dict, mapper, connection in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        params = {}
        # value_params holds SQL-expression values, rendered inline in
        # the statement rather than passed as bound parameters
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        eval_none = mapper._insert_cols_evaluating_none[table]

        for propkey in set(propkey_to_col).intersection(state_dict):
            value = state_dict[propkey]
            col = propkey_to_col[propkey]

            if value is None and col not in eval_none and not render_nulls:
                # omit the column entirely; None is "missing" unless the
                # column is configured to evaluate None as a value
                continue
            elif not bulk and (
                hasattr(value, "__clause_element__")
                or isinstance(value, sql.ClauseElement)
            ):
                value_params[col] = (
                    value.__clause_element__()
                    if hasattr(value, "__clause_element__")
                    else value
                )
            else:
                params[col.key] = value

        if not bulk:
            # for all the columns that have no default and we don't have
            # a value and where "None" is not a special value, add
            # explicit None to the INSERT.   This is a legacy behavior
            # which might be worth removing, as it should not be necessary
            # and also produces confusion, given that "missing" and None
            # now have distinct meanings
            for colkey in (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference([c.key for c in value_params])
            ):
                params[colkey] = None

        if not bulk or return_defaults:
            # params are in terms of Column key objects, so
            # compare to pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)

            if mapper.base_mapper.eager_defaults:
                has_all_defaults = mapper._server_default_cols[table].issubset(
                    params
                )
            else:
                has_all_defaults = True
        else:
            has_all_defaults = has_all_pks = True

        if (
            mapper.version_id_generator is not False
            and mapper.version_id_col is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            # generate the initial version id for the new row
            params[mapper.version_id_col.key] = mapper.version_id_generator(
                None
            )

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )
def _collect_update_commands(
    uowtransaction, table, states_to_update, bulk=False
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement.  Includes some tricky scenarios where the
    primary key of an object might have been changed.

    Yields 8-tuples consumed by _emit_update_statements().  States with
    no net change are skipped entirely (no tuple is yielded).
    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:

        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        # value_params holds SQL-expression values, rendered inline in
        # the statement rather than passed as bound parameters
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = dict(
                (propkey_to_col[propkey].key, state_dict[propkey])
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            )
            has_all_defaults = True
        else:
            params = {}
            # only attributes with pending history (committed_state)
            # are candidates for the SET clause
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            if mapper.base_mapper.eager_defaults:
                has_all_defaults = (
                    mapper._server_onupdate_default_cols[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):

            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is. This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change, break
                    continue

            col = mapper.version_id_col
            no_params = not params and not value_params
            # the WHERE criteria matches on the *old* version id,
            # keyed by the column's label
            params[col._label] = update_version_id

            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            # nothing to update on this table for this state
            continue

        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = dict(
                (propkey_to_col[propkey]._label, state_dict.get(propkey))
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            )
        else:
            pk_params = {}
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        # the PK value changed via a FK cascade (or has
                        # no old value); match on the new value
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        if col in value_params:
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]

                if pk_params[col._label] is None:
                    raise orm_exc.FlushError(
                        "Can't update table %s using NULL for primary "
                        "key value on column %s" % (table, col)
                    )

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync.populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )
def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Identify sets of values to use in UPDATE statements for a
    list of states within a post_update operation.

    Yields (state, dict, mapper, connection, params) tuples only for
    states that actually have pending data among ``post_update_cols``
    or columns with an onupdate default.
    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:

        # assert table in mapper._pks_by_table

        pks = mapper._pks_by_table[table]
        params = {}
        hasdata = False

        for col in mapper._cols_by_table[table]:
            if col in pks:
                # PK values go into the WHERE criteria, keyed by the
                # column's label
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )

            elif col in post_update_cols or col.onupdate is not None:
                prop = mapper._columntoproperty[col]
                history = state.manager[prop.key].impl.get_history(
                    state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                )
                if history.added:
                    value = history.added[0]
                    params[col.key] = value
                    hasdata = True
        if hasdata:
            if (
                update_version_id is not None
                and mapper.version_id_col in mapper._cols_by_table[table]
            ):

                col = mapper.version_id_col
                params[col._label] = update_version_id

                if (
                    bool(state.key)
                    and col.key not in params
                    and mapper.version_id_generator is not False
                ):
                    # bump the version id for the persistent row
                    val = mapper.version_id_generator(update_version_id)
                    params[col.key] = val
            yield state, state_dict, mapper, connection, params
  681. def _collect_delete_commands(
  682. base_mapper, uowtransaction, table, states_to_delete
  683. ):
  684. """Identify values to use in DELETE statements for a list of
  685. states to be deleted."""
  686. for (
  687. state,
  688. state_dict,
  689. mapper,
  690. connection,
  691. update_version_id,
  692. ) in states_to_delete:
  693. if table not in mapper._pks_by_table:
  694. continue
  695. params = {}
  696. for col in mapper._pks_by_table[table]:
  697. params[
  698. col.key
  699. ] = value = mapper._get_committed_state_attr_by_column(
  700. state, state_dict, col
  701. )
  702. if value is None:
  703. raise orm_exc.FlushError(
  704. "Can't delete from table %s "
  705. "using NULL for primary "
  706. "key value on column %s" % (table, col)
  707. )
  708. if (
  709. update_version_id is not None
  710. and mapper.version_id_col in mapper._cols_by_table[table]
  711. ):
  712. params[mapper.version_id_col.key] = update_version_id
  713. yield params, connection
def _emit_update_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    update,
    bookkeeping=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands().

    :param update: iterable of 8-tuples from _collect_update_commands().
    :param bookkeeping: if True, run _postfetch() per row so that
        server-generated values are assigned back onto the states.
    :raises orm_exc.StaleDataError: if the matched rowcount differs
        from the expected number of records and the dialect supports
        reliable rowcounts.
    """
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    def update_stmt():
        # build the canonical UPDATE for this table: WHERE on each PK
        # column (and the version id col, if versioned), each compared
        # against a bindparam named by the column's label
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses.clauses.append(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)
        return stmt

    # the statement construction is memoized per (op, table) pair
    cached_stmt = base_mapper._memo(("update", table), update_stmt)

    # group records that can share a single executemany() call: same
    # connection, same parameter keys, same value-params/defaults/pks
    # characteristics.  groupby relies on _collect_update_commands
    # emitting adjacent compatible records.
    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        rows = 0
        records = list(records)

        statement = cached_stmt
        return_defaults = False

        if not has_all_pks:
            # a PK value is a SQL expression; RETURNING is needed to
            # get the resulting row back
            statement = statement.return_defaults()
            return_defaults = True
        elif (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults
        ):
            statement = statement.return_defaults()
            return_defaults = True
        elif mapper.version_id_col is not None:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        # whether we can trust cursor.rowcount for single / multi-row
        # statements on this dialect
        assert_singlerow = (
            connection.dialect.supports_sane_rowcount
            if not return_defaults
            else connection.dialect.supports_sane_rowcount_returning
        )

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        allow_multirow = has_all_defaults and not needs_version_id

        if hasvalue:
            # SQL-expression values force one execute() per record,
            # since the expression is rendered into each statement
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection._execute_20(
                    statement.values(value_params),
                    params,
                    execution_options=execution_options,
                )
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                        c.returned_defaults,
                    )
                rows += c.rowcount
                check_rowcount = assert_singlerow
        else:
            if not allow_multirow:
                check_rowcount = assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = connection._execute_20(
                        statement, params, execution_options=execution_options
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults,
                        )
                    rows += c.rowcount
            else:
                # executemany() path: one statement, list of param sets
                multiparams = [rec[2] for rec in records]

                check_rowcount = assert_multirow or (
                    assert_singlerow and len(multiparams) == 1
                )

                c = connection._execute_20(
                    statement, multiparams, execution_options=execution_options
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            # returned_defaults not available for
                            # executemany executions
                            c.returned_defaults
                            if not c.context.executemany
                            else None,
                        )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
  895. def _emit_insert_statements(
  896. base_mapper,
  897. uowtransaction,
  898. mapper,
  899. table,
  900. insert,
  901. bookkeeping=True,
  902. ):
  903. """Emit INSERT statements corresponding to value lists collected
  904. by _collect_insert_commands()."""
  905. cached_stmt = base_mapper._memo(("insert", table), table.insert)
  906. execution_options = {"compiled_cache": base_mapper._compiled_cache}
  907. for (
  908. (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
  909. records,
  910. ) in groupby(
  911. insert,
  912. lambda rec: (
  913. rec[4], # connection
  914. set(rec[2]), # parameter keys
  915. bool(rec[5]), # whether we have "value" parameters
  916. rec[6],
  917. rec[7],
  918. ),
  919. ):
  920. statement = cached_stmt
  921. if (
  922. not bookkeeping
  923. or (
  924. has_all_defaults
  925. or not base_mapper.eager_defaults
  926. or not connection.dialect.implicit_returning
  927. )
  928. and has_all_pks
  929. and not hasvalue
  930. ):
  931. # the "we don't need newly generated values back" section.
  932. # here we have all the PKs, all the defaults or we don't want
  933. # to fetch them, or the dialect doesn't support RETURNING at all
  934. # so we have to post-fetch / use lastrowid anyway.
  935. records = list(records)
  936. multiparams = [rec[2] for rec in records]
  937. c = connection._execute_20(
  938. statement, multiparams, execution_options=execution_options
  939. )
  940. if bookkeeping:
  941. for (
  942. (
  943. state,
  944. state_dict,
  945. params,
  946. mapper_rec,
  947. conn,
  948. value_params,
  949. has_all_pks,
  950. has_all_defaults,
  951. ),
  952. last_inserted_params,
  953. ) in zip(records, c.context.compiled_parameters):
  954. if state:
  955. _postfetch(
  956. mapper_rec,
  957. uowtransaction,
  958. table,
  959. state,
  960. state_dict,
  961. c,
  962. last_inserted_params,
  963. value_params,
  964. False,
  965. c.returned_defaults
  966. if not c.context.executemany
  967. else None,
  968. )
  969. else:
  970. _postfetch_bulk_save(mapper_rec, state_dict, table)
  971. else:
  972. # here, we need defaults and/or pk values back.
  973. records = list(records)
  974. if (
  975. not hasvalue
  976. and connection.dialect.insert_executemany_returning
  977. and len(records) > 1
  978. ):
  979. do_executemany = True
  980. else:
  981. do_executemany = False
  982. if not has_all_defaults and base_mapper.eager_defaults:
  983. statement = statement.return_defaults()
  984. elif mapper.version_id_col is not None:
  985. statement = statement.return_defaults(mapper.version_id_col)
  986. elif do_executemany:
  987. statement = statement.return_defaults(*table.primary_key)
  988. if do_executemany:
  989. multiparams = [rec[2] for rec in records]
  990. c = connection._execute_20(
  991. statement, multiparams, execution_options=execution_options
  992. )
  993. if bookkeeping:
  994. for (
  995. (
  996. state,
  997. state_dict,
  998. params,
  999. mapper_rec,
  1000. conn,
  1001. value_params,
  1002. has_all_pks,
  1003. has_all_defaults,
  1004. ),
  1005. last_inserted_params,
  1006. inserted_primary_key,
  1007. returned_defaults,
  1008. ) in util.zip_longest(
  1009. records,
  1010. c.context.compiled_parameters,
  1011. c.inserted_primary_key_rows,
  1012. c.returned_defaults_rows or (),
  1013. ):
  1014. for pk, col in zip(
  1015. inserted_primary_key,
  1016. mapper._pks_by_table[table],
  1017. ):
  1018. prop = mapper_rec._columntoproperty[col]
  1019. if state_dict.get(prop.key) is None:
  1020. state_dict[prop.key] = pk
  1021. if state:
  1022. _postfetch(
  1023. mapper_rec,
  1024. uowtransaction,
  1025. table,
  1026. state,
  1027. state_dict,
  1028. c,
  1029. last_inserted_params,
  1030. value_params,
  1031. False,
  1032. returned_defaults,
  1033. )
  1034. else:
  1035. _postfetch_bulk_save(mapper_rec, state_dict, table)
  1036. else:
  1037. for (
  1038. state,
  1039. state_dict,
  1040. params,
  1041. mapper_rec,
  1042. connection,
  1043. value_params,
  1044. has_all_pks,
  1045. has_all_defaults,
  1046. ) in records:
  1047. if value_params:
  1048. result = connection._execute_20(
  1049. statement.values(value_params),
  1050. params,
  1051. execution_options=execution_options,
  1052. )
  1053. else:
  1054. result = connection._execute_20(
  1055. statement,
  1056. params,
  1057. execution_options=execution_options,
  1058. )
  1059. primary_key = result.inserted_primary_key
  1060. for pk, col in zip(
  1061. primary_key, mapper._pks_by_table[table]
  1062. ):
  1063. prop = mapper_rec._columntoproperty[col]
  1064. if (
  1065. col in value_params
  1066. or state_dict.get(prop.key) is None
  1067. ):
  1068. state_dict[prop.key] = pk
  1069. if bookkeeping:
  1070. if state:
  1071. _postfetch(
  1072. mapper_rec,
  1073. uowtransaction,
  1074. table,
  1075. state,
  1076. state_dict,
  1077. result,
  1078. result.context.compiled_parameters[0],
  1079. value_params,
  1080. False,
  1081. result.returned_defaults
  1082. if not result.context.executemany
  1083. else None,
  1084. )
  1085. else:
  1086. _postfetch_bulk_save(mapper_rec, state_dict, table)
def _emit_post_update_statements(
    base_mapper, uowtransaction, mapper, table, update
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_post_update_commands().

    Rows are updated in the original state order to guarantee
    deterministic row access order; version-id checking is applied when
    the version column is mapped to this table, and rowcounts are
    verified where the dialect can report them reliably.
    """

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    # version checking only applies when the version column actually
    # belongs to the table being updated here
    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def update_stmt():
        # build UPDATE ... WHERE <pk cols> [AND version_id = :v];
        # memoized below so it is constructed once per (post_update, table)
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses.clauses.append(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)

        if mapper.version_id_col is not None:
            # fetch the newly generated version value back via
            # RETURNING / post-fetch so the identity map stays current
            stmt = stmt.return_defaults(mapper.version_id_col)

        return stmt

    statement = base_mapper._memo(("post_update", table), update_stmt)

    # execute each UPDATE in the order according to the original
    # list of states to guarantee row access order, but
    # also group them into common (connection, cols) sets
    # to support executemany().
    for key, records in groupby(
        update,
        lambda rec: (rec[3], set(rec[4])),  # connection  # parameter keys
    ):
        rows = 0

        records = list(records)
        connection = key[0]

        # when versioning is in play, rowcount verification relies on
        # the dialect reporting sane rowcounts even with RETURNING
        assert_singlerow = (
            connection.dialect.supports_sane_rowcount
            if mapper.version_id_col is None
            else connection.dialect.supports_sane_rowcount_returning
        )
        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        allow_multirow = not needs_version_id or assert_multirow

        if not allow_multirow:
            # execute one statement per record so per-row version
            # checks can be verified individually
            check_rowcount = assert_singlerow
            for state, state_dict, mapper_rec, connection, params in records:
                c = connection._execute_20(
                    statement, params, execution_options=execution_options
                )

                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )
                rows += c.rowcount
        else:
            # single executemany() call for the whole group
            multiparams = [
                params
                for state, state_dict, mapper_rec, conn, params in records
            ]

            check_rowcount = assert_multirow or (
                assert_singlerow and len(multiparams) == 1
            )

            c = connection._execute_20(
                statement, multiparams, execution_options=execution_options
            )

            rows += c.rowcount
            for state, state_dict, mapper_rec, connection, params in records:
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )

        if check_rowcount:
            if rows != len(records):
                # concurrent modification detected (or stale identity map)
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )
def _emit_delete_statements(
    base_mapper, uowtransaction, mapper, table, delete
):
    """Emit DELETE statements corresponding to value lists collected
    by _collect_delete_commands().

    Deletes are grouped per connection for executemany(); when a
    version column is present and the dialect cannot report multi-row
    rowcounts, rows are deleted one at a time so each version check
    can be verified.
    """

    need_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def delete_stmt():
        # DELETE ... WHERE <pk cols> [AND version_id = :v];
        # memoized below per ("delete", table)
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses.clauses.append(
                col == sql.bindparam(col.key, type_=col.type)
            )

        if need_version_id:
            clauses.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col.key, type_=mapper.version_id_col.type
                )
            )

        return table.delete().where(clauses)

    statement = base_mapper._memo(("delete", table), delete_stmt)
    for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection
        del_objects = [params for params, connection in recs]

        execution_options = {"compiled_cache": base_mapper._compiled_cache}
        expected = len(del_objects)
        rows_matched = -1  # -1 means "not verifiable"
        only_warn = False

        if (
            need_version_id
            and not connection.dialect.supports_sane_multi_rowcount
        ):
            if connection.dialect.supports_sane_rowcount:
                rows_matched = 0
                # execute deletes individually so that versioned
                # rows can be verified
                for params in del_objects:
                    c = connection._execute_20(
                        statement, params, execution_options=execution_options
                    )
                    rows_matched += c.rowcount
            else:
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified."
                    % connection.dialect.dialect_description
                )
                connection._execute_20(
                    statement, del_objects, execution_options=execution_options
                )
        else:
            c = connection._execute_20(
                statement, del_objects, execution_options=execution_options
            )

            if not need_version_id:
                # without versioning a mismatch is only advisory
                only_warn = True

            rows_matched = c.rowcount

        if (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and (
                connection.dialect.supports_sane_multi_rowcount
                or len(del_objects) == 1
            )
        ):
            # TODO: why does this "only warn" if versioning is turned off,
            # whereas the UPDATE raises?
            if only_warn:
                util.warn(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
            else:
                raise orm_exc.StaleDataError(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched. Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
    """finalize state on states that have been inserted or updated,
    including calling after_insert/after_update events.

    Also expires read-only attributes, eagerly loads server defaults
    when ``eager_defaults`` is set, and validates that a
    server-maintained version id is present after the flush.
    """
    for state, state_dict, mapper, connection, has_identity in states:

        if mapper._readonly_props:
            # expire read-only attributes whose persisted value may have
            # changed server-side, unless they were locally modified
            readonly = state.unmodified_intersection(
                [
                    p.key
                    for p in mapper._readonly_props
                    if (
                        p.expire_on_flush
                        and (not p.deferred or p.key in state.dict)
                    )
                    or (
                        not p.expire_on_flush
                        and not p.deferred
                        and p.key not in state.dict
                    )
                ]
            )
            if readonly:
                state._expire_attributes(state.dict, readonly)

        # if eager_defaults option is enabled, load
        # all expired cols.  Else if we have a version_id_col, make sure
        # it isn't expired.
        toload_now = []

        if base_mapper.eager_defaults:
            toload_now.extend(
                state._unloaded_non_object.intersection(
                    mapper._server_default_plus_onupdate_propkeys
                )
            )

        if (
            mapper.version_id_col is not None
            and mapper.version_id_generator is False
        ):
            # server-generated version id must be loaded for the
            # NULL-check at the bottom of this loop
            if mapper._version_id_prop.key in state.unloaded:
                toload_now.extend([mapper._version_id_prop.key])

        if toload_now:
            # emit a SELECT to refresh just the needed attributes
            state.key = base_mapper._identity_key_from_state(state)
            stmt = future.select(mapper).set_label_style(
                LABEL_STYLE_TABLENAME_PLUS_COL
            )
            loading.load_on_ident(
                uowtransaction.session,
                stmt,
                state.key,
                refresh_state=state,
                only_load_props=toload_now,
            )

        # call after_XXX extensions
        if not has_identity:
            mapper.dispatch.after_insert(mapper, connection, state)
        else:
            mapper.dispatch.after_update(mapper, connection, state)

        if (
            mapper.version_id_generator is False
            and mapper.version_id_col is not None
        ):
            # the server (trigger, default, etc.) was expected to supply
            # a version value; a NULL here means versioning is broken
            if state_dict[mapper._version_id_prop.key] is None:
                raise orm_exc.FlushError(
                    "Instance does not contain a non-NULL version value"
                )
def _postfetch_post_update(
    mapper, uowtransaction, table, state, dict_, result, params
):
    """Refresh in-memory state after a post-update UPDATE statement.

    Copies client-side prefetched values into the instance dict and
    expires columns that received server-side (postfetch) values, so
    they are reloaded on next access.
    """
    if uowtransaction.is_deleted(state):
        # row is going away in this same flush; nothing to refresh
        return

    prefetch_cols = result.context.compiled.prefetch
    postfetch_cols = result.context.compiled.postfetch

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        # treat the version column as prefetched; its new value is in
        # the bound parameters (copy, don't mutate the compiled's list)
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            dict_[mapper._columntoproperty[c].key] = params[c.key]
            if refresh_flush:
                load_evt_attrs.append(mapper._columntoproperty[c].key)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if postfetch_cols:
        # server-generated values not returned; expire so they reload
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )
  1371. def _postfetch(
  1372. mapper,
  1373. uowtransaction,
  1374. table,
  1375. state,
  1376. dict_,
  1377. result,
  1378. params,
  1379. value_params,
  1380. isupdate,
  1381. returned_defaults,
  1382. ):
  1383. """Expire attributes in need of newly persisted database state,
  1384. after an INSERT or UPDATE statement has proceeded for that
  1385. state."""
  1386. prefetch_cols = result.context.compiled.prefetch
  1387. postfetch_cols = result.context.compiled.postfetch
  1388. returning_cols = result.context.compiled.returning
  1389. if (
  1390. mapper.version_id_col is not None
  1391. and mapper.version_id_col in mapper._cols_by_table[table]
  1392. ):
  1393. prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
  1394. refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
  1395. if refresh_flush:
  1396. load_evt_attrs = []
  1397. if returning_cols:
  1398. row = returned_defaults
  1399. if row is not None:
  1400. for row_value, col in zip(row, returning_cols):
  1401. # pk cols returned from insert are handled
  1402. # distinctly, don't step on the values here
  1403. if col.primary_key and result.context.isinsert:
  1404. continue
  1405. # note that columns can be in the "return defaults" that are
  1406. # not mapped to this mapper, typically because they are
  1407. # "excluded", which can be specified directly or also occurs
  1408. # when using declarative w/ single table inheritance
  1409. prop = mapper._columntoproperty.get(col)
  1410. if prop:
  1411. dict_[prop.key] = row_value
  1412. if refresh_flush:
  1413. load_evt_attrs.append(prop.key)
  1414. for c in prefetch_cols:
  1415. if c.key in params and c in mapper._columntoproperty:
  1416. dict_[mapper._columntoproperty[c].key] = params[c.key]
  1417. if refresh_flush:
  1418. load_evt_attrs.append(mapper._columntoproperty[c].key)
  1419. if refresh_flush and load_evt_attrs:
  1420. mapper.class_manager.dispatch.refresh_flush(
  1421. state, uowtransaction, load_evt_attrs
  1422. )
  1423. if isupdate and value_params:
  1424. # explicitly suit the use case specified by
  1425. # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
  1426. # database which are set to themselves in order to do a version bump.
  1427. postfetch_cols.extend(
  1428. [
  1429. col
  1430. for col in value_params
  1431. if col.primary_key and col not in returning_cols
  1432. ]
  1433. )
  1434. if postfetch_cols:
  1435. state._expire_attributes(
  1436. state.dict,
  1437. [
  1438. mapper._columntoproperty[c].key
  1439. for c in postfetch_cols
  1440. if c in mapper._columntoproperty
  1441. ],
  1442. )
  1443. # synchronize newly inserted ids from one table to the next
  1444. # TODO: this still goes a little too often. would be nice to
  1445. # have definitive list of "columns that changed" here
  1446. for m, equated_pairs in mapper._table_to_equated[table]:
  1447. sync.populate(
  1448. state,
  1449. m,
  1450. state,
  1451. m,
  1452. equated_pairs,
  1453. uowtransaction,
  1454. mapper.passive_updates,
  1455. )
  1456. def _postfetch_bulk_save(mapper, dict_, table):
  1457. for m, equated_pairs in mapper._table_to_equated[table]:
  1458. sync.bulk_populate_inherit_keys(dict_, m, equated_pairs)
  1459. def _connections_for_states(base_mapper, uowtransaction, states):
  1460. """Return an iterator of (state, state.dict, mapper, connection).
  1461. The states are sorted according to _sort_states, then paired
  1462. with the connection they should be using for the given
  1463. unit of work transaction.
  1464. """
  1465. # if session has a connection callable,
  1466. # organize individual states with the connection
  1467. # to use for update
  1468. if uowtransaction.session.connection_callable:
  1469. connection_callable = uowtransaction.session.connection_callable
  1470. else:
  1471. connection = uowtransaction.transaction.connection(base_mapper)
  1472. connection_callable = None
  1473. for state in _sort_states(base_mapper, states):
  1474. if connection_callable:
  1475. connection = connection_callable(base_mapper, state.obj())
  1476. mapper = state.manager.mapper
  1477. yield state, state.dict, mapper, connection
  1478. def _sort_states(mapper, states):
  1479. pending = set(states)
  1480. persistent = set(s for s in pending if s.key is not None)
  1481. pending.difference_update(persistent)
  1482. try:
  1483. persistent_sorted = sorted(
  1484. persistent, key=mapper._persistent_sortkey_fn
  1485. )
  1486. except TypeError as err:
  1487. util.raise_(
  1488. sa_exc.InvalidRequestError(
  1489. "Could not sort objects by primary key; primary key "
  1490. "values must be sortable in Python (was: %s)" % err
  1491. ),
  1492. replace_context=err,
  1493. )
  1494. return (
  1495. sorted(pending, key=operator.attrgetter("insert_order"))
  1496. + persistent_sorted
  1497. )
# shared immutable empty-dict sentinel, safe to use as a class-level
# default since it cannot be mutated in place
_EMPTY_DICT = util.immutabledict()
  1499. class BulkUDCompileState(CompileState):
  1500. class default_update_options(Options):
  1501. _synchronize_session = "evaluate"
  1502. _autoflush = True
  1503. _subject_mapper = None
  1504. _resolved_values = _EMPTY_DICT
  1505. _resolved_keys_as_propnames = _EMPTY_DICT
  1506. _value_evaluators = _EMPTY_DICT
  1507. _matched_objects = None
  1508. _matched_rows = None
  1509. _refresh_identity_token = None
  1510. @classmethod
  1511. def orm_pre_session_exec(
  1512. cls,
  1513. session,
  1514. statement,
  1515. params,
  1516. execution_options,
  1517. bind_arguments,
  1518. is_reentrant_invoke,
  1519. ):
  1520. if is_reentrant_invoke:
  1521. return statement, execution_options
  1522. (
  1523. update_options,
  1524. execution_options,
  1525. ) = BulkUDCompileState.default_update_options.from_execution_options(
  1526. "_sa_orm_update_options",
  1527. {"synchronize_session"},
  1528. execution_options,
  1529. statement._execution_options,
  1530. )
  1531. sync = update_options._synchronize_session
  1532. if sync is not None:
  1533. if sync not in ("evaluate", "fetch", False):
  1534. raise sa_exc.ArgumentError(
  1535. "Valid strategies for session synchronization "
  1536. "are 'evaluate', 'fetch', False"
  1537. )
  1538. bind_arguments["clause"] = statement
  1539. try:
  1540. plugin_subject = statement._propagate_attrs["plugin_subject"]
  1541. except KeyError:
  1542. assert False, "statement had 'orm' plugin but no plugin_subject"
  1543. else:
  1544. bind_arguments["mapper"] = plugin_subject.mapper
  1545. update_options += {"_subject_mapper": plugin_subject.mapper}
  1546. if update_options._autoflush:
  1547. session._autoflush()
  1548. statement = statement._annotate(
  1549. {"synchronize_session": update_options._synchronize_session}
  1550. )
  1551. # this stage of the execution is called before the do_orm_execute event
  1552. # hook. meaning for an extension like horizontal sharding, this step
  1553. # happens before the extension splits out into multiple backends and
  1554. # runs only once. if we do pre_sync_fetch, we execute a SELECT
  1555. # statement, which the horizontal sharding extension splits amongst the
  1556. # shards and combines the results together.
  1557. if update_options._synchronize_session == "evaluate":
  1558. update_options = cls._do_pre_synchronize_evaluate(
  1559. session,
  1560. statement,
  1561. params,
  1562. execution_options,
  1563. bind_arguments,
  1564. update_options,
  1565. )
  1566. elif update_options._synchronize_session == "fetch":
  1567. update_options = cls._do_pre_synchronize_fetch(
  1568. session,
  1569. statement,
  1570. params,
  1571. execution_options,
  1572. bind_arguments,
  1573. update_options,
  1574. )
  1575. return (
  1576. statement,
  1577. util.immutabledict(execution_options).union(
  1578. dict(_sa_orm_update_options=update_options)
  1579. ),
  1580. )
  1581. @classmethod
  1582. def orm_setup_cursor_result(
  1583. cls,
  1584. session,
  1585. statement,
  1586. params,
  1587. execution_options,
  1588. bind_arguments,
  1589. result,
  1590. ):
  1591. # this stage of the execution is called after the
  1592. # do_orm_execute event hook. meaning for an extension like
  1593. # horizontal sharding, this step happens *within* the horizontal
  1594. # sharding event handler which calls session.execute() re-entrantly
  1595. # and will occur for each backend individually.
  1596. # the sharding extension then returns its own merged result from the
  1597. # individual ones we return here.
  1598. update_options = execution_options["_sa_orm_update_options"]
  1599. if update_options._synchronize_session == "evaluate":
  1600. cls._do_post_synchronize_evaluate(session, result, update_options)
  1601. elif update_options._synchronize_session == "fetch":
  1602. cls._do_post_synchronize_fetch(session, result, update_options)
  1603. return result
  1604. @classmethod
  1605. def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
  1606. """Apply extra criteria filtering.
  1607. For all distinct single-table-inheritance mappers represented in the
  1608. table being updated or deleted, produce additional WHERE criteria such
  1609. that only the appropriate subtypes are selected from the total results.
  1610. Additionally, add WHERE criteria originating from LoaderCriteriaOptions
  1611. collected from the statement.
  1612. """
  1613. return_crit = ()
  1614. adapter = ext_info._adapter if ext_info.is_aliased_class else None
  1615. if (
  1616. "additional_entity_criteria",
  1617. ext_info.mapper,
  1618. ) in global_attributes:
  1619. return_crit += tuple(
  1620. ae._resolve_where_criteria(ext_info)
  1621. for ae in global_attributes[
  1622. ("additional_entity_criteria", ext_info.mapper)
  1623. ]
  1624. if ae.include_aliases or ae.entity is ext_info
  1625. )
  1626. if ext_info.mapper._single_table_criterion is not None:
  1627. return_crit += (ext_info.mapper._single_table_criterion,)
  1628. if adapter:
  1629. return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
  1630. return return_crit
  1631. @classmethod
  1632. def _do_pre_synchronize_evaluate(
  1633. cls,
  1634. session,
  1635. statement,
  1636. params,
  1637. execution_options,
  1638. bind_arguments,
  1639. update_options,
  1640. ):
  1641. mapper = update_options._subject_mapper
  1642. target_cls = mapper.class_
  1643. value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT
  1644. try:
  1645. evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
  1646. crit = ()
  1647. if statement._where_criteria:
  1648. crit += statement._where_criteria
  1649. global_attributes = {}
  1650. for opt in statement._with_options:
  1651. if opt._is_criteria_option:
  1652. opt.get_global_criteria(global_attributes)
  1653. if global_attributes:
  1654. crit += cls._adjust_for_extra_criteria(
  1655. global_attributes, mapper
  1656. )
  1657. if crit:
  1658. eval_condition = evaluator_compiler.process(*crit)
  1659. else:
  1660. def eval_condition(obj):
  1661. return True
  1662. except evaluator.UnevaluatableError as err:
  1663. util.raise_(
  1664. sa_exc.InvalidRequestError(
  1665. 'Could not evaluate current criteria in Python: "%s". '
  1666. "Specify 'fetch' or False for the "
  1667. "synchronize_session execution option." % err
  1668. ),
  1669. from_=err,
  1670. )
  1671. if statement.__visit_name__ == "lambda_element":
  1672. # ._resolved is called on every LambdaElement in order to
  1673. # generate the cache key, so this access does not add
  1674. # additional expense
  1675. effective_statement = statement._resolved
  1676. else:
  1677. effective_statement = statement
  1678. if effective_statement.__visit_name__ == "update":
  1679. resolved_values = cls._get_resolved_values(
  1680. mapper, effective_statement
  1681. )
  1682. value_evaluators = {}
  1683. resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
  1684. mapper, resolved_values
  1685. )
  1686. for key, value in resolved_keys_as_propnames:
  1687. try:
  1688. _evaluator = evaluator_compiler.process(
  1689. coercions.expect(roles.ExpressionElementRole, value)
  1690. )
  1691. except evaluator.UnevaluatableError:
  1692. pass
  1693. else:
  1694. value_evaluators[key] = _evaluator
  1695. # TODO: detect when the where clause is a trivial primary key match.
  1696. matched_objects = [
  1697. state.obj()
  1698. for state in session.identity_map.all_states()
  1699. if state.mapper.isa(mapper)
  1700. and not state.expired
  1701. and eval_condition(state.obj())
  1702. and (
  1703. update_options._refresh_identity_token is None
  1704. # TODO: coverage for the case where horizontal sharding
  1705. # invokes an update() or delete() given an explicit identity
  1706. # token up front
  1707. or state.identity_token
  1708. == update_options._refresh_identity_token
  1709. )
  1710. ]
  1711. return update_options + {
  1712. "_matched_objects": matched_objects,
  1713. "_value_evaluators": value_evaluators,
  1714. "_resolved_keys_as_propnames": resolved_keys_as_propnames,
  1715. }
  1716. @classmethod
  1717. def _get_resolved_values(cls, mapper, statement):
  1718. if statement._multi_values:
  1719. return []
  1720. elif statement._ordered_values:
  1721. iterator = statement._ordered_values
  1722. elif statement._values:
  1723. iterator = statement._values.items()
  1724. else:
  1725. return []
  1726. values = []
  1727. if iterator:
  1728. for k, v in iterator:
  1729. if mapper:
  1730. if isinstance(k, util.string_types):
  1731. desc = _entity_namespace_key(mapper, k)
  1732. values.extend(desc._bulk_update_tuples(v))
  1733. elif "entity_namespace" in k._annotations:
  1734. k_anno = k._annotations
  1735. attr = _entity_namespace_key(
  1736. k_anno["entity_namespace"], k_anno["proxy_key"]
  1737. )
  1738. values.extend(attr._bulk_update_tuples(v))
  1739. else:
  1740. values.append((k, v))
  1741. else:
  1742. values.append((k, v))
  1743. return values
  1744. @classmethod
  1745. def _resolved_keys_as_propnames(cls, mapper, resolved_values):
  1746. values = []
  1747. for k, v in resolved_values:
  1748. if isinstance(k, attributes.QueryableAttribute):
  1749. values.append((k.key, v))
  1750. continue
  1751. elif hasattr(k, "__clause_element__"):
  1752. k = k.__clause_element__()
  1753. if mapper and isinstance(k, expression.ColumnElement):
  1754. try:
  1755. attr = mapper._columntoproperty[k]
  1756. except orm_exc.UnmappedColumnError:
  1757. pass
  1758. else:
  1759. values.append((attr.key, v))
  1760. else:
  1761. raise sa_exc.InvalidRequestError(
  1762. "Invalid expression type: %r" % k
  1763. )
  1764. return values
  1765. @classmethod
  1766. def _do_pre_synchronize_fetch(
  1767. cls,
  1768. session,
  1769. statement,
  1770. params,
  1771. execution_options,
  1772. bind_arguments,
  1773. update_options,
  1774. ):
  1775. mapper = update_options._subject_mapper
  1776. select_stmt = (
  1777. select(*(mapper.primary_key + (mapper.select_identity_token,)))
  1778. .select_from(mapper)
  1779. .options(*statement._with_options)
  1780. )
  1781. select_stmt._where_criteria = statement._where_criteria
  1782. def skip_for_full_returning(orm_context):
  1783. bind = orm_context.session.get_bind(**orm_context.bind_arguments)
  1784. if bind.dialect.full_returning:
  1785. return _result.null_result()
  1786. else:
  1787. return None
  1788. result = session.execute(
  1789. select_stmt,
  1790. params,
  1791. execution_options,
  1792. bind_arguments,
  1793. _add_event=skip_for_full_returning,
  1794. )
  1795. matched_rows = result.fetchall()
  1796. value_evaluators = _EMPTY_DICT
  1797. if statement.__visit_name__ == "lambda_element":
  1798. # ._resolved is called on every LambdaElement in order to
  1799. # generate the cache key, so this access does not add
  1800. # additional expense
  1801. effective_statement = statement._resolved
  1802. else:
  1803. effective_statement = statement
  1804. if effective_statement.__visit_name__ == "update":
  1805. target_cls = mapper.class_
  1806. evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
  1807. resolved_values = cls._get_resolved_values(
  1808. mapper, effective_statement
  1809. )
  1810. resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
  1811. mapper, resolved_values
  1812. )
  1813. resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
  1814. mapper, resolved_values
  1815. )
  1816. value_evaluators = {}
  1817. for key, value in resolved_keys_as_propnames:
  1818. try:
  1819. _evaluator = evaluator_compiler.process(
  1820. coercions.expect(roles.ExpressionElementRole, value)
  1821. )
  1822. except evaluator.UnevaluatableError:
  1823. pass
  1824. else:
  1825. value_evaluators[key] = _evaluator
  1826. else:
  1827. resolved_keys_as_propnames = _EMPTY_DICT
  1828. return update_options + {
  1829. "_value_evaluators": value_evaluators,
  1830. "_matched_rows": matched_rows,
  1831. "_resolved_keys_as_propnames": resolved_keys_as_propnames,
  1832. }
  1833. @CompileState.plugin_for("orm", "update")
  1834. class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        """Construct the compile state for an ORM-enabled UPDATE,
        rewriting the statement against the mapper's local table,
        resolving SET values through the mapper, and adding
        single-table-inheritance / loader criteria plus RETURNING of
        primary keys when synchronize_session='fetch' applies."""

        self = cls.__new__(cls)

        ext_info = statement.table._annotations["parententity"]

        self.mapper = mapper = ext_info.mapper

        self.extra_criteria_entities = {}

        self._resolved_values = cls._get_resolved_values(mapper, statement)

        extra_criteria_attributes = {}

        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(extra_criteria_attributes)

        if not statement._preserve_parameter_order and statement._values:
            # order not significant; a plain dict suffices
            self._resolved_values = dict(self._resolved_values)

        # shallow-copy the statement so the rewritten table/values do
        # not mutate the user's original statement object
        new_stmt = sql.Update.__new__(sql.Update)
        new_stmt.__dict__.update(statement.__dict__)
        new_stmt.table = mapper.local_table

        # note if the statement has _multi_values, these
        # are passed through to the new statement, which will then raise
        # InvalidRequestError because UPDATE doesn't support multi_values
        # right now.
        if statement._ordered_values:
            new_stmt._ordered_values = self._resolved_values
        elif statement._values:
            new_stmt._values = self._resolved_values

        new_crit = cls._adjust_for_extra_criteria(
            extra_criteria_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # if we are against a lambda statement we might not be the
        # topmost object that received per-execute annotations

        if (
            compiler._annotations.get("synchronize_session", None) == "fetch"
            and compiler.dialect.full_returning
        ):
            if new_stmt._returning:
                raise sa_exc.InvalidRequestError(
                    "Can't use synchronize_session='fetch' "
                    "with explicit returning()"
                )
            new_stmt = new_stmt.returning(*mapper.primary_key)

        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)

        return self
    @classmethod
    def _do_post_synchronize_evaluate(cls, session, result, update_options):
        """Apply the UPDATE's effects to in-session matched objects by
        running the recorded Python value evaluators, committing the
        evaluated attributes and expiring the non-evaluable ones."""

        states = set()
        evaluated_keys = list(update_options._value_evaluators.keys())

        values = update_options._resolved_keys_as_propnames
        attrib = set(k for k, v in values)

        for obj in update_options._matched_objects:

            state, dict_ = (
                attributes.instance_state(obj),
                attributes.instance_dict(obj),
            )

            # the evaluated states were gathered across all identity tokens.
            # however the post_sync events are called per identity token,
            # so filter.
            if (
                update_options._refresh_identity_token is not None
                and state.identity_token
                != update_options._refresh_identity_token
            ):
                continue

            # only evaluate unmodified attributes
            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                if key in dict_:
                    dict_[key] = update_options._value_evaluators[key](obj)

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            # attributes that were SET but could not be evaluated in
            # Python are expired so they reload from the database
            to_expire = attrib.intersection(dict_).difference(to_evaluate)
            if to_expire:
                state._expire_attributes(dict_, to_expire)

            states.add(state)
        session._register_altered(states)
  1910. @classmethod
  1911. def _do_post_synchronize_fetch(cls, session, result, update_options):
  1912. target_mapper = update_options._subject_mapper
  1913. states = set()
  1914. evaluated_keys = list(update_options._value_evaluators.keys())
  1915. if result.returns_rows:
  1916. matched_rows = [
  1917. tuple(row) + (update_options._refresh_identity_token,)
  1918. for row in result.all()
  1919. ]
  1920. else:
  1921. matched_rows = update_options._matched_rows
  1922. objs = [
  1923. session.identity_map[identity_key]
  1924. for identity_key in [
  1925. target_mapper.identity_key_from_primary_key(
  1926. list(primary_key),
  1927. identity_token=identity_token,
  1928. )
  1929. for primary_key, identity_token in [
  1930. (row[0:-1], row[-1]) for row in matched_rows
  1931. ]
  1932. if update_options._refresh_identity_token is None
  1933. or identity_token == update_options._refresh_identity_token
  1934. ]
  1935. if identity_key in session.identity_map
  1936. ]
  1937. values = update_options._resolved_keys_as_propnames
  1938. attrib = set(k for k, v in values)
  1939. for obj in objs:
  1940. state, dict_ = (
  1941. attributes.instance_state(obj),
  1942. attributes.instance_dict(obj),
  1943. )
  1944. to_evaluate = state.unmodified.intersection(evaluated_keys)
  1945. for key in to_evaluate:
  1946. if key in dict_:
  1947. dict_[key] = update_options._value_evaluators[key](obj)
  1948. state.manager.dispatch.refresh(state, None, to_evaluate)
  1949. state._commit(dict_, list(to_evaluate))
  1950. to_expire = attrib.intersection(dict_).difference(to_evaluate)
  1951. if to_expire:
  1952. state._expire_attributes(dict_, to_expire)
  1953. states.add(state)
  1954. session._register_altered(states)
@CompileState.plugin_for("orm", "delete")
class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
    """Compile-state plugin implementing ORM-enabled bulk DELETE
    statements, including post-execution session synchronization."""

    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        # Build the compile state for an ORM DELETE: resolve the target
        # mapper from the annotated table, fold in globally-applicable
        # criteria options, and — when the 'fetch' synchronize strategy is
        # in play and the dialect supports it — add RETURNING of the
        # primary key so the post-sync step can identify deleted rows.
        self = cls.__new__(cls)

        ext_info = statement.table._annotations["parententity"]
        self.mapper = mapper = ext_info.mapper

        self.extra_criteria_entities = {}

        extra_criteria_attributes = {}

        # collect global criteria (e.g. with_loader_criteria options)
        # attached to the statement
        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(extra_criteria_attributes)

        new_crit = cls._adjust_for_extra_criteria(
            extra_criteria_attributes, mapper
        )
        if new_crit:
            statement = statement.where(*new_crit)

        if (
            mapper
            and compiler._annotations.get("synchronize_session", None)
            == "fetch"
            and compiler.dialect.full_returning
        ):
            statement = statement.returning(*mapper.primary_key)

        DeleteDMLState.__init__(self, statement, compiler, **kw)

        return self

    @classmethod
    def _do_post_synchronize_evaluate(cls, session, result, update_options):
        # 'evaluate' strategy: the matched objects were already determined
        # by evaluating the criteria in Python; mark them all as deleted.
        session._remove_newly_deleted(
            [
                attributes.instance_state(obj)
                for obj in update_options._matched_objects
            ]
        )

    @classmethod
    def _do_post_synchronize_fetch(cls, session, result, update_options):
        # 'fetch' strategy: primary keys of the deleted rows come from
        # RETURNING when available, otherwise from the pre-SELECT stored
        # on the execution options.
        target_mapper = update_options._subject_mapper

        if result.returns_rows:
            # RETURNING rows carry no identity token; append the token
            # this execution was refreshed against (may be None)
            matched_rows = [
                tuple(row) + (update_options._refresh_identity_token,)
                for row in result.all()
            ]
        else:
            matched_rows = update_options._matched_rows

        for row in matched_rows:
            primary_key = row[0:-1]
            identity_token = row[-1]

            # TODO: inline this and call remove_newly_deleted
            # once
            identity_key = target_mapper.identity_key_from_primary_key(
                list(primary_key),
                identity_token=identity_token,
            )
            if identity_key in session.identity_map:
                # only objects actually present in the session need to be
                # marked deleted
                session._remove_newly_deleted(
                    [
                        attributes.instance_state(
                            session.identity_map[identity_key]
                        )
                    ]
                )