mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-31 04:48:02 -04:00
cleanup and callcount reduction in mapper._save_obj, _delete_obj.
includes an untested fix for [ticket:1761]
This commit is contained in:
+196
-110
@@ -1224,11 +1224,13 @@ class Mapper(object):
|
||||
try:
|
||||
if item_type == 'property':
|
||||
prop = iterator.next()
|
||||
visitables.append((prop.cascade_iterator(type_, parent_state, visited_instances, halt_on), 'mapper', None))
|
||||
visitables.append((prop.cascade_iterator(type_, parent_state,
|
||||
visited_instances, halt_on), 'mapper', None))
|
||||
elif item_type == 'mapper':
|
||||
instance, instance_mapper, corresponding_state = iterator.next()
|
||||
yield (instance, instance_mapper)
|
||||
visitables.append((instance_mapper._props.itervalues(), 'property', corresponding_state))
|
||||
visitables.append((instance_mapper._props.itervalues(),
|
||||
'property', corresponding_state))
|
||||
except StopIteration:
|
||||
visitables.pop()
|
||||
|
||||
@@ -1263,55 +1265,46 @@ class Mapper(object):
|
||||
# if batch=false, call _save_obj separately for each object
|
||||
if not single and not self.batch:
|
||||
for state in _sort_states(states):
|
||||
self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
|
||||
self._save_obj([state],
|
||||
uowtransaction,
|
||||
postupdate=postupdate,
|
||||
post_update_cols=post_update_cols,
|
||||
single=True)
|
||||
return
|
||||
|
||||
|
||||
# if session has a connection callable,
|
||||
# organize individual states with the connection to use for insert/update
|
||||
tups = []
|
||||
# organize individual states with the connection
|
||||
# to use for insert/update
|
||||
if 'connection_callable' in uowtransaction.mapper_flush_opts:
|
||||
connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
|
||||
for state in _sort_states(states):
|
||||
m = _state_mapper(state)
|
||||
tups.append(
|
||||
(
|
||||
state,
|
||||
m,
|
||||
connection_callable(self, state.obj()),
|
||||
_state_has_identity(state),
|
||||
state.key or m._identity_key_from_state(state)
|
||||
)
|
||||
)
|
||||
connection_callable = \
|
||||
uowtransaction.mapper_flush_opts['connection_callable']
|
||||
else:
|
||||
connection = uowtransaction.transaction.connection(self)
|
||||
for state in _sort_states(states):
|
||||
m = _state_mapper(state)
|
||||
tups.append(
|
||||
(
|
||||
state,
|
||||
m,
|
||||
connection,
|
||||
_state_has_identity(state),
|
||||
state.key or m._identity_key_from_state(state)
|
||||
)
|
||||
)
|
||||
connection_callable = None
|
||||
|
||||
if not postupdate:
|
||||
# call before_XXX extensions
|
||||
for state, mapper, connection, has_identity, instance_key in tups:
|
||||
tups = []
|
||||
for state in _sort_states(states):
|
||||
conn = connection_callable and \
|
||||
connection_callable(self, state.obj()) or \
|
||||
connection
|
||||
|
||||
has_identity = _state_has_identity(state)
|
||||
mapper = _state_mapper(state)
|
||||
instance_key = state.key or mapper._identity_key_from_state(state)
|
||||
|
||||
row_switch = None
|
||||
if not postupdate:
|
||||
# call before_XXX extensions
|
||||
if not has_identity:
|
||||
if 'before_insert' in mapper.extension:
|
||||
mapper.extension.before_insert(mapper, connection, state.obj())
|
||||
mapper.extension.before_insert(mapper, conn, state.obj())
|
||||
else:
|
||||
if 'before_update' in mapper.extension:
|
||||
mapper.extension.before_update(mapper, connection, state.obj())
|
||||
mapper.extension.before_update(mapper, conn, state.obj())
|
||||
|
||||
row_switches = {}
|
||||
if not postupdate:
|
||||
for state, mapper, connection, has_identity, instance_key in tups:
|
||||
# detect if we have a "pending" instance (i.e. has no instance_key attached to it),
|
||||
# and another instance with the same identity key already exists as persistent. convert to an
|
||||
# UPDATE if so.
|
||||
# and another instance with the same identity key already exists as persistent.
|
||||
# convert to an UPDATE if so.
|
||||
if not has_identity and instance_key in uowtransaction.session.identity_map:
|
||||
instance = uowtransaction.session.identity_map[instance_key]
|
||||
existing = attributes.instance_state(instance)
|
||||
@@ -1320,28 +1313,42 @@ class Mapper(object):
|
||||
"New instance %s with identity key %s conflicts "
|
||||
"with persistent instance %s" %
|
||||
(state_str(state), instance_key, state_str(existing)))
|
||||
|
||||
|
||||
self._log_debug(
|
||||
"detected row switch for identity %s. will update %s, remove %s from "
|
||||
"transaction", instance_key, state_str(state), state_str(existing))
|
||||
|
||||
"detected row switch for identity %s. "
|
||||
"will update %s, remove %s from "
|
||||
"transaction", instance_key,
|
||||
state_str(state), state_str(existing))
|
||||
|
||||
# remove the "delete" flag from the existing element
|
||||
uowtransaction.set_row_switch(existing)
|
||||
row_switches[state] = existing
|
||||
|
||||
table_to_mapper = self._sorted_tables
|
||||
row_switch = existing
|
||||
|
||||
for table in table_to_mapper.iterkeys():
|
||||
tups.append(
|
||||
(state,
|
||||
mapper,
|
||||
conn,
|
||||
has_identity,
|
||||
instance_key,
|
||||
row_switch)
|
||||
)
|
||||
|
||||
table_to_mapper = self._sorted_tables
|
||||
|
||||
for table in table_to_mapper:
|
||||
insert = []
|
||||
update = []
|
||||
|
||||
for state, mapper, connection, has_identity, instance_key in tups:
|
||||
for state, mapper, connection, has_identity, \
|
||||
instance_key, row_switch in tups:
|
||||
if table not in mapper._pks_by_table:
|
||||
continue
|
||||
|
||||
pks = mapper._pks_by_table[table]
|
||||
|
||||
isinsert = not has_identity and not postupdate and state not in row_switches
|
||||
isinsert = not has_identity and \
|
||||
not postupdate and \
|
||||
not row_switch
|
||||
|
||||
params = {}
|
||||
value_params = {}
|
||||
@@ -1371,23 +1378,36 @@ class Mapper(object):
|
||||
value_params[col] = value
|
||||
else:
|
||||
params[col.key] = value
|
||||
insert.append((state, params, mapper, connection, value_params))
|
||||
insert.append((state, params, mapper,
|
||||
connection, value_params))
|
||||
else:
|
||||
for col in mapper._cols_by_table[table]:
|
||||
if col is mapper.version_id_col:
|
||||
params[col._label] = mapper._get_state_attr_by_column(row_switches.get(state, state), col)
|
||||
params[col.key] = mapper.version_id_generator(params[col._label])
|
||||
params[col._label] = \
|
||||
mapper._get_state_attr_by_column(
|
||||
row_switch or state,
|
||||
col)
|
||||
params[col.key] = \
|
||||
mapper.version_id_generator(params[col._label])
|
||||
|
||||
# HACK: check for history, in case the history is only
|
||||
# in a different table than the one where the version_id_col
|
||||
# is.
|
||||
for prop in mapper._columntoproperty.itervalues():
|
||||
history = attributes.get_state_history(state, prop.key, passive=True)
|
||||
history = attributes.get_state_history(
|
||||
state, prop.key, passive=True)
|
||||
if history.added:
|
||||
hasdata = True
|
||||
elif mapper.polymorphic_on is not None and \
|
||||
mapper.polymorphic_on.shares_lineage(col) and col not in pks:
|
||||
mapper.polymorphic_on.shares_lineage(col) and \
|
||||
col not in pks:
|
||||
pass
|
||||
else:
|
||||
if post_update_cols is not None and col not in post_update_cols:
|
||||
if post_update_cols is not None and \
|
||||
col not in post_update_cols:
|
||||
if col in pks:
|
||||
params[col._label] = mapper._get_state_attr_by_column(state, col)
|
||||
params[col._label] = \
|
||||
mapper._get_state_attr_by_column(state, col)
|
||||
continue
|
||||
|
||||
prop = mapper._columntoproperty[col]
|
||||
@@ -1424,27 +1444,32 @@ class Mapper(object):
|
||||
elif col in pks:
|
||||
params[col._label] = mapper._get_state_attr_by_column(state, col)
|
||||
if hasdata:
|
||||
update.append((state, params, mapper, connection, value_params))
|
||||
update.append((state, params, mapper,
|
||||
connection, value_params))
|
||||
|
||||
if update:
|
||||
mapper = table_to_mapper[table]
|
||||
clause = sql.and_()
|
||||
|
||||
for col in mapper._pks_by_table[table]:
|
||||
clause.clauses.append(col == sql.bindparam(col._label, type_=col.type))
|
||||
clause.clauses.append(
|
||||
col ==
|
||||
sql.bindparam(col._label, type_=col.type)
|
||||
)
|
||||
|
||||
if mapper.version_id_col is not None and \
|
||||
table.c.contains_column(mapper.version_id_col):
|
||||
|
||||
needs_version_id = mapper.version_id_col is not None and \
|
||||
table.c.contains_column(mapper.version_id_col)
|
||||
|
||||
if needs_version_id:
|
||||
clause.clauses.append(mapper.version_id_col ==\
|
||||
sql.bindparam(mapper.version_id_col._label, type_=col.type))
|
||||
|
||||
statement = table.update(clause)
|
||||
|
||||
|
||||
rows = 0
|
||||
for state, params, mapper, connection, value_params in update:
|
||||
c = connection.execute(statement.values(value_params), params)
|
||||
mapper._postfetch(uowtransaction, connection, table,
|
||||
mapper._postfetch(uowtransaction, table,
|
||||
state, c, c.last_updated_params(), value_params)
|
||||
|
||||
rows += c.rowcount
|
||||
@@ -1452,13 +1477,15 @@ class Mapper(object):
|
||||
if connection.dialect.supports_sane_rowcount:
|
||||
if rows != len(update):
|
||||
raise orm_exc.ConcurrentModificationError(
|
||||
"Updated rowcount %d does not match number of objects updated %d" %
|
||||
"Updated rowcount %d does not match number "
|
||||
"of objects updated %d" %
|
||||
(rows, len(update)))
|
||||
|
||||
elif mapper.version_id_col is not None:
|
||||
|
||||
elif needs_version_id:
|
||||
util.warn("Dialect %s does not support updated rowcount "
|
||||
"- versioning cannot be verified." % c.dialect.dialect_description,
|
||||
stacklevel=12)
|
||||
"- versioning cannot be verified." %
|
||||
c.dialect.dialect_description,
|
||||
stacklevel=12)
|
||||
|
||||
if insert:
|
||||
statement = table.insert()
|
||||
@@ -1473,12 +1500,12 @@ class Mapper(object):
|
||||
len(primary_key) > i:
|
||||
mapper._set_state_attr_by_column(state, col, primary_key[i])
|
||||
|
||||
mapper._postfetch(uowtransaction, connection, table,
|
||||
mapper._postfetch(uowtransaction, table,
|
||||
state, c, c.last_inserted_params(), value_params)
|
||||
|
||||
|
||||
if not postupdate:
|
||||
for state, mapper, connection, has_identity, instance_key in tups:
|
||||
for state, mapper, connection, has_identity, \
|
||||
instance_key, row_switch in tups:
|
||||
|
||||
# expire readonly attributes
|
||||
readonly = state.unmodified.intersection(
|
||||
@@ -1488,8 +1515,8 @@ class Mapper(object):
|
||||
if readonly:
|
||||
_expire_state(state, state.dict, readonly)
|
||||
|
||||
# if specified, eagerly refresh whatever has
|
||||
# been expired.
|
||||
# if eager_defaults option is enabled,
|
||||
# refresh whatever has been expired.
|
||||
if self.eager_defaults and state.unloaded:
|
||||
state.key = self._identity_key_from_state(state)
|
||||
uowtransaction.session.query(self)._get(
|
||||
@@ -1504,7 +1531,7 @@ class Mapper(object):
|
||||
if 'after_update' in mapper.extension:
|
||||
mapper.extension.after_update(mapper, connection, state.obj())
|
||||
|
||||
def _postfetch(self, uowtransaction, connection, table,
|
||||
def _postfetch(self, uowtransaction, table,
|
||||
state, resultproxy, params, value_params):
|
||||
"""Expire attributes in need of newly persisted database state."""
|
||||
|
||||
@@ -1523,23 +1550,37 @@ class Mapper(object):
|
||||
if c.key in params and c in self._columntoproperty:
|
||||
self._set_state_attr_by_column(state, c, params[c.key])
|
||||
|
||||
deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]]
|
||||
|
||||
if deferred_props:
|
||||
_expire_state(state, state.dict, deferred_props)
|
||||
if postfetch_cols:
|
||||
_expire_state(state, state.dict,
|
||||
[self._columntoproperty[c].key
|
||||
for c in postfetch_cols]
|
||||
)
|
||||
|
||||
# synchronize newly inserted ids from one table to the next
|
||||
# TODO: this still goes a little too often. would be nice to
|
||||
# have definitive list of "columns that changed" here
|
||||
cols = set(table.c)
|
||||
for m in self.iterate_to_root():
|
||||
if m._inherits_equated_pairs and \
|
||||
cols.intersection([l for l, r in m._inherits_equated_pairs]):
|
||||
sync.populate(state, m, state, m,
|
||||
m._inherits_equated_pairs,
|
||||
uowtransaction,
|
||||
self.passive_updates)
|
||||
|
||||
for m, equated_pairs in self._table_to_equated[table]:
|
||||
sync.populate(state, m, state, m,
|
||||
equated_pairs,
|
||||
uowtransaction,
|
||||
self.passive_updates)
|
||||
|
||||
@util.memoized_property
|
||||
def _table_to_equated(self):
|
||||
"""memoized map of tables to collections of columns to be
|
||||
synchronized upwards to the base mapper."""
|
||||
|
||||
result = util.defaultdict(list)
|
||||
|
||||
for table in self._sorted_tables:
|
||||
cols = set(table.c)
|
||||
for m in self.iterate_to_root():
|
||||
if m._inherits_equated_pairs and \
|
||||
cols.intersection([l for l, r in m._inherits_equated_pairs]):
|
||||
result[table].append((m, m._inherits_equated_pairs))
|
||||
|
||||
return result
|
||||
|
||||
def _delete_obj(self, states, uowtransaction):
|
||||
"""Issue ``DELETE`` statements for a list of objects.
|
||||
|
||||
@@ -1548,50 +1589,95 @@ class Mapper(object):
|
||||
|
||||
"""
|
||||
if 'connection_callable' in uowtransaction.mapper_flush_opts:
|
||||
connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
|
||||
tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in _sort_states(states)]
|
||||
connection_callable = \
|
||||
uowtransaction.mapper_flush_opts['connection_callable']
|
||||
else:
|
||||
connection = uowtransaction.transaction.connection(self)
|
||||
tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)]
|
||||
|
||||
for state, mapper, connection in tups:
|
||||
connection_callable = None
|
||||
|
||||
tups = []
|
||||
for state in _sort_states(states):
|
||||
mapper = _state_mapper(state)
|
||||
|
||||
conn = connection_callable and \
|
||||
connection_callable(self, state.obj()) or \
|
||||
connection
|
||||
|
||||
if 'before_delete' in mapper.extension:
|
||||
mapper.extension.before_delete(mapper, connection, state.obj())
|
||||
mapper.extension.before_delete(mapper, conn, state.obj())
|
||||
|
||||
tups.append((state,
|
||||
_state_mapper(state),
|
||||
_state_has_identity(state),
|
||||
conn))
|
||||
|
||||
table_to_mapper = self._sorted_tables
|
||||
|
||||
for table in reversed(table_to_mapper.keys()):
|
||||
delete = {}
|
||||
for state, mapper, connection in tups:
|
||||
if table not in mapper._pks_by_table:
|
||||
delete = util.defaultdict(list)
|
||||
for state, mapper, has_identity, connection in tups:
|
||||
if not has_identity or table not in mapper._pks_by_table:
|
||||
continue
|
||||
|
||||
params = {}
|
||||
if not _state_has_identity(state):
|
||||
continue
|
||||
else:
|
||||
delete.setdefault(connection, []).append(params)
|
||||
delete[connection].append(params)
|
||||
for col in mapper._pks_by_table[table]:
|
||||
params[col.key] = mapper._get_state_attr_by_column(state, col)
|
||||
if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
|
||||
params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col)
|
||||
if mapper.version_id_col is not None and \
|
||||
table.c.contains_column(mapper.version_id_col):
|
||||
params[mapper.version_id_col.key] = \
|
||||
mapper._get_state_attr_by_column(state, mapper.version_id_col)
|
||||
|
||||
for connection, del_objects in delete.iteritems():
|
||||
mapper = table_to_mapper[table]
|
||||
clause = sql.and_()
|
||||
for col in mapper._pks_by_table[table]:
|
||||
clause.clauses.append(col == sql.bindparam(col.key, type_=col.type))
|
||||
if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
|
||||
|
||||
need_version_id = mapper.version_id_col is not None and \
|
||||
table.c.contains_column(mapper.version_id_col)
|
||||
|
||||
if need_version_id:
|
||||
clause.clauses.append(
|
||||
mapper.version_id_col ==
|
||||
sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type))
|
||||
statement = table.delete(clause)
|
||||
c = connection.execute(statement, del_objects)
|
||||
if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects):
|
||||
raise orm_exc.ConcurrentModificationError("Deleted rowcount %d does not match "
|
||||
"number of objects deleted %d" % (c.rowcount, len(del_objects)))
|
||||
sql.bindparam(
|
||||
mapper.version_id_col.key,
|
||||
type_=mapper.version_id_col.type
|
||||
)
|
||||
)
|
||||
|
||||
for state, mapper, connection in tups:
|
||||
statement = table.delete(clause)
|
||||
rows = -1
|
||||
|
||||
if need_version_id and \
|
||||
not connection.dialect.supports_sane_multi_rowcount:
|
||||
# TODO: need test coverage for this [ticket:1761]
|
||||
if connection.dialect.supports_sane_rowcount:
|
||||
rows = 0
|
||||
# execute deletes individually so that versioned
|
||||
# rows can be verified
|
||||
for params in del_objects:
|
||||
c = connection.execute(statement, params)
|
||||
rows += c.rowcount
|
||||
else:
|
||||
util.warn("Dialect %s does not support deleted rowcount "
|
||||
"- versioning cannot be verified." %
|
||||
c.dialect.dialect_description,
|
||||
stacklevel=12)
|
||||
connection.execute(statement, del_objects)
|
||||
else:
|
||||
c = connection.execute(statement, del_objects)
|
||||
if connection.dialect.supports_sane_multi_rowcount:
|
||||
rows = c.rowcount
|
||||
|
||||
if rows != -1 and rows != len(del_objects):
|
||||
raise orm_exc.ConcurrentModificationError(
|
||||
"Deleted rowcount %d does not match "
|
||||
"number of objects deleted %d" %
|
||||
(c.rowcount, len(del_objects))
|
||||
)
|
||||
|
||||
for state, mapper, has_identity, connection in tups:
|
||||
if 'after_delete' in mapper.extension:
|
||||
mapper.extension.after_delete(mapper, connection, state.obj())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user