diff --git a/CHANGES b/CHANGES index 501db5d286..3e5d36c20f 100644 --- a/CHANGES +++ b/CHANGES @@ -25,7 +25,17 @@ CHANGES work properly with self-referential relations - the clause inside the EXISTS is aliased on the "remote" side to distinguish it from the parent table. - + + - fixed bug whereby session.expire() attributes were not + loading on an polymorphically-mapped instance mapped + by a select_table mapper. + + - added query.with_polymorphic() - specifies a list + of classes which descend from the base class, which will + be added to the FROM clause of the query. Allows subclasses + to be used within filter() criterion as well as eagerly loads + the attributes of those subclasses. + - Your cries have been heard: removing a pending item from an attribute or collection with delete-orphan expunges the item from the session; no FlushError is raised. Note that if you @@ -35,6 +45,9 @@ CHANGES - Fixed potential generative bug when the same Query was used to generate multiple Query objects using join(). + - deprecated Query methods apply_sum(), apply_max(), apply_min(), + apply_avg(). Better methodologies are coming.... + - Added a new "higher level" operator called "of_type()": used in join() as well as with any() and has(), qualifies the subclass which will be used in filter criterion, e.g.: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 62067fc358..297d222466 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1579,6 +1579,7 @@ def _load_scalar_attributes(instance, attribute_names): identity_key = state.dict['_instance_key'] else: identity_key = mapper._identity_key_from_state(state) + if session.query(mapper)._get(identity_key, refresh_instance=state, only_load_props=attribute_names) is None: raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance)) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 46f986d14e..ebe62e915c 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -33,14 +33,12 @@ class Query(object): """Encapsulates the object-fetching operations provided by Mappers.""" def __init__(self, class_or_mapper, session=None, entity_name=None): - self.mapper = _class_to_mapper(class_or_mapper, entity_name=entity_name) - self.select_mapper = self.mapper.get_select_mapper().compile() - + self._init_mapper(_class_to_mapper(class_or_mapper, entity_name=entity_name)) self._session = session self._with_options = [] self._lockmode = None - self._extension = self.mapper.extension + self._entities = [] self._order_by = False self._group_by = False @@ -54,24 +52,41 @@ class Query(object): self._joinable_tables = None self._having = None self._column_aggregate = None - self._joinpoint = self.mapper self._aliases = None self._alias_ids = {} - self._from_obj = self.table self._populate_existing = False self._version_check = False self._autoflush = True - self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]])) + self._attributes = {} self._current_path = () self._only_load_props = None self._refresh_instance = None - + + def _init_mapper(self, mapper, select_mapper=None): + """populate all instance variables derived from this Query's mapper.""" + + self.mapper = mapper + self.select_mapper = select_mapper or self.mapper.get_select_mapper().compile() + self.table = self._from_obj = self.select_mapper.mapped_table + self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]])) + self._extension = self.mapper.extension self._adapter = self.select_mapper._clause_adapter - + self._joinpoint = self.mapper + self._with_polymorphic = [] + def _no_criterion(self, meth): - q = self._clone() + return self._conditional_clone(meth, [self._no_criterion_condition]) + def _no_statement(self, meth): + return self._conditional_clone(meth, [self._no_statement_condition]) + + def _new_base_mapper(self, mapper, meth): + q = self._conditional_clone(meth, [self._no_criterion_condition]) + q._init_mapper(mapper, mapper) + return q + + def _no_criterion_condition(self, q, meth): if q._criterion or q._statement or q._from_obj is not self.table: util.warn( ("Query.%s() being called on a Query with existing criterion; " @@ -83,16 +98,20 @@ class Query(object): q._joinpoint = self.mapper q._statement = q._aliases = q._criterion = None q._order_by = q._group_by = q._distinct = False - return q - - def _no_statement(self, meth): - q = self._clone() + + def _no_statement_condition(self, q, meth): if q._statement: raise exceptions.InvalidRequestError( ("Query.%s() being called on a Query with an existing full " "statement - can't apply criterion.") % meth) + + def _conditional_clone(self, methname=None, conditions=None): + q = self._clone() + if conditions: + for condition in conditions: + condition(q, methname) return q - + def _clone(self): q = Query.__new__(Query) q.__dict__ = self.__dict__.copy() @@ -104,7 +123,6 @@ class Query(object): else: return self._session - table = property(lambda s:s.select_mapper.mapped_table) primary_key_columns = property(lambda s:s.select_mapper.primary_key) session = property(_get_session) @@ -112,7 +130,63 @@ class Query(object): q = self._clone() q._current_path = path return q + + def with_polymorphic(self, cls_or_mappers, selectable=None): + """Load columns for descendant mappers of this Query's mapper. + + Using this method will ensure that each descendant mapper's + tables are included in the FROM clause, and will allow filter() + criterion to be used against those tables. The resulting + instances will also have those columns already loaded so that + no "post fetch" of those columns will be required. + + If this Query's mapper has a ``select_table`` argument, + with_polymorphic() overrides it; the FROM clause will be against + the local table of the base mapper outer joined with the local + tables of each specified descendant mapper (unless ``selectable`` + is specified). + + ``cls_or_mappers`` is a single class or mapper, or list of class/mappers, + which inherit from this Query's mapper. Alternatively, it + may also be the string ``'*'``, in which case all descending + mappers will be added to the FROM clause. + + ``selectable`` is a table or select() statement that will + be used in place of the generated FROM clause. This argument + is required if any of the desired mappers use concrete table + inheritance, since SQLAlchemy currently cannot generate UNIONs + among tables automatically. If used, the ``selectable`` + argument must represent the full set of tables and columns mapped + by every desired mapper. Otherwise, the unaccounted mapped columns + will result in their table being appended directly to the FROM + clause which will usually lead to incorrect results. + """ + + q = self._new_base_mapper(self.mapper, 'with_polymorphic') + + if cls_or_mappers == '*': + cls_or_mappers = self.mapper.polymorphic_iterator() + else: + cls_or_mappers = util.to_list(cls_or_mappers) + + if selectable: + q = q.select_from(selectable) + + for cls_or_mapper in cls_or_mappers: + poly_mapper = _class_to_mapper(cls_or_mapper) + if poly_mapper is self.mapper: + continue + + q._with_polymorphic.append(poly_mapper) + if not selectable: + if poly_mapper.concrete: + raise exceptions.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.") + elif not poly_mapper.single: + q._from_obj = q._from_obj.outerjoin(poly_mapper.local_table, poly_mapper.inherit_condition) + + return q + def yield_per(self, count): """Yield only ``count`` rows at a time. @@ -412,6 +486,8 @@ class Query(object): # hand side. if self._adapter and not self._aliases: # at the beginning of a join, look at leftmost adapter adapt_against = self._adapter.selectable + elif start is self.select_mapper: # or if its our base mapper, go against our base table + adapt_against = self.table elif start.select_table is not start.mapped_table: # in the middle of a join, look for a polymorphic mapper adapt_against = start.select_table else: @@ -444,7 +520,7 @@ class Query(object): raise exceptions.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (use_selectable.description, prop.mapper.mapped_table.description)) if not isinstance(use_selectable, expression.Alias): use_selectable = use_selectable.alias() - + if prop._is_self_referential() and not create_aliases and not use_selectable: raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires create_aliases=True argument." % str(prop)) @@ -503,24 +579,32 @@ class Query(object): def apply_min(self, col): """apply the SQL ``min()`` function against the given column to the query and return the newly resulting ``Query``. + + DEPRECATED. """ return self._generative_col_aggregate(col, sql.func.min) def apply_max(self, col): """apply the SQL ``max()`` function against the given column to the query and return the newly resulting ``Query``. + + DEPRECATED. """ return self._generative_col_aggregate(col, sql.func.max) def apply_sum(self, col): """apply the SQL ``sum()`` function against the given column to the query and return the newly resulting ``Query``. + + DEPRECATED. """ return self._generative_col_aggregate(col, sql.func.sum) def apply_avg(self, col): """apply the SQL ``avg()`` function against the given column to the query and return the newly resulting ``Query``. + + DEPRECATED. """ return self._generative_col_aggregate(col, sql.func.avg) @@ -852,6 +936,11 @@ class Query(object): context.runid = _new_runid() + # for with_polymorphic, instruct descendant mappers that they + # don't need to post-fetch anything + for m in self._with_polymorphic: + context.attributes[('polymorphic_fetch', m)] = (self.select_mapper, []) + mappers_or_columns = tuple(self._entities) + mappers_or_columns tuples = bool(mappers_or_columns) @@ -950,12 +1039,17 @@ class Query(object): ident = util.to_list(ident) q = self + + # dont use 'polymorphic' mapper if we are refreshing an instance + if refresh_instance and q.select_mapper is not q.mapper: + q = q._new_base_mapper(q.mapper, '_get') + if ident is not None: q = q._no_criterion('get') params = {} - (_get_clause, _get_params) = self.select_mapper._get_clause + (_get_clause, _get_params) = q.select_mapper._get_clause q = q.filter(_get_clause) - for i, primary_key in enumerate(self.primary_key_columns): + for i, primary_key in enumerate(q.primary_key_columns): try: params[_get_params[primary_key].key] = ident[i] except IndexError: @@ -1027,25 +1121,10 @@ class Query(object): return context whereclause = self._criterion - from_obj = self._from_obj - - # if the query's ClauseAdapter is present, and its - # specifically adapting against a modified "select_from" - # argument, apply adaptation to the - # individually selected columns as well as "eager" clauses added; - # otherwise its currently not needed - if self._adapter and self.table not in self._get_joinable_tables(): - adapter = self._adapter - else: - adapter = None - adapter = self._adapter - - # TODO: mappers added via add_entity(), adapt their queries also, - # if those mappers are polymorphic - order_by = self._order_by + if order_by is False: order_by = self.select_mapper.order_by if order_by is False: @@ -1055,22 +1134,31 @@ class Query(object): if from_obj.default_order_by() is not None: order_by = from_obj.default_order_by() - try: - for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode] - except KeyError: - raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode) - + if self._lockmode: + try: + for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode] + except KeyError: + raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode) + else: + for_update = False + # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so # that we only load the appropriate types - if self.select_mapper.single and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None: + if self.select_mapper.single and self.select_mapper.inherits is not None and self.select_mapper.polymorphic_on is not None and self.select_mapper.polymorphic_identity is not None: whereclause = sql.and_(whereclause, self.select_mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.select_mapper.polymorphic_iterator()])) context.from_clause = from_obj - # give all the attached properties a chance to modify the query - # TODO: doing this off the select_mapper. if its the polymorphic mapper, then - # it has no relations() on it. should we compile those too into the query ? (i.e. eagerloads) - for value in self.select_mapper.iterate_properties: + # TODO: compile eagerloads from select_mapper if polymorphic ? [ticket:917] + if self._with_polymorphic: + props = util.Set() + for m in [self.select_mapper] + self._with_polymorphic: + for value in m.iterate_properties: + props.add(value) + else: + props = self.select_mapper.iterate_properties + + for value in props: if self._only_load_props and value.key not in self._only_load_props: continue context.exec_with_path(self.select_mapper, value.key, value.setup, context, only_load_props=self._only_load_props) @@ -1091,12 +1179,9 @@ class Query(object): # eager loaders are present, and the SELECT has limiting criterion # produce a "wrapped" selectable. - # ensure all 'order by' elements are ClauseElement instances - # (since they will potentially be aliased) # locate all embedded Column clauses so they can be added to the # "inner" select statement where they'll be available to the enclosing # statement's "order by" - cf = util.Set() if order_by: order_by = [expression._literal_as_text(o) for o in util.to_list(order_by) or []] @@ -1105,7 +1190,7 @@ class Query(object): if adapter: # TODO: make usage of the ClauseAdapter here to create the list - # of primary columns + # of primary columns ? context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns] cf = [from_obj.corresponding_column(c) or c for c in cf] @@ -1128,7 +1213,7 @@ class Query(object): else: if adapter: # TODO: make usage of the ClauseAdapter here to create row adapter, list - # of primary columns + # of primary columns ? context.primary_columns = [from_obj.corresponding_column(c) or c for c in context.primary_columns] context.row_adapter = mapperutil.create_row_adapter(from_obj, self.table) @@ -1425,13 +1510,12 @@ class Query(object): return self._legacy_filter_by(*args, **params).one() - for deprecated_method in ('list', 'scalar', 'count_by', 'select_whereclause', 'get_by', 'select_by', 'join_by', 'selectfirst', 'selectone', 'select', 'execute', 'select_statement', 'select_text', 'join_to', 'join_via', 'selectfirst_by', - 'selectone_by'): + 'selectone_by', 'apply_max', 'apply_min', 'apply_avg', 'apply_sum'): setattr(Query, deprecated_method, util.deprecated(getattr(Query, deprecated_method), add_deprecation_to_docstring=False)) diff --git a/test/orm/generative.py b/test/orm/generative.py index 9967f34f7e..db8e313e67 100644 --- a/test/orm/generative.py +++ b/test/orm/generative.py @@ -53,6 +53,7 @@ class GenerativeQueryTest(TestBase): assert list(query[-5:]) == orig[-5:] assert query[10:20][5] == orig[10:20][5] + @testing.uses_deprecated('Call to deprecated function apply_max') def test_aggregate(self): sess = create_session(bind=testing.db) query = sess.query(Foo) @@ -77,6 +78,7 @@ class GenerativeQueryTest(TestBase): assert round(avg, 1) == 14.5 @testing.fails_on('firebird', 'mssql') + @testing.uses_deprecated('Call to deprecated function apply_avg') def test_aggregate_3(self): query = create_session(bind=testing.db).query(Foo) diff --git a/test/orm/inheritance/query.py b/test/orm/inheritance/query.py index 3571480292..7d7b8b9d91 100644 --- a/test/orm/inheritance/query.py +++ b/test/orm/inheritance/query.py @@ -194,6 +194,19 @@ def make_test(select_type): self.assertEquals(sess.query(Engineer).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1]) self.assertEquals(sess.query(Person).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) + + def test_join_from_with_polymorphic(self): + sess = create_session() + + for aliased in (True, False): + sess.clear() + self.assertEquals(sess.query(Person).with_polymorphic(Manager).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%review%')).all(), [b1, m1]) + + sess.clear() + self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Paperwork.description.like('%#2%')).all(), [e1, m1]) + + sess.clear() + self.assertEquals(sess.query(Person).with_polymorphic([Manager, Engineer]).join('paperwork', aliased=aliased).filter(Person.c.name.like('%dog%')).filter(Paperwork.description.like('%#2%')).all(), [m1]) def test_join_to_polymorphic(self): sess = create_session() @@ -223,7 +236,58 @@ def make_test(select_type): sess.query(Company).filter(Company.employees.any(and_(Engineer.primary_language=='cobol', people.c.person_id==engineers.c.person_id))).one(), c2 ) - + + def test_expire(self): + """test that individual column refresh doesn't get tripped up by the select_table mapper""" + + sess = create_session() + m1 = sess.query(Manager).filter(Manager.name=='dogbert').one() + sess.expire(m1) + assert m1.status == 'regular manager' + + m2 = sess.query(Manager).filter(Manager.name=='pointy haired boss').one() + sess.expire(m2, ['manager_name', 'golf_swing']) + assert m2.golf_swing=='fore' + + def test_with_polymorphic(self): + + sess = create_session() + + # compare to entities without related collections to prevent additional lazy SQL from firing on + # loaded entities + emps_without_relations = [ + Engineer(name="dilbert", engineer_name="dilbert", primary_language="java", status="regular engineer"), + Engineer(name="wally", engineer_name="wally", primary_language="c++", status="regular engineer"), + Boss(name="pointy haired boss", golf_swing="fore", manager_name="pointy", status="da boss"), + Manager(name="dogbert", manager_name="dogbert", status="regular manager"), + Engineer(name="vlad", engineer_name="vlad", primary_language="cobol", status="elbonian engineer") + ] + + def go(): + self.assertEquals(sess.query(Person).with_polymorphic(Engineer).filter(Engineer.primary_language=='java').all(), emps_without_relations[0:1]) + self.assert_sql_count(testing.db, go, 1) + + sess.clear() + def go(): + self.assertEquals(sess.query(Person).with_polymorphic('*').all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 1) + + sess.clear() + def go(): + self.assertEquals(sess.query(Person).with_polymorphic(Engineer).all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 3) + + sess.clear() + def go(): + self.assertEquals(sess.query(Person).with_polymorphic(Engineer, people.outerjoin(engineers)).all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 3) + + sess.clear() + def go(): + # limit the polymorphic join down to just "Person", overriding select_table + self.assertEquals(sess.query(Person).with_polymorphic(Person).all(), emps_without_relations) + self.assert_sql_count(testing.db, go, 6) + def test_join_to_subclass(self): sess = create_session() diff --git a/test/orm/query.py b/test/orm/query.py index 41ae444614..62bb99a323 100644 --- a/test/orm/query.py +++ b/test/orm/query.py @@ -389,6 +389,7 @@ class AggregateTest(QueryTest): orders = sess.query(Order).filter(Order.id.in_([2, 3, 4])) assert orders.sum(Order.user_id * Order.address_id) == 79 + @testing.uses_deprecated('Call to deprecated function apply_sum') def test_apply(self): sess = create_session() assert sess.query(Order).apply_sum(Order.user_id * Order.address_id).filter(Order.id.in_([2, 3, 4])).one() == 79