Files
sqlalchemy/test/orm/generative.py
T
Jason Kirtland 8fa48edbf9 - Removed 2.3 set emulations/enhancements.
(sets.Set-based collections & DB-API returns still work.)
2008-07-15 19:23:52 +00:00

273 lines
8.9 KiB
Python

import testenv; testenv.configure_for_tests()
from testlib import testing, sa
from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, func
from sqlalchemy.orm import mapper, relation, create_session
from testlib.testing import eq_
from orm import _base, _fixtures
class GenerativeQueryTest(_base.MappedTest):
    """Generative Query behavior (filter_by, slicing, aggregates, options,
    order_by, offset, limit) against a single 100-row ``foo`` table."""

    # Fixture rows are inserted once for the whole class and never deleted;
    # the tests below only query them.
    run_inserts = 'once'
    run_deletes = None

    def define_tables(self, metadata):
        Table('foo', metadata,
              Column('id', Integer, sa.Sequence('foo_id_seq'), primary_key=True),
              Column('bar', Integer),
              Column('range', Integer))

    def fixtures(self):
        # 100 rows: bar = 0..99, range = bar % 10.
        rows = tuple([(i, i % 10) for i in range(100)])
        foo_data = (('bar', 'range'),) + rows
        return dict(foo=foo_data)

    @testing.resolve_artifact_names
    def setup_mappers(self):
        class Foo(_base.BasicEntity):
            pass

        mapper(Foo, foo)

    @testing.resolve_artifact_names
    def test_selectby(self):
        # range == 5 selects bar = 5, 15, ..., 95.
        res = create_session().query(Foo).filter_by(range=5)
        eq_(res.order_by([Foo.bar])[0].bar, 5)
        eq_(res.order_by([sa.desc(Foo.bar)])[0].bar, 95)

    @testing.crashes('mssql', 'FIXME: verify not fails_on()')
    @testing.fails_on('maxdb')
    @testing.resolve_artifact_names
    def test_slice(self):
        """Query slicing/indexing must agree with slicing the full result list."""
        sess = create_session()
        query = sess.query(Foo)
        orig = query.all()
        eq_(query[1], orig[1])
        eq_(list(query[10:20]), orig[10:20])
        eq_(list(query[10:]), orig[10:])
        eq_(list(query[:10]), orig[:10])
        eq_(list(query[10:40:3]), orig[10:40:3])
        eq_(list(query[-5:]), orig[-5:])
        eq_(query[10:20][5], orig[10:20][5])

    # NOTE(review): no apply_max() call remains in this test; this decorator
    # may be stale and could be removed once confirmed.
    @testing.uses_deprecated('Call to deprecated function apply_max')
    @testing.resolve_artifact_names
    def test_aggregate(self):
        sess = create_session()
        query = sess.query(Foo)
        eq_(query.count(), 100)
        eq_(sess.query(func.min(foo.c.bar)).filter(foo.c.bar < 30).one(), (0,))
        eq_(sess.query(func.max(foo.c.bar)).filter(foo.c.bar < 30).one(), (29,))
        # values() yields result rows; take the first row's only column.
        eq_(query.filter(foo.c.bar < 30).values(
            sa.func.max(foo.c.bar)).next()[0], 29)

    @testing.resolve_artifact_names
    def test_aggregate_1(self):
        # Silently skipped on MySQL with DB-API driver version 1.2.1 gamma
        # (presumably a known driver defect with SUM results).
        if (testing.against('mysql') and
            testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, 'gamma')):
            return

        query = create_session().query(func.sum(foo.c.bar))
        # sum(0..29) == 435
        eq_(query.filter(foo.c.bar < 30).one(), (435,))

    @testing.fails_on('firebird', 'mssql')
    @testing.resolve_artifact_names
    def test_aggregate_2(self):
        query = create_session().query(func.avg(foo.c.bar))
        avg = query.filter(foo.c.bar < 30).one()[0]
        eq_(round(avg, 1), 14.5)

    @testing.resolve_artifact_names
    def test_aggregate_3(self):
        query = create_session().query(Foo)
        avg_f = query.filter(foo.c.bar < 30).values(
            sa.func.avg(foo.c.bar)).next()[0]
        eq_(round(avg_f, 1), 14.5)

    @testing.resolve_artifact_names
    def test_filter(self):
        query = create_session().query(Foo)
        eq_(query.count(), 100)
        eq_(query.filter(Foo.bar < 30).count(), 30)
        # Chained filters are ANDed: 10 < bar < 30 -> 11..29.
        res2 = query.filter(Foo.bar < 30).filter(Foo.bar > 10)
        eq_(res2.count(), 19)

    @testing.resolve_artifact_names
    def test_options(self):
        """A MapperExtension passed via options() is applied to loaded rows."""
        query = create_session().query(Foo)

        class ext1(sa.orm.MapperExtension):
            def populate_instance(self, mapper, selectcontext, row, instance, **flags):
                instance.TEST = "hello world"
                return sa.orm.EXT_CONTINUE

        eq_(query.options(sa.orm.extension(ext1()))[0].TEST, "hello world")

    @testing.resolve_artifact_names
    def test_order_by(self):
        query = create_session().query(Foo)
        eq_(query.order_by([Foo.bar])[0].bar, 0)
        eq_(query.order_by([sa.desc(Foo.bar)])[0].bar, 99)

    @testing.resolve_artifact_names
    def test_offset(self):
        query = create_session().query(Foo)
        eq_(list(query.order_by([Foo.bar]).offset(10))[0].bar, 10)

    # Previously also named ``test_offset``, which silently replaced the
    # offset test above; it actually exercises limit().
    @testing.resolve_artifact_names
    def test_limit(self):
        query = create_session().query(Foo)
        eq_(len(list(query.limit(10))), 10)
class GenerativeTest2(_base.MappedTest):
    """count() over an implicit Table1/Table2 join, with and without
    distinct()."""

    def define_tables(self, metadata):
        Table('Table1', metadata, Column('id', Integer, primary_key=True))
        Table('Table2', metadata,
              Column('t1id', Integer, ForeignKey("Table1.id"), primary_key=True),
              Column('num', Integer, primary_key=True))

    @testing.resolve_artifact_names
    def setup_mappers(self):
        class Obj1(_base.BasicEntity):
            pass

        class Obj2(_base.BasicEntity):
            pass

        mapper(Obj1, Table1)
        mapper(Obj2, Table2)

    def fixtures(self):
        # Parent ids 1..4.
        parents = (('id',),) + tuple([(pk,) for pk in range(1, 5)])
        # Parent 1 has three children, parent 2 has two, parent 3 has one.
        children = (('num', 't1id'),
                    (1, 1), (2, 1), (3, 1),
                    (4, 2), (5, 2),
                    (6, 3))
        return dict(Table1=parents, Table2=children)

    @testing.resolve_artifact_names
    def test_distinct_count(self):
        query = create_session().query(Obj1)
        eq_(query.count(), 4)

        # Implicit join restricted to parent id 1.
        criterion = sa.and_(Table1.c.id == Table2.c.t1id,
                            Table2.c.t1id == 1)
        # Without distinct() the count reflects the three joined rows ...
        eq_(query.filter(criterion).count(), 3)
        # ... and with distinct() it drops to 1.
        eq_(query.filter(criterion).distinct().count(), 1)
class RelationsTest(_fixtures.FixtureTest):
    """Query.join / outerjoin / select_from across the
    User.orders -> Order.addresses relation chain."""

    # Mappers and fixture rows are set up once per class; nothing is deleted.
    run_setup_mappers = 'once'
    run_inserts = 'once'
    run_deletes = None

    @testing.resolve_artifact_names
    def setup_mappers(self):
        # Each target mapper is built before the mapper that relates to it.
        address_mapper = mapper(Address, addresses)
        order_mapper = mapper(Order, orders, properties={
            'addresses': relation(address_mapper)})
        mapper(User, users, properties={
            'orders': relation(order_mapper)})

    @testing.resolve_artifact_names
    def test_join(self):
        """Query.join"""
        q = create_session().query(User).join(['orders', 'addresses'])
        q = q.filter(Address.id == 1)
        eq_([User(id=7)], q.all())

    @testing.resolve_artifact_names
    def test_outer_join(self):
        """Query.outerjoin"""
        q = create_session().query(User).outerjoin(['orders', 'addresses'])
        # ``== None`` (not ``is None``) is required to build SQL IS NULL.
        q = q.filter(sa.or_(Order.id == None, Address.id == 1))
        expected = set([User(id=7), User(id=8), User(id=10)])
        eq_(expected, set(q.all()))

    @testing.resolve_artifact_names
    def test_outer_join_count(self):
        """count() on an outerjoin()ed Query"""
        q = create_session().query(User).outerjoin(['orders', 'addresses'])
        q = q.filter(sa.or_(Order.id == None, Address.id == 1))
        # NOTE(review): 4 although test_outer_join sees three distinct users;
        # presumably count() is taken over joined rows — confirm.
        eq_(q.count(), 4)

    @testing.resolve_artifact_names
    def test_from(self):
        # Explicit outer join chain supplied through select_from().
        joined = users.outerjoin(orders).outerjoin(
            addresses, orders.c.address_id == addresses.c.id)
        q = create_session().query(User).select_from(joined)
        q = q.filter(sa.or_(Order.id == None, Address.id == 1))
        expected = set([User(id=7), User(id=8), User(id=10)])
        eq_(expected, set(q.all()))
class CaseSensitiveTest(_base.MappedTest):
    """The GenerativeTest2 distinct-count scenario with upper-case column
    names ('ID', 'T1ID', 'NUM'), presumably to exercise case-sensitive
    identifier quoting."""

    def define_tables(self, metadata):
        Table('Table1', metadata,
              Column('ID', Integer, primary_key=True))
        Table('Table2', metadata,
              Column('T1ID', Integer, ForeignKey("Table1.ID"),
                     primary_key=True),
              Column('NUM', Integer, primary_key=True))

    @testing.resolve_artifact_names
    def setup_mappers(self):
        class Obj1(_base.BasicEntity):
            pass

        class Obj2(_base.BasicEntity):
            pass

        mapper(Obj1, Table1)
        mapper(Obj2, Table2)

    def fixtures(self):
        return dict(
            Table1=(('ID',),
                    (1,),
                    (2,),
                    (3,),
                    (4,)),
            Table2=(('NUM', 'T1ID'),
                    (1, 1),
                    (2, 1),
                    (3, 1),
                    (4, 2),
                    (5, 2),
                    (6, 3)))

    @testing.resolve_artifact_names
    def test_distinct_count(self):
        q = create_session(bind=testing.db).query(Obj1)
        eq_(q.count(), 4)

        # Implicit join restricted to parent ID 1: three joined rows.
        res = q.filter(sa.and_(Table1.c.ID == Table2.c.T1ID,
                               Table2.c.T1ID == 1))
        eq_(res.count(), 3)

        # distinct() collapses the count to 1.
        res = q.filter(sa.and_(Table1.c.ID == Table2.c.T1ID,
                               Table2.c.T1ID == 1)).distinct()
        eq_(res.count(), 1)
# Running this file directly hands control to testenv's test runner.
if __name__ == "__main__":
    testenv.main()