diff --git a/doc/build/changelog/unreleased_13/4787.rst b/doc/build/changelog/unreleased_13/4787.rst new file mode 100644 index 0000000000..911a287e6d --- /dev/null +++ b/doc/build/changelog/unreleased_13/4787.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, sql + :tickets: 4787 + + Fixed bug where :meth:`.TypeEngine.column_expression` method would not be + applied to subsequent SELECT statements inside of a UNION or other + :class:`.CompoundSelect`, even though the SELECT statements are rendered at + the topmost level of the statement. New logic now differentiates between + rendering the column expression, which is needed for all SELECTs in the + list, vs. gathering the returned data type for the result row, which is + needed only for the first SELECT. diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 400ac2749c..8bd2249e27 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1799,18 +1799,26 @@ class SQLCompiler(Compiled): column_clause_args, name=None, within_columns_clause=True, + need_column_expressions=False, ): """produce labeled columns present in a select().""" impl = column.type.dialect_impl(self.dialect) - if impl._has_column_expression and populate_result_map: + + if impl._has_column_expression and ( + need_column_expressions or populate_result_map + ): col_expr = impl.column_expression(column) - def add_to_result_map(keyname, name, objects, type_): - self._add_to_result_map( - keyname, name, (column,) + objects, type_ - ) + if populate_result_map: + def add_to_result_map(keyname, name, objects, type_): + self._add_to_result_map( + keyname, name, (column,) + objects, type_ + ) + + else: + add_to_result_map = None else: col_expr = column if populate_result_map: @@ -2085,15 +2093,15 @@ class SQLCompiler(Compiled): toplevel = not self.stack entry = self._default_stack_entry if toplevel else self.stack[-1] - populate_result_map = ( + populate_result_map = need_column_expressions = ( toplevel - or ( - compound_index == 0 - and entry.get("need_result_map_for_compound", False) - ) + or entry.get("need_result_map_for_compound", False) or entry.get("need_result_map_for_nested", False) ) + if compound_index > 0: + populate_result_map = False + # this was first proposed as part of #3372; however, it is not # reached in current tests and could possibly be an assertion # instead. @@ -2138,6 +2146,7 @@ class SQLCompiler(Compiled): asfrom, column_clause_args, name=name, + need_column_expressions=need_column_expressions, ) for name, column in select._columns_plus_names ] diff --git a/test/sql/test_type_expressions.py b/test/sql/test_type_expressions.py index f913ab6389..1f4649ffce 100644 --- a/test/sql/test_type_expressions.py +++ b/test/sql/test_type_expressions.py @@ -7,6 +7,7 @@ from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import TypeDecorator +from sqlalchemy import union from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures @@ -272,6 +273,35 @@ class SelectTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL): "test_table.y = outside_bind(inside_bind(:y_1))", ) + def test_compound_select(self): + table = self._fixture() + + s1 = select([table]).where(table.c.y == "hi") + s2 = select([table]).where(table.c.y == "there") + + self.assert_compile( + union(s1, s2), + "SELECT test_table.x, lower(test_table.y) AS y " + "FROM test_table WHERE test_table.y = lower(:y_1) " + "UNION SELECT test_table.x, lower(test_table.y) AS y " + "FROM test_table WHERE test_table.y = lower(:y_2)", + ) + + def test_select_of_compound_select(self): + table = self._fixture() + + s1 = select([table]).where(table.c.y == "hi") + s2 = select([table]).where(table.c.y == "there") + + self.assert_compile( + union(s1, s2).alias().select(), + "SELECT anon_1.x, lower(anon_1.y) AS y FROM " + "(SELECT test_table.x AS x, test_table.y AS y " + "FROM test_table WHERE test_table.y = lower(:y_1) " + "UNION SELECT test_table.x AS x, test_table.y AS y " + "FROM test_table WHERE test_table.y = lower(:y_2)) AS anon_1", + ) + class DerivedTest(_ExprFixture, fixtures.TestBase, AssertsCompiledSQL): __dialect__ = "default"