mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-06-04 23:06:24 -04:00
- removed query.min()/max()/sum()/avg(). these should be called using column arguments or values in conjunction with func.
- fixed [ticket:1008], count() works with single table inheritance - changed the relationship of InstrumentedAttribute to class such that each subclass in an inheritance hierarchy gets a unique InstrumentedAttribute per column-oriented attribute, including for the same underlying ColumnProperty. This allows expressions from subclasses to be annotated accurately so that Query can get a hold of the exact entities to be queried when using column-based expressions. This repairs various polymorphic scenarios with both single and joined table inheritance. - still to be determined is what does something like query(Person.name, Engineer.engineer_info) do; currently it's problematic. Even trickier is query(Person.name, Engineer.engineer_info, Manager.manager_name)
This commit is contained in:
@@ -439,9 +439,10 @@ class PropComparator(expression.ColumnOperators):
|
||||
return a.has(b, **kwargs)
|
||||
has_op = staticmethod(has_op)
|
||||
|
||||
def __init__(self, prop):
|
||||
def __init__(self, prop, mapper):
|
||||
self.prop = self.property = prop
|
||||
|
||||
self.mapper = mapper
|
||||
|
||||
def of_type_op(a, class_):
|
||||
return a.of_type(class_)
|
||||
of_type_op = staticmethod(of_type_op)
|
||||
@@ -753,11 +754,12 @@ class LoaderStrategy(object):
|
||||
def __init__(self, parent):
|
||||
self.parent_property = parent
|
||||
self.is_class_level = False
|
||||
|
||||
def init(self):
|
||||
self.parent = self.parent_property.parent
|
||||
self.key = self.parent_property.key
|
||||
|
||||
def init(self):
|
||||
raise NotImplementedError("LoaderStrategy")
|
||||
|
||||
def init_class_attribute(self):
|
||||
pass
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class ColumnProperty(StrategizedProperty):
|
||||
self.columns = [expression._labeled(c) for c in columns]
|
||||
self.group = kwargs.pop('group', None)
|
||||
self.deferred = kwargs.pop('deferred', False)
|
||||
self.comparator = ColumnProperty.ColumnComparator(self)
|
||||
self.comparator_factory = ColumnProperty.ColumnComparator
|
||||
util.set_creation_order(self)
|
||||
if self.deferred:
|
||||
self.strategy_class = strategies.DeferredColumnLoader
|
||||
@@ -80,7 +80,7 @@ class ColumnProperty(StrategizedProperty):
|
||||
|
||||
class ColumnComparator(PropComparator):
|
||||
def __clause_element__(self):
|
||||
return self.prop.columns[0]._annotate({"parententity": self.prop.parent})
|
||||
return self.prop.columns[0]._annotate({"parententity": self.mapper})
|
||||
__clause_element__ = util.cache_decorator(__clause_element__)
|
||||
|
||||
def operate(self, op, *other, **kwargs):
|
||||
@@ -101,7 +101,7 @@ class CompositeProperty(ColumnProperty):
|
||||
def __init__(self, class_, *columns, **kwargs):
|
||||
super(CompositeProperty, self).__init__(*columns, **kwargs)
|
||||
self.composite_class = class_
|
||||
self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self)
|
||||
self.comparator_factory = kwargs.pop('comparator', CompositeProperty.Comparator)
|
||||
self.strategy_class = strategies.CompositeColumnLoader
|
||||
|
||||
def do_init(self):
|
||||
@@ -170,8 +170,7 @@ class SynonymProperty(MapperProperty):
|
||||
|
||||
def do_init(self):
|
||||
class_ = self.parent.class_
|
||||
def comparator():
|
||||
return self.parent._get_property(self.key, resolve_synonyms=True).comparator
|
||||
|
||||
self.logger.info("register managed attribute %s on class %s" % (self.key, class_.__name__))
|
||||
if self.descriptor is None:
|
||||
class SynonymProp(object):
|
||||
@@ -184,7 +183,14 @@ class SynonymProperty(MapperProperty):
|
||||
return s
|
||||
return getattr(obj, self.name)
|
||||
self.descriptor = SynonymProp()
|
||||
sessionlib.register_attribute(class_, self.key, uselist=False, proxy_property=self.descriptor, useobject=False, comparator=comparator, parententity=self.parent)
|
||||
|
||||
def comparator_callable(prop, mapper):
|
||||
def comparator():
|
||||
prop = self.parent._get_property(self.key, resolve_synonyms=True)
|
||||
return prop.comparator_factory(prop, mapper)
|
||||
return comparator
|
||||
|
||||
strategies.DefaultColumnLoader(self)._register_attribute(None, None, False, comparator_callable, proxy_property=self.descriptor)
|
||||
|
||||
def merge(self, session, source, dest, _recursive):
|
||||
pass
|
||||
@@ -195,18 +201,13 @@ class ComparableProperty(MapperProperty):
|
||||
|
||||
def __init__(self, comparator_factory, descriptor=None):
|
||||
self.descriptor = descriptor
|
||||
self.comparator = comparator_factory(self)
|
||||
self.comparator_factory = comparator_factory
|
||||
util.set_creation_order(self)
|
||||
|
||||
def do_init(self):
|
||||
"""Set up a proxy to the unmanaged descriptor."""
|
||||
|
||||
class_ = self.parent.class_
|
||||
# refactor me
|
||||
sessionlib.register_attribute(class_, self.key, uselist=False,
|
||||
proxy_property=self.descriptor,
|
||||
useobject=False,
|
||||
comparator=self.comparator)
|
||||
strategies.DefaultColumnLoader(self)._register_attribute(None, None, False, self.comparator_factory, proxy_property=self.descriptor)
|
||||
|
||||
def setup(self, context, entity, path, adapter, **kwargs):
|
||||
pass
|
||||
@@ -252,10 +253,11 @@ class PropertyLoader(StrategizedProperty):
|
||||
self.passive_updates = passive_updates
|
||||
self.remote_side = remote_side
|
||||
self.enable_typechecks = enable_typechecks
|
||||
self.comparator = PropertyLoader.Comparator(self)
|
||||
self.comparator = PropertyLoader.Comparator(self, None)
|
||||
self.join_depth = join_depth
|
||||
self.local_remote_pairs = _local_remote_pairs
|
||||
self.__join_cache = {}
|
||||
self.comparator_factory = PropertyLoader.Comparator
|
||||
util.set_creation_order(self)
|
||||
|
||||
if strategy_class:
|
||||
@@ -295,8 +297,9 @@ class PropertyLoader(StrategizedProperty):
|
||||
self._is_backref = _is_backref
|
||||
|
||||
class Comparator(PropComparator):
|
||||
def __init__(self, prop, of_type=None):
|
||||
def __init__(self, prop, mapper, of_type=None):
|
||||
self.prop = self.property = prop
|
||||
self.mapper = mapper
|
||||
if of_type:
|
||||
self._of_type = _class_to_mapper(of_type)
|
||||
|
||||
@@ -314,7 +317,7 @@ class PropertyLoader(StrategizedProperty):
|
||||
return op(self, *other, **kwargs)
|
||||
|
||||
def of_type(self, cls):
|
||||
return PropertyLoader.Comparator(self.prop, cls)
|
||||
return PropertyLoader.Comparator(self.prop, self.mapper, cls)
|
||||
|
||||
def __eq__(self, other):
|
||||
if other is None:
|
||||
|
||||
+58
-41
@@ -97,7 +97,14 @@ class Query(object):
|
||||
self.__setup_aliasizers(self._entities)
|
||||
|
||||
def __setup_aliasizers(self, entities):
|
||||
d = {}
|
||||
if hasattr(self, '_mapper_adapter_map'):
|
||||
# usually safe to share a single map, but copying to prevent
|
||||
# subtle leaks if end-user is reusing base query with arbitrary
|
||||
# number of aliased() objects
|
||||
self._mapper_adapter_map = d = self._mapper_adapter_map.copy()
|
||||
else:
|
||||
self._mapper_adapter_map = d = {}
|
||||
|
||||
for ent in entities:
|
||||
for entity in ent.entities:
|
||||
if entity not in d:
|
||||
@@ -114,7 +121,7 @@ class Query(object):
|
||||
|
||||
d[entity] = (mapper, adapter, selectable, is_aliased_class, with_polymorphic)
|
||||
ent.setup_entity(entity, *d[entity])
|
||||
|
||||
|
||||
def __mapper_loads_polymorphically_with(self, mapper, adapter):
|
||||
for m2 in mapper._with_polymorphic_mappers:
|
||||
for m in m2.iterate_to_root():
|
||||
@@ -650,26 +657,6 @@ class Query(object):
|
||||
return self.filter(sql.and_(*clauses))
|
||||
|
||||
|
||||
def min(self, col):
|
||||
"""Execute the SQL ``min()`` function against the given column."""
|
||||
|
||||
return self._col_aggregate(col, sql.func.min)
|
||||
|
||||
def max(self, col):
|
||||
"""Execute the SQL ``max()`` function against the given column."""
|
||||
|
||||
return self._col_aggregate(col, sql.func.max)
|
||||
|
||||
def sum(self, col):
|
||||
"""Execute the SQL ``sum()`` function against the given column."""
|
||||
|
||||
return self._col_aggregate(col, sql.func.sum)
|
||||
|
||||
def avg(self, col):
|
||||
"""Execute the SQL ``avg()`` function against the given column."""
|
||||
|
||||
return self._col_aggregate(col, sql.func.avg)
|
||||
|
||||
def order_by(self, *criterion):
|
||||
"""apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``"""
|
||||
|
||||
@@ -1213,18 +1200,17 @@ class Query(object):
|
||||
_should_nest_selectable = property(_should_nest_selectable)
|
||||
|
||||
def count(self):
|
||||
"""Apply this query's criterion to a SELECT COUNT statement.
|
||||
|
||||
this is the purely generative version which will become
|
||||
the public method in version 0.5.
|
||||
|
||||
"""
|
||||
return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._mapper_zero().primary_key))
|
||||
"""Apply this query's criterion to a SELECT COUNT statement."""
|
||||
|
||||
return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._only_mapper_zero().primary_key))
|
||||
|
||||
def _col_aggregate(self, col, func, nested_cols=None):
|
||||
whereclause = self._criterion
|
||||
|
||||
context = QueryContext(self)
|
||||
|
||||
self._adjust_for_single_inheritance(context)
|
||||
|
||||
whereclause = context.whereclause
|
||||
|
||||
from_obj = self.__mapper_zero_from_obj()
|
||||
|
||||
if self._should_nest_selectable:
|
||||
@@ -1371,7 +1357,9 @@ class Query(object):
|
||||
froms = [context.from_clause] # "load from a single FROM" mode, i.e. when select_from() or join() is used
|
||||
else:
|
||||
froms = context.froms # "load from discrete FROMs" mode, i.e. when each _MappedEntity has its own FROM
|
||||
|
||||
|
||||
self._adjust_for_single_inheritance(context)
|
||||
|
||||
if eager_joins and self._should_nest_selectable:
|
||||
# for eager joins present and LIMIT/OFFSET/DISTINCT, wrap the query inside a select,
|
||||
# then append eager joins onto that
|
||||
@@ -1382,7 +1370,15 @@ class Query(object):
|
||||
context.order_by = None
|
||||
order_by_col_expr = []
|
||||
|
||||
inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=froms, use_labels=labels, correlate=False, order_by=context.order_by, **self._select_args)
|
||||
inner = sql.select(
|
||||
context.primary_columns + order_by_col_expr,
|
||||
context.whereclause,
|
||||
from_obj=froms,
|
||||
use_labels=labels,
|
||||
correlate=False,
|
||||
order_by=context.order_by,
|
||||
**self._select_args
|
||||
)
|
||||
|
||||
if self._correlate:
|
||||
inner = inner.correlate(*self._correlate)
|
||||
@@ -1418,7 +1414,17 @@ class Query(object):
|
||||
|
||||
froms += context.eager_joins.values()
|
||||
|
||||
statement = sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=froms, use_labels=labels, for_update=for_update, correlate=False, order_by=context.order_by, **self._select_args)
|
||||
statement = sql.select(
|
||||
context.primary_columns + context.secondary_columns,
|
||||
context.whereclause,
|
||||
from_obj=froms,
|
||||
use_labels=labels,
|
||||
for_update=for_update,
|
||||
correlate=False,
|
||||
order_by=context.order_by,
|
||||
**self._select_args
|
||||
)
|
||||
|
||||
if self._correlate:
|
||||
statement = statement.correlate(*self._correlate)
|
||||
|
||||
@@ -1429,6 +1435,22 @@ class Query(object):
|
||||
|
||||
return context
|
||||
|
||||
def _adjust_for_single_inheritance(self, context):
|
||||
"""Apply single-table-inheritance filtering.
|
||||
|
||||
For all distinct single-table-inheritance mappers represented in the columns
|
||||
clause of this query, add criterion to the WHERE clause of the given QueryContext
|
||||
such that only the appropriate subtypes are selected from the total results.
|
||||
|
||||
"""
|
||||
for entity, (mapper, adapter, s, i, w) in self._mapper_adapter_map.iteritems():
|
||||
if mapper.single and mapper.inherits and mapper.polymorphic_on and mapper.polymorphic_identity is not None:
|
||||
crit = mapper.polymorphic_on.in_([m.polymorphic_identity for m in mapper.polymorphic_iterator()])
|
||||
if adapter:
|
||||
crit = adapter.traverse(crit)
|
||||
crit = self._adapt_clause(crit, False, False)
|
||||
context.whereclause = sql.and_(context.whereclause, crit)
|
||||
|
||||
def __log_debug(self, msg):
|
||||
self.logger.debug(msg)
|
||||
|
||||
@@ -1463,7 +1485,7 @@ class _MapperEntity(_QueryEntity):
|
||||
self.entities = [entity]
|
||||
self.entity_zero = entity
|
||||
self.entity_name = entity_name
|
||||
|
||||
|
||||
def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic):
|
||||
self.mapper = mapper
|
||||
self.extension = self.mapper.extension
|
||||
@@ -1554,15 +1576,10 @@ class _MapperEntity(_QueryEntity):
|
||||
return main, entname
|
||||
|
||||
def setup_context(self, query, context):
|
||||
# if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
|
||||
# that we only load the appropriate types
|
||||
if self.mapper.single and self.mapper.inherits is not None and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
|
||||
context.whereclause = sql.and_(context.whereclause, self.mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.mapper.polymorphic_iterator()]))
|
||||
adapter = self._get_entity_clauses(query, context)
|
||||
|
||||
context.froms.append(self.selectable)
|
||||
|
||||
adapter = self._get_entity_clauses(query, context)
|
||||
|
||||
if context.order_by is False and self.mapper.order_by:
|
||||
context.order_by = self.mapper.order_by
|
||||
|
||||
|
||||
@@ -17,11 +17,31 @@ from sqlalchemy.orm import session as sessionlib
|
||||
from sqlalchemy.orm import util as mapperutil
|
||||
|
||||
|
||||
class ColumnLoader(LoaderStrategy):
|
||||
"""Default column loader."""
|
||||
class DefaultColumnLoader(LoaderStrategy):
|
||||
def _register_attribute(self, compare_function, copy_function, mutable_scalars, comparator_factory, callable_=None, proxy_property=None):
|
||||
self.logger.info("%s register managed attribute" % self)
|
||||
|
||||
for mapper in self.parent.polymorphic_iterator():
|
||||
if mapper is self.parent or not mapper.concrete:
|
||||
sessionlib.register_attribute(
|
||||
mapper.class_,
|
||||
self.key,
|
||||
uselist=False,
|
||||
useobject=False,
|
||||
copy_function=copy_function,
|
||||
compare_function=compare_function,
|
||||
mutable_scalars=mutable_scalars,
|
||||
comparator=comparator_factory(self.parent_property, mapper),
|
||||
parententity=mapper,
|
||||
callable_=callable_,
|
||||
proxy_property=proxy_property
|
||||
)
|
||||
|
||||
DefaultColumnLoader.logger = log.class_logger(DefaultColumnLoader)
|
||||
|
||||
class ColumnLoader(DefaultColumnLoader):
|
||||
|
||||
def init(self):
|
||||
super(ColumnLoader, self).init()
|
||||
self.columns = self.parent_property.columns
|
||||
self._should_log_debug = log.is_debug_enabled(self.logger)
|
||||
self.is_composite = hasattr(self.parent_property, 'composite_class')
|
||||
@@ -34,9 +54,14 @@ class ColumnLoader(LoaderStrategy):
|
||||
|
||||
def init_class_attribute(self):
|
||||
self.is_class_level = True
|
||||
self.logger.info("%s register managed attribute" % self)
|
||||
coltype = self.columns[0].type
|
||||
sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent)
|
||||
|
||||
self._register_attribute(
|
||||
coltype.compare_values,
|
||||
coltype.copy_value,
|
||||
self.columns[0].type.is_mutable(),
|
||||
self.parent_property.comparator_factory
|
||||
)
|
||||
|
||||
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
|
||||
key, col = self.key, self.columns[0]
|
||||
@@ -78,7 +103,13 @@ class CompositeColumnLoader(ColumnLoader):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator, parententity=self.parent)
|
||||
|
||||
self._register_attribute(
|
||||
compare,
|
||||
copy,
|
||||
True,
|
||||
self.parent_property.comparator_factory
|
||||
)
|
||||
|
||||
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
|
||||
key, columns, composite_class = self.key, self.columns, self.parent_property.composite_class
|
||||
@@ -106,7 +137,7 @@ class CompositeColumnLoader(ColumnLoader):
|
||||
|
||||
CompositeColumnLoader.logger = log.class_logger(CompositeColumnLoader)
|
||||
|
||||
class DeferredColumnLoader(LoaderStrategy):
|
||||
class DeferredColumnLoader(DefaultColumnLoader):
|
||||
"""Deferred column loader, a per-column or per-column-group lazy loader."""
|
||||
|
||||
def create_row_processor(self, selectcontext, path, mapper, row, adapter):
|
||||
@@ -130,7 +161,6 @@ class DeferredColumnLoader(LoaderStrategy):
|
||||
return (new_execute, None)
|
||||
|
||||
def init(self):
|
||||
super(DeferredColumnLoader, self).init()
|
||||
if hasattr(self.parent_property, 'composite_class'):
|
||||
raise NotImplementedError("Deferred loading for composite types not implemented yet")
|
||||
self.columns = self.parent_property.columns
|
||||
@@ -139,8 +169,13 @@ class DeferredColumnLoader(LoaderStrategy):
|
||||
|
||||
def init_class_attribute(self):
|
||||
self.is_class_level = True
|
||||
self.logger.info("%s register managed attribute" % self)
|
||||
sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.class_level_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent)
|
||||
self._register_attribute(
|
||||
self.columns[0].type.compare_values,
|
||||
self.columns[0].type.copy_value,
|
||||
self.columns[0].type.is_mutable(),
|
||||
self.parent_property.comparator_factory,
|
||||
callable_=self.class_level_loader,
|
||||
)
|
||||
|
||||
def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs):
|
||||
if \
|
||||
@@ -238,7 +273,6 @@ class UndeferGroupOption(MapperOption):
|
||||
|
||||
class AbstractRelationLoader(LoaderStrategy):
|
||||
def init(self):
|
||||
super(AbstractRelationLoader, self).init()
|
||||
for attr in ['mapper', 'target', 'table', 'uselist']:
|
||||
setattr(self, attr, getattr(self.parent_property, attr))
|
||||
self._should_log_debug = log.is_debug_enabled(self.logger)
|
||||
@@ -249,7 +283,7 @@ class AbstractRelationLoader(LoaderStrategy):
|
||||
else:
|
||||
state.initialize(self.key)
|
||||
|
||||
def _register_attribute(self, class_, callable_=None, **kwargs):
|
||||
def _register_attribute(self, class_, callable_=None, impl_class=None, **kwargs):
|
||||
self.logger.info("%s register managed %s attribute" % (self, (self.uselist and "collection" or "scalar")))
|
||||
|
||||
if self.parent_property.backref:
|
||||
@@ -257,7 +291,21 @@ class AbstractRelationLoader(LoaderStrategy):
|
||||
else:
|
||||
attribute_ext = None
|
||||
|
||||
sessionlib.register_attribute(class_, self.key, uselist=self.uselist, useobject=True, extension=attribute_ext, cascade=self.parent_property.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, parententity=self.parent, **kwargs)
|
||||
sessionlib.register_attribute(
|
||||
class_,
|
||||
self.key,
|
||||
uselist=self.uselist,
|
||||
useobject=True,
|
||||
extension=attribute_ext,
|
||||
cascade=self.parent_property.cascade,
|
||||
trackparent=True,
|
||||
typecallable=self.parent_property.collection_class,
|
||||
callable_=callable_,
|
||||
comparator=self.parent_property.comparator,
|
||||
parententity=self.parent,
|
||||
impl_class=impl_class,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
class NoLoader(AbstractRelationLoader):
|
||||
def init_class_attribute(self):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import testenv; testenv.configure_for_tests()
|
||||
from testlib import testing, sa
|
||||
from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData
|
||||
from testlib.sa import Table, Column, Integer, String, ForeignKey, MetaData, func
|
||||
from sqlalchemy.orm import mapper, relation, create_session
|
||||
from testlib.testing import eq_
|
||||
from testlib.compat import set
|
||||
@@ -57,8 +57,9 @@ class GenerativeQueryTest(_base.MappedTest):
|
||||
sess = create_session()
|
||||
query = sess.query(Foo)
|
||||
assert query.count() == 100
|
||||
assert query.filter(foo.c.bar<30).min(foo.c.bar) == 0
|
||||
assert query.filter(foo.c.bar<30).max(foo.c.bar) == 29
|
||||
assert sess.query(func.min(foo.c.bar)).filter(foo.c.bar<30).one() == (0,)
|
||||
|
||||
assert sess.query(func.max(foo.c.bar)).filter(foo.c.bar<30).one() == (29,)
|
||||
assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29
|
||||
assert query.filter(foo.c.bar<30).values(sa.func.max(foo.c.bar)).next()[0] == 29
|
||||
|
||||
@@ -68,14 +69,14 @@ class GenerativeQueryTest(_base.MappedTest):
|
||||
testing.db.dialect.dbapi.version_info[:4] == (1, 2, 1, 'gamma')):
|
||||
return
|
||||
|
||||
query = create_session().query(Foo)
|
||||
assert query.filter(foo.c.bar<30).sum(foo.c.bar) == 435
|
||||
query = create_session().query(func.sum(foo.c.bar))
|
||||
assert query.filter(foo.c.bar<30).one() == (435,)
|
||||
|
||||
@testing.fails_on('firebird', 'mssql')
|
||||
@testing.resolve_artifact_names
|
||||
def test_aggregate_2(self):
|
||||
query = create_session().query(Foo)
|
||||
avg = query.filter(foo.c.bar < 30).avg(foo.c.bar)
|
||||
query = create_session().query(func.avg(foo.c.bar))
|
||||
avg = query.filter(foo.c.bar < 30).one()[0]
|
||||
eq_(round(avg, 1), 14.5)
|
||||
|
||||
@testing.resolve_artifact_names
|
||||
|
||||
@@ -510,7 +510,7 @@ def make_test(select_type):
|
||||
|
||||
self.assertEquals(sess.query(Person).filter(Person.person_id==subq).one(), e1)
|
||||
|
||||
|
||||
|
||||
def test_mixed_entities(self):
|
||||
sess = create_session()
|
||||
|
||||
@@ -525,6 +525,41 @@ def make_test(select_type):
|
||||
[(Engineer(status=u'elbonian engineer',engineer_name=u'vlad',name=u'vlad',primary_language=u'cobol'),
|
||||
u'Elbonia, Inc.')]
|
||||
)
|
||||
|
||||
|
||||
self.assertEquals(
|
||||
sess.query(Manager.name).all(),
|
||||
[('pointy haired boss', ), ('dogbert',)]
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
sess.query(Manager.name + " foo").all(),
|
||||
[('pointy haired boss foo', ), ('dogbert foo',)]
|
||||
)
|
||||
|
||||
|
||||
self.assertEquals(
|
||||
sess.query(Engineer.name, Engineer.primary_language).all(),
|
||||
[(u'dilbert', u'java'), (u'wally', u'c++'), (u'vlad', u'cobol')]
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
sess.query(Boss.name, Boss.golf_swing).all(),
|
||||
[(u'pointy haired boss', u'fore')]
|
||||
)
|
||||
|
||||
# TODO: I think raise error on these for now. different inheritance/loading schemes have different
|
||||
# results here, all incorrect
|
||||
#
|
||||
# self.assertEquals(
|
||||
# sess.query(Person.name, Engineer.primary_language).all(),
|
||||
# []
|
||||
# )
|
||||
|
||||
# self.assertEquals(
|
||||
# sess.query(Person.name, Engineer.primary_language, Manager.manager_name).all(),
|
||||
# []
|
||||
# )
|
||||
|
||||
self.assertEquals(
|
||||
sess.query(Person.name, Company.name).join(Company.employees).filter(Company.name=='Elbonia, Inc.').all(),
|
||||
|
||||
@@ -3,8 +3,9 @@ from sqlalchemy import *
|
||||
from sqlalchemy.orm import *
|
||||
from testlib import *
|
||||
from testlib.fixtures import Base
|
||||
from orm._base import MappedTest, ComparableEntity
|
||||
|
||||
class SingleInheritanceTest(ORMTest):
|
||||
class SingleInheritanceTest(MappedTest):
|
||||
def define_tables(self, metadata):
|
||||
global employees_table
|
||||
employees_table = Table('employees', metadata,
|
||||
@@ -14,9 +15,9 @@ class SingleInheritanceTest(ORMTest):
|
||||
Column('engineer_info', String(50)),
|
||||
Column('type', String(20))
|
||||
)
|
||||
|
||||
def test_single_inheritance(self):
|
||||
class Employee(Base):
|
||||
|
||||
def setup_classes(self):
|
||||
class Employee(ComparableEntity):
|
||||
pass
|
||||
class Manager(Employee):
|
||||
pass
|
||||
@@ -25,19 +26,22 @@ class SingleInheritanceTest(ORMTest):
|
||||
class JuniorEngineer(Engineer):
|
||||
pass
|
||||
|
||||
@testing.resolve_artifact_names
|
||||
def setup_mappers(self):
|
||||
mapper(Employee, employees_table, polymorphic_on=employees_table.c.type)
|
||||
mapper(Manager, inherits=Employee, polymorphic_identity='manager')
|
||||
mapper(Engineer, inherits=Employee, polymorphic_identity='engineer')
|
||||
mapper(JuniorEngineer, inherits=Engineer, polymorphic_identity='juniorengineer')
|
||||
|
||||
@testing.resolve_artifact_names
|
||||
def test_single_inheritance(self):
|
||||
|
||||
session = create_session()
|
||||
|
||||
m1 = Manager(name='Tom', manager_data='knows how to manage things')
|
||||
e1 = Engineer(name='Kurt', engineer_info='knows how to hack')
|
||||
e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed')
|
||||
session.save(m1)
|
||||
session.save(e1)
|
||||
session.save(e2)
|
||||
session.add_all([m1, e1, e2])
|
||||
session.flush()
|
||||
|
||||
assert session.query(Employee).all() == [m1, e1, e2]
|
||||
@@ -48,6 +52,86 @@ class SingleInheritanceTest(ORMTest):
|
||||
m1 = session.query(Manager).one()
|
||||
session.expire(m1, ['manager_data'])
|
||||
self.assertEquals(m1.manager_data, "knows how to manage things")
|
||||
|
||||
@testing.resolve_artifact_names
|
||||
def test_multi_qualification(self):
|
||||
session = create_session()
|
||||
|
||||
m1 = Manager(name='Tom', manager_data='knows how to manage things')
|
||||
e1 = Engineer(name='Kurt', engineer_info='knows how to hack')
|
||||
e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed')
|
||||
|
||||
session.add_all([m1, e1, e2])
|
||||
session.flush()
|
||||
|
||||
ealias = aliased(Engineer)
|
||||
self.assertEquals(
|
||||
session.query(Manager, ealias).all(),
|
||||
[(m1, e1), (m1, e2)]
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
session.query(Manager.name).all(),
|
||||
[("Tom",)]
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
session.query(Manager.name, ealias.name).all(),
|
||||
[("Tom", "Kurt"), ("Tom", "Ed")]
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
session.query(func.upper(Manager.name), func.upper(ealias.name)).all(),
|
||||
[("TOM", "KURT"), ("TOM", "ED")]
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
session.query(Manager).add_entity(ealias).all(),
|
||||
[(m1, e1), (m1, e2)]
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
session.query(Manager.name).add_column(ealias.name).all(),
|
||||
[("Tom", "Kurt"), ("Tom", "Ed")]
|
||||
)
|
||||
|
||||
# TODO: I think raise error on this for now
|
||||
# self.assertEquals(
|
||||
# session.query(Employee.name, Manager.manager_data, Engineer.engineer_info).all(),
|
||||
# []
|
||||
# )
|
||||
|
||||
@testing.resolve_artifact_names
|
||||
def test_select_from(self):
|
||||
sess = create_session()
|
||||
m1 = Manager(name='Tom', manager_data='data1')
|
||||
m2 = Manager(name='Tom2', manager_data='data2')
|
||||
e1 = Engineer(name='Kurt', engineer_info='knows how to hack')
|
||||
e2 = JuniorEngineer(name='Ed', engineer_info='oh that ed')
|
||||
sess.add_all([m1, m2, e1, e2])
|
||||
sess.flush()
|
||||
|
||||
self.assertEquals(
|
||||
sess.query(Manager).select_from(employees_table.select().limit(10)).all(),
|
||||
[m1, m2]
|
||||
)
|
||||
|
||||
@testing.resolve_artifact_names
|
||||
def test_count(self):
|
||||
sess = create_session()
|
||||
m1 = Manager(name='Tom', manager_data='data1')
|
||||
m2 = Manager(name='Tom2', manager_data='data2')
|
||||
e1 = Engineer(name='Kurt', engineer_info='data3')
|
||||
e2 = JuniorEngineer(name='marvin', engineer_info='data4')
|
||||
sess.add_all([m1, m2, e1, e2])
|
||||
sess.flush()
|
||||
|
||||
self.assertEquals(sess.query(Manager).count(), 2)
|
||||
self.assertEquals(sess.query(Engineer).count(), 2)
|
||||
self.assertEquals(sess.query(Employee).count(), 4)
|
||||
|
||||
self.assertEquals(sess.query(Manager).filter(Manager.name.like('%m%')).count(), 2)
|
||||
self.assertEquals(sess.query(Employee).filter(Employee.name.like('%m%')).count(), 3)
|
||||
|
||||
class SingleOnJoinedTest(ORMTest):
|
||||
def define_tables(self, metadata):
|
||||
|
||||
+1
-1
@@ -542,7 +542,7 @@ class AggregateTest(QueryTest):
|
||||
def test_sum(self):
|
||||
sess = create_session()
|
||||
orders = sess.query(Order).filter(Order.id.in_([2, 3, 4]))
|
||||
assert orders.sum(Order.user_id * Order.address_id) == 79
|
||||
self.assertEquals(orders.values(func.sum(Order.user_id * Order.address_id)).next(), (79,))
|
||||
|
||||
def test_apply(self):
|
||||
sess = create_session()
|
||||
|
||||
Reference in New Issue
Block a user