diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 228cfef3aa..c7850ac1da 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -40,32 +40,58 @@ def save_obj(base_mapper, states, uowtransaction, single=False): save_obj(base_mapper, [state], uowtransaction, single=True) return - states_to_insert, states_to_update = _organize_states_for_save( - base_mapper, - states, - uowtransaction) - + states_to_update = [] + states_to_insert = [] cached_connections = _cached_connection_dict(base_mapper) + for (state, dict_, mapper, connection, + has_identity, row_switch) in _organize_states_for_save( + base_mapper, states, uowtransaction + ): + if has_identity or row_switch: + states_to_update.append( + (state, dict_, mapper, connection, + has_identity, row_switch) + ) + else: + states_to_insert.append( + (state, dict_, mapper, connection, + has_identity, row_switch) + ) + for table, mapper in base_mapper._sorted_tables.items(): - insert = _collect_insert_commands(base_mapper, uowtransaction, - table, states_to_insert) + if table not in mapper._pks_by_table: + continue + insert = ( + (state, state_dict, mapper, connection) + for state, state_dict, mapper, connection, has_identity, + row_switch in states_to_insert + ) + insert = _collect_insert_commands(table, insert) - update = _collect_update_commands(base_mapper, uowtransaction, - table, states_to_update) + update = ( + (state, state_dict, mapper, connection, row_switch) + for state, state_dict, mapper, connection, has_identity, + row_switch in states_to_update + ) + update = _collect_update_commands(uowtransaction, table, update) - if update: - _emit_update_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, update) + _emit_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) - if insert: - _emit_insert_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, insert) + _emit_insert_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, insert) - _finalize_insert_update_commands(base_mapper, uowtransaction, - states_to_insert, states_to_update) + _finalize_insert_update_commands( + base_mapper, uowtransaction, + ( + (state, state_dict, mapper, connection, has_identity) + for state, state_dict, mapper, connection, has_identity, + row_switch in states_to_insert + states_to_update + ) + ) def post_update(base_mapper, states, uowtransaction, post_update_cols): @@ -75,19 +101,20 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): """ cached_connections = _cached_connection_dict(base_mapper) - states_to_update = _organize_states_for_post_update( + states_to_update = list(_organize_states_for_post_update( base_mapper, - states, uowtransaction) + states, uowtransaction)) for table, mapper in base_mapper._sorted_tables.items(): + if table not in mapper._pks_by_table: + continue update = _collect_post_update_commands(base_mapper, uowtransaction, table, states_to_update, post_update_cols) - if update: - _emit_post_update_statements(base_mapper, uowtransaction, - cached_connections, - mapper, table, update) + _emit_post_update_statements(base_mapper, uowtransaction, + cached_connections, + mapper, table, update) def delete_obj(base_mapper, states, uowtransaction): @@ -100,19 +127,21 @@ def delete_obj(base_mapper, states, uowtransaction): cached_connections = _cached_connection_dict(base_mapper) - states_to_delete = _organize_states_for_delete( + states_to_delete = list(_organize_states_for_delete( base_mapper, states, - uowtransaction) + uowtransaction)) table_to_mapper = base_mapper._sorted_tables for table in reversed(list(table_to_mapper.keys())): + mapper = table_to_mapper[table] + if table not in mapper._pks_by_table: + continue + delete = _collect_delete_commands(base_mapper, uowtransaction, table, states_to_delete) - mapper = table_to_mapper[table] - _emit_delete_statements(base_mapper, uowtransaction, cached_connections, mapper, table, delete) @@ -133,9 +162,6 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): """ - states_to_insert = [] - states_to_update = [] - for state, dict_, mapper, connection in _connections_for_states( base_mapper, uowtransaction, states): @@ -181,18 +207,8 @@ def _organize_states_for_save(base_mapper, states, uowtransaction): uowtransaction.remove_state_actions(existing) row_switch = existing - if not has_identity and not row_switch: - states_to_insert.append( - (state, dict_, mapper, connection, - has_identity, row_switch) - ) - else: - states_to_update.append( - (state, dict_, mapper, connection, - has_identity, row_switch) - ) - - return states_to_insert, states_to_update + yield (state, dict_, mapper, connection, + has_identity, row_switch) def _organize_states_for_post_update(base_mapper, states, @@ -205,8 +221,7 @@ def _organize_states_for_post_update(base_mapper, states, the execution per state. """ - return list(_connections_for_states(base_mapper, uowtransaction, - states)) + return _connections_for_states(base_mapper, uowtransaction, states) def _organize_states_for_delete(base_mapper, states, uowtransaction): @@ -217,28 +232,21 @@ def _organize_states_for_delete(base_mapper, states, uowtransaction): mapper, the connection to use for the execution per state. """ - states_to_delete = [] - for state, dict_, mapper, connection in _connections_for_states( base_mapper, uowtransaction, states): mapper.dispatch.before_delete(mapper, connection, state) - states_to_delete.append((state, dict_, mapper, - bool(state.key), connection)) - return states_to_delete + yield state, dict_, mapper, bool(state.key), connection -def _collect_insert_commands(base_mapper, uowtransaction, table, - states_to_insert): +def _collect_insert_commands(table, states_to_insert): """Identify sets of values to use in INSERT statements for a list of states. """ - insert = [] - for state, state_dict, mapper, connection, has_identity, \ - row_switch in states_to_insert: + for state, state_dict, mapper, connection in states_to_insert: if table not in mapper._pks_by_table: continue @@ -262,7 +270,7 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, has_all_pks = mapper._pk_keys_by_table[table].issubset(params) - if base_mapper.eager_defaults: + if mapper.base_mapper.eager_defaults: has_all_defaults = mapper._server_default_cols[table].\ issubset(params) else: @@ -274,14 +282,13 @@ def _collect_insert_commands(base_mapper, uowtransaction, table, params[mapper.version_id_col.key] = \ mapper.version_id_generator(None) - insert.append((state, state_dict, params, mapper, - connection, value_params, has_all_pks, - has_all_defaults)) - return insert + yield ( + state, state_dict, params, mapper, + connection, value_params, has_all_pks, + has_all_defaults) -def _collect_update_commands(base_mapper, uowtransaction, - table, states_to_update): +def _collect_update_commands(uowtransaction, table, states_to_update): """Identify sets of values to use in UPDATE statements for a list of states. @@ -293,9 +300,7 @@ def _collect_update_commands(base_mapper, uowtransaction, """ - update = [] - for state, state_dict, mapper, connection, has_identity, \ - row_switch in states_to_update: + for state, state_dict, mapper, connection, row_switch in states_to_update: if table not in mapper._pks_by_table: continue @@ -368,9 +373,9 @@ def _collect_update_commands(base_mapper, uowtransaction, "Can't update table using NULL for primary " "key value") params.update(pk_params) - update.append((state, state_dict, params, mapper, - connection, value_params)) - return update + yield ( + state, state_dict, params, mapper, + connection, value_params) def _collect_post_update_commands(base_mapper, uowtransaction, table, @@ -380,7 +385,6 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, """ - update = [] for state, state_dict, mapper, connection in states_to_update: if table not in mapper._pks_by_table: continue @@ -405,9 +409,7 @@ def _collect_post_update_commands(base_mapper, uowtransaction, table, params[col.key] = value hasdata = True if hasdata: - update.append((state, state_dict, params, mapper, - connection)) - return update + yield params, connection def _collect_delete_commands(base_mapper, uowtransaction, table, @@ -415,15 +417,12 @@ def _collect_delete_commands(base_mapper, uowtransaction, table, """Identify values to use in DELETE statements for a list of states to be deleted.""" - delete = util.defaultdict(list) - for state, state_dict, mapper, has_identity, connection \ in states_to_delete: if not has_identity or table not in mapper._pks_by_table: continue params = {} - delete[connection].append(params) for col in mapper._pks_by_table[table]: params[col.key] = \ value = \ @@ -441,7 +440,7 @@ def _collect_delete_commands(base_mapper, uowtransaction, table, mapper._get_committed_state_attr_by_column( state, state_dict, mapper.version_id_col) - return delete + yield params, connection def _emit_update_statements(base_mapper, uowtransaction, @@ -481,8 +480,7 @@ def _emit_update_statements(base_mapper, uowtransaction, lambda rec: ( rec[4], tuple(sorted(rec[2])), - bool(rec[5])) - ): + bool(rec[5]))): rows = 0 records = list(records) @@ -652,11 +650,10 @@ def _emit_post_update_statements(base_mapper, uowtransaction, # also group them into common (connection, cols) sets # to support executemany(). for key, grouper in groupby( - update, lambda rec: (rec[4], list(rec[2].keys())) + update, lambda rec: (rec[1], sorted(rec[0])) ): connection = key[0] - multiparams = [params for state, state_dict, - params, mapper, conn in grouper] + multiparams = [params for params, conn in grouper] cached_connections[connection].\ execute(statement, multiparams) @@ -686,8 +683,15 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, return table.delete(clause) - for connection, del_objects in delete.items(): - statement = base_mapper._memo(('delete', table), delete_stmt) + statement = base_mapper._memo(('delete', table), delete_stmt) + for connection, recs in groupby( + delete, + lambda rec: rec[1] + ): + del_objects = [ + params + for params, connection in recs + ] connection = cached_connections[connection] @@ -740,15 +744,12 @@ def _emit_delete_statements(base_mapper, uowtransaction, cached_connections, ) -def _finalize_insert_update_commands(base_mapper, uowtransaction, - states_to_insert, states_to_update): +def _finalize_insert_update_commands(base_mapper, uowtransaction, states): """finalize state on states that have been inserted or updated, including calling after_insert/after_update events. """ - for state, state_dict, mapper, connection, has_identity, \ - row_switch in states_to_insert + \ - states_to_update: + for state, state_dict, mapper, connection, has_identity in states: if mapper._readonly_props: readonly = state.unmodified_intersection(