- any(), has(), contains(), attribute level == and != now

work properly with self-referential relations - the clause
inside the EXISTS is aliased on the "remote" side to
distinguish it from the parent table.
This commit is contained in:
Mike Bayer
2008-02-17 01:15:43 +00:00
parent 191dbee5c8
commit a3f67fecb2
3 changed files with 82 additions and 28 deletions
+8 -1
View File
@@ -1,7 +1,14 @@
=======
CHANGES
=======
0.4.4
------
- orm
- any(), has(), contains(), attribute level == and != now
work properly with self-referential relations - the clause
inside the EXISTS is aliased on the "remote" side to
distinguish it from the parent table.
0.4.4
------
+31 -20
View File
@@ -15,7 +15,7 @@ from sqlalchemy.sql.util import ClauseAdapter, ColumnsInClause
from sqlalchemy.sql import visitors, operators, ColumnElement
from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper
from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm.util import CascadeOptions
from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses
from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty
from sqlalchemy.exceptions import ArgumentError
import weakref
@@ -265,33 +265,44 @@ class PropertyLoader(StrategizedProperty):
return sql.and_(*clauses)
else:
return self.prop._optimized_compare(other)
def _join_and_criterion(self, criterion=None, **kwargs):
if self.prop._is_self_referential():
pac = PropertyAliasedClauses(self.prop,
self.prop.primaryjoin,
self.prop.secondaryjoin)
j = pac.primaryjoin
if pac.secondaryjoin:
j = j & pac.secondaryjoin
else:
j = self.prop.primaryjoin
if self.prop.secondaryjoin:
j = j & self.prop.secondaryjoin
def any(self, criterion=None, **kwargs):
if not self.prop.uselist:
raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
j = self.prop.primaryjoin
if self.prop.secondaryjoin:
j = j & self.prop.secondaryjoin
for k in kwargs:
crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
if criterion is None:
criterion = crit
else:
criterion = criterion & crit
if criterion and self.prop._is_self_referential():
criterion = pac.adapt_clause(criterion)
return j, criterion
def any(self, criterion=None, **kwargs):
if not self.prop.uselist:
raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().")
j, criterion = self._join_and_criterion(criterion, **kwargs)
return sql.exists([1], j & criterion)
def has(self, criterion=None, **kwargs):
if self.prop.uselist:
raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().")
j = self.prop.primaryjoin
if self.prop.secondaryjoin:
j = j & self.prop.secondaryjoin
for k in kwargs:
crit = (getattr(self.prop.mapper.class_, k) == kwargs[k])
if criterion is None:
criterion = crit
else:
criterion = criterion & crit
j, criterion = self._join_and_criterion(criterion, **kwargs)
return sql.exists([1], j & criterion)
def contains(self, other):
@@ -309,11 +320,11 @@ class PropertyLoader(StrategizedProperty):
def __ne__(self, other):
if self.prop.uselist and not hasattr(other, '__iter__'):
raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object")
criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))])
j, criterion = self._join_and_criterion(criterion)
j = self.prop.primaryjoin
if self.prop.secondaryjoin:
j = j & self.prop.secondaryjoin
return ~sql.exists([1], j & sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]))
return ~sql.exists([1], j & criterion)
def compare(self, op, value, value_is_parent=False):
if op == operators.eq:
+43 -7
View File
@@ -1121,15 +1121,20 @@ class CustomJoinTest(QueryTest):
assert [User(id=7)] == q.join(['open_orders', 'items'], aliased=True).filter(Item.id==4).join(['closed_orders', 'items'], aliased=True).filter(Item.id==3).all()
class SelfReferentialJoinTest(ORMTest):
class SelfReferentialTest(ORMTest):
keep_mappers = True
keep_data = True
def define_tables(self, metadata):
global nodes
nodes = Table('nodes', metadata,
Column('id', Integer, primary_key=True),
Column('parent_id', Integer, ForeignKey('nodes.id')),
Column('data', String(30)))
def test_join(self):
def insert_data(self):
global Node
class Node(Base):
def append(self, node):
self.children.append(node)
@@ -1149,11 +1154,11 @@ class SelfReferentialJoinTest(ORMTest):
n1.children[1].append(Node(data='n123'))
sess.save(n1)
sess.flush()
sess.clear()
sess.close()
def test_join(self):
sess = create_session()
# TODO: the aliasing of the join in query._join_to has to limit the aliasing
# among local_side / remote_side (add local_side as an attribute on PropertyLoader)
# also implement this idea in EagerLoader
node = sess.query(Node).join('children', aliased=True).filter_by(data='n122').first()
assert node.data=='n12'
@@ -1164,6 +1169,37 @@ class SelfReferentialJoinTest(ORMTest):
join('parent', aliased=True, from_joinpoint=True).filter_by(data='n1').first()
assert node.data == 'n122'
def test_any(self):
sess = create_session()
self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n1')).all(), [])
self.assertEquals(sess.query(Node).filter(Node.children.any(Node.data=='n12')).all(), [Node(data='n1')])
self.assertEquals(sess.query(Node).filter(~Node.children.any()).all(), [Node(data='n11'), Node(data='n13'),Node(data='n121'),Node(data='n122'),Node(data='n123'),])
def test_has(self):
sess = create_session()
self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n12')).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
self.assertEquals(sess.query(Node).filter(Node.parent.has(Node.data=='n122')).all(), [])
self.assertEquals(sess.query(Node).filter(~Node.parent.has()).all(), [Node(data='n1')])
def test_contains(self):
sess = create_session()
n122 = sess.query(Node).filter(Node.data=='n122').one()
self.assertEquals(sess.query(Node).filter(Node.children.contains(n122)).all(), [Node(data='n12')])
n13 = sess.query(Node).filter(Node.data=='n13').one()
self.assertEquals(sess.query(Node).filter(Node.children.contains(n13)).all(), [Node(data='n1')])
def test_eq_ne(self):
sess = create_session()
n12 = sess.query(Node).filter(Node.data=='n12').one()
self.assertEquals(sess.query(Node).filter(Node.parent==n12).all(), [Node(data='n121'),Node(data='n122'),Node(data='n123')])
self.assertEquals(sess.query(Node).filter(Node.parent != n12).all(), [Node(data='n1'), Node(data='n11'), Node(data='n12'), Node(data='n13')])
class ExternalColumnsTest(QueryTest):
keep_mappers = False