mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-06-03 22:35:46 -04:00
further refinement to the inheritance "descriptor" detection such that
local columns will still override superclass descriptors.
This commit is contained in:
@@ -639,18 +639,25 @@ class Mapper(object):
|
||||
|
||||
return getattr(getattr(cls, clskey), key)
|
||||
|
||||
def _should_exclude(self, name):
|
||||
def _should_exclude(self, name, local):
|
||||
"""determine whether a particular property should be implicitly present on the class.
|
||||
|
||||
This occurs when properties are propagated from an inherited class, or are
|
||||
applied from the columns present in the mapped table.
|
||||
|
||||
"""
|
||||
# check for an existing descriptor
|
||||
if getattr(self.class_, name, None) \
|
||||
and hasattr(getattr(self.class_, name), '__get__'):
|
||||
return True
|
||||
|
||||
# check for descriptors, either local or from
|
||||
# an inherited class
|
||||
if local:
|
||||
if self.class_.__dict__.get(name, None)\
|
||||
and hasattr(self.class_.__dict__[name], '__get__'):
|
||||
return True
|
||||
else:
|
||||
if getattr(self.class_, name, None)\
|
||||
and hasattr(getattr(self.class_, name), '__get__'):
|
||||
return True
|
||||
|
||||
if (self.include_properties is not None and
|
||||
name not in self.include_properties):
|
||||
self.__log("not including property %s" % (name))
|
||||
@@ -681,7 +688,7 @@ class Mapper(object):
|
||||
# pull properties from the inherited mapper if any.
|
||||
if self.inherits:
|
||||
for key, prop in self.inherits.__props.iteritems():
|
||||
if key not in self.__props and not self._should_exclude(key):
|
||||
if key not in self.__props and not self._should_exclude(key, local=False):
|
||||
self._adapt_inherited_property(key, prop)
|
||||
|
||||
# create properties for each column in the mapped table,
|
||||
@@ -690,9 +697,6 @@ class Mapper(object):
|
||||
if column in self._columntoproperty:
|
||||
continue
|
||||
|
||||
if self._should_exclude(column.key):
|
||||
continue
|
||||
|
||||
column_key = (self.column_prefix or '') + column.key
|
||||
|
||||
# adjust the "key" used for this column to that
|
||||
@@ -700,14 +704,15 @@ class Mapper(object):
|
||||
for mapper in self.iterate_to_root():
|
||||
if column in mapper._columntoproperty:
|
||||
column_key = mapper._columntoproperty[column].key
|
||||
|
||||
self._compile_property(column_key, column, init=False, setparent=True)
|
||||
|
||||
if not self._should_exclude(column_key, local=self.local_table.c.contains_column(column)):
|
||||
self._compile_property(column_key, column, init=False, setparent=True)
|
||||
|
||||
# do a special check for the "discriminiator" column, as it may only be present
|
||||
# in the 'with_polymorphic' selectable but we need it for the base mapper
|
||||
if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty:
|
||||
col = self.mapped_table.corresponding_column(self.polymorphic_on) or self.polymorphic_on
|
||||
if self._should_exclude(col.key):
|
||||
if self._should_exclude(col.key, local=False):
|
||||
raise sa_exc.InvalidRequestError("Cannot exclude or override the discriminator column %r" % col.key)
|
||||
self._compile_property(col.key, ColumnProperty(col), init=False, setparent=True)
|
||||
|
||||
|
||||
@@ -873,39 +873,57 @@ class OverrideColKeyTest(ORMTest):
|
||||
sess.flush()
|
||||
assert sess.query(Sub).one().data == "im the data"
|
||||
|
||||
def test_two_levels(self):
|
||||
def test_sub_columns_over_base_descriptors(self):
|
||||
class Base(object):
|
||||
pass
|
||||
@property
|
||||
def subdata(self):
|
||||
return "this is base"
|
||||
|
||||
class Sub(Base):
|
||||
@property
|
||||
def data(self):
|
||||
return "im sub"
|
||||
pass
|
||||
|
||||
class SubSub(Sub):
|
||||
@property
|
||||
def data(self):
|
||||
return "im sub sub"
|
||||
|
||||
mapper(Base, base)
|
||||
mapper(Sub, subtable, inherits=Base)
|
||||
mapper(SubSub, inherits=Sub)
|
||||
|
||||
sess = create_session()
|
||||
s1 = Sub()
|
||||
assert s1.data == "im sub"
|
||||
s2 = SubSub()
|
||||
assert s2.data == "im sub sub"
|
||||
b1 = Base()
|
||||
b1.data="this is some data"
|
||||
assert b1.data == "this is some data"
|
||||
|
||||
sess.add_all([s1, s2, b1])
|
||||
assert b1.subdata == "this is base"
|
||||
s1 = Sub()
|
||||
s1.subdata = "this is sub"
|
||||
assert s1.subdata == "this is sub"
|
||||
|
||||
sess.add_all([s1, b1])
|
||||
sess.flush()
|
||||
sess.clear()
|
||||
|
||||
assert sess.query(Sub).get(s1.base_id).data == "im sub"
|
||||
assert sess.query(SubSub).get(s2.base_id).data == "im sub sub"
|
||||
assert sess.query(Base).get(b1.base_id).subdata == "this is base"
|
||||
assert sess.query(Sub).get(s1.base_id).subdata == "this is sub"
|
||||
|
||||
def test_base_descriptors_over_base_cols(self):
|
||||
class Base(object):
|
||||
@property
|
||||
def data(self):
|
||||
return "this is base"
|
||||
|
||||
class Sub(Base):
|
||||
pass
|
||||
|
||||
mapper(Base, base)
|
||||
mapper(Sub, subtable, inherits=Base)
|
||||
|
||||
sess = create_session()
|
||||
b1 = Base()
|
||||
assert b1.data == "this is base"
|
||||
s1 = Sub()
|
||||
assert s1.data == "this is base"
|
||||
|
||||
sess.add_all([s1, b1])
|
||||
sess.flush()
|
||||
sess.clear()
|
||||
|
||||
assert sess.query(Base).get(b1.base_id).data == "this is base"
|
||||
assert sess.query(Sub).get(s1.base_id).data == "this is base"
|
||||
|
||||
|
||||
class DeleteOrphanTest(ORMTest):
|
||||
def define_tables(self, metadata):
|
||||
|
||||
Reference in New Issue
Block a user