mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-15 13:17:24 -04:00
Merge "Add more nesting features to add_cte()" into main
This commit is contained in:
+12
@@ -0,0 +1,12 @@
|
||||
.. change::
|
||||
:tags: usecase, sql
|
||||
:tickets: 7759
|
||||
|
||||
Added new parameter :paramref:`.HasCTE.add_cte.nest_here` to
|
||||
:meth:`.HasCTE.add_cte` which will "nest" a given :class:`.CTE` at the
|
||||
level of the parent statement. This parameter is equivalent to using the
|
||||
:paramref:`.HasCTE.cte.nesting` parameter, but may be more intuitive in
|
||||
some scenarios as it allows the nesting attribute to be set simultaneously
|
||||
along with the explicit level of the CTE.
|
||||
|
||||
The :meth:`.HasCTE.add_cte` method also accepts multiple CTE objects.
|
||||
@@ -31,6 +31,13 @@ import itertools
|
||||
import operator
|
||||
import re
|
||||
from time import perf_counter
|
||||
import typing
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import MutableMapping
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from . import base
|
||||
from . import coercions
|
||||
@@ -47,6 +54,12 @@ from .elements import quoted_name
|
||||
from .. import exc
|
||||
from .. import util
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .selectable import CTE
|
||||
from .selectable import FromClause
|
||||
|
||||
_FromHintsType = Dict["FromClause", str]
|
||||
|
||||
RESERVED_WORDS = set(
|
||||
[
|
||||
"all",
|
||||
@@ -842,7 +855,7 @@ class SQLCompiler(Compiled):
|
||||
return {}
|
||||
|
||||
@util.memoized_instancemethod
|
||||
def _init_cte_state(self):
|
||||
def _init_cte_state(self) -> None:
|
||||
"""Initialize collections related to CTEs only if
|
||||
a CTE is located, to save on the overhead of
|
||||
these collections otherwise.
|
||||
@@ -850,19 +863,21 @@ class SQLCompiler(Compiled):
|
||||
"""
|
||||
# collect CTEs to tack on top of a SELECT
|
||||
# To store the query to print - Dict[cte, text_query]
|
||||
self.ctes = util.OrderedDict()
|
||||
self.ctes: MutableMapping[CTE, str] = util.OrderedDict()
|
||||
|
||||
# Detect same CTE references - Dict[(level, name), cte]
|
||||
# Level is required for supporting nesting
|
||||
self.ctes_by_level_name = {}
|
||||
self.ctes_by_level_name: Dict[Tuple[int, str], CTE] = {}
|
||||
|
||||
# To retrieve key/level in ctes_by_level_name -
|
||||
# Dict[cte_reference, (level, cte_name)]
|
||||
self.level_name_by_cte = {}
|
||||
# Dict[cte_reference, (level, cte_name, cte_opts)]
|
||||
self.level_name_by_cte: Dict[
|
||||
CTE, Tuple[int, str, selectable._CTEOpts]
|
||||
] = {}
|
||||
|
||||
self.ctes_recursive = False
|
||||
self.ctes_recursive: bool = False
|
||||
if self.positional:
|
||||
self.cte_positional = {}
|
||||
self.cte_positional: Dict[CTE, List[str]] = {}
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _nested_result(self):
|
||||
@@ -1604,8 +1619,7 @@ class SQLCompiler(Compiled):
|
||||
self.stack.append(new_entry)
|
||||
|
||||
if taf._independent_ctes:
|
||||
for cte in taf._independent_ctes:
|
||||
cte._compiler_dispatch(self, **kw)
|
||||
self._dispatch_independent_ctes(taf, kw)
|
||||
|
||||
populate_result_map = (
|
||||
toplevel
|
||||
@@ -1879,8 +1893,7 @@ class SQLCompiler(Compiled):
|
||||
)
|
||||
|
||||
if compound_stmt._independent_ctes:
|
||||
for cte in compound_stmt._independent_ctes:
|
||||
cte._compiler_dispatch(self, **kwargs)
|
||||
self._dispatch_independent_ctes(compound_stmt, kwargs)
|
||||
|
||||
keyword = self.compound_keywords.get(cs.keyword)
|
||||
|
||||
@@ -2671,16 +2684,25 @@ class SQLCompiler(Compiled):
|
||||
|
||||
return ret
|
||||
|
||||
def _dispatch_independent_ctes(self, stmt, kw):
|
||||
local_kw = kw.copy()
|
||||
local_kw.pop("cte_opts", None)
|
||||
for cte, opt in zip(
|
||||
stmt._independent_ctes, stmt._independent_ctes_opts
|
||||
):
|
||||
cte._compiler_dispatch(self, cte_opts=opt, **local_kw)
|
||||
|
||||
def visit_cte(
|
||||
self,
|
||||
cte,
|
||||
asfrom=False,
|
||||
ashint=False,
|
||||
fromhints=None,
|
||||
visiting_cte=None,
|
||||
from_linter=None,
|
||||
**kwargs,
|
||||
):
|
||||
cte: CTE,
|
||||
asfrom: bool = False,
|
||||
ashint: bool = False,
|
||||
fromhints: Optional[_FromHintsType] = None,
|
||||
visiting_cte: Optional[CTE] = None,
|
||||
from_linter: Optional[FromLinter] = None,
|
||||
cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
|
||||
**kwargs: Any,
|
||||
) -> Optional[str]:
|
||||
self._init_cte_state()
|
||||
|
||||
kwargs["visiting_cte"] = cte
|
||||
@@ -2695,15 +2717,48 @@ class SQLCompiler(Compiled):
|
||||
|
||||
_reference_cte = cte._get_reference_cte()
|
||||
|
||||
if _reference_cte in self.level_name_by_cte:
|
||||
cte_level, _ = self.level_name_by_cte[_reference_cte]
|
||||
assert _ == cte_name
|
||||
else:
|
||||
cte_level = len(self.stack) if cte.nesting else 1
|
||||
nesting = cte.nesting or cte_opts.nesting
|
||||
|
||||
cte_level_name = (cte_level, cte_name)
|
||||
if cte_level_name in self.ctes_by_level_name:
|
||||
# check for CTE already encountered
|
||||
if _reference_cte in self.level_name_by_cte:
|
||||
cte_level, _, existing_cte_opts = self.level_name_by_cte[
|
||||
_reference_cte
|
||||
]
|
||||
assert _ == cte_name
|
||||
|
||||
cte_level_name = (cte_level, cte_name)
|
||||
existing_cte = self.ctes_by_level_name[cte_level_name]
|
||||
|
||||
# check if we are receiving it here with a specific
|
||||
# "nest_here" location; if so, move it to this location
|
||||
|
||||
if cte_opts.nesting:
|
||||
if existing_cte_opts.nesting:
|
||||
raise exc.CompileError(
|
||||
"CTE is stated as 'nest_here' in "
|
||||
"more than one location"
|
||||
)
|
||||
|
||||
old_level_name = (cte_level, cte_name)
|
||||
cte_level = len(self.stack) if nesting else 1
|
||||
cte_level_name = new_level_name = (cte_level, cte_name)
|
||||
|
||||
del self.ctes_by_level_name[old_level_name]
|
||||
self.ctes_by_level_name[new_level_name] = existing_cte
|
||||
self.level_name_by_cte[_reference_cte] = new_level_name + (
|
||||
cte_opts,
|
||||
)
|
||||
|
||||
else:
|
||||
cte_level = len(self.stack) if nesting else 1
|
||||
cte_level_name = (cte_level, cte_name)
|
||||
|
||||
if cte_level_name in self.ctes_by_level_name:
|
||||
existing_cte = self.ctes_by_level_name[cte_level_name]
|
||||
else:
|
||||
existing_cte = None
|
||||
|
||||
if existing_cte is not None:
|
||||
embedded_in_current_named_cte = visiting_cte is existing_cte
|
||||
|
||||
# we've generated a same-named CTE that we are enclosed in,
|
||||
@@ -2718,10 +2773,8 @@ class SQLCompiler(Compiled):
|
||||
|
||||
existing_cte_reference_cte = existing_cte._get_reference_cte()
|
||||
|
||||
# TODO: determine if these assertions are correct. they
|
||||
# pass for current test cases
|
||||
# assert existing_cte_reference_cte is _reference_cte
|
||||
# assert existing_cte_reference_cte is existing_cte
|
||||
assert existing_cte_reference_cte is _reference_cte
|
||||
assert existing_cte_reference_cte is existing_cte
|
||||
|
||||
del self.level_name_by_cte[existing_cte_reference_cte]
|
||||
else:
|
||||
@@ -2746,19 +2799,9 @@ class SQLCompiler(Compiled):
|
||||
|
||||
if is_new_cte:
|
||||
self.ctes_by_level_name[cte_level_name] = cte
|
||||
self.level_name_by_cte[_reference_cte] = cte_level_name
|
||||
|
||||
if (
|
||||
"autocommit" in cte.element._execution_options
|
||||
and "autocommit" not in self.execution_options
|
||||
):
|
||||
self.execution_options = self.execution_options.union(
|
||||
{
|
||||
"autocommit": cte.element._execution_options[
|
||||
"autocommit"
|
||||
]
|
||||
}
|
||||
)
|
||||
self.level_name_by_cte[_reference_cte] = cte_level_name + (
|
||||
cte_opts,
|
||||
)
|
||||
|
||||
if pre_alias_cte not in self.ctes:
|
||||
self.visit_cte(pre_alias_cte, **kwargs)
|
||||
@@ -3378,8 +3421,7 @@ class SQLCompiler(Compiled):
|
||||
byfrom = None
|
||||
|
||||
if select_stmt._independent_ctes:
|
||||
for cte in select_stmt._independent_ctes:
|
||||
cte._compiler_dispatch(self, **kwargs)
|
||||
self._dispatch_independent_ctes(select_stmt, kwargs)
|
||||
|
||||
if select_stmt._prefixes:
|
||||
text += self._generate_prefixes(
|
||||
@@ -3485,7 +3527,9 @@ class SQLCompiler(Compiled):
|
||||
|
||||
return text
|
||||
|
||||
def _setup_select_hints(self, select):
|
||||
def _setup_select_hints(
|
||||
self, select: Select
|
||||
) -> Tuple[str, _FromHintsType]:
|
||||
byfrom = dict(
|
||||
[
|
||||
(
|
||||
@@ -3663,13 +3707,14 @@ class SQLCompiler(Compiled):
|
||||
if nesting_level and nesting_level > 1:
|
||||
ctes = util.OrderedDict()
|
||||
for cte in list(self.ctes.keys()):
|
||||
cte_level, cte_name = self.level_name_by_cte[
|
||||
cte_level, cte_name, cte_opts = self.level_name_by_cte[
|
||||
cte._get_reference_cte()
|
||||
]
|
||||
nesting = cte.nesting or cte_opts.nesting
|
||||
is_rendered_level = cte_level == nesting_level or (
|
||||
include_following_stack and cte_level == nesting_level + 1
|
||||
)
|
||||
if not (cte.nesting and is_rendered_level):
|
||||
if not (nesting and is_rendered_level):
|
||||
continue
|
||||
|
||||
ctes[cte] = self.ctes[cte]
|
||||
@@ -3693,7 +3738,7 @@ class SQLCompiler(Compiled):
|
||||
|
||||
if nesting_level and nesting_level > 1:
|
||||
for cte in list(ctes.keys()):
|
||||
cte_level, cte_name = self.level_name_by_cte[
|
||||
cte_level, cte_name, cte_opts = self.level_name_by_cte[
|
||||
cte._get_reference_cte()
|
||||
]
|
||||
del self.ctes[cte]
|
||||
@@ -3939,8 +3984,7 @@ class SQLCompiler(Compiled):
|
||||
_, table_text = self._setup_crud_hints(insert_stmt, table_text)
|
||||
|
||||
if insert_stmt._independent_ctes:
|
||||
for cte in insert_stmt._independent_ctes:
|
||||
cte._compiler_dispatch(self, **kw)
|
||||
self._dispatch_independent_ctes(insert_stmt, kw)
|
||||
|
||||
text += table_text
|
||||
|
||||
@@ -4108,8 +4152,7 @@ class SQLCompiler(Compiled):
|
||||
dialect_hints = None
|
||||
|
||||
if update_stmt._independent_ctes:
|
||||
for cte in update_stmt._independent_ctes:
|
||||
cte._compiler_dispatch(self, **kw)
|
||||
self._dispatch_independent_ctes(update_stmt, kw)
|
||||
|
||||
text += table_text
|
||||
|
||||
@@ -4221,8 +4264,7 @@ class SQLCompiler(Compiled):
|
||||
dialect_hints = None
|
||||
|
||||
if delete_stmt._independent_ctes:
|
||||
for cte in delete_stmt._independent_ctes:
|
||||
cte._compiler_dispatch(self, **kw)
|
||||
self._dispatch_independent_ctes(delete_stmt, kw)
|
||||
|
||||
text += table_text
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import itertools
|
||||
from operator import attrgetter
|
||||
import typing
|
||||
from typing import Any as TODO_Any
|
||||
from typing import NamedTuple
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
@@ -1809,6 +1810,10 @@ class CTE(
|
||||
SelfHasCTE = typing.TypeVar("SelfHasCTE", bound="HasCTE")
|
||||
|
||||
|
||||
class _CTEOpts(NamedTuple):
|
||||
nesting: bool
|
||||
|
||||
|
||||
class HasCTE(roles.HasCTERole):
|
||||
"""Mixin that declares a class to include CTE support.
|
||||
|
||||
@@ -1818,20 +1823,36 @@ class HasCTE(roles.HasCTERole):
|
||||
|
||||
_has_ctes_traverse_internals = [
|
||||
("_independent_ctes", InternalTraversal.dp_clauseelement_list),
|
||||
("_independent_ctes_opts", InternalTraversal.dp_plain_obj),
|
||||
]
|
||||
|
||||
_independent_ctes = ()
|
||||
_independent_ctes_opts = ()
|
||||
|
||||
@_generative
|
||||
def add_cte(self: SelfHasCTE, cte) -> SelfHasCTE:
|
||||
"""Add a :class:`_sql.CTE` to this statement object that will be
|
||||
independently rendered even if not referenced in the statement
|
||||
otherwise.
|
||||
def add_cte(self: SelfHasCTE, *ctes, nest_here=False) -> SelfHasCTE:
|
||||
r"""Add one or more :class:`_sql.CTE` constructs to this statement.
|
||||
|
||||
This feature is useful for the use case of embedding a DML statement
|
||||
such as an INSERT or UPDATE as a CTE inline with a primary statement
|
||||
that may draw from its results indirectly; while PostgreSQL is known
|
||||
to support this usage, it may not be supported by other backends.
|
||||
This method will associate the given :class:`_sql.CTE` constructs with
|
||||
the parent statement such that they will each be unconditionally
|
||||
rendered in the WITH clause of the final statement, even if not
|
||||
referenced elsewhere within the statement or any sub-selects.
|
||||
|
||||
The optional :paramref:`.HasCTE.add_cte.nest_here` parameter when set
|
||||
to True will have the effect that each given :class:`_sql.CTE` will
|
||||
render in a WITH clause rendered directly along with this statement,
|
||||
rather than being moved to the top of the ultimate rendered statement,
|
||||
even if this statement is rendered as a subquery within a larger
|
||||
statement.
|
||||
|
||||
This method has two general uses. One is to embed CTE statements that
|
||||
serve some purpose without being referenced explicitly, such as the use
|
||||
case of embedding a DML statement such as an INSERT or UPDATE as a CTE
|
||||
inline with a primary statement that may draw from its results
|
||||
indirectly. The other is to provide control over the exact placement
|
||||
of a particular series of CTE constructs that should remain rendered
|
||||
directly in terms of a particular statement that may be nested in a
|
||||
larger statement.
|
||||
|
||||
E.g.::
|
||||
|
||||
@@ -1885,9 +1906,32 @@ class HasCTE(roles.HasCTERole):
|
||||
|
||||
.. versionadded:: 1.4.21
|
||||
|
||||
:param \*ctes: zero or more :class:`.CTE` constructs.
|
||||
|
||||
.. versionchanged:: 2.0 Multiple CTE instances are accepted
|
||||
|
||||
:param nest_here: if True, the given CTE or CTEs will be rendered
|
||||
as though they specified the :paramref:`.HasCTE.cte.nesting` flag
|
||||
to ``True`` when they were added to this :class:`.HasCTE`.
|
||||
Assuming the given CTEs are not referenced in an outer-enclosing
|
||||
statement as well, the CTEs given should render at the level of
|
||||
this statement when this flag is given.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
.. seealso::
|
||||
|
||||
:paramref:`.HasCTE.cte.nesting`
|
||||
|
||||
|
||||
"""
|
||||
cte = coercions.expect(roles.IsCTERole, cte)
|
||||
self._independent_ctes += (cte,)
|
||||
opt = _CTEOpts(
|
||||
nest_here,
|
||||
)
|
||||
for cte in ctes:
|
||||
cte = coercions.expect(roles.IsCTERole, cte)
|
||||
self._independent_ctes += (cte,)
|
||||
self._independent_ctes_opts += (opt,)
|
||||
return self
|
||||
|
||||
def cte(self, name=None, recursive=False, nesting=False):
|
||||
@@ -1931,10 +1975,18 @@ class HasCTE(roles.HasCTERole):
|
||||
conjunction with UNION ALL in order to derive rows
|
||||
from those already selected.
|
||||
:param nesting: if ``True``, will render the CTE locally to the
|
||||
actual statement.
|
||||
statement in which it is referenced. For more complex scenarios,
|
||||
the :meth:`.HasCTE.add_cte` method using the
|
||||
:paramref:`.HasCTE.add_cte.nest_here`
|
||||
parameter may also be used to more carefully
|
||||
control the exact placement of a particular CTE.
|
||||
|
||||
.. versionadded:: 1.4.24
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`.HasCTE.add_cte`
|
||||
|
||||
The following examples include two from PostgreSQL's documentation at
|
||||
https://www.postgresql.org/docs/current/static/queries-with.html,
|
||||
as well as additional examples.
|
||||
@@ -2084,6 +2136,28 @@ class HasCTE(roles.HasCTERole):
|
||||
SELECT value_a.n AS a, value_b.n AS b
|
||||
FROM value_a, value_b
|
||||
|
||||
The same CTE can be set up using the :meth:`.HasCTE.add_cte` method
|
||||
as follows (SQLAlchemy 2.0 and above)::
|
||||
|
||||
value_a = select(
|
||||
literal("root").label("n")
|
||||
).cte("value_a")
|
||||
|
||||
# A nested CTE with the same name as the root one
|
||||
value_a_nested = select(
|
||||
literal("nesting").label("n")
|
||||
).cte("value_a")
|
||||
|
||||
# Nesting CTEs takes ascendency locally
|
||||
# over the CTEs at a higher level
|
||||
value_b = (
|
||||
select(value_a_nested.c.n).
|
||||
add_cte(value_a_nested, nest_here=True).
|
||||
cte("value_b")
|
||||
)
|
||||
|
||||
value_ab = select(value_a.c.n.label("a"), value_b.c.n.label("b"))
|
||||
|
||||
Example 5, Non-Linear CTE (SQLAlchemy 1.4.28 and above)::
|
||||
|
||||
edge = Table(
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy import testing
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import true
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects import mssql
|
||||
from sqlalchemy.engine import default
|
||||
@@ -25,6 +27,7 @@ from sqlalchemy.sql.visitors import cloned_traverse
|
||||
from sqlalchemy.testing import assert_raises_message
|
||||
from sqlalchemy.testing import AssertsCompiledSQL
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import expect_raises_message
|
||||
from sqlalchemy.testing import fixtures
|
||||
|
||||
|
||||
@@ -1869,6 +1872,21 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
|
||||
"SELECT cte.outer_cte FROM cte",
|
||||
)
|
||||
|
||||
def test_select_with_nesting_cte_in_cte_w_add_cte(self):
|
||||
nesting_cte = select(literal(1).label("inner_cte")).cte("nesting")
|
||||
stmt = select(
|
||||
select(nesting_cte.c.inner_cte.label("outer_cte"))
|
||||
.add_cte(nesting_cte, nest_here=True)
|
||||
.cte("cte")
|
||||
)
|
||||
|
||||
self.assert_compile(
|
||||
stmt,
|
||||
"WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) "
|
||||
"SELECT nesting.inner_cte AS outer_cte FROM nesting) "
|
||||
"SELECT cte.outer_cte FROM cte",
|
||||
)
|
||||
|
||||
def test_select_with_aliased_nesting_cte_in_cte(self):
|
||||
nesting_cte = (
|
||||
select(literal(1).label("inner_cte"))
|
||||
@@ -1887,6 +1905,25 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
|
||||
"SELECT cte.outer_cte FROM cte",
|
||||
)
|
||||
|
||||
def test_select_with_aliased_nesting_cte_in_cte_w_add_cte(self):
|
||||
inner_nesting_cte = select(literal(1).label("inner_cte")).cte(
|
||||
"nesting"
|
||||
)
|
||||
outer_cte = select().add_cte(inner_nesting_cte, nest_here=True)
|
||||
nesting_cte = inner_nesting_cte.alias("aliased_nested")
|
||||
outer_cte = outer_cte.add_columns(
|
||||
nesting_cte.c.inner_cte.label("outer_cte")
|
||||
).cte("cte")
|
||||
stmt = select(outer_cte)
|
||||
|
||||
self.assert_compile(
|
||||
stmt,
|
||||
"WITH cte AS (WITH nesting AS (SELECT :param_1 AS inner_cte) "
|
||||
"SELECT aliased_nested.inner_cte AS outer_cte "
|
||||
"FROM nesting AS aliased_nested) "
|
||||
"SELECT cte.outer_cte FROM cte",
|
||||
)
|
||||
|
||||
def test_nesting_cte_in_cte_with_same_name(self):
|
||||
nesting_cte = select(literal(1).label("inner_cte")).cte(
|
||||
"some_cte", nesting=True
|
||||
@@ -1904,6 +1941,23 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
|
||||
"SELECT some_cte.outer_cte FROM some_cte",
|
||||
)
|
||||
|
||||
def test_nesting_cte_in_cte_with_same_name_w_add_cte(self):
|
||||
nesting_cte = select(literal(1).label("inner_cte")).cte("some_cte")
|
||||
stmt = select(
|
||||
select(nesting_cte.c.inner_cte.label("outer_cte"))
|
||||
.add_cte(nesting_cte, nest_here=True)
|
||||
.cte("some_cte")
|
||||
)
|
||||
|
||||
self.assert_compile(
|
||||
stmt,
|
||||
"WITH some_cte AS (WITH some_cte AS "
|
||||
"(SELECT :param_1 AS inner_cte) "
|
||||
"SELECT some_cte.inner_cte AS outer_cte "
|
||||
"FROM some_cte) "
|
||||
"SELECT some_cte.outer_cte FROM some_cte",
|
||||
)
|
||||
|
||||
def test_nesting_cte_at_top_level(self):
|
||||
nesting_cte = select(literal(1).label("val")).cte(
|
||||
"nesting_cte", nesting=True
|
||||
@@ -1918,6 +1972,20 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
|
||||
" SELECT nesting_cte.val, cte.val AS val_1 FROM nesting_cte, cte",
|
||||
)
|
||||
|
||||
def test_nesting_cte_at_top_level_w_add_cte(self):
|
||||
nesting_cte = select(literal(1).label("val")).cte("nesting_cte")
|
||||
cte = select(literal(2).label("val")).cte("cte")
|
||||
stmt = select(nesting_cte.c.val, cte.c.val).add_cte(
|
||||
nesting_cte, nest_here=True
|
||||
)
|
||||
|
||||
self.assert_compile(
|
||||
stmt,
|
||||
"WITH nesting_cte AS (SELECT :param_1 AS val)"
|
||||
", cte AS (SELECT :param_2 AS val)"
|
||||
" SELECT nesting_cte.val, cte.val AS val_1 FROM nesting_cte, cte",
|
||||
)
|
||||
|
||||
def test_double_nesting_cte_in_cte(self):
|
||||
"""
|
||||
Validate that the SELECT in the 2nd nesting CTE does not render
|
||||
@@ -1950,6 +2018,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
|
||||
") SELECT cte.outer_1, cte.outer_2 FROM cte",
|
||||
)
|
||||
|
||||
def test_double_nesting_cte_in_cte_w_add_cte(self):
|
||||
"""
|
||||
Validate that the SELECT in the 2nd nesting CTE does not render
|
||||
the 1st CTE.
|
||||
|
||||
It implies that nesting CTE level is taken in account.
|
||||
"""
|
||||
select_1_cte = select(literal(1).label("inner_cte")).cte("nesting_1")
|
||||
select_2_cte = select(literal(2).label("inner_cte")).cte("nesting_2")
|
||||
|
||||
stmt = select(
|
||||
select(
|
||||
select_1_cte.c.inner_cte.label("outer_1"),
|
||||
select_2_cte.c.inner_cte.label("outer_2"),
|
||||
)
|
||||
.add_cte(select_1_cte, select_2_cte, nest_here=True)
|
||||
.cte("cte")
|
||||
)
|
||||
|
||||
self.assert_compile(
|
||||
stmt,
|
||||
"WITH cte AS ("
|
||||
"WITH nesting_1 AS (SELECT :param_1 AS inner_cte)"
|
||||
", nesting_2 AS (SELECT :param_2 AS inner_cte)"
|
||||
" SELECT nesting_1.inner_cte AS outer_1"
|
||||
", nesting_2.inner_cte AS outer_2"
|
||||
" FROM nesting_1, nesting_2"
|
||||
") SELECT cte.outer_1, cte.outer_2 FROM cte",
|
||||
)
|
||||
|
||||
def test_double_nesting_cte_with_cross_reference_in_cte(self):
|
||||
select_1_cte = select(literal(1).label("inner_cte_1")).cte(
|
||||
"nesting_1", nesting=True
|
||||
@@ -1993,6 +2091,32 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
|
||||
") SELECT cte.inner_cte_2, cte.inner_cte_1 FROM cte",
|
||||
)
|
||||
|
||||
def test_double_nesting_cte_with_cross_reference_in_cte_w_add_cte(self):
|
||||
select_1_cte = select(literal(1).label("inner_cte_1")).cte("nesting_1")
|
||||
select_2_cte = select(
|
||||
(select_1_cte.c.inner_cte_1 + 1).label("inner_cte_2")
|
||||
).cte("nesting_2")
|
||||
|
||||
# 1 next 2
|
||||
|
||||
nesting_cte_1_2 = (
|
||||
select(select_1_cte, select_2_cte)
|
||||
.add_cte(select_1_cte, select_2_cte, nest_here=True)
|
||||
.cte("cte")
|
||||
)
|
||||
stmt_1_2 = select(nesting_cte_1_2)
|
||||
self.assert_compile(
|
||||
stmt_1_2,
|
||||
"WITH cte AS ("
|
||||
"WITH nesting_1 AS (SELECT :param_1 AS inner_cte_1)"
|
||||
", nesting_2 AS (SELECT nesting_1.inner_cte_1 + :inner_cte_1_1"
|
||||
" AS inner_cte_2 FROM nesting_1)"
|
||||
" SELECT nesting_1.inner_cte_1 AS inner_cte_1"
|
||||
", nesting_2.inner_cte_2 AS inner_cte_2"
|
||||
" FROM nesting_1, nesting_2"
|
||||
") SELECT cte.inner_cte_1, cte.inner_cte_2 FROM cte",
|
||||
)
|
||||
|
||||
def test_nesting_cte_in_nesting_cte_in_cte(self):
|
||||
select_1_cte = select(literal(1).label("inner_cte")).cte(
|
||||
"nesting_1", nesting=True
|
||||
@@ -2069,6 +2193,31 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
|
||||
"SELECT rec_cte.outer_cte FROM rec_cte",
|
||||
)
|
||||
|
||||
def test_nesting_cte_in_recursive_cte_w_add_cte(self):
|
||||
nesting_cte = select(literal(1).label("inner_cte")).cte(
|
||||
"nesting", nesting=True
|
||||
)
|
||||
|
||||
rec_cte = select(nesting_cte.c.inner_cte.label("outer_cte")).cte(
|
||||
"rec_cte", recursive=True
|
||||
)
|
||||
rec_part = select(rec_cte.c.outer_cte).where(
|
||||
rec_cte.c.outer_cte == literal(1)
|
||||
)
|
||||
rec_cte = rec_cte.union(rec_part)
|
||||
|
||||
stmt = select(rec_cte)
|
||||
|
||||
self.assert_compile(
|
||||
stmt,
|
||||
"WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
|
||||
"(SELECT :param_1 AS inner_cte) "
|
||||
"SELECT nesting.inner_cte AS outer_cte FROM nesting UNION "
|
||||
"SELECT rec_cte.outer_cte AS outer_cte FROM rec_cte "
|
||||
"WHERE rec_cte.outer_cte = :param_2) "
|
||||
"SELECT rec_cte.outer_cte FROM rec_cte",
|
||||
)
|
||||
|
||||
def test_recursive_nesting_cte_in_cte(self):
|
||||
rec_root = select(literal(1).label("inner_cte")).cte(
|
||||
"nesting", recursive=True, nesting=True
|
||||
@@ -2209,6 +2358,80 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
|
||||
"FROM nesting_cte",
|
||||
)
|
||||
|
||||
def test_add_cte_dont_nest_in_two_places(self):
|
||||
nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte(
|
||||
"nesting_cte"
|
||||
)
|
||||
select_add_cte = select(
|
||||
(nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value")
|
||||
).cte("nesting_2")
|
||||
|
||||
union_cte = (
|
||||
select(
|
||||
(nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value")
|
||||
)
|
||||
.add_cte(nesting_cte_used_twice, nest_here=True)
|
||||
.union(
|
||||
select(select_add_cte).add_cte(select_add_cte, nest_here=True)
|
||||
)
|
||||
.cte("wrapper")
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(union_cte)
|
||||
.add_cte(nesting_cte_used_twice, nest_here=True)
|
||||
.union(select(nesting_cte_used_twice))
|
||||
)
|
||||
with expect_raises_message(
|
||||
exc.CompileError,
|
||||
"CTE is stated as 'nest_here' in more than one location",
|
||||
):
|
||||
stmt.compile()
|
||||
|
||||
def test_same_nested_cte_is_not_generated_twice_w_add_cte(self):
|
||||
# Same = name and query
|
||||
nesting_cte_used_twice = select(literal(1).label("inner_cte_1")).cte(
|
||||
"nesting_cte"
|
||||
)
|
||||
select_add_cte = select(
|
||||
(nesting_cte_used_twice.c.inner_cte_1 + 1).label("next_value")
|
||||
).cte("nesting_2")
|
||||
|
||||
union_cte = (
|
||||
select(
|
||||
(nesting_cte_used_twice.c.inner_cte_1 - 1).label("next_value")
|
||||
)
|
||||
.add_cte(nesting_cte_used_twice)
|
||||
.union(
|
||||
select(select_add_cte).add_cte(select_add_cte, nest_here=True)
|
||||
)
|
||||
.cte("wrapper")
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(union_cte)
|
||||
.add_cte(nesting_cte_used_twice, nest_here=True)
|
||||
.union(select(nesting_cte_used_twice))
|
||||
)
|
||||
|
||||
self.assert_compile(
|
||||
stmt,
|
||||
"WITH nesting_cte AS "
|
||||
"(SELECT :param_1 AS inner_cte_1)"
|
||||
", wrapper AS "
|
||||
"(WITH nesting_2 AS "
|
||||
"(SELECT nesting_cte.inner_cte_1 + :inner_cte_1_2 "
|
||||
"AS next_value "
|
||||
"FROM nesting_cte)"
|
||||
" SELECT nesting_cte.inner_cte_1 - :inner_cte_1_1 "
|
||||
"AS next_value "
|
||||
"FROM nesting_cte UNION SELECT nesting_2.next_value AS next_value "
|
||||
"FROM nesting_2)"
|
||||
" SELECT wrapper.next_value "
|
||||
"FROM wrapper UNION SELECT nesting_cte.inner_cte_1 "
|
||||
"FROM nesting_cte",
|
||||
)
|
||||
|
||||
def test_recursive_nesting_cte_in_recursive_cte(self):
|
||||
nesting_cte = select(literal(1).label("inner_cte")).cte(
|
||||
"nesting", nesting=True, recursive=True
|
||||
@@ -2363,6 +2586,36 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
|
||||
") SELECT cte.outer_cte FROM cte",
|
||||
)
|
||||
|
||||
def test_compound_select_with_nesting_cte_in_custom_order_w_add_cte(self):
|
||||
select_1_cte = select(literal(1).label("inner_cte")).cte("nesting_1")
|
||||
select_2_cte = select(literal(2).label("inner_cte")).cte("nesting_2")
|
||||
|
||||
nesting_cte = (
|
||||
select(select_1_cte)
|
||||
.add_cte(select_1_cte, nest_here=True)
|
||||
.union(select(select_2_cte))
|
||||
# Generate "select_2_cte" first
|
||||
.add_cte(select_2_cte, nest_here=True)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
stmt = select(
|
||||
select(nesting_cte.c.inner_cte.label("outer_cte")).cte("cte")
|
||||
)
|
||||
|
||||
self.assert_compile(
|
||||
stmt,
|
||||
"WITH cte AS ("
|
||||
"SELECT anon_1.inner_cte AS outer_cte FROM ("
|
||||
"WITH nesting_2 AS (SELECT :param_1 AS inner_cte)"
|
||||
", nesting_1 AS (SELECT :param_2 AS inner_cte)"
|
||||
" SELECT nesting_1.inner_cte AS inner_cte FROM nesting_1"
|
||||
" UNION"
|
||||
" SELECT nesting_2.inner_cte AS inner_cte FROM nesting_2"
|
||||
") AS anon_1"
|
||||
") SELECT cte.outer_cte FROM cte",
|
||||
)
|
||||
|
||||
def test_recursive_cte_referenced_multiple_times_with_nesting_cte(self):
|
||||
rec_root = select(literal(1).label("the_value")).cte(
|
||||
"recursive_cte", recursive=True
|
||||
@@ -2411,3 +2664,43 @@ class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
|
||||
" WHERE should_continue.val != true))"
|
||||
" SELECT recursive_cte.the_value FROM recursive_cte",
|
||||
)
|
||||
|
||||
@testing.combinations(True, False)
|
||||
def test_correlated_cte_in_lateral_w_add_cte(self, reverse_direction):
|
||||
"""this is the original use case that led to #7759"""
|
||||
contracts = table("contracts", column("id"))
|
||||
|
||||
invoices = table("invoices", column("id"), column("contract_id"))
|
||||
|
||||
contracts_alias = contracts.alias()
|
||||
cte1 = (
|
||||
select(contracts_alias)
|
||||
.where(contracts_alias.c.id == contracts.c.id)
|
||||
.correlate(contracts)
|
||||
.cte(name="cte1")
|
||||
)
|
||||
cte2 = (
|
||||
select(invoices)
|
||||
.join(cte1, invoices.c.contract_id == cte1.c.id)
|
||||
.cte(name="cte2")
|
||||
)
|
||||
|
||||
if reverse_direction:
|
||||
subq = select(cte1, cte2).add_cte(cte2, cte1, nest_here=True)
|
||||
else:
|
||||
subq = select(cte1, cte2).add_cte(cte1, cte2, nest_here=True)
|
||||
stmt = select(contracts).outerjoin(subq.lateral(), true())
|
||||
|
||||
self.assert_compile(
|
||||
stmt,
|
||||
"SELECT contracts.id FROM contracts LEFT OUTER JOIN LATERAL "
|
||||
"(WITH cte1 AS (SELECT contracts_1.id AS id "
|
||||
"FROM contracts AS contracts_1 "
|
||||
"WHERE contracts_1.id = contracts.id), "
|
||||
"cte2 AS (SELECT invoices.id AS id, "
|
||||
"invoices.contract_id AS contract_id FROM invoices "
|
||||
"JOIN cte1 ON invoices.contract_id = cte1.id) "
|
||||
"SELECT cte1.id AS id, cte2.id AS id_1, "
|
||||
"cte2.contract_id AS contract_id "
|
||||
"FROM cte1, cte2) AS anon_1 ON true",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user