cleanup and callcount reduction in mapper._save_obj, _delete_obj.

includes an untested fix for [ticket:1761]
This commit is contained in:
Mike Bayer
2010-04-03 21:42:41 -04:00
parent eefdbd3757
commit dfab13e9ae
+196 -110
View File
@@ -1224,11 +1224,13 @@ class Mapper(object):
try:
if item_type == 'property':
prop = iterator.next()
visitables.append((prop.cascade_iterator(type_, parent_state, visited_instances, halt_on), 'mapper', None))
visitables.append((prop.cascade_iterator(type_, parent_state,
visited_instances, halt_on), 'mapper', None))
elif item_type == 'mapper':
instance, instance_mapper, corresponding_state = iterator.next()
yield (instance, instance_mapper)
visitables.append((instance_mapper._props.itervalues(), 'property', corresponding_state))
visitables.append((instance_mapper._props.itervalues(),
'property', corresponding_state))
except StopIteration:
visitables.pop()
@@ -1263,55 +1265,46 @@ class Mapper(object):
# if batch=false, call _save_obj separately for each object
if not single and not self.batch:
for state in _sort_states(states):
self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True)
self._save_obj([state],
uowtransaction,
postupdate=postupdate,
post_update_cols=post_update_cols,
single=True)
return
# if session has a connection callable,
# organize individual states with the connection to use for insert/update
tups = []
# organize individual states with the connection
# to use for insert/update
if 'connection_callable' in uowtransaction.mapper_flush_opts:
connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
for state in _sort_states(states):
m = _state_mapper(state)
tups.append(
(
state,
m,
connection_callable(self, state.obj()),
_state_has_identity(state),
state.key or m._identity_key_from_state(state)
)
)
connection_callable = \
uowtransaction.mapper_flush_opts['connection_callable']
else:
connection = uowtransaction.transaction.connection(self)
for state in _sort_states(states):
m = _state_mapper(state)
tups.append(
(
state,
m,
connection,
_state_has_identity(state),
state.key or m._identity_key_from_state(state)
)
)
connection_callable = None
if not postupdate:
# call before_XXX extensions
for state, mapper, connection, has_identity, instance_key in tups:
tups = []
for state in _sort_states(states):
conn = connection_callable and \
connection_callable(self, state.obj()) or \
connection
has_identity = _state_has_identity(state)
mapper = _state_mapper(state)
instance_key = state.key or mapper._identity_key_from_state(state)
row_switch = None
if not postupdate:
# call before_XXX extensions
if not has_identity:
if 'before_insert' in mapper.extension:
mapper.extension.before_insert(mapper, connection, state.obj())
mapper.extension.before_insert(mapper, conn, state.obj())
else:
if 'before_update' in mapper.extension:
mapper.extension.before_update(mapper, connection, state.obj())
mapper.extension.before_update(mapper, conn, state.obj())
row_switches = {}
if not postupdate:
for state, mapper, connection, has_identity, instance_key in tups:
# detect if we have a "pending" instance (i.e. has no instance_key attached to it),
# and another instance with the same identity key already exists as persistent. convert to an
# UPDATE if so.
# and another instance with the same identity key already exists as persistent.
# convert to an UPDATE if so.
if not has_identity and instance_key in uowtransaction.session.identity_map:
instance = uowtransaction.session.identity_map[instance_key]
existing = attributes.instance_state(instance)
@@ -1320,28 +1313,42 @@ class Mapper(object):
"New instance %s with identity key %s conflicts "
"with persistent instance %s" %
(state_str(state), instance_key, state_str(existing)))
self._log_debug(
"detected row switch for identity %s. will update %s, remove %s from "
"transaction", instance_key, state_str(state), state_str(existing))
"detected row switch for identity %s. "
"will update %s, remove %s from "
"transaction", instance_key,
state_str(state), state_str(existing))
# remove the "delete" flag from the existing element
uowtransaction.set_row_switch(existing)
row_switches[state] = existing
table_to_mapper = self._sorted_tables
row_switch = existing
for table in table_to_mapper.iterkeys():
tups.append(
(state,
mapper,
conn,
has_identity,
instance_key,
row_switch)
)
table_to_mapper = self._sorted_tables
for table in table_to_mapper:
insert = []
update = []
for state, mapper, connection, has_identity, instance_key in tups:
for state, mapper, connection, has_identity, \
instance_key, row_switch in tups:
if table not in mapper._pks_by_table:
continue
pks = mapper._pks_by_table[table]
isinsert = not has_identity and not postupdate and state not in row_switches
isinsert = not has_identity and \
not postupdate and \
not row_switch
params = {}
value_params = {}
@@ -1371,23 +1378,36 @@ class Mapper(object):
value_params[col] = value
else:
params[col.key] = value
insert.append((state, params, mapper, connection, value_params))
insert.append((state, params, mapper,
connection, value_params))
else:
for col in mapper._cols_by_table[table]:
if col is mapper.version_id_col:
params[col._label] = mapper._get_state_attr_by_column(row_switches.get(state, state), col)
params[col.key] = mapper.version_id_generator(params[col._label])
params[col._label] = \
mapper._get_state_attr_by_column(
row_switch or state,
col)
params[col.key] = \
mapper.version_id_generator(params[col._label])
# HACK: check for history, in case the history is only
# in a different table than the one where the version_id_col
# is.
for prop in mapper._columntoproperty.itervalues():
history = attributes.get_state_history(state, prop.key, passive=True)
history = attributes.get_state_history(
state, prop.key, passive=True)
if history.added:
hasdata = True
elif mapper.polymorphic_on is not None and \
mapper.polymorphic_on.shares_lineage(col) and col not in pks:
mapper.polymorphic_on.shares_lineage(col) and \
col not in pks:
pass
else:
if post_update_cols is not None and col not in post_update_cols:
if post_update_cols is not None and \
col not in post_update_cols:
if col in pks:
params[col._label] = mapper._get_state_attr_by_column(state, col)
params[col._label] = \
mapper._get_state_attr_by_column(state, col)
continue
prop = mapper._columntoproperty[col]
@@ -1424,27 +1444,32 @@ class Mapper(object):
elif col in pks:
params[col._label] = mapper._get_state_attr_by_column(state, col)
if hasdata:
update.append((state, params, mapper, connection, value_params))
update.append((state, params, mapper,
connection, value_params))
if update:
mapper = table_to_mapper[table]
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col._label, type_=col.type))
clause.clauses.append(
col ==
sql.bindparam(col._label, type_=col.type)
)
if mapper.version_id_col is not None and \
table.c.contains_column(mapper.version_id_col):
needs_version_id = mapper.version_id_col is not None and \
table.c.contains_column(mapper.version_id_col)
if needs_version_id:
clause.clauses.append(mapper.version_id_col ==\
sql.bindparam(mapper.version_id_col._label, type_=col.type))
statement = table.update(clause)
rows = 0
for state, params, mapper, connection, value_params in update:
c = connection.execute(statement.values(value_params), params)
mapper._postfetch(uowtransaction, connection, table,
mapper._postfetch(uowtransaction, table,
state, c, c.last_updated_params(), value_params)
rows += c.rowcount
@@ -1452,13 +1477,15 @@ class Mapper(object):
if connection.dialect.supports_sane_rowcount:
if rows != len(update):
raise orm_exc.ConcurrentModificationError(
"Updated rowcount %d does not match number of objects updated %d" %
"Updated rowcount %d does not match number "
"of objects updated %d" %
(rows, len(update)))
elif mapper.version_id_col is not None:
elif needs_version_id:
util.warn("Dialect %s does not support updated rowcount "
"- versioning cannot be verified." % c.dialect.dialect_description,
stacklevel=12)
"- versioning cannot be verified." %
c.dialect.dialect_description,
stacklevel=12)
if insert:
statement = table.insert()
@@ -1473,12 +1500,12 @@ class Mapper(object):
len(primary_key) > i:
mapper._set_state_attr_by_column(state, col, primary_key[i])
mapper._postfetch(uowtransaction, connection, table,
mapper._postfetch(uowtransaction, table,
state, c, c.last_inserted_params(), value_params)
if not postupdate:
for state, mapper, connection, has_identity, instance_key in tups:
for state, mapper, connection, has_identity, \
instance_key, row_switch in tups:
# expire readonly attributes
readonly = state.unmodified.intersection(
@@ -1488,8 +1515,8 @@ class Mapper(object):
if readonly:
_expire_state(state, state.dict, readonly)
# if specified, eagerly refresh whatever has
# been expired.
# if eager_defaults option is enabled,
# refresh whatever has been expired.
if self.eager_defaults and state.unloaded:
state.key = self._identity_key_from_state(state)
uowtransaction.session.query(self)._get(
@@ -1504,7 +1531,7 @@ class Mapper(object):
if 'after_update' in mapper.extension:
mapper.extension.after_update(mapper, connection, state.obj())
def _postfetch(self, uowtransaction, connection, table,
def _postfetch(self, uowtransaction, table,
state, resultproxy, params, value_params):
"""Expire attributes in need of newly persisted database state."""
@@ -1523,23 +1550,37 @@ class Mapper(object):
if c.key in params and c in self._columntoproperty:
self._set_state_attr_by_column(state, c, params[c.key])
deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]]
if deferred_props:
_expire_state(state, state.dict, deferred_props)
if postfetch_cols:
_expire_state(state, state.dict,
[self._columntoproperty[c].key
for c in postfetch_cols]
)
# synchronize newly inserted ids from one table to the next
# TODO: this still goes a little too often. would be nice to
# have definitive list of "columns that changed" here
cols = set(table.c)
for m in self.iterate_to_root():
if m._inherits_equated_pairs and \
cols.intersection([l for l, r in m._inherits_equated_pairs]):
sync.populate(state, m, state, m,
m._inherits_equated_pairs,
uowtransaction,
self.passive_updates)
for m, equated_pairs in self._table_to_equated[table]:
sync.populate(state, m, state, m,
equated_pairs,
uowtransaction,
self.passive_updates)
@util.memoized_property
def _table_to_equated(self):
"""memoized map of tables to collections of columns to be
synchronized upwards to the base mapper."""
result = util.defaultdict(list)
for table in self._sorted_tables:
cols = set(table.c)
for m in self.iterate_to_root():
if m._inherits_equated_pairs and \
cols.intersection([l for l, r in m._inherits_equated_pairs]):
result[table].append((m, m._inherits_equated_pairs))
return result
def _delete_obj(self, states, uowtransaction):
"""Issue ``DELETE`` statements for a list of objects.
@@ -1548,50 +1589,95 @@ class Mapper(object):
"""
if 'connection_callable' in uowtransaction.mapper_flush_opts:
connection_callable = uowtransaction.mapper_flush_opts['connection_callable']
tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in _sort_states(states)]
connection_callable = \
uowtransaction.mapper_flush_opts['connection_callable']
else:
connection = uowtransaction.transaction.connection(self)
tups = [(state, _state_mapper(state), connection) for state in _sort_states(states)]
for state, mapper, connection in tups:
connection_callable = None
tups = []
for state in _sort_states(states):
mapper = _state_mapper(state)
conn = connection_callable and \
connection_callable(self, state.obj()) or \
connection
if 'before_delete' in mapper.extension:
mapper.extension.before_delete(mapper, connection, state.obj())
mapper.extension.before_delete(mapper, conn, state.obj())
tups.append((state,
_state_mapper(state),
_state_has_identity(state),
conn))
table_to_mapper = self._sorted_tables
for table in reversed(table_to_mapper.keys()):
delete = {}
for state, mapper, connection in tups:
if table not in mapper._pks_by_table:
delete = util.defaultdict(list)
for state, mapper, has_identity, connection in tups:
if not has_identity or table not in mapper._pks_by_table:
continue
params = {}
if not _state_has_identity(state):
continue
else:
delete.setdefault(connection, []).append(params)
delete[connection].append(params)
for col in mapper._pks_by_table[table]:
params[col.key] = mapper._get_state_attr_by_column(state, col)
if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col)
if mapper.version_id_col is not None and \
table.c.contains_column(mapper.version_id_col):
params[mapper.version_id_col.key] = \
mapper._get_state_attr_by_column(state, mapper.version_id_col)
for connection, del_objects in delete.iteritems():
mapper = table_to_mapper[table]
clause = sql.and_()
for col in mapper._pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col.key, type_=col.type))
if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col):
need_version_id = mapper.version_id_col is not None and \
table.c.contains_column(mapper.version_id_col)
if need_version_id:
clause.clauses.append(
mapper.version_id_col ==
sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type))
statement = table.delete(clause)
c = connection.execute(statement, del_objects)
if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects):
raise orm_exc.ConcurrentModificationError("Deleted rowcount %d does not match "
"number of objects deleted %d" % (c.rowcount, len(del_objects)))
sql.bindparam(
mapper.version_id_col.key,
type_=mapper.version_id_col.type
)
)
for state, mapper, connection in tups:
statement = table.delete(clause)
rows = -1
if need_version_id and \
not connection.dialect.supports_sane_multi_rowcount:
# TODO: need test coverage for this [ticket:1761]
if connection.dialect.supports_sane_rowcount:
rows = 0
# execute deletes individually so that versioned
# rows can be verified
for params in del_objects:
c = connection.execute(statement, params)
rows += c.rowcount
else:
util.warn("Dialect %s does not support deleted rowcount "
"- versioning cannot be verified." %
c.dialect.dialect_description,
stacklevel=12)
connection.execute(statement, del_objects)
else:
c = connection.execute(statement, del_objects)
if connection.dialect.supports_sane_multi_rowcount:
rows = c.rowcount
if rows != -1 and rows != len(del_objects):
raise orm_exc.ConcurrentModificationError(
"Deleted rowcount %d does not match "
"number of objects deleted %d" %
(c.rowcount, len(del_objects))
)
for state, mapper, has_identity, connection in tups:
if 'after_delete' in mapper.extension:
mapper.extension.after_delete(mapper, connection, state.obj())