- fix a regression from ref #3178, where dialects that don't actually support

sane multi rowcount (e.g. pyodbc) would fail on multirow update.  add
a test that mocks this breakage into plain dialects
This commit is contained in:
Mike Bayer
2015-01-17 21:36:52 -05:00
parent 469b6fabaf
commit f49c367ef7
2 changed files with 77 additions and 7 deletions
+12 -4
View File
@@ -617,6 +617,14 @@ def _emit_update_statements(base_mapper, uowtransaction,
rows = 0
records = list(records)
# TODO: would be super-nice to not have to determine this boolean
# inside the loop here, in the 99.9999% of the time there's only
# one connection in use
assert_singlerow = connection.dialect.supports_sane_rowcount
assert_multirow = assert_singlerow and \
connection.dialect.supports_sane_multi_rowcount
allow_multirow = not needs_version_id or assert_multirow
if hasvalue:
for state, state_dict, params, mapper, \
connection, value_params in records:
@@ -635,9 +643,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
value_params)
rows += c.rowcount
else:
if needs_version_id and \
not connection.dialect.supports_sane_multi_rowcount and \
connection.dialect.supports_sane_rowcount:
if not allow_multirow:
for state, state_dict, params, mapper, \
connection, value_params in records:
c = cached_connections[connection].\
@@ -654,6 +660,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
rows += c.rowcount
else:
multiparams = [rec[2] for rec in records]
c = cached_connections[connection].\
execute(statement, multiparams)
@@ -670,7 +677,8 @@ def _emit_update_statements(base_mapper, uowtransaction,
c.context.compiled_parameters[0],
value_params)
if connection.dialect.supports_sane_rowcount:
if assert_multirow or assert_singlerow and \
len(multiparams) == 1:
if rows != len(records):
raise orm_exc.StaleDataError(
"UPDATE statement on table '%s' expected to "
+65 -3
View File
@@ -3,13 +3,13 @@ from sqlalchemy import testing
from sqlalchemy.testing import engines
from sqlalchemy.testing.schema import Table, Column
from test.orm import _fixtures
from sqlalchemy import exc
from sqlalchemy.testing import fixtures
from sqlalchemy import exc, util
from sqlalchemy.testing import fixtures, config
from sqlalchemy import Integer, String, ForeignKey, func
from sqlalchemy.orm import mapper, relationship, backref, \
create_session, unitofwork, attributes,\
Session, exc as orm_exc
from sqlalchemy.testing.mock import Mock
from sqlalchemy.testing.mock import Mock, patch
from sqlalchemy.testing.assertsql import AllOf, CompiledSQL
from sqlalchemy import event
@@ -1473,6 +1473,67 @@ class BasicStaleChecksTest(fixtures.MappedTest):
sess.flush
)
def test_update_single_missing_broken_multi_rowcount(self):
@util.memoized_property
def rowcount(self):
if len(self.context.compiled_parameters) > 1:
return -1
else:
return self.context.rowcount
with patch.object(
config.db.dialect, "supports_sane_multi_rowcount", False):
with patch(
"sqlalchemy.engine.result.ResultProxy.rowcount",
rowcount):
Parent, Child = self._fixture()
sess = Session()
p1 = Parent(id=1, data=2)
sess.add(p1)
sess.flush()
sess.execute(self.tables.parent.delete())
p1.data = 3
assert_raises_message(
orm_exc.StaleDataError,
"UPDATE statement on table 'parent' expected to "
"update 1 row\(s\); 0 were matched.",
sess.flush
)
def test_update_multi_missing_broken_multi_rowcount(self):
@util.memoized_property
def rowcount(self):
if len(self.context.compiled_parameters) > 1:
return -1
else:
return self.context.rowcount
with patch.object(
config.db.dialect, "supports_sane_multi_rowcount", False):
with patch(
"sqlalchemy.engine.result.ResultProxy.rowcount",
rowcount):
Parent, Child = self._fixture()
sess = Session()
p1 = Parent(id=1, data=2)
p2 = Parent(id=2, data=3)
sess.add_all([p1, p2])
sess.flush()
sess.execute(self.tables.parent.delete().where(Parent.id == 1))
p1.data = 3
p2.data = 4
sess.flush() # no exception
# update occurred for remaining row
eq_(
sess.query(Parent.id, Parent.data).all(),
[(2, 4)]
)
@testing.requires.sane_multi_rowcount
def test_delete_multi_missing_warning(self):
Parent, Child = self._fixture()
@@ -1544,6 +1605,7 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults):
T(id=10, data='t10', def_='def3'),
T(id=11, data='t11'),
])
self.assert_sql_execution(
testing.db,
sess.flush,