- fixed bug whereby session.expire() attributes were not

loading on an polymorphically-mapped instance mapped
by a select_table mapper.

- added query.with_polymorphic() - specifies a list
of classes which descend from the base class, which will
be added to the FROM clause of the query.  Allows subclasses
to be used within filter() criterion as well as eagerly loads
the attributes of those subclasses.

- deprecated Query methods apply_sum(), apply_max(), apply_min(),
apply_avg().  Better methodologies are coming....
This commit is contained in:
Mike Bayer
2008-03-01 01:46:23 +00:00
parent bda6f1e06f
commit 075eb9076b
6 changed files with 219 additions and 54 deletions
+14 -1
View File
@@ -25,7 +25,17 @@ CHANGES
work properly with self-referential relations - the clause
inside the EXISTS is aliased on the "remote" side to
distinguish it from the parent table.
- fixed bug whereby session.expire() attributes were not
loading on an polymorphically-mapped instance mapped
by a select_table mapper.
- added query.with_polymorphic() - specifies a list
of classes which descend from the base class, which will
be added to the FROM clause of the query. Allows subclasses
to be used within filter() criterion as well as eagerly loads
the attributes of those subclasses.
- Your cries have been heard: removing a pending item from an
attribute or collection with delete-orphan expunges the item
from the session; no FlushError is raised. Note that if you
@@ -35,6 +45,9 @@ CHANGES
- Fixed potential generative bug when the same Query was used
to generate multiple Query objects using join().
- deprecated Query methods apply_sum(), apply_max(), apply_min(),
apply_avg(). Better methodologies are coming....
- Added a new "higher level" operator called "of_type()": used
in join() as well as with any() and has(), qualifies the
subclass which will be used in filter criterion, e.g.:
+1
View File
@@ -1579,6 +1579,7 @@ def _load_scalar_attributes(instance, attribute_names):
identity_key = state.dict['_instance_key']
else:
identity_key = mapper._identity_key_from_state(state)
if session.query(mapper)._get(identity_key, refresh_instance=state, only_load_props=attribute_names) is None:
raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance))
+136 -52
View File
@@ -33,14 +33,12 @@ class Query(object):
"""Encapsulates the object-fetching operations provided by Mappers."""
def __init__(self, class_or_mapper, session=None, entity_name=None):
self.mapper = _class_to_mapper(class_or_mapper, entity_name=entity_name)
self.select_mapper = self.mapper.get_select_mapper().compile()
self._init_mapper(_class_to_mapper(class_or_mapper, entity_name=entity_name))
self._session = session
self._with_options = []
self._lockmode = None
self._extension = self.mapper.extension
self._entities = []
self._order_by = False
self._group_by = False
@@ -54,24 +52,41 @@ class Query(object):
self._joinable_tables = None
self._having = None
self._column_aggregate = None
self._joinpoint = self.mapper
self._aliases = None
self._alias_ids = {}
self._from_obj = self.table
self._populate_existing = False
self._version_check = False
self._autoflush = True
self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]]))
self._attributes = {}
self._current_path = ()
self._only_load_props = None
self._refresh_instance = None
def _init_mapper(self, mapper, select_mapper=None):
"""populate all instance variables derived from this Query's mapper."""
self.mapper = mapper
self.select_mapper = select_mapper or self.mapper.get_select_mapper().compile()
self.table = self._from_obj = self.select_mapper.mapped_table
self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]]))
self._extension = self.mapper.extension
self._adapter = self.select_mapper._clause_adapter
self._joinpoint = self.mapper
self._with_polymorphic = []
def _no_criterion(self, meth):
q = self._clone()
return self._conditional_clone(meth, [self._no_criterion_condition])
def _no_statement(self, meth):
return self._conditional_clone(meth, [self._no_statement_condition])
def _new_base_mapper(self, mapper, meth):
q = self._conditional_clone(meth, [self._no_criterion_condition])
q._init_mapper(mapper, mapper)
return q
def _no_criterion_condition(self, q, meth):
if q._criterion or q._statement or q._from_obj is not self.table:
util.warn(
("Query.%s() being called on a Query with existing criterion; "
@@ -83,16 +98,20 @@ class Query(object):
q._joinpoint = self.mapper
q._statement = q._aliases = q._criterion = None
q._order_by = q._group_by = q._distinct = False
return q
def _no_statement(self, meth):
q = self._clone()
def _no_statement_condition(self, q, meth):
if q._statement:
raise exceptions.InvalidRequestError(
("Query.%s() being called on a Query with an existing full "
"statement - can't apply criterion.") % meth)
def _conditional_clone(self, methname=None, conditions=None):
q = self._clone()
if conditions:
for condition in conditions:
condition(q, methname)
return q
def _clone(self):
q = Query.__new__(Query)
q.__dict__ = self.__dict__.copy()
@@ -104,7 +123,6 @@ class Query(object):
else:
return self._session
table = property(lambda s:s.select_mapper.mapped_table)
primary_key_columns = property(lambda s:s.select_mapper.primary_key)
session = property(_get_session)
@@ -112,7 +130,63 @@ class Query(object):
q = self._clone()
q._current_path = path
return q
def with_polymorphic(self, cls_or_mappers, selectable=None):
"""Load columns for descendant mappers of this Query's mapper.
Using this method will ensure that each descendant mapper's
tables are included in the FROM clause, and will allow filter()
criterion to be used against those tables. The resulting
instances will also have those columns already loaded so that
no "post fetch" of those columns will be required.
If this Query's mapper has a ``select_table`` argument,
with_polymorphic() overrides it; the FROM clause will be against
the local table of the base mapper outer joined with the local
tables of each specified descendant mapper (unless ``selectable``
is specified).
``cls_or_mappers`` is a single class or mapper, or list of class/mappers,
which inherit from this Query's mapper. Alternatively, it
may also be the string ``'*'``, in which case all descending
mappers will be added to the FROM clause.
``selectable`` is a table or select() statement that will
be used in place of the generated FROM clause. This argument
is required if any of the desired mappers use concrete table
inheritance, since SQLAlchemy currently cannot generate UNIONs
among tables automatically. If used, the ``selectable``
argument must represent the full set of tables and columns mapped
by every desired mapper. Otherwise, the unaccounted mapped columns
will result in their table being appended directly to the FROM
clause which will usually lead to incorrect results.
"""
q = self._new_base_mapper(self.mapper, 'with_polymorphic')
if cls_or_mappers == '*':
cls_or_mappers = self.mapper.polymorphic_iterator()
else:
cls_or_mappers = util.to_list(cls_or_mappers)
if selectable:
q = q.select_from(selectable)
for cls_or_mapper in cls_or_mappers:
poly_mapper = _class_to_mapper(cls_or_mapper)
if poly_mapper is self.mapper:
continue
q._with_polymorphic.append(poly_mapper)
if not selectable:
if poly_mapper.concrete:
raise exceptions.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.")
elif not poly_mapper.single:
q._from_obj = q._from_obj.outerjoin(poly_mapper.local_table, poly_mapper.inherit_condition)
return q
def yield_per(self, count):
"""Yield only ``count`` rows at a time.
@@ -412,6 +486,8 @@ class Query(object):
# hand side.
if self._adapter and not self._aliases: # at the beginning of a join, look at leftmost adapter
adapt_against = self._adapter.selectable
elif start is self.select_mapper: # or if its our base mapper, go against our base table
adapt_against = self.table
elif start.select_table is not start.mapped_table: # in the middle of a join, look for a polymorphic mapper
adapt_against = start.select_table
else:
@@ -444,7 +520,7 @@ class Query(object):
raise exceptions.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (use_selectable.description, prop.mapper.mapped_table.description))
if not isinstance(use_selectable, expression.Alias):
use_selectable = use_selectable.alias()
if prop._is_self_referential() and not create_aliases and not use_selectable:
raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires create_aliases=True argument." % str(prop))
@@ -503,24 +579,32 @@ class Query(object):
def apply_min(self, col):
"""apply the SQL ``min()`` function against the given column to the
query and return the newly resulting ``Query``.
DEPRECATED.
"""
return self._generative_col_aggregate(col, sql.func.min)
def apply_max(self, col):
"""apply the SQL ``max()`` function against the given column to the
query and return the newly resulting ``Query``.
DEPRECATED.
"""
return self._generative_col_aggregate(col, sql.func.max)
def apply_sum(self, col):
"""apply the SQL ``sum()`` function against the given column to the
query and return the newly resulting ``Query``.
DEPRECATED.
"""
return self._generative_col_aggregate(col, sql.func.sum)
def apply_avg(self, col):
"""apply the SQL ``avg()`` function against the given column to the
query and return the newly resulting ``Query``.
DEPRECATED.
"""
return self._generative_col_aggregate(col, sql.func.avg)
@@ -852,6 +936,11 @@ class Query(object):
context.runid = _new_runid()
# for with_polymorphic, instruct descendant mappers that they
# don't need to post-fetch anything
for m in self._with_polymorphic:
context.attributes[('polymorphic_fetch', m)] = (self.select_mapper, [])
mappers_or_columns = tuple(self._entities) + mappers_or_columns
tuples = bool(mappers_or_columns)
@@ -950,12 +1039,17 @@ class Query(object):
ident = util.to_list(ident)
q = self
# dont use 'polymorphic' mapper if we are refreshing an instance
if refresh_instance and q.select_mapper is not q.mapper:
q = q._new_base_mapper(q.mapper, '_get')
if ident is not None:
q = q._no_criterion('get')
params = {}
(_get_clause, _get_params) = self.select_mapper._get_clause
(_get_clause, _get_params) = q.select_mapper._get_clause
q = q.filter(_get_clause)
for i, primary_key in enumerate(self.primary_key_columns):
for i, primary_key in enumerate(q.primary_key_columns):
try:
params[_get_params[primary_key].key] = ident[i]
except IndexError:
@@ -1027,25 +1121,10 @@ class Query(object):
return context
whereclause = self._criterion
from_obj = self._from_obj
# if the query's ClauseAdapter is present, and its
# specifically adapting against a modified "select_from"
# argument, apply adaptation to the
# individually selected columns as well as "eager" clauses added;
# otherwise its currently not needed
if self._adapter and self.table not in self._get_joinable_tables():
adapter = self._adapter
else:
adapter = None
adapter = self._adapter
# TODO: mappers added via add_entity(), adapt their queries also,
# if those mappers are polymorphic
order_by = self._order_by
if order_by is False:
order_by = self.select_mapper.order_by
if order_by is False:
@@ -1055,22 +1134,31 @@ class Query(object):
if from_obj.default_order_by() is not None:
order_by = from_obj.default_order_by()
try:
for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode]
except KeyError:
raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode)
if self._lockmode:
try:
for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode]
except KeyError:
raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode)
else:
for_update = False
# if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so
# that we only load the appropriate types
if self.select_mapper.single and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None:
if self.select_mapper.single and self.select_mapper.inherits is not None and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None:
whereclause = sql.and_(whereclause, self.select_mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.select_mapper.polymorphic_iterator()]))
context.from_clause = from_obj
# give all the attached properties a chance to modify the query
# TODO: doing this off the select_mapper. if its the polymorphic mapper, then
# it has no relations() on it. should we compile those too into the query ? (i.e. eagerloads)
for value in self.select_mapper.iterate_properties:
# TODO: compile eagerloads from select_mapper if polymorphic ? [ticket:917]
if self._with_polymorphic:
props = util.Set()
for m in [self.select_mapper] + self._with_polymorphic:
for value in m.iterate_properties:
props.add(value)
else:
props = self.select_mapper.iterate_properties
for value in props:
if self._only_load_props and value.key not in self._only_load_props:
continue
context.exec_with_path(self.select_mapper, value.key, value.setup, context, only_load_props=self._only_load_props)
@@ -1091,12 +1179,9 @@ class Query(object):
# eager loaders are present, and the SELECT has limiting criterion
# produce a "wrapped" selectable.
# ensure all 'order by' elements are ClauseElement instances
# (since they will potentially be aliased)
# locate all embedded Column clauses so they can be added to the
# "inner" select statement where they'll be available to the enclosing
# statement's "order by"
cf = util.Set()
if order_by:
order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []]
@@ -1105,7 +1190,7 @@ class Query(object):
if adapter:
# TODO: make usage of the ClauseAdapter here to create the list
# of primary columns
# of primary columns ?
context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns]
cf = [from_obj.corresponding_column(c) or c for c in cf]
@@ -1128,7 +1213,7 @@ class Query(object):
else:
if adapter:
# TODO: make usage of the ClauseAdapter here to create row adapter, list
# of primary columns
# of primary columns ?
context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns]
context.row_adapter = mapperutil.create_row_adapter(from_obj, self.table)
@@ -1425,13 +1510,12 @@ class Query(object):
return self._legacy_filter_by(*args, **params).one()
for deprecated_method in ('list', 'scalar', 'count_by',
'select_whereclause', 'get_by', 'select_by',
'join_by', 'selectfirst', 'selectone', 'select',
'execute', 'select_statement', 'select_text',
'join_to', 'join_via', 'selectfirst_by',
'selectone_by'):
'selectone_by', 'apply_max', 'apply_min', 'apply_avg', 'apply_sum'):
setattr(Query, deprecated_method,
util.deprecated(getattr(Query, deprecated_method),
add_deprecation_to_docstring=False))
+2
View File
@@ -53,6 +53,7 @@ class GenerativeQueryTest(TestBase):
assert list(query[-5:]) == orig[-5:]
assert query[10:20][5] == orig[10:20][5]
@testing.uses_deprecated('Call to deprecated function apply_max')
def test_aggregate(self):
sess = create_session(bind=testing.db)
query = sess.query(Foo)
@@ -77,6 +78,7 @@ class GenerativeQueryTest(TestBase):
assert round(avg, 1) == 14.5
@testing.fails_on('firebird', 'mssql')
@testing.uses_deprecated('Call to deprecated function apply_avg')
def test_aggregate_3(self):
query = create_session(bind=testing.db).query(Foo)
+65 -1
View File
@@ -194,6 +194,19 @@ def make_test(select_type):
self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1])
self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
def test_join_from_with_polymorphic(self):
sess = create_session()
for aliased in (True, False):
sess.clear()
self.assertEquals(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1])
sess.clear()
self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1])
sess.clear()
self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1])
def test_join_to_polymorphic(self):
sess = create_session()
@@ -223,7 +236,58 @@ def make_test(select_type):
sess.query(Company).filter(Company.employees.any(and_(Engineer.primary_language=='cobol', people.c.person_id==engineers.c.person_id))).one(),
c2
)
def test_expire(self):
"""test that individual column refresh doesn't get tripped up by the select_table mapper"""
sess = create_session()
m1 = sess.query(Manager).filter(Manager.name=='dogbert').one()
sess.expire(m1)
assert m1.status == 'regular manager'
m2 = sess.query(Manager).filter(Manager.name=='pointy haired boss').one()
sess.expire(m2, ['manager_name', 'golf_swing'])
assert m2.golf_swing=='fore'
def test_with_polymorphic(self):
sess = create_session()
# compare to entities without related collections to prevent additional lazy SQL from firing on
# loaded entities
emps_without_relations = [
Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer"),
Engineer(name="wally", engineer_name="wally", primary_language="c++", status="regular engineer"),
Boss(name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss"),
Manager(name="dogbert", manager_name="dogbert", status="regular manager"),
Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer")
]
def go():
self.assertEquals(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1])
self.assert_sql_count(testing.db, go, 1)
sess.clear()
def go():
self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations)
self.assert_sql_count(testing.db, go, 1)
sess.clear()
def go():
self.assertEquals(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations)
self.assert_sql_count(testing.db, go, 3)
sess.clear()
def go():
self.assertEquals(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations)
self.assert_sql_count(testing.db, go, 3)
sess.clear()
def go():
# limit the polymorphic join down to just "Person", overriding select_table
self.assertEquals(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations)
self.assert_sql_count(testing.db, go, 6)
def test_join_to_subclass(self):
sess = create_session()
+1
View File
@@ -389,6 +389,7 @@ class AggregateTest(QueryTest):
orders = sess.query(Order).filter(Order.id.in_([2, 3, 4]))
assert orders.sum(Order.user_id * Order.address_id) == 79
@testing.uses_deprecated('Call to deprecated function apply_sum')
def test_apply(self):
sess = create_session()
assert sess.query(Order).apply_sum(Order.user_id * Order.address_id).filter(Order.id.in_([2, 3, 4])).one() == 79