mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-06-04 23:06:24 -04:00
- fix a regression from ref #3178, where dialects that don't actually support
sane multi rowcount (e.g. pyodbc) would fail on multirow update. add a test that mocks this breakage into plain dialects
This commit is contained in:
@@ -617,6 +617,14 @@ def _emit_update_statements(base_mapper, uowtransaction,
|
||||
rows = 0
|
||||
records = list(records)
|
||||
|
||||
# TODO: would be super-nice to not have to determine this boolean
|
||||
# inside the loop here, in the 99.9999% of the time there's only
|
||||
# one connection in use
|
||||
assert_singlerow = connection.dialect.supports_sane_rowcount
|
||||
assert_multirow = assert_singlerow and \
|
||||
connection.dialect.supports_sane_multi_rowcount
|
||||
allow_multirow = not needs_version_id or assert_multirow
|
||||
|
||||
if hasvalue:
|
||||
for state, state_dict, params, mapper, \
|
||||
connection, value_params in records:
|
||||
@@ -635,9 +643,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
|
||||
value_params)
|
||||
rows += c.rowcount
|
||||
else:
|
||||
if needs_version_id and \
|
||||
not connection.dialect.supports_sane_multi_rowcount and \
|
||||
connection.dialect.supports_sane_rowcount:
|
||||
if not allow_multirow:
|
||||
for state, state_dict, params, mapper, \
|
||||
connection, value_params in records:
|
||||
c = cached_connections[connection].\
|
||||
@@ -654,6 +660,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
|
||||
rows += c.rowcount
|
||||
else:
|
||||
multiparams = [rec[2] for rec in records]
|
||||
|
||||
c = cached_connections[connection].\
|
||||
execute(statement, multiparams)
|
||||
|
||||
@@ -670,7 +677,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
|
||||
c.context.compiled_parameters[0],
|
||||
value_params)
|
||||
|
||||
if connection.dialect.supports_sane_rowcount:
|
||||
if assert_multirow or assert_singlerow and \
|
||||
len(multiparams) == 1:
|
||||
if rows != len(records):
|
||||
raise orm_exc.StaleDataError(
|
||||
"UPDATE statement on table '%s' expected to "
|
||||
|
||||
@@ -3,13 +3,13 @@ from sqlalchemy import testing
|
||||
from sqlalchemy.testing import engines
|
||||
from sqlalchemy.testing.schema import Table, Column
|
||||
from test.orm import _fixtures
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy.testing import fixtures
|
||||
from sqlalchemy import exc, util
|
||||
from sqlalchemy.testing import fixtures, config
|
||||
from sqlalchemy import Integer, String, ForeignKey, func
|
||||
from sqlalchemy.orm import mapper, relationship, backref, \
|
||||
create_session, unitofwork, attributes,\
|
||||
Session, exc as orm_exc
|
||||
from sqlalchemy.testing.mock import Mock
|
||||
from sqlalchemy.testing.mock import Mock, patch
|
||||
from sqlalchemy.testing.assertsql import AllOf, CompiledSQL
|
||||
from sqlalchemy import event
|
||||
|
||||
@@ -1473,6 +1473,67 @@ class BasicStaleChecksTest(fixtures.MappedTest):
|
||||
sess.flush
|
||||
)
|
||||
|
||||
def test_update_single_missing_broken_multi_rowcount(self):
|
||||
@util.memoized_property
|
||||
def rowcount(self):
|
||||
if len(self.context.compiled_parameters) > 1:
|
||||
return -1
|
||||
else:
|
||||
return self.context.rowcount
|
||||
|
||||
with patch.object(
|
||||
config.db.dialect, "supports_sane_multi_rowcount", False):
|
||||
with patch(
|
||||
"sqlalchemy.engine.result.ResultProxy.rowcount",
|
||||
rowcount):
|
||||
Parent, Child = self._fixture()
|
||||
sess = Session()
|
||||
p1 = Parent(id=1, data=2)
|
||||
sess.add(p1)
|
||||
sess.flush()
|
||||
|
||||
sess.execute(self.tables.parent.delete())
|
||||
|
||||
p1.data = 3
|
||||
assert_raises_message(
|
||||
orm_exc.StaleDataError,
|
||||
"UPDATE statement on table 'parent' expected to "
|
||||
"update 1 row\(s\); 0 were matched.",
|
||||
sess.flush
|
||||
)
|
||||
|
||||
def test_update_multi_missing_broken_multi_rowcount(self):
|
||||
@util.memoized_property
|
||||
def rowcount(self):
|
||||
if len(self.context.compiled_parameters) > 1:
|
||||
return -1
|
||||
else:
|
||||
return self.context.rowcount
|
||||
|
||||
with patch.object(
|
||||
config.db.dialect, "supports_sane_multi_rowcount", False):
|
||||
with patch(
|
||||
"sqlalchemy.engine.result.ResultProxy.rowcount",
|
||||
rowcount):
|
||||
Parent, Child = self._fixture()
|
||||
sess = Session()
|
||||
p1 = Parent(id=1, data=2)
|
||||
p2 = Parent(id=2, data=3)
|
||||
sess.add_all([p1, p2])
|
||||
sess.flush()
|
||||
|
||||
sess.execute(self.tables.parent.delete().where(Parent.id == 1))
|
||||
|
||||
p1.data = 3
|
||||
p2.data = 4
|
||||
sess.flush() # no exception
|
||||
|
||||
# update occurred for remaining row
|
||||
eq_(
|
||||
sess.query(Parent.id, Parent.data).all(),
|
||||
[(2, 4)]
|
||||
)
|
||||
|
||||
@testing.requires.sane_multi_rowcount
|
||||
def test_delete_multi_missing_warning(self):
|
||||
Parent, Child = self._fixture()
|
||||
@@ -1544,6 +1605,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
|
||||
T(id=10, data='t10', def_='def3'),
|
||||
T(id=11, data='t11'),
|
||||
])
|
||||
|
||||
self.assert_sql_execution(
|
||||
testing.db,
|
||||
sess.flush,
|
||||
|
||||
Reference in New Issue
Block a user