Merge "Add more nesting features to add_cte()" into main

This commit is contained in:
mike bayer
2022-02-25 17:48:30 +00:00
committed by Gerrit Code Review
4 changed files with 487 additions and 66 deletions
+12
View File
@@ -0,0 +1,12 @@
.. change::
:tags: usecase, sql
:tickets: 7759
Added new parameter :paramref:`.HasCTE.add_cte.nest_here` to
:meth:`.HasCTE.add_cte` which will "nest" a given :class:`.CTE` at the
level of the parent statement. This parameter is equivalent to using the
:paramref:`.HasCTE.cte.nesting` parameter, but may be more intuitive in
some scenarios as it allows the nesting attribute to be set simultaneously
along with the explicit level of the CTE.
The :meth:`.HasCTE.add_cte` method also accepts multiple CTE objects.
+97 -55
View File
@@ -31,6 +31,13 @@ import itertools
import operator
import re
from time import perf_counter
import typing
from typing import Any
from typing import Dict
from typing import List
from typing import MutableMapping
from typing import Optional
from typing import Tuple
from . import base
from . import coercions
@@ -47,6 +54,12 @@ from .elements import quoted_name
from .. import exc
from .. import util
if typing.TYPE_CHECKING:
from .selectable import CTE
from .selectable import FromClause
_FromHintsType = Dict["FromClause", str]
RESERVED_WORDS = set(
[
"all",
@@ -842,7 +855,7 @@ class SQLCompiler(Compiled):
return {}
@util.memoized_instancemethod
def _init_cte_state(self):
def _init_cte_state(self) -> None:
"""Initialize collections related to CTEs only if
a CTE is located, to save on the overhead of
these collections otherwise.
@@ -850,19 +863,21 @@ class SQLCompiler(Compiled):
"""
# collect CTEs to tack on top of a SELECT
# To store the query to print - Dict[cte, text_query]
self.ctes = util.OrderedDict()
self.ctes: MutableMapping[CTE, str] = util.OrderedDict()
# Detect same CTE references - Dict[(level, name), cte]
# Level is required for supporting nesting
self.ctes_by_level_name = {}
self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {}
# To retrieve key/level in ctes_by_level_name -
# Dict[cte_reference, (level, cte_name)]
self.level_name_by_cte = {}
# Dict[cte_reference, (level, cte_name, cte_opts)]
self.level_name_by_cte: Dict[
CTE, Tuple[int, str, selectable._CTEOpts]
] = {}
self.ctes_recursive = False
self.ctes_recursive: bool = False
if self.positional:
self.cte_positional = {}
self.cte_positional: Dict[CTE, List[str]] = {}
@contextlib.contextmanager
def _nested_result(self):
@@ -1604,8 +1619,7 @@ class SQLCompiler(Compiled):
self.stack.append(new_entry)
if taf._independent_ctes:
for cte in taf._independent_ctes:
cte._compiler_dispatch(self, **kw)
self._dispatch_independent_ctes(taf, kw)
populate_result_map = (
toplevel
@@ -1879,8 +1893,7 @@ class SQLCompiler(Compiled):
)
if compound_stmt._independent_ctes:
for cte in compound_stmt._independent_ctes:
cte._compiler_dispatch(self, **kwargs)
self._dispatch_independent_ctes(compound_stmt, kwargs)
keyword = self.compound_keywords.get(cs.keyword)
@@ -2671,16 +2684,25 @@ class SQLCompiler(Compiled):
return ret
def _dispatch_independent_ctes(self, stmt, kw):
local_kw = kw.copy()
local_kw.pop("cte_opts", None)
for cte, opt in zip(
stmt._independent_ctes, stmt._independent_ctes_opts
):
cte._compiler_dispatch(self, cte_opts=opt, **local_kw)
def visit_cte(
self,
cte,
asfrom=False,
ashint=False,
fromhints=None,
visiting_cte=None,
from_linter=None,
**kwargs,
):
cte: CTE,
asfrom: bool = False,
ashint: bool = False,
fromhints: Optional[_FromHintsType] = None,
visiting_cte: Optional[CTE] = None,
from_linter: Optional[FromLinter] = None,
cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
**kwargs: Any,
) -> Optional[str]:
self._init_cte_state()
kwargs["visiting_cte"] = cte
@@ -2695,15 +2717,48 @@ class SQLCompiler(Compiled):
_reference_cte = cte._get_reference_cte()
if _reference_cte in self.level_name_by_cte:
cte_level, _ = self.level_name_by_cte[_reference_cte]
assert _ == cte_name
else:
cte_level = len(self.stack) if cte.nesting else 1
nesting = cte.nesting or cte_opts.nesting
cte_level_name = (cte_level, cte_name)
if cte_level_name in self.ctes_by_level_name:
# check for CTE already encountered
if _reference_cte in self.level_name_by_cte:
cte_level, _, existing_cte_opts = self.level_name_by_cte[
_reference_cte
]
assert _ == cte_name
cte_level_name = (cte_level, cte_name)
existing_cte = self.ctes_by_level_name[cte_level_name]
# check if we are receiving it here with a specific
# "nest_here" location; if so, move it to this location
if cte_opts.nesting:
if existing_cte_opts.nesting:
raise exc.CompileError(
"CTE is stated as 'nest_here' in "
"more than one location"
)
old_level_name = (cte_level, cte_name)
cte_level = len(self.stack) if nesting else 1
cte_level_name = new_level_name = (cte_level, cte_name)
del self.ctes_by_level_name[old_level_name]
self.ctes_by_level_name[new_level_name] = existing_cte
self.level_name_by_cte[_reference_cte] = new_level_name + (
cte_opts,
)
else:
cte_level = len(self.stack) if nesting else 1
cte_level_name = (cte_level, cte_name)
if cte_level_name in self.ctes_by_level_name:
existing_cte = self.ctes_by_level_name[cte_level_name]
else:
existing_cte = None
if existing_cte is not None:
embedded_in_current_named_cte = visiting_cte is existing_cte
# we've generated a same-named CTE that we are enclosed in,
@@ -2718,10 +2773,8 @@ class SQLCompiler(Compiled):
existing_cte_reference_cte = existing_cte._get_reference_cte()
# TODO: determine if these assertions are correct. they
# pass for current test cases
# assert existing_cte_reference_cte is _reference_cte
# assert existing_cte_reference_cte is existing_cte
assert existing_cte_reference_cte is _reference_cte
assert existing_cte_reference_cte is existing_cte
del self.level_name_by_cte[existing_cte_reference_cte]
else:
@@ -2746,19 +2799,9 @@ class SQLCompiler(Compiled):
if is_new_cte:
self.ctes_by_level_name[cte_level_name] = cte
self.level_name_by_cte[_reference_cte] = cte_level_name
if (
"autocommit" in cte.element._execution_options
and "autocommit" not in self.execution_options
):
self.execution_options = self.execution_options.union(
{
"autocommit": cte.element._execution_options[
"autocommit"
]
}
)
self.level_name_by_cte[_reference_cte] = cte_level_name + (
cte_opts,
)
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
@@ -3378,8 +3421,7 @@ class SQLCompiler(Compiled):
byfrom = None
if select_stmt._independent_ctes:
for cte in select_stmt._independent_ctes:
cte._compiler_dispatch(self, **kwargs)
self._dispatch_independent_ctes(select_stmt, kwargs)
if select_stmt._prefixes:
text += self._generate_prefixes(
@@ -3485,7 +3527,9 @@ class SQLCompiler(Compiled):
return text
def _setup_select_hints(self, select):
def _setup_select_hints(
self, select: Select
) -> Tuple[str, _FromHintsType]:
byfrom = dict(
[
(
@@ -3663,13 +3707,14 @@ class SQLCompiler(Compiled):
if nesting_level and nesting_level > 1:
ctes = util.OrderedDict()
for cte in list(self.ctes.keys()):
cte_level, cte_name = self.level_name_by_cte[
cte_level, cte_name, cte_opts = self.level_name_by_cte[
cte._get_reference_cte()
]
nesting = cte.nesting or cte_opts.nesting
is_rendered_level = cte_level == nesting_level or (
include_following_stack and cte_level == nesting_level + 1
)
if not (cte.nesting and is_rendered_level):
if not (nesting and is_rendered_level):
continue
ctes[cte] = self.ctes[cte]
@@ -3693,7 +3738,7 @@ class SQLCompiler(Compiled):
if nesting_level and nesting_level > 1:
for cte in list(ctes.keys()):
cte_level, cte_name = self.level_name_by_cte[
cte_level, cte_name, cte_opts = self.level_name_by_cte[
cte._get_reference_cte()
]
del self.ctes[cte]
@@ -3939,8 +3984,7 @@ class SQLCompiler(Compiled):
_, table_text = self._setup_crud_hints(insert_stmt, table_text)
if insert_stmt._independent_ctes:
for cte in insert_stmt._independent_ctes:
cte._compiler_dispatch(self, **kw)
self._dispatch_independent_ctes(insert_stmt, kw)
text += table_text
@@ -4108,8 +4152,7 @@ class SQLCompiler(Compiled):
dialect_hints = None
if update_stmt._independent_ctes:
for cte in update_stmt._independent_ctes:
cte._compiler_dispatch(self, **kw)
self._dispatch_independent_ctes(update_stmt, kw)
text += table_text
@@ -4221,8 +4264,7 @@ class SQLCompiler(Compiled):
dialect_hints = None
if delete_stmt._independent_ctes:
for cte in delete_stmt._independent_ctes:
cte._compiler_dispatch(self, **kw)
self._dispatch_independent_ctes(delete_stmt, kw)
text += table_text
+85 -11
View File
@@ -19,6 +19,7 @@ import itertools
from operator import attrgetter
import typing
from typing import Any as TODO_Any
from typing import NamedTuple
from typing import Optional
from typing import Tuple
@@ -1809,6 +1810,10 @@ class CTE(
SelfHasCTE = typing.TypeVar("SelfHasCTE", bound="HasCTE")
class _CTEOpts(NamedTuple):
nesting: bool
class HasCTE(roles.HasCTERole):
"""Mixin that declares a class to include CTE support.
@@ -1818,20 +1823,36 @@ class HasCTE(roles.HasCTERole):
_has_ctes_traverse_internals = [
("_independent_ctes", InternalTraversal.dp_clauseelement_list),
("_independent_ctes_opts", InternalTraversal.dp_plain_obj),
]
_independent_ctes = ()
_independent_ctes_opts = ()
@_generative
def add_cte(self: SelfHasCTE, cte) -> SelfHasCTE:
"""Add a :class:`_sql.CTE` to this statement object that will be
independently rendered even if not referenced in the statement
otherwise.
def add_cte(self: SelfHasCTE, *ctes, nest_here=False) -> SelfHasCTE:
r"""Add one or more :class:`_sql.CTE` constructs to this statement.
This feature is useful for the use case of embedding a DML statement
such as an INSERT or UPDATE as a CTE inline with a primary statement
that may draw from its results indirectly; while PostgreSQL is known
to support this usage, it may not be supported by other backends.
This method will associate the given :class:`_sql.CTE` constructs with
the parent statement such that they will each be unconditionally
rendered in the WITH clause of the final statement, even if not
referenced elsewhere within the statement or any sub-selects.
The optional :paramref:`.HasCTE.add_cte.nest_here` parameter when set
to True will have the effect that each given :class:`_sql.CTE` will
render in a WITH clause rendered directly along with this statement,
rather than being moved to the top of the ultimate rendered statement,
even if this statement is rendered as a subquery within a larger
statement.
This method has two general uses. One is to embed CTE statements that
serve some purpose without being referenced explicitly, such as the use
case of embedding a DML statement such as an INSERT or UPDATE as a CTE
inline with a primary statement that may draw from its results
indirectly. The other is to provide control over the exact placement
of a particular series of CTE constructs that should remain rendered
directly in terms of a particular statement that may be nested in a
larger statement.
E.g.::
@@ -1885,9 +1906,32 @@ class HasCTE(roles.HasCTERole):
.. versionadded:: 1.4.21
:param \*ctes: zero or more :class:`.CTE` constructs.
.. versionchanged:: 2.0 Multiple CTE instances are accepted
:param nest_here: if True, the given CTE or CTEs will be rendered
as though they specified the :paramref:`.HasCTE.cte.nesting` flag
to ``True`` when they were added to this :class:`.HasCTE`.
Assuming the given CTEs are not referenced in an outer-enclosing
statement as well, the CTEs given should render at the level of
this statement when this flag is given.
.. versionadded:: 2.0
.. seealso::
:paramref:`.HasCTE.cte.nesting`
"""
cte = coercions.expect(roles.IsCTERole, cte)
self._independent_ctes += (cte,)
opt = _CTEOpts(
nest_here,
)
for cte in ctes:
cte = coercions.expect(roles.IsCTERole, cte)
self._independent_ctes += (cte,)
self._independent_ctes_opts += (opt,)
return self
def cte(self, name=None, recursive=False, nesting=False):
@@ -1931,10 +1975,18 @@ class HasCTE(roles.HasCTERole):
conjunction with UNION ALL in order to derive rows
from those already selected.
:param nesting: if ``True``, will render the CTE locally to the
actual statement.
statement in which it is referenced. For more complex scenarios,
the :meth:`.HasCTE.add_cte` method using the
:paramref:`.HasCTE.add_cte.nest_here`
parameter may also be used to more carefully
control the exact placement of a particular CTE.
.. versionadded:: 1.4.24
.. seealso::
:meth:`.HasCTE.add_cte`
The following examples include two from PostgreSQL's documentation at
https://www.postgresql.org/docs/current/static/queries-with.html,
as well as additional examples.
@@ -2084,6 +2136,28 @@ class HasCTE(roles.HasCTERole):
SELECT value_a.n AS a, value_b.n AS b
FROM value_a, value_b
The same CTE can be set up using the :meth:`.HasCTE.add_cte` method
as follows (SQLAlchemy 2.0 and above)::
value_a = select(
literal("root").label("n")
).cte("value_a")
# A nested CTE with the same name as the root one
value_a_nested = select(
literal("nesting").label("n")
).cte("value_a")
# Nesting CTEs takes ascendency locally
# over the CTEs at a higher level
value_b = (
select(value_a_nested.c.n).
add_cte(value_a_nested, nest_here=True).
cte("value_b")
)
value_ab = select(value_a.c.n.label("a"), value_b.c.n.label("b"))
Example 5, Non-Linear CTE (SQLAlchemy 1.4.28 and above)::
edge = Table(
+293
View File
@@ -1,11 +1,13 @@
from sqlalchemy import Column
from sqlalchemy import delete
from sqlalchemy import exc
from sqlalchemy import Integer
from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
from sqlalchemy import MetaData
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy import true
from sqlalchemy import update
from sqlalchemy.dialects import mssql
from sqlalchemy.engine import default
@@ -25,6 +27,7 @@ from sqlalchemy.sql.visitors import cloned_traverse
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
@@ -1869,6 +1872,21 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT cte.outer_cte FROM cte",
)
def test_select_with_nesting_cte_in_cte_w_add_cte(self):
nesting_cte = select(literal(1).label("inner_cte")).cte("nesting")
stmt = select(
select(nesting_cte.c.inner_cte.label("outer_cte"))
.add_cte(nesting_cte, nest_here=True)
.cte("cte")
)
self.assert_compile(
stmt,
"WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) "
"SELECT nesting.inner_cte AS outer_cte FROM nesting) "
"SELECT cte.outer_cte FROM cte",
)
def test_select_with_aliased_nesting_cte_in_cte(self):
nesting_cte = (
select(literal(1).label("inner_cte"))
@@ -1887,6 +1905,25 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT cte.outer_cte FROM cte",
)
def test_select_with_aliased_nesting_cte_in_cte_w_add_cte(self):
inner_nesting_cte = select(literal(1).label("inner_cte")).cte(
"nesting"
)
outer_cte = select().add_cte(inner_nesting_cte, nest_here=True)
nesting_cte = inner_nesting_cte.alias("aliased_nested")
outer_cte = outer_cte.add_columns(
nesting_cte.c.inner_cte.label("outer_cte")
).cte("cte")
stmt = select(outer_cte)
self.assert_compile(
stmt,
"WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) "
"SELECT aliased_nested.inner_cte AS outer_cte "
"FROM nesting AS aliased_nested) "
"SELECT cte.outer_cte FROM cte",
)
def test_nesting_cte_in_cte_with_same_name(self):
nesting_cte = select(literal(1).label("inner_cte")).cte(
"some_cte", nesting=True
@@ -1904,6 +1941,23 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT some_cte.outer_cte FROM some_cte",
)
def test_nesting_cte_in_cte_with_same_name_w_add_cte(self):
nesting_cte = select(literal(1).label("inner_cte")).cte("some_cte")
stmt = select(
select(nesting_cte.c.inner_cte.label("outer_cte"))
.add_cte(nesting_cte, nest_here=True)
.cte("some_cte")
)
self.assert_compile(
stmt,
"WITH some_cte AS (WITH some_cte AS "
"(SELECT :param_1 AS inner_cte) "
"SELECT some_cte.inner_cte AS outer_cte "
"FROM some_cte) "
"SELECT some_cte.outer_cte FROM some_cte",
)
def test_nesting_cte_at_top_level(self):
nesting_cte = select(literal(1).label("val")).cte(
"nesting_cte", nesting=True
@@ -1918,6 +1972,20 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
" SELECT nesting_cte.val, cte.val AS val_1 FROM nesting_cte, cte",
)
def test_nesting_cte_at_top_level_w_add_cte(self):
nesting_cte = select(literal(1).label("val")).cte("nesting_cte")
cte = select(literal(2).label("val")).cte("cte")
stmt = select(nesting_cte.c.val, cte.c.val).add_cte(
nesting_cte, nest_here=True
)
self.assert_compile(
stmt,
"WITH nesting_cte AS (SELECT :param_1 AS val)"
", cte AS (SELECT :param_2 AS val)"
" SELECT nesting_cte.val, cte.val AS val_1 FROM nesting_cte, cte",
)
def test_double_nesting_cte_in_cte(self):
"""
Validate that the SELECT in the 2nd nesting CTE does not render
@@ -1950,6 +2018,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
") SELECT cte.outer_1, cte.outer_2 FROM cte",
)
def test_double_nesting_cte_in_cte_w_add_cte(self):
"""
Validate that the SELECT in the 2nd nesting CTE does not render
the 1st CTE.
It implies that nesting CTE level is taken in account.
"""
select_1_cte = select(literal(1).label("inner_cte")).cte("nesting_1")
select_2_cte = select(literal(2).label("inner_cte")).cte("nesting_2")
stmt = select(
select(
select_1_cte.c.inner_cte.label("outer_1"),
select_2_cte.c.inner_cte.label("outer_2"),
)
.add_cte(select_1_cte, select_2_cte, nest_here=True)
.cte("cte")
)
self.assert_compile(
stmt,
"WITH cte AS ("
"WITH nesting_1 AS (SELECT :param_1 AS inner_cte)"
", nesting_2 AS (SELECT :param_2 AS inner_cte)"
" SELECT nesting_1.inner_cte AS outer_1"
", nesting_2.inner_cte AS outer_2"
" FROM nesting_1, nesting_2"
") SELECT cte.outer_1, cte.outer_2 FROM cte",
)
def test_double_nesting_cte_with_cross_reference_in_cte(self):
select_1_cte = select(literal(1).label("inner_cte_1")).cte(
"nesting_1", nesting=True
@@ -1993,6 +2091,32 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
") SELECT cte.inner_cte_2, cte.inner_cte_1 FROM cte",
)
def test_double_nesting_cte_with_cross_reference_in_cte_w_add_cte(self):
select_1_cte = select(literal(1).label("inner_cte_1")).cte("nesting_1")
select_2_cte = select(
(select_1_cte.c.inner_cte_1 + 1).label("inner_cte_2")
).cte("nesting_2")
# 1 next 2
nesting_cte_1_2 = (
select(select_1_cte, select_2_cte)
.add_cte(select_1_cte, select_2_cte, nest_here=True)
.cte("cte")
)
stmt_1_2 = select(nesting_cte_1_2)
self.assert_compile(
stmt_1_2,
"WITH cte AS ("
"WITH nesting_1 AS (SELECT :param_1 AS inner_cte_1)"
", nesting_2 AS (SELECT nesting_1.inner_cte_1 + :inner_cte_1_1"
" AS inner_cte_2 FROM nesting_1)"
" SELECT nesting_1.inner_cte_1 AS inner_cte_1"
", nesting_2.inner_cte_2 AS inner_cte_2"
" FROM nesting_1, nesting_2"
") SELECT cte.inner_cte_1, cte.inner_cte_2 FROM cte",
)
def test_nesting_cte_in_nesting_cte_in_cte(self):
select_1_cte = select(literal(1).label("inner_cte")).cte(
"nesting_1", nesting=True
@@ -2069,6 +2193,31 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
"SELECT rec_cte.outer_cte FROM rec_cte",
)
def test_nesting_cte_in_recursive_cte_w_add_cte(self):
nesting_cte = select(literal(1).label("inner_cte")).cte(
"nesting", nesting=True
)
rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
"rec_cte", recursive=True
)
rec_part = select(rec_cte.c.outer_cte).where(
rec_cte.c.outer_cte == literal(1)
)
rec_cte = rec_cte.union(rec_part)
stmt = select(rec_cte)
self.assert_compile(
stmt,
"WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
"(SELECT :param_1 AS inner_cte) "
"SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
"SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
"WHERE rec_cte.outer_cte = :param_2) "
"SELECT rec_cte.outer_cte FROM rec_cte",
)
def test_recursive_nesting_cte_in_cte(self):
rec_root = select(literal(1).label("inner_cte")).cte(
"nesting", recursive=True, nesting=True
@@ -2209,6 +2358,80 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
"FROM nesting_cte",
)
def test_add_cte_dont_nest_in_two_places(self):
nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte(
"nesting_cte"
)
select_add_cte = select(
(nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value")
).cte("nesting_2")
union_cte = (
select(
(nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value")
)
.add_cte(nesting_cte_used_twice, nest_here=True)
.union(
select(select_add_cte).add_cte(select_add_cte, nest_here=True)
)
.cte("wrapper")
)
stmt = (
select(union_cte)
.add_cte(nesting_cte_used_twice, nest_here=True)
.union(select(nesting_cte_used_twice))
)
with expect_raises_message(
exc.CompileError,
"CTE is stated as 'nest_here' in more than one location",
):
stmt.compile()
def test_same_nested_cte_is_not_generated_twice_w_add_cte(self):
# Same = name and query
nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte(
"nesting_cte"
)
select_add_cte = select(
(nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value")
).cte("nesting_2")
union_cte = (
select(
(nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value")
)
.add_cte(nesting_cte_used_twice)
.union(
select(select_add_cte).add_cte(select_add_cte, nest_here=True)
)
.cte("wrapper")
)
stmt = (
select(union_cte)
.add_cte(nesting_cte_used_twice, nest_here=True)
.union(select(nesting_cte_used_twice))
)
self.assert_compile(
stmt,
"WITH nesting_cte AS "
"(SELECT :param_1 AS inner_cte_1)"
", wrapper AS "
"(WITH nesting_2 AS "
"(SELECT nesting_cte.inner_cte_1 + :inner_cte_1_2 "
"AS next_value "
"FROM nesting_cte)"
" SELECT nesting_cte.inner_cte_1 - :inner_cte_1_1 "
"AS next_value "
"FROM nesting_cte UNION SELECT nesting_2.next_value AS next_value "
"FROM nesting_2)"
" SELECT wrapper.next_value "
"FROM wrapper UNION SELECT nesting_cte.inner_cte_1 "
"FROM nesting_cte",
)
def test_recursive_nesting_cte_in_recursive_cte(self):
nesting_cte = select(literal(1).label("inner_cte")).cte(
"nesting", nesting=True, recursive=True
@@ -2363,6 +2586,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
") SELECT cte.outer_cte FROM cte",
)
def test_compound_select_with_nesting_cte_in_custom_order_w_add_cte(self):
select_1_cte = select(literal(1).label("inner_cte")).cte("nesting_1")
select_2_cte = select(literal(2).label("inner_cte")).cte("nesting_2")
nesting_cte = (
select(select_1_cte)
.add_cte(select_1_cte, nest_here=True)
.union(select(select_2_cte))
# Generate "select_2_cte" first
.add_cte(select_2_cte, nest_here=True)
.subquery()
)
stmt = select(
select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
)
self.assert_compile(
stmt,
"WITH cte AS ("
"SELECT anon_1.inner_cte AS outer_cte FROM ("
"WITH nesting_2 AS (SELECT :param_1 AS inner_cte)"
", nesting_1 AS (SELECT :param_2 AS inner_cte)"
" SELECT nesting_1.inner_cte AS inner_cte FROM nesting_1"
" UNION"
" SELECT nesting_2.inner_cte AS inner_cte FROM nesting_2"
") AS anon_1"
") SELECT cte.outer_cte FROM cte",
)
def test_recursive_cte_referenced_multiple_times_with_nesting_cte(self):
rec_root = select(literal(1).label("the_value")).cte(
"recursive_cte", recursive=True
@@ -2411,3 +2664,43 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
" WHERE should_continue.val != true))"
" SELECT recursive_cte.the_value FROM recursive_cte",
)
@testing.combinations(True, False)
def test_correlated_cte_in_lateral_w_add_cte(self, reverse_direction):
"""this is the original use case that led to #7759"""
contracts = table("contracts", column("id"))
invoices = table("invoices", column("id"), column("contract_id"))
contracts_alias = contracts.alias()
cte1 = (
select(contracts_alias)
.where(contracts_alias.c.id == contracts.c.id)
.correlate(contracts)
.cte(name="cte1")
)
cte2 = (
select(invoices)
.join(cte1, invoices.c.contract_id == cte1.c.id)
.cte(name="cte2")
)
if reverse_direction:
subq = select(cte1, cte2).add_cte(cte2, cte1, nest_here=True)
else:
subq = select(cte1, cte2).add_cte(cte1, cte2, nest_here=True)
stmt = select(contracts).outerjoin(subq.lateral(), true())
self.assert_compile(
stmt,
"SELECT contracts.id FROM contracts LEFT OUTER JOIN LATERAL "
"(WITH cte1 AS (SELECT contracts_1.id AS id "
"FROM contracts AS contracts_1 "
"WHERE contracts_1.id = contracts.id), "
"cte2 AS (SELECT invoices.id AS id, "
"invoices.contract_id AS contract_id FROM invoices "
"JOIN cte1 ON invoices.contract_id = cte1.id) "
"SELECT cte1.id AS id, cte2.id AS id_1, "
"cte2.contract_id AS contract_id "
"FROM cte1, cte2) AS anon_1 ON true",
)