- query.get() now returns None if queried for an identifier

that is present in the identity map with a different class
than the one requested, i.e. when using polymorphic loading.
[ticket:1727]
This commit is contained in:
Mike Bayer
2010-03-13 12:28:50 -05:00
parent 10659a005c
commit 3290ac23df
3 changed files with 137 additions and 111 deletions
+5
View File
@@ -47,6 +47,11 @@ CHANGES
from_statement() to start with since it no longer modifies
the query. [ticket:1688]
- query.get() now returns None if queried for an identifier
that is present in the identity map with a different class
than the one requested, i.e. when using polymorphic loading.
[ticket:1727]
- A major fix in query.join(), when the "on" clause is an
attribute of an aliased() construct, but there is already
an existing join made out to a compatible target, query properly
+18 -12
View File
@@ -1537,17 +1537,23 @@ class Query(object):
only_load_props=None, passive=None):
lockmode = lockmode or self._lockmode
mapper = self._mapper_zero()
if not self._populate_existing and \
not refresh_state and \
not self._mapper_zero().always_refresh and \
not mapper.always_refresh and \
lockmode is None:
instance = self.session.identity_map.get(key)
if instance:
# item present in identity map with a different class
if not issubclass(instance.__class__, mapper.class_):
return None
state = attributes.instance_state(instance)
# expired - ensure it still exists
if state.expired:
if passive is attributes.PASSIVE_NO_FETCH:
return attributes.PASSIVE_NO_RESULT
try:
state()
except orm_exc.ObjectDeletedError:
@@ -1570,8 +1576,6 @@ class Query(object):
q = self._clone()
if ident is not None:
mapper = q._mapper_zero()
params = {}
(_get_clause, _get_params) = mapper._get_clause
# None present in ident - turn those comparisons
@@ -1587,14 +1591,16 @@ class Query(object):
_get_clause = q._adapt_clause(_get_clause, True, False)
q._criterion = _get_clause
for i, primary_key in enumerate(mapper.primary_key):
try:
params[_get_params[primary_key].key] = ident[i]
except IndexError:
raise sa_exc.InvalidRequestError(
"Could not find enough values to formulate primary "
"key for query.get(); primary key columns are %s" %
','.join("'%s'" % c for c in mapper.primary_key))
params = dict([
(_get_params[primary_key].key, id_val)
for id_val, primary_key in zip(ident, mapper.primary_key)
])
if len(params) != len(mapper.primary_key):
raise sa_exc.InvalidRequestError(
"Incorrect number of values in identifier to formulate primary "
"key for query.get(); primary key columns are %s" %
','.join("'%s'" % c for c in mapper.primary_key))
q._params = params
+114 -99
View File
@@ -28,7 +28,7 @@ class O2MTest(_base.MappedTest):
Column('foo_id', Integer, ForeignKey('foo.id'), nullable=False),
Column('data', String(20)))
def testbasic(self):
def test_basic(self):
class Foo(object):
def __init__(self, data=None):
self.data = data
@@ -279,78 +279,88 @@ class GetTest(_base.MappedTest):
Column('foo_id', Integer, ForeignKey('foo.id')),
Column('bar_id', Integer, ForeignKey('bar.id')),
Column('data', String(20)))
@classmethod
def setup_classes(cls):
class Foo(_base.BasicEntity):
pass
def _create_test(polymorphic, name):
def test_get(self):
class Foo(object):
pass
class Bar(Foo):
pass
class Bar(Foo):
pass
class Blub(Bar):
pass
class Blub(Bar):
pass
def test_get_polymorphic(self):
self._do_get_test(True)
def test_get_nonpolymorphic(self):
self._do_get_test(False)
if polymorphic:
mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo')
mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar')
mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub')
else:
mapper(Foo, foo)
mapper(Bar, bar, inherits=Foo)
mapper(Blub, blub, inherits=Bar)
@testing.resolve_artifact_names
def _do_get_test(self, polymorphic):
if polymorphic:
mapper(Foo, foo, polymorphic_on=foo.c.type, polymorphic_identity='foo')
mapper(Bar, bar, inherits=Foo, polymorphic_identity='bar')
mapper(Blub, blub, inherits=Bar, polymorphic_identity='blub')
else:
mapper(Foo, foo)
mapper(Bar, bar, inherits=Foo)
mapper(Blub, blub, inherits=Bar)
sess = create_session()
f = Foo()
b = Bar()
bl = Blub()
sess.add(f)
sess.add(b)
sess.add(bl)
sess.flush()
sess = create_session()
f = Foo()
b = Bar()
bl = Blub()
sess.add(f)
sess.add(b)
sess.add(bl)
sess.flush()
if polymorphic:
def go():
assert sess.query(Foo).get(f.id) == f
assert sess.query(Foo).get(b.id) == b
assert sess.query(Foo).get(bl.id) == bl
assert sess.query(Bar).get(b.id) == b
assert sess.query(Bar).get(bl.id) == bl
assert sess.query(Blub).get(bl.id) == bl
if polymorphic:
def go():
assert sess.query(Foo).get(f.id) is f
assert sess.query(Foo).get(b.id) is b
assert sess.query(Foo).get(bl.id) is bl
assert sess.query(Bar).get(b.id) is b
assert sess.query(Bar).get(bl.id) is bl
assert sess.query(Blub).get(bl.id) is bl
self.assert_sql_count(testing.db, go, 0)
else:
# this is testing the 'wrong' behavior of using get()
# polymorphically with mappers that are not configured to be
# polymorphic. the important part being that get() always
# returns an instance of the query's type.
def go():
assert sess.query(Foo).get(f.id) == f
# test class mismatches - item is present
# in the identity map but we requested a subclass
assert sess.query(Blub).get(f.id) is None
assert sess.query(Blub).get(b.id) is None
assert sess.query(Bar).get(f.id) is None
self.assert_sql_count(testing.db, go, 0)
else:
# this is testing the 'wrong' behavior of using get()
# polymorphically with mappers that are not configured to be
# polymorphic. the important part being that get() always
# returns an instance of the query's type.
def go():
assert sess.query(Foo).get(f.id) is f
bb = sess.query(Foo).get(b.id)
assert isinstance(b, Foo) and bb.id==b.id
bb = sess.query(Foo).get(b.id)
assert isinstance(b, Foo) and bb.id==b.id
bll = sess.query(Foo).get(bl.id)
assert isinstance(bll, Foo) and bll.id==bl.id
bll = sess.query(Foo).get(bl.id)
assert isinstance(bll, Foo) and bll.id==bl.id
assert sess.query(Bar).get(b.id) == b
assert sess.query(Bar).get(b.id) is b
bll = sess.query(Bar).get(bl.id)
assert isinstance(bll, Bar) and bll.id == bl.id
bll = sess.query(Bar).get(bl.id)
assert isinstance(bll, Bar) and bll.id == bl.id
assert sess.query(Blub).get(bl.id) == bl
assert sess.query(Blub).get(bl.id) is bl
self.assert_sql_count(testing.db, go, 3)
self.assert_sql_count(testing.db, go, 3)
test_get = function_named(test_get, name)
return test_get
test_get_polymorphic = _create_test(True, 'test_get_polymorphic')
test_get_nonpolymorphic = _create_test(False, 'test_get_nonpolymorphic')
class EagerLazyTest(_base.MappedTest):
"""tests eager load/lazy load of child items off inheritance mappers, tests that
LazyLoader constructs the right query condition."""
@classmethod
def define_tables(cls, metadata):
global foo, bar, bar_foo
@@ -367,7 +377,7 @@ class EagerLazyTest(_base.MappedTest):
)
@testing.fails_on('maxdb', 'FIXME: unknown')
def testbasic(self):
def test_basic(self):
class Foo(object): pass
class Bar(Foo): pass
@@ -394,7 +404,8 @@ class EagerLazyTest(_base.MappedTest):
self.assert_(len(q.first().eager) == 1)
class EagerTargetingTest(_base.MappedTest):
"""test a scenario where joined table inheritance might be confused as an eagerly loaded joined table."""
"""test a scenario where joined table inheritance might be
confused as an eagerly loaded joined table."""
@classmethod
def define_tables(cls, metadata):
@@ -450,31 +461,32 @@ class EagerTargetingTest(_base.MappedTest):
class FlushTest(_base.MappedTest):
"""test dependency sorting among inheriting mappers"""
@classmethod
def define_tables(cls, metadata):
global users, roles, user_roles, admins
users = Table('users', metadata,
Table('users', metadata,
Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('email', String(128)),
Column('password', String(16)),
)
roles = Table('role', metadata,
Table('roles', metadata,
Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('description', String(32))
)
user_roles = Table('user_role', metadata,
Table('user_roles', metadata,
Column('user_id', Integer, ForeignKey('users.id'), primary_key=True),
Column('role_id', Integer, ForeignKey('role.id'), primary_key=True)
Column('role_id', Integer, ForeignKey('roles.id'), primary_key=True)
)
admins = Table('admin', metadata,
Table('admins', metadata,
Column('admin_id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('user_id', Integer, ForeignKey('users.id'))
)
def testone(self):
@testing.resolve_artifact_names
def test_one(self):
class User(object):pass
class Role(object):pass
class Admin(User):pass
@@ -501,7 +513,8 @@ class FlushTest(_base.MappedTest):
assert user_roles.count().scalar() == 1
def testtwo(self):
@testing.resolve_artifact_names
def test_two(self):
class User(object):
def __init__(self, email=None, password=None):
self.email = email
@@ -541,34 +554,24 @@ class FlushTest(_base.MappedTest):
class VersioningTest(_base.MappedTest):
@classmethod
def define_tables(cls, metadata):
global base, subtable, stuff
base = Table('base', metadata,
Table('base', metadata,
Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('version_id', Integer, nullable=False),
Column('value', String(40)),
Column('discriminator', Integer, nullable=False)
)
subtable = Table('subtable', metadata,
Table('subtable', metadata,
Column('id', None, ForeignKey('base.id'), primary_key=True),
Column('subdata', String(50))
)
stuff = Table('stuff', metadata,
Table('stuff', metadata,
Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('parent', Integer, ForeignKey('base.id'))
)
def setup(self):
super(VersioningTest, self).setup()
if not testing.db.dialect.supports_sane_rowcount:
self._warnings_filters = warnings.filters[:]
warnings.filterwarnings('ignore', category=sa_exc.SAWarning)
def teardown(self):
super(VersioningTest, self).teardown()
if not testing.db.dialect.supports_sane_rowcount:
warnings.filters[:] = self._warnings_filters
@testing.emits_warning(r".*updated rowcount")
@engines.close_open_connections
@testing.resolve_artifact_names
def test_save_update(self):
class Base(_fixtures.Base):
pass
@@ -577,7 +580,10 @@ class VersioningTest(_base.MappedTest):
class Stuff(Base):
pass
mapper(Stuff, stuff)
mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1, properties={
mapper(Base, base,
polymorphic_on=base.c.discriminator,
version_id_col=base.c.version_id,
polymorphic_identity=1, properties={
'stuff':relation(Stuff)
})
mapper(Sub, subtable, inherits=Base, polymorphic_identity=2)
@@ -599,17 +605,14 @@ class VersioningTest(_base.MappedTest):
sess.flush()
try:
sess2.query(Base).with_lockmode('read').get(s1.id)
assert False
except orm_exc.ConcurrentModificationError, e:
assert True
assert_raises(orm_exc.ConcurrentModificationError,
sess2.query(Base).with_lockmode('read').get,
s1.id)
try:
if not testing.db.dialect.supports_sane_rowcount:
sess2.flush()
assert not testing.db.dialect.supports_sane_rowcount
except orm_exc.ConcurrentModificationError, e:
assert True
else:
assert_raises(orm_exc.ConcurrentModificationError, sess2.flush)
sess2.refresh(s2)
if testing.db.dialect.supports_sane_rowcount:
@@ -617,13 +620,17 @@ class VersioningTest(_base.MappedTest):
s2.subdata = 'sess2 subdata'
sess2.flush()
@testing.emits_warning(r".*updated rowcount")
@testing.resolve_artifact_names
def test_delete(self):
class Base(_fixtures.Base):
pass
class Sub(Base):
pass
mapper(Base, base, polymorphic_on=base.c.discriminator, version_id_col=base.c.version_id, polymorphic_identity=1)
mapper(Base, base,
polymorphic_on=base.c.discriminator,
version_id_col=base.c.version_id, polymorphic_identity=1)
mapper(Sub, subtable, inherits=Base, polymorphic_identity=2)
sess = create_session()
@@ -697,17 +704,24 @@ class DistinctPKTest(_base.MappedTest):
def test_explicit_props(self):
person_mapper = mapper(Person, person_table)
mapper(Employee, employee_table, inherits=person_mapper, properties={'pid':person_table.c.id, 'eid':employee_table.c.id})
mapper(Employee, employee_table, inherits=person_mapper,
properties={'pid':person_table.c.id,
'eid':employee_table.c.id})
self._do_test(True)
def test_explicit_composite_pk(self):
person_mapper = mapper(Person, person_table)
try:
mapper(Employee, employee_table, inherits=person_mapper, primary_key=[person_table.c.id, employee_table.c.id])
self._do_test(True)
assert False
except sa_exc.SAWarning, e:
assert str(e) == "On mapper Mapper|Employee|employees, primary key column 'employees.id' is being combined with distinct primary key column 'persons.id' in attribute 'id'. Use explicit properties to give each column its own mapped attribute name.", str(e)
mapper(Employee, employee_table,
inherits=person_mapper,
primary_key=[person_table.c.id, employee_table.c.id])
assert_raises_message(sa_exc.SAWarning,
r"On mapper Mapper\|Employee\|employees, "
"primary key column 'employees.id' is being "
"combined with distinct primary key column 'persons.id' "
"in attribute 'id'. Use explicit properties to give "
"each column its own mapped attribute name.",
self._do_test, True
)
def test_explicit_pk(self):
person_mapper = mapper(Person, person_table)
@@ -1242,6 +1256,7 @@ class DeleteOrphanTest(_base.MappedTest):
s1 = SubClass(data='s1')
sess.add(s1)
assert_raises_message(orm_exc.FlushError,
"is not attached to any parent 'Parent' instance via that classes' 'related' attribute", sess.flush)
r"is not attached to any parent 'Parent' instance via "
"that classes' 'related' attribute", sess.flush)