- added **kw to ClauseElement.compare(), so that we can smarten up the "use_get" operation

- many-to-one relation to a joined-table subclass now uses get()
  for a simple load (known as the "use_get" condition),
  i.e. Related->Sub(Base), without the need
  to redefine the primaryjoin condition in terms of the base
  table. [ticket:1186]
- specifying a foreign key with a declarative column,
  i.e. ForeignKey(MyRelatedClass.id) doesn't break the "use_get"
  condition from taking place [ticket:1492]
This commit is contained in:
Mike Bayer
2009-08-08 22:21:02 +00:00
parent cbdccb7fd2
commit a04da2a417
5 changed files with 148 additions and 14 deletions
+10 -1
View File
@@ -14,7 +14,16 @@
- added "make_transient()" helper function which transforms a persistent/
detached instance into a transient one (i.e. deletes the instance_key
and removes from any session.) [ticket:1052]
- many-to-one "lazyload" fixes:
- many-to-one relation to a joined-table subclass now uses get()
for a simple load (known as the "use_get" condition),
i.e. Related->Sub(Base), without the need
to redefine the primaryjoin condition in terms of the base
table. [ticket:1186]
- specifying a foreign key with a declarative column,
i.e. ForeignKey(MyRelatedClass.id) doesn't break the "use_get"
condition from taking place [ticket:1492]
- sql
- returning() support is native to insert(), update(), delete(). Implementations
of varying levels of functionality exist for Postgresql, Firebird, MSSQL and
+11 -1
View File
@@ -372,8 +372,18 @@ class LazyLoader(AbstractRelationLoader):
# determine if our "lazywhere" clause is the same as the mapper's
# get() clause. then we can just use mapper.get()
#from sqlalchemy.orm import query
self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.__lazywhere)
self.use_get = not self.uselist and \
self.mapper._get_clause[0].compare(
self.__lazywhere,
use_proxies=True,
equivalents=self.mapper._equivalent_columns
)
if self.use_get:
for col in self._equated_columns.keys():
if col in self.mapper._equivalent_columns:
for c in self.mapper._equivalent_columns[col]:
self._equated_columns[c] = self._equated_columns[col]
self.logger.info("%s will use query.get() to optimize instance loads" % self)
def init_class_attribute(self, mapper):
+44 -12
View File
@@ -1101,11 +1101,15 @@ class ClauseElement(Visitable):
bind._convert_to_unique()
return cloned_traverse(self, {}, {'bindparam':visit_bindparam})
def compare(self, other):
def compare(self, other, **kw):
"""Compare this ClauseElement to the given ClauseElement.
Subclasses should override the default behavior, which is a
straight identity comparison.
**kw are arguments consumed by subclass compare() methods and
may be used to modify the criteria for comparison.
(see :class:`ColumnElement`)
"""
return self is other
@@ -1697,6 +1701,34 @@ class ColumnElement(ClauseElement, _CompareMixin):
selectable.columns[name] = co
return co
def compare(self, other, use_proxies=False, equivalents=None, **kw):
"""Compare this ColumnElement to another.
Special arguments understood:
:param use_proxies: when True, consider two columns that
share a common base column as equivalent (i.e. shares_lineage())
:param equivalents: a dictionary of columns as keys mapped to sets
of columns. If the given "other" column is present in this dictionary,
if any of the columns in the correponding set() pass the comparison
test, the result is True. This is used to expand the comparison to
other columns that may be known to be equivalent to this one via
foreign key or other criterion.
"""
to_compare = (other, )
if equivalents and other in equivalents:
to_compare = equivalents[other].union(to_compare)
for oth in to_compare:
if use_proxies and self.shares_lineage(oth):
return True
elif oth is self:
return True
else:
return False
@util.memoized_property
def anon_label(self):
"""provides a constant 'anonymous label' for this ColumnElement.
@@ -2109,7 +2141,7 @@ class _BindParamClause(ColumnElement):
else:
return obj.type
def compare(self, other):
def compare(self, other, **kw):
"""Compare this ``_BindParamClause`` to the given clause.
Since ``compare()`` is meant to compare statement syntax, this
@@ -2274,16 +2306,16 @@ class ClauseList(ClauseElement):
else:
return self
def compare(self, other):
def compare(self, other, **kw):
"""Compare this ``ClauseList`` to the given ``ClauseList``,
including a comparison of all the clause items.
"""
if not isinstance(other, ClauseList) and len(self.clauses) == 1:
return self.clauses[0].compare(other)
return self.clauses[0].compare(other, **kw)
elif isinstance(other, ClauseList) and len(self.clauses) == len(other.clauses):
for i in range(0, len(self.clauses)):
if not self.clauses[i].compare(other.clauses[i]):
if not self.clauses[i].compare(other.clauses[i], **kw):
return False
else:
return self.operator == other.operator
@@ -2473,14 +2505,14 @@ class _UnaryExpression(ColumnElement):
def get_children(self, **kwargs):
return self.element,
def compare(self, other):
def compare(self, other, **kw):
"""Compare this ``_UnaryExpression`` against the given ``ClauseElement``."""
return (
isinstance(other, _UnaryExpression) and
self.operator == other.operator and
self.modifier == other.modifier and
self.element.compare(other.element)
self.element.compare(other.element, **kw)
)
def _negate(self):
@@ -2528,19 +2560,19 @@ class _BinaryExpression(ColumnElement):
def get_children(self, **kwargs):
return self.left, self.right
def compare(self, other):
def compare(self, other, **kw):
"""Compare this ``_BinaryExpression`` against the given ``_BinaryExpression``."""
return (
isinstance(other, _BinaryExpression) and
self.operator == other.operator and
(
self.left.compare(other.left) and
self.right.compare(other.right) or
self.left.compare(other.left, **kw) and
self.right.compare(other.right, **kw) or
(
operators.is_commutative(self.operator) and
self.left.compare(other.right) and
self.right.compare(other.left)
self.left.compare(other.right, **kw) and
self.right.compare(other.left, **kw)
)
)
)
+31
View File
@@ -246,6 +246,37 @@ class DeclarativeTest(DeclarativeTestBase):
Base = decl.declarative_base(cls=MyBase)
assert hasattr(Base, 'metadata')
assert Base().foobar() == "foobar"
def test_uses_get_on_class_col_fk(self):
# test [ticket:1492]
class Master(Base):
__tablename__ = 'master'
id = Column(Integer, primary_key=True)
class Detail(Base):
__tablename__ = 'detail'
id = Column(Integer, primary_key=True)
master_id = Column(None, ForeignKey(Master.id))
master = relation(Master)
Base.metadata.create_all()
compile_mappers()
assert class_mapper(Detail).get_property('master').strategy.use_get
m1 = Master()
d1 = Detail(master=m1)
sess = create_session()
sess.add(d1)
sess.flush()
sess.expunge_all()
d1 = sess.query(Detail).first()
m1 = sess.query(Master).first()
def go():
assert d1.master
self.assert_sql_count(testing.db, go, 0)
def test_index_doesnt_compile(self):
class User(Base):
+52
View File
@@ -208,6 +208,58 @@ class CascadeTest(_base.MappedTest):
assert t4_1 in sess.deleted
sess.flush()
class M2OUseGetTest(_base.MappedTest):
@classmethod
def define_tables(cls, metadata):
Table('base', metadata,
Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('type', String(30))
)
Table('sub', metadata,
Column('id', Integer, ForeignKey('base.id'), primary_key=True),
)
Table('related', metadata,
Column('id', Integer, primary_key=True, test_needs_autoincrement=True),
Column('sub_id', Integer, ForeignKey('sub.id')),
)
@testing.resolve_artifact_names
def test_use_get(self):
# test [ticket:1186]
class Base(_base.BasicEntity):
pass
class Sub(Base):
pass
class Related(Base):
pass
mapper(Base, base, polymorphic_on=base.c.type, polymorphic_identity='b')
mapper(Sub, sub, inherits=Base, polymorphic_identity='s')
mapper(Related, related, properties={
# previously, this was needed for the comparison to occur:
# the 'primaryjoin' looks just like "Sub"'s "get" clause (based on the Base id),
# and foreign_keys since that join condition doesn't actually have any fks in it
#'sub':relation(Sub, primaryjoin=base.c.id==related.c.sub_id, foreign_keys=related.c.sub_id)
# now we can use this:
'sub':relation(Sub)
})
assert class_mapper(Related).get_property('sub').strategy.use_get
sess = create_session()
s1 = Sub()
r1 = Related(sub=s1)
sess.add(r1)
sess.flush()
sess.expunge_all()
r1 = sess.query(Related).first()
s1 = sess.query(Sub).first()
def go():
assert r1.sub
self.assert_sql_count(testing.db, go, 0)
class GetTest(_base.MappedTest):
@classmethod
def define_tables(cls, metadata):