- removed query.min()/max()/sum()/avg(). these should be called using column arguments or values in conjunction with func.

- fixed [ticket:1008], count() works with single table inheritance
- changed the relationship of InstrumentedAttribute to class such that each subclass in an inheritance hierarchy gets a unique InstrumentedAttribute per column-oriented attribute, including for the same underlying ColumnProperty.  This allows expressions from subclasses to be annotated accurately so that Query can get a hold of the exact entities to be queried when using column-based expressions.  This repairs various polymorphic scenarios with both single and joined table inheritance.
- still to be determined is what does something like query(Person.name, Engineer.engineer_info) do; currently it's problematic.  Even trickier is query(Person.name, Engineer.engineer_info, Manager.manager_name)
This commit is contained in:
Mike Bayer
2008-06-02 03:07:12 +00:00
parent e3e1535720
commit e525aee015
8 changed files with 280 additions and 90 deletions
+6 -4
View File
@@ -439,9 +439,10 @@ class PropComparator(expression.ColumnOperators):
return a.has(b, **kwargs)
has_op = staticmethod(has_op)
def __init__(self, prop):
def __init__(self, prop, mapper):
self.prop = self.property = prop
self.mapper = mapper
def of_type_op(a, class_):
return a.of_type(class_)
of_type_op = staticmethod(of_type_op)
@@ -753,11 +754,12 @@ class LoaderStrategy(object):
def __init__(self, parent):
self.parent_property = parent
self.is_class_level = False
def init(self):
self.parent = self.parent_property.parent
self.key = self.parent_property.key
def init(self):
raise NotImplementedError("LoaderStrategy")
def init_class_attribute(self):
pass
+19 -16
View File
@@ -38,7 +38,7 @@ class ColumnProperty(StrategizedProperty):
self.columns = [expression._labeled(c) for c in columns]
self.group = kwargs.pop('group', None)
self.deferred = kwargs.pop('deferred', False)
self.comparator = ColumnProperty.ColumnComparator(self)
self.comparator_factory = ColumnProperty.ColumnComparator
util.set_creation_order(self)
if self.deferred:
self.strategy_class = strategies.DeferredColumnLoader
@@ -80,7 +80,7 @@ class ColumnProperty(StrategizedProperty):
class ColumnComparator(PropComparator):
def __clause_element__(self):
return self.prop.columns[0]._annotate({"parententity": self.prop.parent})
return self.prop.columns[0]._annotate({"parententity": self.mapper})
__clause_element__ = util.cache_decorator(__clause_element__)
def operate(self, op, *other, **kwargs):
@@ -101,7 +101,7 @@ class CompositeProperty(ColumnProperty):
def __init__(self, class_, *columns, **kwargs):
super(CompositeProperty, self).__init__(*columns, **kwargs)
self.composite_class = class_
self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self)
self.comparator_factory = kwargs.pop('comparator', CompositeProperty.Comparator)
self.strategy_class = strategies.CompositeColumnLoader
def do_init(self):
@@ -170,8 +170,7 @@ class SynonymProperty(MapperProperty):
def do_init(self):
class_ = self.parent.class_
def comparator():
return self.parent._get_property(self.key, resolve_synonyms=True).comparator
self.logger.info("register managed attribute %s on class %s" % (self.key, class_.__name__))
if self.descriptor is None:
class SynonymProp(object):
@@ -184,7 +183,14 @@ class SynonymProperty(MapperProperty):
return s
return getattr(obj, self.name)
self.descriptor = SynonymProp()
sessionlib.register_attribute(class_, self.key, uselist=False, proxy_property=self.descriptor, useobject=False, comparator=comparator, parententity=self.parent)
def comparator_callable(prop, mapper):
def comparator():
prop = self.parent._get_property(self.key, resolve_synonyms=True)
return prop.comparator_factory(prop, mapper)
return comparator
strategies.DefaultColumnLoader(self)._register_attribute(None, None, False, comparator_callable, proxy_property=self.descriptor)
def merge(self, session, source, dest, _recursive):
pass
@@ -195,18 +201,13 @@ class ComparableProperty(MapperProperty):
def __init__(self, comparator_factory, descriptor=None):
self.descriptor = descriptor
self.comparator = comparator_factory(self)
self.comparator_factory = comparator_factory
util.set_creation_order(self)
def do_init(self):
"""Set up a proxy to the unmanaged descriptor."""
class_ = self.parent.class_
# refactor me
sessionlib.register_attribute(class_, self.key, uselist=False,
proxy_property=self.descriptor,
useobject=False,
comparator=self.comparator)
strategies.DefaultColumnLoader(self)._register_attribute(None, None, False, self.comparator_factory, proxy_property=self.descriptor)
def setup(self, context, entity, path, adapter, **kwargs):
pass
@@ -252,10 +253,11 @@ class PropertyLoader(StrategizedProperty):
self.passive_updates = passive_updates
self.remote_side = remote_side
self.enable_typechecks = enable_typechecks
self.comparator = PropertyLoader.Comparator(self)
self.comparator = PropertyLoader.Comparator(self, None)
self.join_depth = join_depth
self.local_remote_pairs = _local_remote_pairs
self.__join_cache = {}
self.comparator_factory = PropertyLoader.Comparator
util.set_creation_order(self)
if strategy_class:
@@ -295,8 +297,9 @@ class PropertyLoader(StrategizedProperty):
self._is_backref = _is_backref
class Comparator(PropComparator):
def __init__(self, prop, of_type=None):
def __init__(self, prop, mapper, of_type=None):
self.prop = self.property = prop
self.mapper = mapper
if of_type:
self._of_type = _class_to_mapper(of_type)
@@ -314,7 +317,7 @@ class PropertyLoader(StrategizedProperty):
return op(self, *other, **kwargs)
def of_type(self, cls):
return PropertyLoader.Comparator(self.prop, cls)
return PropertyLoader.Comparator(self.prop, self.mapper, cls)
def __eq__(self, other):
if other is None:
+58 -41
View File
@@ -97,7 +97,14 @@ class Query(object):
self.__setup_aliasizers(self._entities)
def __setup_aliasizers(self, entities):
d = {}
if hasattr(self, '_mapper_adapter_map'):
# usually safe to share a single map, but copying to prevent
# subtle leaks if end-user is reusing base query with arbitrary
# number of aliased() objects
self._mapper_adapter_map = d = self._mapper_adapter_map.copy()
else:
self._mapper_adapter_map = d = {}
for ent in entities:
for entity in ent.entities:
if entity not in d:
@@ -114,7 +121,7 @@ class Query(object):
d[entity] = (mapper, adapter, selectable, is_aliased_class, with_polymorphic)
ent.setup_entity(entity, *d[entity])
def __mapper_loads_polymorphically_with(self, mapper, adapter):
for m2 in mapper._with_polymorphic_mappers:
for m in m2.iterate_to_root():
@@ -650,26 +657,6 @@ class Query(object):
return self.filter(sql.and_(*clauses))
def min(self, col):
"""Execute the SQL ``min()`` function against the given column."""
return self._col_aggregate(col, sql.func.min)
def max(self, col):
"""Execute the SQL ``max()`` function against the given column."""
return self._col_aggregate(col, sql.func.max)
def sum(self, col):
"""Execute the SQL ``sum()`` function against the given column."""
return self._col_aggregate(col, sql.func.sum)
def avg(self, col):
"""Execute the SQL ``avg()`` function against the given column."""
return self._col_aggregate(col, sql.func.avg)
def order_by(self, *criterion):
"""apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``"""
@@ -1213,18 +1200,17 @@ class Query(object):
_should_nest_selectable = property(_should_nest_selectable)
def count(self):
"""Apply this query's criterion to a SELECT COUNT statement.
this is the purely generative version which will become
the public method in version 0.5.
"""
return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._mapper_zero().primary_key))
"""Apply this query's criterion to a SELECT COUNT statement."""
return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._only_mapper_zero().primary_key))
def _col_aggregate(self, col, func, nested_cols=None):
whereclause = self._criterion
context = QueryContext(self)
self._adjust_for_single_inheritance(context)
whereclause = context.whereclause
from_obj = self.__mapper_zero_from_obj()
if self._should_nest_selectable:
@@ -1371,7 +1357,9 @@ class Query(object):
froms = [context.from_clause] # "load from a single FROM" mode, i.e. when select_from() or join() is used
else:
froms = context.froms # "load from discrete FROMs" mode, i.e. when each _MappedEntity has its own FROM
self._adjust_for_single_inheritance(context)
if eager_joins and self._should_nest_selectable:
# for eager joins present and LIMIT/OFFSET/DISTINCT, wrap the query inside a select,
# then append eager joins onto that
@@ -1382,7 +1370,15 @@ class Query(object):
context.order_by = None
order_by_col_expr = []
inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=froms, use_labels=labels, correlate=False, order_by=context.order_by, **self._select_args)
inner = sql.select(
context.primary_columns + order_by_col_expr,
context.whereclause,
from_obj=froms,
use_labels=labels,
correlate=False,
order_by=context.order_by,
**self._select_args
)
if self._correlate:
inner = inner.correlate(*self._correlate)
@@ -1418,7 +1414,17 @@ class Query(object):
froms += context.eager_joins.values()
statement = sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=froms, use_labels=labels, for_update=for_update, correlate=False, order_by=context.order_by, **self._select_args)
statement = sql.select(
context.primary_columns + context.secondary_columns,
context.whereclause,
from_obj=froms,
use_labels=labels,
for_update=for_update,
correlate=False,
order_by=context.order_by,
**self._select_args
)
if self._correlate:
statement = statement.correlate(*self._correlate)
@@ -1429,6 +1435,22 @@ class Query(object):
return context
def _adjust_for_single_inheritance(self, context):
"""Apply single-table-inheritance filtering.
For all distinct single-table-inheritance mappers represented in the columns
clause of this query, add criterion to the WHERE clause of the given QueryContext
such that only the appropriate subtypes are selected from the total results.
"""
for entity, (mapper, adapter, s, i, w) in self._mapper_adapter_map.iteritems():
if mapper.single and mapper.inherits and mapper.polymorphic_on and mapper.polymorphic_identity is not None:
crit = mapper.polymorphic_on.in_([m.polymorphic_identity for m in mapper.polymorphic_iterator()])
if adapter:
crit = adapter.traverse(crit)
crit = self._adapt_clause(crit, False, False)
context.whereclause = sql.and_(context.whereclause, crit)
def __log_debug(self, msg):
self.logger.debug(msg)
@@ -1463,7 +1485,7 @@ class _MapperEntity(_QueryEntity):
self.entities = [entity]
self.entity_zero = entity
self.entity_name = entity_name
def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic):
self.mapper = mapper
self.extension = self.mapper.extension
@@ -1554,15 +1576,10 @@ class _MapperEntity(_QueryEntity):
return main, entname
def setup_context(self, query, context):
# if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
# that we only load the appropriate types
if self.mapper.single and self.mapper.inherits is not None and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
context.whereclause = sql.and_(context.whereclause, self.mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.mapper.polymorphic_iterator()]))
adapter = self._get_entity_clauses(query, context)
context.froms.append(self.selectable)
adapter = self._get_entity_clauses(query, context)
if context.order_by is False and self.mapper.order_by:
context.order_by = self.mapper.order_by
+61 -13
View File
@@ -17,11 +17,31 @@ from sqlalchemy.orm import session as sessionlib
from sqlalchemy.orm import util as mapperutil
class ColumnLoader(LoaderStrategy):
"""Default column loader."""
class DefaultColumnLoader(LoaderStrategy):
def _register_attribute(self, compare_function, copy_function, mutable_scalars, comparator_factory, callable_=None, proxy_property=None):
self.logger.info("%s register managed attribute" % self)
for mapper in self.parent.polymorphic_iterator():
if mapper is self.parent or not mapper.concrete:
sessionlib.register_attribute(
mapper.class_,
self.key,
uselist=False,
useobject=False,
copy_function=copy_function,
compare_function=compare_function,
mutable_scalars=mutable_scalars,
comparator=comparator_factory(self.parent_property, mapper),
parententity=mapper,
callable_=callable_,
proxy_property=proxy_property
)
DefaultColumnLoader.logger = log.class_logger(DefaultColumnLoader)
class ColumnLoader(DefaultColumnLoader):
def init(self):
super(ColumnLoader, self).init()
self.columns = self.parent_property.columns
self._should_log_debug = log.is_debug_enabled(self.logger)
self.is_composite = hasattr(self.parent_property, 'composite_class')
@@ -34,9 +54,14 @@ class ColumnLoader(LoaderStrategy):
def init_class_attribute(self):
self.is_class_level = True
self.logger.info("%s register managed attribute" % self)
coltype = self.columns[0].type
sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent)
self._register_attribute(
coltype.compare_values,
coltype.copy_value,
self.columns[0].type.is_mutable(),
self.parent_property.comparator_factory
)
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
key, col = self.key, self.columns[0]
@@ -78,7 +103,13 @@ class CompositeColumnLoader(ColumnLoader):
return False
else:
return True
sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator, parententity=self.parent)
self._register_attribute(
compare,
copy,
True,
self.parent_property.comparator_factory
)
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
key, columns, composite_class = self.key, self.columns, self.parent_property.composite_class
@@ -106,7 +137,7 @@ class CompositeColumnLoader(ColumnLoader):
CompositeColumnLoader.logger = log.class_logger(CompositeColumnLoader)
class DeferredColumnLoader(LoaderStrategy):
class DeferredColumnLoader(DefaultColumnLoader):
"""Deferred column loader, a per-column or per-column-group lazy loader."""
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
@@ -130,7 +161,6 @@ class DeferredColumnLoader(LoaderStrategy):
return (new_execute, None)
def init(self):
super(DeferredColumnLoader, self).init()
if hasattr(self.parent_property, 'composite_class'):
raise NotImplementedError("Deferred loading for composite types not implemented yet")
self.columns = self.parent_property.columns
@@ -139,8 +169,13 @@ class DeferredColumnLoader(LoaderStrategy):
def init_class_attribute(self):
self.is_class_level = True
self.logger.info("%s register managed attribute" % self)
sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.class_level_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent)
self._register_attribute(
self.columns[0].type.compare_values,
self.columns[0].type.copy_value,
self.columns[0].type.is_mutable(),
self.parent_property.comparator_factory,
callable_=self.class_level_loader,
)
def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs):
if \
@@ -238,7 +273,6 @@ class UndeferGroupOption(MapperOption):
class AbstractRelationLoader(LoaderStrategy):
def init(self):
super(AbstractRelationLoader, self).init()
for attr in ['mapper', 'target', 'table', 'uselist']:
setattr(self, attr, getattr(self.parent_property, attr))
self._should_log_debug = log.is_debug_enabled(self.logger)
@@ -249,7 +283,7 @@ class AbstractRelationLoader(LoaderStrategy):
else:
state.initialize(self.key)
def _register_attribute(self, class_, callable_=None, **kwargs):
def _register_attribute(self, class_, callable_=None, impl_class=None, **kwargs):
self.logger.info("%s register managed %s attribute" % (self, (self.uselist and "collection" or "scalar")))
if self.parent_property.backref:
@@ -257,7 +291,21 @@ class AbstractRelationLoader(LoaderStrategy):
else:
attribute_ext = None
sessionlib.register_attribute(class_, self.key, uselist=self.uselist, useobject=True, extension=attribute_ext, cascade=self.parent_property.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, parententity=self.parent, **kwargs)
sessionlib.register_attribute(
class_,
self.key,
uselist=self.uselist,
useobject=True,
extension=attribute_ext,
cascade=self.parent_property.cascade,
trackparent=True,
typecallable=self.parent_property.collection_class,
callable_=callable_,
comparator=self.parent_property.comparator,
parententity=self.parent,
impl_class=impl_class,
**kwargs
)
class NoLoader(AbstractRelationLoader):
def init_class_attribute(self):
+8 -7
View File
@@ -1,6 +1,6 @@
import testenv; testenv.configure_for_tests()
from testlib import testing, sa
from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData
from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, func
from sqlalchemy.orm import mapper, relation, create_session
from testlib.testing import eq_
from testlib.compat import set
@@ -57,8 +57,9 @@ class GenerativeQueryTest(_base.MappedTest):
sess = create_session()
query = sess.query(Foo)
assert query.count() == 100
assert query.filter(foo.c.bar<30).min(foo.c.bar) == 0
assert query.filter(foo.c.bar<30).max(foo.c.bar) == 29
assert sess.query(func.min(foo.c.bar)).filter(foo.c.bar<30).one() == (0,)
assert sess.query(func.max(foo.c.bar)).filter(foo.c.bar<30).one() == (29,)
assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29
assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29
@@ -68,14 +69,14 @@ class GenerativeQueryTest(_base.MappedTest):
testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, 'gamma')):
return
query = create_session().query(Foo)
assert query.filter(foo.c.bar<30).sum(foo.c.bar) == 435
query = create_session().query(func.sum(foo.c.bar))
assert query.filter(foo.c.bar<30).one() == (435,)
@testing.fails_on('firebird', 'mssql')
@testing.resolve_artifact_names
def test_aggregate_2(self):
query = create_session().query(Foo)
avg = query.filter(foo.c.bar < 30).avg(foo.c.bar)
query = create_session().query(func.avg(foo.c.bar))
avg = query.filter(foo.c.bar < 30).one()[0]
eq_(round(avg, 1), 14.5)
@testing.resolve_artifact_names
+36 -1
View File
@@ -510,7 +510,7 @@ def make_test(select_type):
self.assertEquals(sess.query(Person).filter(Person.person_id==subq).one(), e1)
def test_mixed_entities(self):
sess = create_session()
@@ -525,6 +525,41 @@ def make_test(select_type):
[(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),
u'Elbonia, Inc.')]
)
self.assertEquals(
sess.query(Manager.name).all(),
[('pointy haired boss', ), ('dogbert',)]
)
self.assertEquals(
sess.query(Manager.name + " foo").all(),
[('pointy haired boss foo', ), ('dogbert foo',)]
)
self.assertEquals(
sess.query(Engineer.name, Engineer.primary_language).all(),
[(u'dilbert', u'java'), (u'wally', u'c++'), (u'vlad', u'cobol')]
)
self.assertEquals(
sess.query(Boss.name, Boss.golf_swing).all(),
[(u'pointy haired boss', u'fore')]
)
# TODO: I think raise error on these for now. different inheritance/loading schemes have different
# results here, all incorrect
#
# self.assertEquals(
# sess.query(Person.name, Engineer.primary_language).all(),
# []
# )
# self.assertEquals(
# sess.query(Person.name, Engineer.primary_language, Manager.manager_name).all(),
# []
# )
self.assertEquals(
sess.query(Person.name, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
+91 -7
View File
@@ -3,8 +3,9 @@ from sqlalchemy import *
from sqlalchemy.orm import *
from testlib import *
from testlib.fixtures import Base
from orm._base import MappedTest, ComparableEntity
class SingleInheritanceTest(ORMTest):
class SingleInheritanceTest(MappedTest):
def define_tables(self, metadata):
global employees_table
employees_table = Table('employees', metadata,
@@ -14,9 +15,9 @@ class SingleInheritanceTest(ORMTest):
Column('engineer_info', String(50)),
Column('type', String(20))
)
def test_single_inheritance(self):
class Employee(Base):
def setup_classes(self):
class Employee(ComparableEntity):
pass
class Manager(Employee):
pass
@@ -25,19 +26,22 @@ class SingleInheritanceTest(ORMTest):
class JuniorEngineer(Engineer):
pass
@testing.resolve_artifact_names
def setup_mappers(self):
mapper(Employee, employees_table, polymorphic_on=employees_table.c.type)
mapper(Manager, inherits=Employee, polymorphic_identity='manager')
mapper(Engineer, inherits=Employee, polymorphic_identity='engineer')
mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer')
@testing.resolve_artifact_names
def test_single_inheritance(self):
session = create_session()
m1 = Manager(name='Tom', manager_data='knows how to manage things')
e1 = Engineer(name='Kurt', engineer_info='knows how to hack')
e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed')
session.save(m1)
session.save(e1)
session.save(e2)
session.add_all([m1, e1, e2])
session.flush()
assert session.query(Employee).all() == [m1, e1, e2]
@@ -48,6 +52,86 @@ class SingleInheritanceTest(ORMTest):
m1 = session.query(Manager).one()
session.expire(m1, ['manager_data'])
self.assertEquals(m1.manager_data, "knows how to manage things")
@testing.resolve_artifact_names
def test_multi_qualification(self):
session = create_session()
m1 = Manager(name='Tom', manager_data='knows how to manage things')
e1 = Engineer(name='Kurt', engineer_info='knows how to hack')
e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed')
session.add_all([m1, e1, e2])
session.flush()
ealias = aliased(Engineer)
self.assertEquals(
session.query(Manager, ealias).all(),
[(m1, e1), (m1, e2)]
)
self.assertEquals(
session.query(Manager.name).all(),
[("Tom",)]
)
self.assertEquals(
session.query(Manager.name, ealias.name).all(),
[("Tom", "Kurt"), ("Tom", "Ed")]
)
self.assertEquals(
session.query(func.upper(Manager.name), func.upper(ealias.name)).all(),
[("TOM", "KURT"), ("TOM", "ED")]
)
self.assertEquals(
session.query(Manager).add_entity(ealias).all(),
[(m1, e1), (m1, e2)]
)
self.assertEquals(
session.query(Manager.name).add_column(ealias.name).all(),
[("Tom", "Kurt"), ("Tom", "Ed")]
)
# TODO: I think raise error on this for now
# self.assertEquals(
# session.query(Employee.name, Manager.manager_data, Engineer.engineer_info).all(),
# []
# )
@testing.resolve_artifact_names
def test_select_from(self):
sess = create_session()
m1 = Manager(name='Tom', manager_data='data1')
m2 = Manager(name='Tom2', manager_data='data2')
e1 = Engineer(name='Kurt', engineer_info='knows how to hack')
e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed')
sess.add_all([m1, m2, e1, e2])
sess.flush()
self.assertEquals(
sess.query(Manager).select_from(employees_table.select().limit(10)).all(),
[m1, m2]
)
@testing.resolve_artifact_names
def test_count(self):
sess = create_session()
m1 = Manager(name='Tom', manager_data='data1')
m2 = Manager(name='Tom2', manager_data='data2')
e1 = Engineer(name='Kurt', engineer_info='data3')
e2 = JuniorEngineer(name='marvin', engineer_info='data4')
sess.add_all([m1, m2, e1, e2])
sess.flush()
self.assertEquals(sess.query(Manager).count(), 2)
self.assertEquals(sess.query(Engineer).count(), 2)
self.assertEquals(sess.query(Employee).count(), 4)
self.assertEquals(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2)
self.assertEquals(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3)
class SingleOnJoinedTest(ORMTest):
def define_tables(self, metadata):
+1 -1
View File
@@ -542,7 +542,7 @@ class AggregateTest(QueryTest):
def test_sum(self):
sess = create_session()
orders = sess.query(Order).filter(Order.id.in_([2, 3, 4]))
assert orders.sum(Order.user_id * Order.address_id) == 79
self.assertEquals(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,))
def test_apply(self):
sess = create_session()