|
- # orm/persistence.py
- # Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
- # <see AUTHORS file>
- #
- # This module is part of SQLAlchemy and is released under
- # the MIT License: http://www.opensource.org/licenses/mit-license.php
-
- """private module containing functions used to emit INSERT, UPDATE
- and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
- mappers.
-
- The functions here are called only by the unit of work functions
- in unitofwork.py.
-
- """
-
- from itertools import chain
- from itertools import groupby
- import operator
-
- from . import attributes
- from . import evaluator
- from . import exc as orm_exc
- from . import loading
- from . import sync
- from .base import state_str
- from .. import exc as sa_exc
- from .. import future
- from .. import sql
- from .. import util
- from ..engine import result as _result
- from ..sql import coercions
- from ..sql import expression
- from ..sql import operators
- from ..sql import roles
- from ..sql import select
- from ..sql.base import _entity_namespace_key
- from ..sql.base import CompileState
- from ..sql.base import Options
- from ..sql.dml import DeleteDMLState
- from ..sql.dml import UpdateDMLState
- from ..sql.elements import BooleanClauseList
- from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
-
-
- def _bulk_insert(
- mapper,
- mappings,
- session_transaction,
- isstates,
- return_defaults,
- render_nulls,
- ):
- base_mapper = mapper.base_mapper
-
- if session_transaction.session.connection_callable:
- raise NotImplementedError(
- "connection_callable / per-instance sharding "
- "not supported in bulk_insert()"
- )
-
- if isstates:
- if return_defaults:
- states = [(state, state.dict) for state in mappings]
- mappings = [dict_ for (state, dict_) in states]
- else:
- mappings = [state.dict for state in mappings]
- else:
- mappings = list(mappings)
-
- connection = session_transaction.connection(base_mapper)
- for table, super_mapper in base_mapper._sorted_tables.items():
- if not mapper.isa(super_mapper):
- continue
-
- records = (
- (
- None,
- state_dict,
- params,
- mapper,
- connection,
- value_params,
- has_all_pks,
- has_all_defaults,
- )
- for (
- state,
- state_dict,
- params,
- mp,
- conn,
- value_params,
- has_all_pks,
- has_all_defaults,
- ) in _collect_insert_commands(
- table,
- ((None, mapping, mapper, connection) for mapping in mappings),
- bulk=True,
- return_defaults=return_defaults,
- render_nulls=render_nulls,
- )
- )
- _emit_insert_statements(
- base_mapper,
- None,
- super_mapper,
- table,
- records,
- bookkeeping=return_defaults,
- )
-
- if return_defaults and isstates:
- identity_cls = mapper._identity_class
- identity_props = [p.key for p in mapper._identity_key_props]
- for state, dict_ in states:
- state.key = (
- identity_cls,
- tuple([dict_[key] for key in identity_props]),
- )
-
-
- def _bulk_update(
- mapper, mappings, session_transaction, isstates, update_changed_only
- ):
- base_mapper = mapper.base_mapper
-
- search_keys = mapper._primary_key_propkeys
- if mapper._version_id_prop:
- search_keys = {mapper._version_id_prop.key}.union(search_keys)
-
- def _changed_dict(mapper, state):
- return dict(
- (k, v)
- for k, v in state.dict.items()
- if k in state.committed_state or k in search_keys
- )
-
- if isstates:
- if update_changed_only:
- mappings = [_changed_dict(mapper, state) for state in mappings]
- else:
- mappings = [state.dict for state in mappings]
- else:
- mappings = list(mappings)
-
- if session_transaction.session.connection_callable:
- raise NotImplementedError(
- "connection_callable / per-instance sharding "
- "not supported in bulk_update()"
- )
-
- connection = session_transaction.connection(base_mapper)
-
- for table, super_mapper in base_mapper._sorted_tables.items():
- if not mapper.isa(super_mapper):
- continue
-
- records = _collect_update_commands(
- None,
- table,
- (
- (
- None,
- mapping,
- mapper,
- connection,
- (
- mapping[mapper._version_id_prop.key]
- if mapper._version_id_prop
- else None
- ),
- )
- for mapping in mappings
- ),
- bulk=True,
- )
-
- _emit_update_statements(
- base_mapper,
- None,
- super_mapper,
- table,
- records,
- bookkeeping=False,
- )
-
-
- def save_obj(base_mapper, states, uowtransaction, single=False):
- """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
- of objects.
-
- This is called within the context of a UOWTransaction during a
- flush operation, given a list of states to be flushed. The
- base mapper in an inheritance hierarchy handles the inserts/
- updates for all descendant mappers.
-
- """
-
- # if batch=false, call _save_obj separately for each object
- if not single and not base_mapper.batch:
- for state in _sort_states(base_mapper, states):
- save_obj(base_mapper, [state], uowtransaction, single=True)
- return
-
- states_to_update = []
- states_to_insert = []
-
- for (
- state,
- dict_,
- mapper,
- connection,
- has_identity,
- row_switch,
- update_version_id,
- ) in _organize_states_for_save(base_mapper, states, uowtransaction):
- if has_identity or row_switch:
- states_to_update.append(
- (state, dict_, mapper, connection, update_version_id)
- )
- else:
- states_to_insert.append((state, dict_, mapper, connection))
-
- for table, mapper in base_mapper._sorted_tables.items():
- if table not in mapper._pks_by_table:
- continue
- insert = _collect_insert_commands(table, states_to_insert)
-
- update = _collect_update_commands(
- uowtransaction, table, states_to_update
- )
-
- _emit_update_statements(
- base_mapper,
- uowtransaction,
- mapper,
- table,
- update,
- )
-
- _emit_insert_statements(
- base_mapper,
- uowtransaction,
- mapper,
- table,
- insert,
- )
-
- _finalize_insert_update_commands(
- base_mapper,
- uowtransaction,
- chain(
- (
- (state, state_dict, mapper, connection, False)
- for (state, state_dict, mapper, connection) in states_to_insert
- ),
- (
- (state, state_dict, mapper, connection, True)
- for (
- state,
- state_dict,
- mapper,
- connection,
- update_version_id,
- ) in states_to_update
- ),
- ),
- )
-
-
- def post_update(base_mapper, states, uowtransaction, post_update_cols):
- """Issue UPDATE statements on behalf of a relationship() which
- specifies post_update.
-
- """
-
- states_to_update = list(
- _organize_states_for_post_update(base_mapper, states, uowtransaction)
- )
-
- for table, mapper in base_mapper._sorted_tables.items():
- if table not in mapper._pks_by_table:
- continue
-
- update = (
- (
- state,
- state_dict,
- sub_mapper,
- connection,
- mapper._get_committed_state_attr_by_column(
- state, state_dict, mapper.version_id_col
- )
- if mapper.version_id_col is not None
- else None,
- )
- for state, state_dict, sub_mapper, connection in states_to_update
- if table in sub_mapper._pks_by_table
- )
-
- update = _collect_post_update_commands(
- base_mapper, uowtransaction, table, update, post_update_cols
- )
-
- _emit_post_update_statements(
- base_mapper,
- uowtransaction,
- mapper,
- table,
- update,
- )
-
-
- def delete_obj(base_mapper, states, uowtransaction):
- """Issue ``DELETE`` statements for a list of objects.
-
- This is called within the context of a UOWTransaction during a
- flush operation.
-
- """
-
- states_to_delete = list(
- _organize_states_for_delete(base_mapper, states, uowtransaction)
- )
-
- table_to_mapper = base_mapper._sorted_tables
-
- for table in reversed(list(table_to_mapper.keys())):
- mapper = table_to_mapper[table]
- if table not in mapper._pks_by_table:
- continue
- elif mapper.inherits and mapper.passive_deletes:
- continue
-
- delete = _collect_delete_commands(
- base_mapper, uowtransaction, table, states_to_delete
- )
-
- _emit_delete_statements(
- base_mapper,
- uowtransaction,
- mapper,
- table,
- delete,
- )
-
- for (
- state,
- state_dict,
- mapper,
- connection,
- update_version_id,
- ) in states_to_delete:
- mapper.dispatch.after_delete(mapper, connection, state)
-
-
- def _organize_states_for_save(base_mapper, states, uowtransaction):
- """Make an initial pass across a set of states for INSERT or
- UPDATE.
-
- This includes splitting out into distinct lists for
- each, calling before_insert/before_update, obtaining
- key information for each state including its dictionary,
- mapper, the connection to use for the execution per state,
- and the identity flag.
-
- """
-
- for state, dict_, mapper, connection in _connections_for_states(
- base_mapper, uowtransaction, states
- ):
-
- has_identity = bool(state.key)
-
- instance_key = state.key or mapper._identity_key_from_state(state)
-
- row_switch = update_version_id = None
-
- # call before_XXX extensions
- if not has_identity:
- mapper.dispatch.before_insert(mapper, connection, state)
- else:
- mapper.dispatch.before_update(mapper, connection, state)
-
- if mapper._validate_polymorphic_identity:
- mapper._validate_polymorphic_identity(mapper, state, dict_)
-
- # detect if we have a "pending" instance (i.e. has
- # no instance_key attached to it), and another instance
- # with the same identity key already exists as persistent.
- # convert to an UPDATE if so.
- if (
- not has_identity
- and instance_key in uowtransaction.session.identity_map
- ):
- instance = uowtransaction.session.identity_map[instance_key]
- existing = attributes.instance_state(instance)
-
- if not uowtransaction.was_already_deleted(existing):
- if not uowtransaction.is_deleted(existing):
- util.warn(
- "New instance %s with identity key %s conflicts "
- "with persistent instance %s"
- % (state_str(state), instance_key, state_str(existing))
- )
- else:
- base_mapper._log_debug(
- "detected row switch for identity %s. "
- "will update %s, remove %s from "
- "transaction",
- instance_key,
- state_str(state),
- state_str(existing),
- )
-
- # remove the "delete" flag from the existing element
- uowtransaction.remove_state_actions(existing)
- row_switch = existing
-
- if (has_identity or row_switch) and mapper.version_id_col is not None:
- update_version_id = mapper._get_committed_state_attr_by_column(
- row_switch if row_switch else state,
- row_switch.dict if row_switch else dict_,
- mapper.version_id_col,
- )
-
- yield (
- state,
- dict_,
- mapper,
- connection,
- has_identity,
- row_switch,
- update_version_id,
- )
-
-
- def _organize_states_for_post_update(base_mapper, states, uowtransaction):
- """Make an initial pass across a set of states for UPDATE
- corresponding to post_update.
-
- This includes obtaining key information for each state
- including its dictionary, mapper, the connection to use for
- the execution per state.
-
- """
- return _connections_for_states(base_mapper, uowtransaction, states)
-
-
- def _organize_states_for_delete(base_mapper, states, uowtransaction):
- """Make an initial pass across a set of states for DELETE.
-
- This includes calling out before_delete and obtaining
- key information for each state including its dictionary,
- mapper, the connection to use for the execution per state.
-
- """
- for state, dict_, mapper, connection in _connections_for_states(
- base_mapper, uowtransaction, states
- ):
-
- mapper.dispatch.before_delete(mapper, connection, state)
-
- if mapper.version_id_col is not None:
- update_version_id = mapper._get_committed_state_attr_by_column(
- state, dict_, mapper.version_id_col
- )
- else:
- update_version_id = None
-
- yield (state, dict_, mapper, connection, update_version_id)
-
-
- def _collect_insert_commands(
- table,
- states_to_insert,
- bulk=False,
- return_defaults=False,
- render_nulls=False,
- ):
- """Identify sets of values to use in INSERT statements for a
- list of states.
-
- """
- for state, state_dict, mapper, connection in states_to_insert:
- if table not in mapper._pks_by_table:
- continue
-
- params = {}
- value_params = {}
-
- propkey_to_col = mapper._propkey_to_col[table]
-
- eval_none = mapper._insert_cols_evaluating_none[table]
-
- for propkey in set(propkey_to_col).intersection(state_dict):
- value = state_dict[propkey]
- col = propkey_to_col[propkey]
- if value is None and col not in eval_none and not render_nulls:
- continue
- elif not bulk and (
- hasattr(value, "__clause_element__")
- or isinstance(value, sql.ClauseElement)
- ):
- value_params[col] = (
- value.__clause_element__()
- if hasattr(value, "__clause_element__")
- else value
- )
- else:
- params[col.key] = value
-
- if not bulk:
- # for all the columns that have no default and we don't have
- # a value and where "None" is not a special value, add
- # explicit None to the INSERT. This is a legacy behavior
- # which might be worth removing, as it should not be necessary
- # and also produces confusion, given that "missing" and None
- # now have distinct meanings
- for colkey in (
- mapper._insert_cols_as_none[table]
- .difference(params)
- .difference([c.key for c in value_params])
- ):
- params[colkey] = None
-
- if not bulk or return_defaults:
- # params are in terms of Column key objects, so
- # compare to pk_keys_by_table
- has_all_pks = mapper._pk_keys_by_table[table].issubset(params)
-
- if mapper.base_mapper.eager_defaults:
- has_all_defaults = mapper._server_default_cols[table].issubset(
- params
- )
- else:
- has_all_defaults = True
- else:
- has_all_defaults = has_all_pks = True
-
- if (
- mapper.version_id_generator is not False
- and mapper.version_id_col is not None
- and mapper.version_id_col in mapper._cols_by_table[table]
- ):
- params[mapper.version_id_col.key] = mapper.version_id_generator(
- None
- )
-
- yield (
- state,
- state_dict,
- params,
- mapper,
- connection,
- value_params,
- has_all_pks,
- has_all_defaults,
- )
-
-
- def _collect_update_commands(
- uowtransaction, table, states_to_update, bulk=False
- ):
- """Identify sets of values to use in UPDATE statements for a
- list of states.
-
- This function works intricately with the history system
- to determine exactly what values should be updated
- as well as how the row should be matched within an UPDATE
- statement. Includes some tricky scenarios where the primary
- key of an object might have been changed.
-
- """
-
- for (
- state,
- state_dict,
- mapper,
- connection,
- update_version_id,
- ) in states_to_update:
-
- if table not in mapper._pks_by_table:
- continue
-
- pks = mapper._pks_by_table[table]
-
- value_params = {}
-
- propkey_to_col = mapper._propkey_to_col[table]
-
- if bulk:
- # keys here are mapped attribute keys, so
- # look at mapper attribute keys for pk
- params = dict(
- (propkey_to_col[propkey].key, state_dict[propkey])
- for propkey in set(propkey_to_col)
- .intersection(state_dict)
- .difference(mapper._pk_attr_keys_by_table[table])
- )
- has_all_defaults = True
- else:
- params = {}
- for propkey in set(propkey_to_col).intersection(
- state.committed_state
- ):
- value = state_dict[propkey]
- col = propkey_to_col[propkey]
-
- if hasattr(value, "__clause_element__") or isinstance(
- value, sql.ClauseElement
- ):
- value_params[col] = (
- value.__clause_element__()
- if hasattr(value, "__clause_element__")
- else value
- )
- # guard against values that generate non-__nonzero__
- # objects for __eq__()
- elif (
- state.manager[propkey].impl.is_equal(
- value, state.committed_state[propkey]
- )
- is not True
- ):
- params[col.key] = value
-
- if mapper.base_mapper.eager_defaults:
- has_all_defaults = (
- mapper._server_onupdate_default_cols[table]
- ).issubset(params)
- else:
- has_all_defaults = True
-
- if (
- update_version_id is not None
- and mapper.version_id_col in mapper._cols_by_table[table]
- ):
-
- if not bulk and not (params or value_params):
- # HACK: check for history in other tables, in case the
- # history is only in a different table than the one
- # where the version_id_col is. This logic was lost
- # from 0.9 -> 1.0.0 and restored in 1.0.6.
- for prop in mapper._columntoproperty.values():
- history = state.manager[prop.key].impl.get_history(
- state, state_dict, attributes.PASSIVE_NO_INITIALIZE
- )
- if history.added:
- break
- else:
- # no net change, break
- continue
-
- col = mapper.version_id_col
- no_params = not params and not value_params
- params[col._label] = update_version_id
-
- if (
- bulk or col.key not in params
- ) and mapper.version_id_generator is not False:
- val = mapper.version_id_generator(update_version_id)
- params[col.key] = val
- elif mapper.version_id_generator is False and no_params:
- # no version id generator, no values set on the table,
- # and version id wasn't manually incremented.
- # set version id to itself so we get an UPDATE
- # statement
- params[col.key] = update_version_id
-
- elif not (params or value_params):
- continue
-
- has_all_pks = True
- expect_pk_cascaded = False
- if bulk:
- # keys here are mapped attribute keys, so
- # look at mapper attribute keys for pk
- pk_params = dict(
- (propkey_to_col[propkey]._label, state_dict.get(propkey))
- for propkey in set(propkey_to_col).intersection(
- mapper._pk_attr_keys_by_table[table]
- )
- )
- else:
- pk_params = {}
- for col in pks:
- propkey = mapper._columntoproperty[col].key
-
- history = state.manager[propkey].impl.get_history(
- state, state_dict, attributes.PASSIVE_OFF
- )
-
- if history.added:
- if (
- not history.deleted
- or ("pk_cascaded", state, col)
- in uowtransaction.attributes
- ):
- expect_pk_cascaded = True
- pk_params[col._label] = history.added[0]
- params.pop(col.key, None)
- else:
- # else, use the old value to locate the row
- pk_params[col._label] = history.deleted[0]
- if col in value_params:
- has_all_pks = False
- else:
- pk_params[col._label] = history.unchanged[0]
- if pk_params[col._label] is None:
- raise orm_exc.FlushError(
- "Can't update table %s using NULL for primary "
- "key value on column %s" % (table, col)
- )
-
- if params or value_params:
- params.update(pk_params)
- yield (
- state,
- state_dict,
- params,
- mapper,
- connection,
- value_params,
- has_all_defaults,
- has_all_pks,
- )
- elif expect_pk_cascaded:
- # no UPDATE occurs on this table, but we expect that CASCADE rules
- # have changed the primary key of the row; propagate this event to
- # other columns that expect to have been modified. this normally
- # occurs after the UPDATE is emitted however we invoke it here
- # explicitly in the absence of our invoking an UPDATE
- for m, equated_pairs in mapper._table_to_equated[table]:
- sync.populate(
- state,
- m,
- state,
- m,
- equated_pairs,
- uowtransaction,
- mapper.passive_updates,
- )
-
-
- def _collect_post_update_commands(
- base_mapper, uowtransaction, table, states_to_update, post_update_cols
- ):
- """Identify sets of values to use in UPDATE statements for a
- list of states within a post_update operation.
-
- """
-
- for (
- state,
- state_dict,
- mapper,
- connection,
- update_version_id,
- ) in states_to_update:
-
- # assert table in mapper._pks_by_table
-
- pks = mapper._pks_by_table[table]
- params = {}
- hasdata = False
-
- for col in mapper._cols_by_table[table]:
- if col in pks:
- params[col._label] = mapper._get_state_attr_by_column(
- state, state_dict, col, passive=attributes.PASSIVE_OFF
- )
-
- elif col in post_update_cols or col.onupdate is not None:
- prop = mapper._columntoproperty[col]
- history = state.manager[prop.key].impl.get_history(
- state, state_dict, attributes.PASSIVE_NO_INITIALIZE
- )
- if history.added:
- value = history.added[0]
- params[col.key] = value
- hasdata = True
- if hasdata:
- if (
- update_version_id is not None
- and mapper.version_id_col in mapper._cols_by_table[table]
- ):
-
- col = mapper.version_id_col
- params[col._label] = update_version_id
-
- if (
- bool(state.key)
- and col.key not in params
- and mapper.version_id_generator is not False
- ):
- val = mapper.version_id_generator(update_version_id)
- params[col.key] = val
- yield state, state_dict, mapper, connection, params
-
-
- def _collect_delete_commands(
- base_mapper, uowtransaction, table, states_to_delete
- ):
- """Identify values to use in DELETE statements for a list of
- states to be deleted."""
-
- for (
- state,
- state_dict,
- mapper,
- connection,
- update_version_id,
- ) in states_to_delete:
-
- if table not in mapper._pks_by_table:
- continue
-
- params = {}
- for col in mapper._pks_by_table[table]:
- params[
- col.key
- ] = value = mapper._get_committed_state_attr_by_column(
- state, state_dict, col
- )
- if value is None:
- raise orm_exc.FlushError(
- "Can't delete from table %s "
- "using NULL for primary "
- "key value on column %s" % (table, col)
- )
-
- if (
- update_version_id is not None
- and mapper.version_id_col in mapper._cols_by_table[table]
- ):
- params[mapper.version_id_col.key] = update_version_id
- yield params, connection
-
-
- def _emit_update_statements(
- base_mapper,
- uowtransaction,
- mapper,
- table,
- update,
- bookkeeping=True,
- ):
- """Emit UPDATE statements corresponding to value lists collected
- by _collect_update_commands()."""
-
- needs_version_id = (
- mapper.version_id_col is not None
- and mapper.version_id_col in mapper._cols_by_table[table]
- )
-
- execution_options = {"compiled_cache": base_mapper._compiled_cache}
-
- def update_stmt():
- clauses = BooleanClauseList._construct_raw(operators.and_)
-
- for col in mapper._pks_by_table[table]:
- clauses.clauses.append(
- col == sql.bindparam(col._label, type_=col.type)
- )
-
- if needs_version_id:
- clauses.clauses.append(
- mapper.version_id_col
- == sql.bindparam(
- mapper.version_id_col._label,
- type_=mapper.version_id_col.type,
- )
- )
-
- stmt = table.update().where(clauses)
- return stmt
-
- cached_stmt = base_mapper._memo(("update", table), update_stmt)
-
- for (
- (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
- records,
- ) in groupby(
- update,
- lambda rec: (
- rec[4], # connection
- set(rec[2]), # set of parameter keys
- bool(rec[5]), # whether or not we have "value" parameters
- rec[6], # has_all_defaults
- rec[7], # has all pks
- ),
- ):
- rows = 0
- records = list(records)
-
- statement = cached_stmt
- return_defaults = False
-
- if not has_all_pks:
- statement = statement.return_defaults()
- return_defaults = True
- elif (
- bookkeeping
- and not has_all_defaults
- and mapper.base_mapper.eager_defaults
- ):
- statement = statement.return_defaults()
- return_defaults = True
- elif mapper.version_id_col is not None:
- statement = statement.return_defaults(mapper.version_id_col)
- return_defaults = True
-
- assert_singlerow = (
- connection.dialect.supports_sane_rowcount
- if not return_defaults
- else connection.dialect.supports_sane_rowcount_returning
- )
-
- assert_multirow = (
- assert_singlerow
- and connection.dialect.supports_sane_multi_rowcount
- )
- allow_multirow = has_all_defaults and not needs_version_id
-
- if hasvalue:
- for (
- state,
- state_dict,
- params,
- mapper,
- connection,
- value_params,
- has_all_defaults,
- has_all_pks,
- ) in records:
- c = connection._execute_20(
- statement.values(value_params),
- params,
- execution_options=execution_options,
- )
- if bookkeeping:
- _postfetch(
- mapper,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- c.context.compiled_parameters[0],
- value_params,
- True,
- c.returned_defaults,
- )
- rows += c.rowcount
- check_rowcount = assert_singlerow
- else:
- if not allow_multirow:
- check_rowcount = assert_singlerow
- for (
- state,
- state_dict,
- params,
- mapper,
- connection,
- value_params,
- has_all_defaults,
- has_all_pks,
- ) in records:
- c = connection._execute_20(
- statement, params, execution_options=execution_options
- )
-
- # TODO: why with bookkeeping=False?
- if bookkeeping:
- _postfetch(
- mapper,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- c.context.compiled_parameters[0],
- value_params,
- True,
- c.returned_defaults,
- )
- rows += c.rowcount
- else:
- multiparams = [rec[2] for rec in records]
-
- check_rowcount = assert_multirow or (
- assert_singlerow and len(multiparams) == 1
- )
-
- c = connection._execute_20(
- statement, multiparams, execution_options=execution_options
- )
-
- rows += c.rowcount
-
- for (
- state,
- state_dict,
- params,
- mapper,
- connection,
- value_params,
- has_all_defaults,
- has_all_pks,
- ) in records:
- if bookkeeping:
- _postfetch(
- mapper,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- c.context.compiled_parameters[0],
- value_params,
- True,
- c.returned_defaults
- if not c.context.executemany
- else None,
- )
-
- if check_rowcount:
- if rows != len(records):
- raise orm_exc.StaleDataError(
- "UPDATE statement on table '%s' expected to "
- "update %d row(s); %d were matched."
- % (table.description, len(records), rows)
- )
-
- elif needs_version_id:
- util.warn(
- "Dialect %s does not support updated rowcount "
- "- versioning cannot be verified."
- % c.dialect.dialect_description
- )
-
-
- def _emit_insert_statements(
- base_mapper,
- uowtransaction,
- mapper,
- table,
- insert,
- bookkeeping=True,
- ):
- """Emit INSERT statements corresponding to value lists collected
- by _collect_insert_commands()."""
-
- cached_stmt = base_mapper._memo(("insert", table), table.insert)
-
- execution_options = {"compiled_cache": base_mapper._compiled_cache}
-
- for (
- (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
- records,
- ) in groupby(
- insert,
- lambda rec: (
- rec[4], # connection
- set(rec[2]), # parameter keys
- bool(rec[5]), # whether we have "value" parameters
- rec[6],
- rec[7],
- ),
- ):
-
- statement = cached_stmt
-
- if (
- not bookkeeping
- or (
- has_all_defaults
- or not base_mapper.eager_defaults
- or not connection.dialect.implicit_returning
- )
- and has_all_pks
- and not hasvalue
- ):
- # the "we don't need newly generated values back" section.
- # here we have all the PKs, all the defaults or we don't want
- # to fetch them, or the dialect doesn't support RETURNING at all
- # so we have to post-fetch / use lastrowid anyway.
- records = list(records)
- multiparams = [rec[2] for rec in records]
-
- c = connection._execute_20(
- statement, multiparams, execution_options=execution_options
- )
-
- if bookkeeping:
- for (
- (
- state,
- state_dict,
- params,
- mapper_rec,
- conn,
- value_params,
- has_all_pks,
- has_all_defaults,
- ),
- last_inserted_params,
- ) in zip(records, c.context.compiled_parameters):
- if state:
- _postfetch(
- mapper_rec,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- last_inserted_params,
- value_params,
- False,
- c.returned_defaults
- if not c.context.executemany
- else None,
- )
- else:
- _postfetch_bulk_save(mapper_rec, state_dict, table)
-
- else:
- # here, we need defaults and/or pk values back.
-
- records = list(records)
- if (
- not hasvalue
- and connection.dialect.insert_executemany_returning
- and len(records) > 1
- ):
- do_executemany = True
- else:
- do_executemany = False
-
- if not has_all_defaults and base_mapper.eager_defaults:
- statement = statement.return_defaults()
- elif mapper.version_id_col is not None:
- statement = statement.return_defaults(mapper.version_id_col)
- elif do_executemany:
- statement = statement.return_defaults(*table.primary_key)
-
- if do_executemany:
- multiparams = [rec[2] for rec in records]
-
- c = connection._execute_20(
- statement, multiparams, execution_options=execution_options
- )
-
- if bookkeeping:
- for (
- (
- state,
- state_dict,
- params,
- mapper_rec,
- conn,
- value_params,
- has_all_pks,
- has_all_defaults,
- ),
- last_inserted_params,
- inserted_primary_key,
- returned_defaults,
- ) in util.zip_longest(
- records,
- c.context.compiled_parameters,
- c.inserted_primary_key_rows,
- c.returned_defaults_rows or (),
- ):
- for pk, col in zip(
- inserted_primary_key,
- mapper._pks_by_table[table],
- ):
- prop = mapper_rec._columntoproperty[col]
- if state_dict.get(prop.key) is None:
- state_dict[prop.key] = pk
-
- if state:
- _postfetch(
- mapper_rec,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- last_inserted_params,
- value_params,
- False,
- returned_defaults,
- )
- else:
- _postfetch_bulk_save(mapper_rec, state_dict, table)
- else:
- for (
- state,
- state_dict,
- params,
- mapper_rec,
- connection,
- value_params,
- has_all_pks,
- has_all_defaults,
- ) in records:
- if value_params:
- result = connection._execute_20(
- statement.values(value_params),
- params,
- execution_options=execution_options,
- )
- else:
- result = connection._execute_20(
- statement,
- params,
- execution_options=execution_options,
- )
-
- primary_key = result.inserted_primary_key
- for pk, col in zip(
- primary_key, mapper._pks_by_table[table]
- ):
- prop = mapper_rec._columntoproperty[col]
- if (
- col in value_params
- or state_dict.get(prop.key) is None
- ):
- state_dict[prop.key] = pk
- if bookkeeping:
- if state:
- _postfetch(
- mapper_rec,
- uowtransaction,
- table,
- state,
- state_dict,
- result,
- result.context.compiled_parameters[0],
- value_params,
- False,
- result.returned_defaults
- if not result.context.executemany
- else None,
- )
- else:
- _postfetch_bulk_save(mapper_rec, state_dict, table)
-
-
- def _emit_post_update_statements(
- base_mapper, uowtransaction, mapper, table, update
- ):
- """Emit UPDATE statements corresponding to value lists collected
- by _collect_post_update_commands()."""
-
- execution_options = {"compiled_cache": base_mapper._compiled_cache}
-
- needs_version_id = (
- mapper.version_id_col is not None
- and mapper.version_id_col in mapper._cols_by_table[table]
- )
-
- def update_stmt():
- clauses = BooleanClauseList._construct_raw(operators.and_)
-
- for col in mapper._pks_by_table[table]:
- clauses.clauses.append(
- col == sql.bindparam(col._label, type_=col.type)
- )
-
- if needs_version_id:
- clauses.clauses.append(
- mapper.version_id_col
- == sql.bindparam(
- mapper.version_id_col._label,
- type_=mapper.version_id_col.type,
- )
- )
-
- stmt = table.update().where(clauses)
-
- if mapper.version_id_col is not None:
- stmt = stmt.return_defaults(mapper.version_id_col)
-
- return stmt
-
- statement = base_mapper._memo(("post_update", table), update_stmt)
-
- # execute each UPDATE in the order according to the original
- # list of states to guarantee row access order, but
- # also group them into common (connection, cols) sets
- # to support executemany().
- for key, records in groupby(
- update,
- lambda rec: (rec[3], set(rec[4])), # connection # parameter keys
- ):
- rows = 0
-
- records = list(records)
- connection = key[0]
-
- assert_singlerow = (
- connection.dialect.supports_sane_rowcount
- if mapper.version_id_col is None
- else connection.dialect.supports_sane_rowcount_returning
- )
- assert_multirow = (
- assert_singlerow
- and connection.dialect.supports_sane_multi_rowcount
- )
- allow_multirow = not needs_version_id or assert_multirow
-
- if not allow_multirow:
- check_rowcount = assert_singlerow
- for state, state_dict, mapper_rec, connection, params in records:
-
- c = connection._execute_20(
- statement, params, execution_options=execution_options
- )
-
- _postfetch_post_update(
- mapper_rec,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- c.context.compiled_parameters[0],
- )
- rows += c.rowcount
- else:
- multiparams = [
- params
- for state, state_dict, mapper_rec, conn, params in records
- ]
-
- check_rowcount = assert_multirow or (
- assert_singlerow and len(multiparams) == 1
- )
-
- c = connection._execute_20(
- statement, multiparams, execution_options=execution_options
- )
-
- rows += c.rowcount
- for state, state_dict, mapper_rec, connection, params in records:
- _postfetch_post_update(
- mapper_rec,
- uowtransaction,
- table,
- state,
- state_dict,
- c,
- c.context.compiled_parameters[0],
- )
-
- if check_rowcount:
- if rows != len(records):
- raise orm_exc.StaleDataError(
- "UPDATE statement on table '%s' expected to "
- "update %d row(s); %d were matched."
- % (table.description, len(records), rows)
- )
-
- elif needs_version_id:
- util.warn(
- "Dialect %s does not support updated rowcount "
- "- versioning cannot be verified."
- % c.dialect.dialect_description
- )
-
-
- def _emit_delete_statements(
- base_mapper, uowtransaction, mapper, table, delete
- ):
- """Emit DELETE statements corresponding to value lists collected
- by _collect_delete_commands()."""
-
- need_version_id = (
- mapper.version_id_col is not None
- and mapper.version_id_col in mapper._cols_by_table[table]
- )
-
- def delete_stmt():
- clauses = BooleanClauseList._construct_raw(operators.and_)
-
- for col in mapper._pks_by_table[table]:
- clauses.clauses.append(
- col == sql.bindparam(col.key, type_=col.type)
- )
-
- if need_version_id:
- clauses.clauses.append(
- mapper.version_id_col
- == sql.bindparam(
- mapper.version_id_col.key, type_=mapper.version_id_col.type
- )
- )
-
- return table.delete().where(clauses)
-
- statement = base_mapper._memo(("delete", table), delete_stmt)
- for connection, recs in groupby(delete, lambda rec: rec[1]): # connection
- del_objects = [params for params, connection in recs]
-
- execution_options = {"compiled_cache": base_mapper._compiled_cache}
- expected = len(del_objects)
- rows_matched = -1
- only_warn = False
-
- if (
- need_version_id
- and not connection.dialect.supports_sane_multi_rowcount
- ):
- if connection.dialect.supports_sane_rowcount:
- rows_matched = 0
- # execute deletes individually so that versioned
- # rows can be verified
- for params in del_objects:
-
- c = connection._execute_20(
- statement, params, execution_options=execution_options
- )
- rows_matched += c.rowcount
- else:
- util.warn(
- "Dialect %s does not support deleted rowcount "
- "- versioning cannot be verified."
- % connection.dialect.dialect_description
- )
- connection._execute_20(
- statement, del_objects, execution_options=execution_options
- )
- else:
- c = connection._execute_20(
- statement, del_objects, execution_options=execution_options
- )
-
- if not need_version_id:
- only_warn = True
-
- rows_matched = c.rowcount
-
- if (
- base_mapper.confirm_deleted_rows
- and rows_matched > -1
- and expected != rows_matched
- and (
- connection.dialect.supports_sane_multi_rowcount
- or len(del_objects) == 1
- )
- ):
- # TODO: why does this "only warn" if versioning is turned off,
- # whereas the UPDATE raises?
- if only_warn:
- util.warn(
- "DELETE statement on table '%s' expected to "
- "delete %d row(s); %d were matched. Please set "
- "confirm_deleted_rows=False within the mapper "
- "configuration to prevent this warning."
- % (table.description, expected, rows_matched)
- )
- else:
- raise orm_exc.StaleDataError(
- "DELETE statement on table '%s' expected to "
- "delete %d row(s); %d were matched. Please set "
- "confirm_deleted_rows=False within the mapper "
- "configuration to prevent this warning."
- % (table.description, expected, rows_matched)
- )
-
-
- def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
- """finalize state on states that have been inserted or updated,
- including calling after_insert/after_update events.
-
- """
- for state, state_dict, mapper, connection, has_identity in states:
-
- if mapper._readonly_props:
- readonly = state.unmodified_intersection(
- [
- p.key
- for p in mapper._readonly_props
- if (
- p.expire_on_flush
- and (not p.deferred or p.key in state.dict)
- )
- or (
- not p.expire_on_flush
- and not p.deferred
- and p.key not in state.dict
- )
- ]
- )
- if readonly:
- state._expire_attributes(state.dict, readonly)
-
- # if eager_defaults option is enabled, load
- # all expired cols. Else if we have a version_id_col, make sure
- # it isn't expired.
- toload_now = []
-
- if base_mapper.eager_defaults:
- toload_now.extend(
- state._unloaded_non_object.intersection(
- mapper._server_default_plus_onupdate_propkeys
- )
- )
-
- if (
- mapper.version_id_col is not None
- and mapper.version_id_generator is False
- ):
- if mapper._version_id_prop.key in state.unloaded:
- toload_now.extend([mapper._version_id_prop.key])
-
- if toload_now:
- state.key = base_mapper._identity_key_from_state(state)
- stmt = future.select(mapper).set_label_style(
- LABEL_STYLE_TABLENAME_PLUS_COL
- )
- loading.load_on_ident(
- uowtransaction.session,
- stmt,
- state.key,
- refresh_state=state,
- only_load_props=toload_now,
- )
-
- # call after_XXX extensions
- if not has_identity:
- mapper.dispatch.after_insert(mapper, connection, state)
- else:
- mapper.dispatch.after_update(mapper, connection, state)
-
- if (
- mapper.version_id_generator is False
- and mapper.version_id_col is not None
- ):
- if state_dict[mapper._version_id_prop.key] is None:
- raise orm_exc.FlushError(
- "Instance does not contain a non-NULL version value"
- )
-
-
- def _postfetch_post_update(
- mapper, uowtransaction, table, state, dict_, result, params
- ):
- if uowtransaction.is_deleted(state):
- return
-
- prefetch_cols = result.context.compiled.prefetch
- postfetch_cols = result.context.compiled.postfetch
-
- if (
- mapper.version_id_col is not None
- and mapper.version_id_col in mapper._cols_by_table[table]
- ):
- prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
-
- refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
- if refresh_flush:
- load_evt_attrs = []
-
- for c in prefetch_cols:
- if c.key in params and c in mapper._columntoproperty:
- dict_[mapper._columntoproperty[c].key] = params[c.key]
- if refresh_flush:
- load_evt_attrs.append(mapper._columntoproperty[c].key)
-
- if refresh_flush and load_evt_attrs:
- mapper.class_manager.dispatch.refresh_flush(
- state, uowtransaction, load_evt_attrs
- )
-
- if postfetch_cols:
- state._expire_attributes(
- state.dict,
- [
- mapper._columntoproperty[c].key
- for c in postfetch_cols
- if c in mapper._columntoproperty
- ],
- )
-
-
- def _postfetch(
- mapper,
- uowtransaction,
- table,
- state,
- dict_,
- result,
- params,
- value_params,
- isupdate,
- returned_defaults,
- ):
- """Expire attributes in need of newly persisted database state,
- after an INSERT or UPDATE statement has proceeded for that
- state."""
-
- prefetch_cols = result.context.compiled.prefetch
- postfetch_cols = result.context.compiled.postfetch
- returning_cols = result.context.compiled.returning
-
- if (
- mapper.version_id_col is not None
- and mapper.version_id_col in mapper._cols_by_table[table]
- ):
- prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]
-
- refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
- if refresh_flush:
- load_evt_attrs = []
-
- if returning_cols:
- row = returned_defaults
- if row is not None:
- for row_value, col in zip(row, returning_cols):
- # pk cols returned from insert are handled
- # distinctly, don't step on the values here
- if col.primary_key and result.context.isinsert:
- continue
-
- # note that columns can be in the "return defaults" that are
- # not mapped to this mapper, typically because they are
- # "excluded", which can be specified directly or also occurs
- # when using declarative w/ single table inheritance
- prop = mapper._columntoproperty.get(col)
- if prop:
- dict_[prop.key] = row_value
- if refresh_flush:
- load_evt_attrs.append(prop.key)
-
- for c in prefetch_cols:
- if c.key in params and c in mapper._columntoproperty:
- dict_[mapper._columntoproperty[c].key] = params[c.key]
- if refresh_flush:
- load_evt_attrs.append(mapper._columntoproperty[c].key)
-
- if refresh_flush and load_evt_attrs:
- mapper.class_manager.dispatch.refresh_flush(
- state, uowtransaction, load_evt_attrs
- )
-
- if isupdate and value_params:
- # explicitly suit the use case specified by
- # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
- # database which are set to themselves in order to do a version bump.
- postfetch_cols.extend(
- [
- col
- for col in value_params
- if col.primary_key and col not in returning_cols
- ]
- )
-
- if postfetch_cols:
- state._expire_attributes(
- state.dict,
- [
- mapper._columntoproperty[c].key
- for c in postfetch_cols
- if c in mapper._columntoproperty
- ],
- )
-
- # synchronize newly inserted ids from one table to the next
- # TODO: this still goes a little too often. would be nice to
- # have definitive list of "columns that changed" here
- for m, equated_pairs in mapper._table_to_equated[table]:
- sync.populate(
- state,
- m,
- state,
- m,
- equated_pairs,
- uowtransaction,
- mapper.passive_updates,
- )
-
-
- def _postfetch_bulk_save(mapper, dict_, table):
- for m, equated_pairs in mapper._table_to_equated[table]:
- sync.bulk_populate_inherit_keys(dict_, m, equated_pairs)
-
-
- def _connections_for_states(base_mapper, uowtransaction, states):
- """Return an iterator of (state, state.dict, mapper, connection).
-
- The states are sorted according to _sort_states, then paired
- with the connection they should be using for the given
- unit of work transaction.
-
- """
- # if session has a connection callable,
- # organize individual states with the connection
- # to use for update
- if uowtransaction.session.connection_callable:
- connection_callable = uowtransaction.session.connection_callable
- else:
- connection = uowtransaction.transaction.connection(base_mapper)
- connection_callable = None
-
- for state in _sort_states(base_mapper, states):
- if connection_callable:
- connection = connection_callable(base_mapper, state.obj())
-
- mapper = state.manager.mapper
-
- yield state, state.dict, mapper, connection
-
-
- def _sort_states(mapper, states):
- pending = set(states)
- persistent = set(s for s in pending if s.key is not None)
- pending.difference_update(persistent)
-
- try:
- persistent_sorted = sorted(
- persistent, key=mapper._persistent_sortkey_fn
- )
- except TypeError as err:
- util.raise_(
- sa_exc.InvalidRequestError(
- "Could not sort objects by primary key; primary key "
- "values must be sortable in Python (was: %s)" % err
- ),
- replace_context=err,
- )
- return (
- sorted(pending, key=operator.attrgetter("insert_order"))
- + persistent_sorted
- )
-
-
- _EMPTY_DICT = util.immutabledict()
-
-
- class BulkUDCompileState(CompileState):
- class default_update_options(Options):
- _synchronize_session = "evaluate"
- _autoflush = True
- _subject_mapper = None
- _resolved_values = _EMPTY_DICT
- _resolved_keys_as_propnames = _EMPTY_DICT
- _value_evaluators = _EMPTY_DICT
- _matched_objects = None
- _matched_rows = None
- _refresh_identity_token = None
-
- @classmethod
- def orm_pre_session_exec(
- cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- is_reentrant_invoke,
- ):
- if is_reentrant_invoke:
- return statement, execution_options
-
- (
- update_options,
- execution_options,
- ) = BulkUDCompileState.default_update_options.from_execution_options(
- "_sa_orm_update_options",
- {"synchronize_session"},
- execution_options,
- statement._execution_options,
- )
-
- sync = update_options._synchronize_session
- if sync is not None:
- if sync not in ("evaluate", "fetch", False):
- raise sa_exc.ArgumentError(
- "Valid strategies for session synchronization "
- "are 'evaluate', 'fetch', False"
- )
-
- bind_arguments["clause"] = statement
- try:
- plugin_subject = statement._propagate_attrs["plugin_subject"]
- except KeyError:
- assert False, "statement had 'orm' plugin but no plugin_subject"
- else:
- bind_arguments["mapper"] = plugin_subject.mapper
-
- update_options += {"_subject_mapper": plugin_subject.mapper}
-
- if update_options._autoflush:
- session._autoflush()
-
- statement = statement._annotate(
- {"synchronize_session": update_options._synchronize_session}
- )
-
- # this stage of the execution is called before the do_orm_execute event
- # hook. meaning for an extension like horizontal sharding, this step
- # happens before the extension splits out into multiple backends and
- # runs only once. if we do pre_sync_fetch, we execute a SELECT
- # statement, which the horizontal sharding extension splits amongst the
- # shards and combines the results together.
-
- if update_options._synchronize_session == "evaluate":
- update_options = cls._do_pre_synchronize_evaluate(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
- elif update_options._synchronize_session == "fetch":
- update_options = cls._do_pre_synchronize_fetch(
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- )
-
- return (
- statement,
- util.immutabledict(execution_options).union(
- dict(_sa_orm_update_options=update_options)
- ),
- )
-
- @classmethod
- def orm_setup_cursor_result(
- cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- result,
- ):
-
- # this stage of the execution is called after the
- # do_orm_execute event hook. meaning for an extension like
- # horizontal sharding, this step happens *within* the horizontal
- # sharding event handler which calls session.execute() re-entrantly
- # and will occur for each backend individually.
- # the sharding extension then returns its own merged result from the
- # individual ones we return here.
-
- update_options = execution_options["_sa_orm_update_options"]
- if update_options._synchronize_session == "evaluate":
- cls._do_post_synchronize_evaluate(session, result, update_options)
- elif update_options._synchronize_session == "fetch":
- cls._do_post_synchronize_fetch(session, result, update_options)
-
- return result
-
- @classmethod
- def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
- """Apply extra criteria filtering.
-
- For all distinct single-table-inheritance mappers represented in the
- table being updated or deleted, produce additional WHERE criteria such
- that only the appropriate subtypes are selected from the total results.
-
- Additionally, add WHERE criteria originating from LoaderCriteriaOptions
- collected from the statement.
-
- """
-
- return_crit = ()
-
- adapter = ext_info._adapter if ext_info.is_aliased_class else None
-
- if (
- "additional_entity_criteria",
- ext_info.mapper,
- ) in global_attributes:
- return_crit += tuple(
- ae._resolve_where_criteria(ext_info)
- for ae in global_attributes[
- ("additional_entity_criteria", ext_info.mapper)
- ]
- if ae.include_aliases or ae.entity is ext_info
- )
-
- if ext_info.mapper._single_table_criterion is not None:
- return_crit += (ext_info.mapper._single_table_criterion,)
-
- if adapter:
- return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
-
- return return_crit
-
- @classmethod
- def _do_pre_synchronize_evaluate(
- cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- ):
- mapper = update_options._subject_mapper
- target_cls = mapper.class_
-
- value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT
-
- try:
- evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
- crit = ()
- if statement._where_criteria:
- crit += statement._where_criteria
-
- global_attributes = {}
- for opt in statement._with_options:
- if opt._is_criteria_option:
- opt.get_global_criteria(global_attributes)
-
- if global_attributes:
- crit += cls._adjust_for_extra_criteria(
- global_attributes, mapper
- )
-
- if crit:
- eval_condition = evaluator_compiler.process(*crit)
- else:
-
- def eval_condition(obj):
- return True
-
- except evaluator.UnevaluatableError as err:
- util.raise_(
- sa_exc.InvalidRequestError(
- 'Could not evaluate current criteria in Python: "%s". '
- "Specify 'fetch' or False for the "
- "synchronize_session execution option." % err
- ),
- from_=err,
- )
-
- if statement.__visit_name__ == "lambda_element":
- # ._resolved is called on every LambdaElement in order to
- # generate the cache key, so this access does not add
- # additional expense
- effective_statement = statement._resolved
- else:
- effective_statement = statement
-
- if effective_statement.__visit_name__ == "update":
- resolved_values = cls._get_resolved_values(
- mapper, effective_statement
- )
- value_evaluators = {}
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
- for key, value in resolved_keys_as_propnames:
- try:
- _evaluator = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
- except evaluator.UnevaluatableError:
- pass
- else:
- value_evaluators[key] = _evaluator
-
- # TODO: detect when the where clause is a trivial primary key match.
- matched_objects = [
- state.obj()
- for state in session.identity_map.all_states()
- if state.mapper.isa(mapper)
- and not state.expired
- and eval_condition(state.obj())
- and (
- update_options._refresh_identity_token is None
- # TODO: coverage for the case where horizontal sharding
- # invokes an update() or delete() given an explicit identity
- # token up front
- or state.identity_token
- == update_options._refresh_identity_token
- )
- ]
- return update_options + {
- "_matched_objects": matched_objects,
- "_value_evaluators": value_evaluators,
- "_resolved_keys_as_propnames": resolved_keys_as_propnames,
- }
-
- @classmethod
- def _get_resolved_values(cls, mapper, statement):
- if statement._multi_values:
- return []
- elif statement._ordered_values:
- iterator = statement._ordered_values
- elif statement._values:
- iterator = statement._values.items()
- else:
- return []
-
- values = []
- if iterator:
- for k, v in iterator:
- if mapper:
- if isinstance(k, util.string_types):
- desc = _entity_namespace_key(mapper, k)
- values.extend(desc._bulk_update_tuples(v))
- elif "entity_namespace" in k._annotations:
- k_anno = k._annotations
- attr = _entity_namespace_key(
- k_anno["entity_namespace"], k_anno["proxy_key"]
- )
- values.extend(attr._bulk_update_tuples(v))
- else:
- values.append((k, v))
- else:
- values.append((k, v))
- return values
-
- @classmethod
- def _resolved_keys_as_propnames(cls, mapper, resolved_values):
- values = []
- for k, v in resolved_values:
- if isinstance(k, attributes.QueryableAttribute):
- values.append((k.key, v))
- continue
- elif hasattr(k, "__clause_element__"):
- k = k.__clause_element__()
-
- if mapper and isinstance(k, expression.ColumnElement):
- try:
- attr = mapper._columntoproperty[k]
- except orm_exc.UnmappedColumnError:
- pass
- else:
- values.append((attr.key, v))
- else:
- raise sa_exc.InvalidRequestError(
- "Invalid expression type: %r" % k
- )
- return values
-
- @classmethod
- def _do_pre_synchronize_fetch(
- cls,
- session,
- statement,
- params,
- execution_options,
- bind_arguments,
- update_options,
- ):
- mapper = update_options._subject_mapper
-
- select_stmt = (
- select(*(mapper.primary_key + (mapper.select_identity_token,)))
- .select_from(mapper)
- .options(*statement._with_options)
- )
- select_stmt._where_criteria = statement._where_criteria
-
- def skip_for_full_returning(orm_context):
- bind = orm_context.session.get_bind(**orm_context.bind_arguments)
- if bind.dialect.full_returning:
- return _result.null_result()
- else:
- return None
-
- result = session.execute(
- select_stmt,
- params,
- execution_options,
- bind_arguments,
- _add_event=skip_for_full_returning,
- )
- matched_rows = result.fetchall()
-
- value_evaluators = _EMPTY_DICT
-
- if statement.__visit_name__ == "lambda_element":
- # ._resolved is called on every LambdaElement in order to
- # generate the cache key, so this access does not add
- # additional expense
- effective_statement = statement._resolved
- else:
- effective_statement = statement
-
- if effective_statement.__visit_name__ == "update":
- target_cls = mapper.class_
- evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
- resolved_values = cls._get_resolved_values(
- mapper, effective_statement
- )
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
-
- resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
- mapper, resolved_values
- )
- value_evaluators = {}
- for key, value in resolved_keys_as_propnames:
- try:
- _evaluator = evaluator_compiler.process(
- coercions.expect(roles.ExpressionElementRole, value)
- )
- except evaluator.UnevaluatableError:
- pass
- else:
- value_evaluators[key] = _evaluator
-
- else:
- resolved_keys_as_propnames = _EMPTY_DICT
-
- return update_options + {
- "_value_evaluators": value_evaluators,
- "_matched_rows": matched_rows,
- "_resolved_keys_as_propnames": resolved_keys_as_propnames,
- }
-
-
- @CompileState.plugin_for("orm", "update")
- class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
- @classmethod
- def create_for_statement(cls, statement, compiler, **kw):
-
- self = cls.__new__(cls)
-
- ext_info = statement.table._annotations["parententity"]
-
- self.mapper = mapper = ext_info.mapper
-
- self.extra_criteria_entities = {}
-
- self._resolved_values = cls._get_resolved_values(mapper, statement)
-
- extra_criteria_attributes = {}
-
- for opt in statement._with_options:
- if opt._is_criteria_option:
- opt.get_global_criteria(extra_criteria_attributes)
-
- if not statement._preserve_parameter_order and statement._values:
- self._resolved_values = dict(self._resolved_values)
-
- new_stmt = sql.Update.__new__(sql.Update)
- new_stmt.__dict__.update(statement.__dict__)
- new_stmt.table = mapper.local_table
-
- # note if the statement has _multi_values, these
- # are passed through to the new statement, which will then raise
- # InvalidRequestError because UPDATE doesn't support multi_values
- # right now.
- if statement._ordered_values:
- new_stmt._ordered_values = self._resolved_values
- elif statement._values:
- new_stmt._values = self._resolved_values
-
- new_crit = cls._adjust_for_extra_criteria(
- extra_criteria_attributes, mapper
- )
- if new_crit:
- new_stmt = new_stmt.where(*new_crit)
-
- # if we are against a lambda statement we might not be the
- # topmost object that received per-execute annotations
-
- if (
- compiler._annotations.get("synchronize_session", None) == "fetch"
- and compiler.dialect.full_returning
- ):
- if new_stmt._returning:
- raise sa_exc.InvalidRequestError(
- "Can't use synchronize_session='fetch' "
- "with explicit returning()"
- )
- new_stmt = new_stmt.returning(*mapper.primary_key)
-
- UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
-
- return self
-
- @classmethod
- def _do_post_synchronize_evaluate(cls, session, result, update_options):
-
- states = set()
- evaluated_keys = list(update_options._value_evaluators.keys())
- values = update_options._resolved_keys_as_propnames
- attrib = set(k for k, v in values)
- for obj in update_options._matched_objects:
-
- state, dict_ = (
- attributes.instance_state(obj),
- attributes.instance_dict(obj),
- )
-
- # the evaluated states were gathered across all identity tokens.
- # however the post_sync events are called per identity token,
- # so filter.
- if (
- update_options._refresh_identity_token is not None
- and state.identity_token
- != update_options._refresh_identity_token
- ):
- continue
-
- # only evaluate unmodified attributes
- to_evaluate = state.unmodified.intersection(evaluated_keys)
- for key in to_evaluate:
- if key in dict_:
- dict_[key] = update_options._value_evaluators[key](obj)
-
- state.manager.dispatch.refresh(state, None, to_evaluate)
-
- state._commit(dict_, list(to_evaluate))
-
- to_expire = attrib.intersection(dict_).difference(to_evaluate)
- if to_expire:
- state._expire_attributes(dict_, to_expire)
-
- states.add(state)
- session._register_altered(states)
-
- @classmethod
- def _do_post_synchronize_fetch(cls, session, result, update_options):
- target_mapper = update_options._subject_mapper
-
- states = set()
- evaluated_keys = list(update_options._value_evaluators.keys())
-
- if result.returns_rows:
- matched_rows = [
- tuple(row) + (update_options._refresh_identity_token,)
- for row in result.all()
- ]
- else:
- matched_rows = update_options._matched_rows
-
- objs = [
- session.identity_map[identity_key]
- for identity_key in [
- target_mapper.identity_key_from_primary_key(
- list(primary_key),
- identity_token=identity_token,
- )
- for primary_key, identity_token in [
- (row[0:-1], row[-1]) for row in matched_rows
- ]
- if update_options._refresh_identity_token is None
- or identity_token == update_options._refresh_identity_token
- ]
- if identity_key in session.identity_map
- ]
-
- values = update_options._resolved_keys_as_propnames
- attrib = set(k for k, v in values)
-
- for obj in objs:
- state, dict_ = (
- attributes.instance_state(obj),
- attributes.instance_dict(obj),
- )
-
- to_evaluate = state.unmodified.intersection(evaluated_keys)
- for key in to_evaluate:
- if key in dict_:
- dict_[key] = update_options._value_evaluators[key](obj)
- state.manager.dispatch.refresh(state, None, to_evaluate)
-
- state._commit(dict_, list(to_evaluate))
-
- to_expire = attrib.intersection(dict_).difference(to_evaluate)
- if to_expire:
- state._expire_attributes(dict_, to_expire)
-
- states.add(state)
- session._register_altered(states)
-
-
- @CompileState.plugin_for("orm", "delete")
- class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
- @classmethod
- def create_for_statement(cls, statement, compiler, **kw):
- self = cls.__new__(cls)
-
- ext_info = statement.table._annotations["parententity"]
- self.mapper = mapper = ext_info.mapper
-
- self.extra_criteria_entities = {}
-
- extra_criteria_attributes = {}
-
- for opt in statement._with_options:
- if opt._is_criteria_option:
- opt.get_global_criteria(extra_criteria_attributes)
-
- new_crit = cls._adjust_for_extra_criteria(
- extra_criteria_attributes, mapper
- )
- if new_crit:
- statement = statement.where(*new_crit)
-
- if (
- mapper
- and compiler._annotations.get("synchronize_session", None)
- == "fetch"
- and compiler.dialect.full_returning
- ):
- statement = statement.returning(*mapper.primary_key)
-
- DeleteDMLState.__init__(self, statement, compiler, **kw)
-
- return self
-
- @classmethod
- def _do_post_synchronize_evaluate(cls, session, result, update_options):
-
- session._remove_newly_deleted(
- [
- attributes.instance_state(obj)
- for obj in update_options._matched_objects
- ]
- )
-
- @classmethod
- def _do_post_synchronize_fetch(cls, session, result, update_options):
- target_mapper = update_options._subject_mapper
-
- if result.returns_rows:
- matched_rows = [
- tuple(row) + (update_options._refresh_identity_token,)
- for row in result.all()
- ]
- else:
- matched_rows = update_options._matched_rows
-
- for row in matched_rows:
- primary_key = row[0:-1]
- identity_token = row[-1]
-
- # TODO: inline this and call remove_newly_deleted
- # once
- identity_key = target_mapper.identity_key_from_primary_key(
- list(primary_key),
- identity_token=identity_token,
- )
- if identity_key in session.identity_map:
- session._remove_newly_deleted(
- [
- attributes.instance_state(
- session.identity_map[identity_key]
- )
- ]
- )
|