mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-08 01:40:48 -04:00
7c80e521f0
be always be correctly propagated when one CTE referred to another aliased CTE in a statement. Fixes #3154
465 lines
17 KiB
Python
465 lines
17 KiB
Python
from sqlalchemy.testing import fixtures
|
|
from sqlalchemy.testing import AssertsCompiledSQL, assert_raises_message
|
|
from sqlalchemy.sql import table, column, select, func, literal
|
|
from sqlalchemy.dialects import mssql
|
|
from sqlalchemy.engine import default
|
|
from sqlalchemy.exc import CompileError
|
|
|
|
|
|
class CTETest(fixtures.TestBase, AssertsCompiledSQL):
|
|
|
|
__dialect__ = 'default'
|
|
|
|
def test_nonrecursive(self):
|
|
orders = table('orders',
|
|
column('region'),
|
|
column('amount'),
|
|
column('product'),
|
|
column('quantity')
|
|
)
|
|
|
|
regional_sales = select([
|
|
orders.c.region,
|
|
func.sum(orders.c.amount).label('total_sales')
|
|
]).group_by(orders.c.region).cte("regional_sales")
|
|
|
|
top_regions = select([regional_sales.c.region]).\
|
|
where(
|
|
regional_sales.c.total_sales > select([
|
|
func.sum(regional_sales.c.total_sales) / 10
|
|
])
|
|
).cte("top_regions")
|
|
|
|
s = select([
|
|
orders.c.region,
|
|
orders.c.product,
|
|
func.sum(orders.c.quantity).label("product_units"),
|
|
func.sum(orders.c.amount).label("product_sales")
|
|
]).where(orders.c.region.in_(
|
|
select([top_regions.c.region])
|
|
)).group_by(orders.c.region, orders.c.product)
|
|
|
|
# needs to render regional_sales first as top_regions
|
|
# refers to it
|
|
self.assert_compile(
|
|
s,
|
|
"WITH regional_sales AS (SELECT orders.region AS region, "
|
|
"sum(orders.amount) AS total_sales FROM orders "
|
|
"GROUP BY orders.region), "
|
|
"top_regions AS (SELECT "
|
|
"regional_sales.region AS region FROM regional_sales "
|
|
"WHERE regional_sales.total_sales > "
|
|
"(SELECT sum(regional_sales.total_sales) / :sum_1 AS "
|
|
"anon_1 FROM regional_sales)) "
|
|
"SELECT orders.region, orders.product, "
|
|
"sum(orders.quantity) AS product_units, "
|
|
"sum(orders.amount) AS product_sales "
|
|
"FROM orders WHERE orders.region "
|
|
"IN (SELECT top_regions.region FROM top_regions) "
|
|
"GROUP BY orders.region, orders.product"
|
|
)
|
|
|
|
def test_recursive(self):
|
|
parts = table('parts',
|
|
column('part'),
|
|
column('sub_part'),
|
|
column('quantity'),
|
|
)
|
|
|
|
included_parts = select([
|
|
parts.c.sub_part,
|
|
parts.c.part,
|
|
parts.c.quantity]).\
|
|
where(parts.c.part == 'our part').\
|
|
cte(recursive=True)
|
|
|
|
incl_alias = included_parts.alias()
|
|
parts_alias = parts.alias()
|
|
included_parts = included_parts.union(
|
|
select([
|
|
parts_alias.c.sub_part,
|
|
parts_alias.c.part,
|
|
parts_alias.c.quantity]).
|
|
where(parts_alias.c.part == incl_alias.c.sub_part)
|
|
)
|
|
|
|
s = select([
|
|
included_parts.c.sub_part,
|
|
func.sum(included_parts.c.quantity).label('total_quantity')]).\
|
|
select_from(included_parts.join(
|
|
parts, included_parts.c.part == parts.c.part)).\
|
|
group_by(included_parts.c.sub_part)
|
|
self.assert_compile(
|
|
s, "WITH RECURSIVE anon_1(sub_part, part, quantity) "
|
|
"AS (SELECT parts.sub_part AS sub_part, parts.part "
|
|
"AS part, parts.quantity AS quantity FROM parts "
|
|
"WHERE parts.part = :part_1 UNION "
|
|
"SELECT parts_1.sub_part AS sub_part, "
|
|
"parts_1.part AS part, parts_1.quantity "
|
|
"AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
|
|
"WHERE parts_1.part = anon_2.sub_part) "
|
|
"SELECT anon_1.sub_part, "
|
|
"sum(anon_1.quantity) AS total_quantity FROM anon_1 "
|
|
"JOIN parts ON anon_1.part = parts.part "
|
|
"GROUP BY anon_1.sub_part")
|
|
|
|
# quick check that the "WITH RECURSIVE" varies per
|
|
# dialect
|
|
self.assert_compile(
|
|
s, "WITH anon_1(sub_part, part, quantity) "
|
|
"AS (SELECT parts.sub_part AS sub_part, parts.part "
|
|
"AS part, parts.quantity AS quantity FROM parts "
|
|
"WHERE parts.part = :part_1 UNION "
|
|
"SELECT parts_1.sub_part AS sub_part, "
|
|
"parts_1.part AS part, parts_1.quantity "
|
|
"AS quantity FROM parts AS parts_1, anon_1 AS anon_2 "
|
|
"WHERE parts_1.part = anon_2.sub_part) "
|
|
"SELECT anon_1.sub_part, "
|
|
"sum(anon_1.quantity) AS total_quantity FROM anon_1 "
|
|
"JOIN parts ON anon_1.part = parts.part "
|
|
"GROUP BY anon_1.sub_part", dialect=mssql.dialect())
|
|
|
|
def test_recursive_union_no_alias_one(self):
|
|
s1 = select([literal(0).label("x")])
|
|
cte = s1.cte(name="cte", recursive=True)
|
|
cte = cte.union_all(
|
|
select([cte.c.x + 1]).where(cte.c.x < 10)
|
|
)
|
|
s2 = select([cte])
|
|
self.assert_compile(s2,
|
|
"WITH RECURSIVE cte(x) AS "
|
|
"(SELECT :param_1 AS x UNION ALL "
|
|
"SELECT cte.x + :x_1 AS anon_1 "
|
|
"FROM cte WHERE cte.x < :x_2) "
|
|
"SELECT cte.x FROM cte"
|
|
)
|
|
|
|
def test_recursive_union_no_alias_two(self):
|
|
"""
|
|
|
|
pg's example:
|
|
|
|
WITH RECURSIVE t(n) AS (
|
|
VALUES (1)
|
|
UNION ALL
|
|
SELECT n+1 FROM t WHERE n < 100
|
|
)
|
|
SELECT sum(n) FROM t;
|
|
|
|
"""
|
|
|
|
# I know, this is the PG VALUES keyword,
|
|
# we're cheating here. also yes we need the SELECT,
|
|
# sorry PG.
|
|
t = select([func.values(1).label("n")]).cte("t", recursive=True)
|
|
t = t.union_all(select([t.c.n + 1]).where(t.c.n < 100))
|
|
s = select([func.sum(t.c.n)])
|
|
self.assert_compile(s,
|
|
"WITH RECURSIVE t(n) AS "
|
|
"(SELECT values(:values_1) AS n "
|
|
"UNION ALL SELECT t.n + :n_1 AS anon_1 "
|
|
"FROM t "
|
|
"WHERE t.n < :n_2) "
|
|
"SELECT sum(t.n) AS sum_1 FROM t"
|
|
)
|
|
|
|
def test_recursive_union_no_alias_three(self):
|
|
# like test one, but let's refer to the CTE
|
|
# in a sibling CTE.
|
|
|
|
s1 = select([literal(0).label("x")])
|
|
cte = s1.cte(name="cte", recursive=True)
|
|
|
|
# can't do it here...
|
|
#bar = select([cte]).cte('bar')
|
|
cte = cte.union_all(
|
|
select([cte.c.x + 1]).where(cte.c.x < 10)
|
|
)
|
|
bar = select([cte]).cte('bar')
|
|
|
|
s2 = select([cte, bar])
|
|
self.assert_compile(s2,
|
|
"WITH RECURSIVE cte(x) AS "
|
|
"(SELECT :param_1 AS x UNION ALL "
|
|
"SELECT cte.x + :x_1 AS anon_1 "
|
|
"FROM cte WHERE cte.x < :x_2), "
|
|
"bar AS (SELECT cte.x AS x FROM cte) "
|
|
"SELECT cte.x, bar.x FROM cte, bar"
|
|
)
|
|
|
|
def test_recursive_union_no_alias_four(self):
|
|
# like test one and three, but let's refer
|
|
# previous version of "cte". here we test
|
|
# how the compiler resolves multiple instances
|
|
# of "cte".
|
|
|
|
s1 = select([literal(0).label("x")])
|
|
cte = s1.cte(name="cte", recursive=True)
|
|
|
|
bar = select([cte]).cte('bar')
|
|
cte = cte.union_all(
|
|
select([cte.c.x + 1]).where(cte.c.x < 10)
|
|
)
|
|
|
|
# outer cte rendered first, then bar, which
|
|
# includes "inner" cte
|
|
s2 = select([cte, bar])
|
|
self.assert_compile(s2,
|
|
"WITH RECURSIVE cte(x) AS "
|
|
"(SELECT :param_1 AS x UNION ALL "
|
|
"SELECT cte.x + :x_1 AS anon_1 "
|
|
"FROM cte WHERE cte.x < :x_2), "
|
|
"bar AS (SELECT cte.x AS x FROM cte) "
|
|
"SELECT cte.x, bar.x FROM cte, bar"
|
|
)
|
|
|
|
# bar rendered, only includes "inner" cte,
|
|
# "outer" cte isn't present
|
|
s2 = select([bar])
|
|
self.assert_compile(s2,
|
|
"WITH RECURSIVE cte(x) AS "
|
|
"(SELECT :param_1 AS x), "
|
|
"bar AS (SELECT cte.x AS x FROM cte) "
|
|
"SELECT bar.x FROM bar"
|
|
)
|
|
|
|
# bar rendered, but then the "outer"
|
|
# cte is rendered.
|
|
s2 = select([bar, cte])
|
|
self.assert_compile(
|
|
s2, "WITH RECURSIVE bar AS (SELECT cte.x AS x FROM cte), "
|
|
"cte(x) AS "
|
|
"(SELECT :param_1 AS x UNION ALL "
|
|
"SELECT cte.x + :x_1 AS anon_1 "
|
|
"FROM cte WHERE cte.x < :x_2) "
|
|
"SELECT bar.x, cte.x FROM bar, cte")
|
|
|
|
def test_conflicting_names(self):
|
|
"""test a flat out name conflict."""
|
|
|
|
s1 = select([1])
|
|
c1 = s1.cte(name='cte1', recursive=True)
|
|
s2 = select([1])
|
|
c2 = s2.cte(name='cte1', recursive=True)
|
|
|
|
s = select([c1, c2])
|
|
assert_raises_message(
|
|
CompileError,
|
|
"Multiple, unrelated CTEs found "
|
|
"with the same name: 'cte1'",
|
|
s.compile
|
|
)
|
|
|
|
def test_union(self):
|
|
orders = table('orders',
|
|
column('region'),
|
|
column('amount'),
|
|
)
|
|
|
|
regional_sales = select([
|
|
orders.c.region,
|
|
orders.c.amount
|
|
]).cte("regional_sales")
|
|
|
|
s = select(
|
|
[regional_sales.c.region]).where(
|
|
regional_sales.c.amount > 500
|
|
)
|
|
|
|
self.assert_compile(s,
|
|
"WITH regional_sales AS "
|
|
"(SELECT orders.region AS region, "
|
|
"orders.amount AS amount FROM orders) "
|
|
"SELECT regional_sales.region "
|
|
"FROM regional_sales WHERE "
|
|
"regional_sales.amount > :amount_1")
|
|
|
|
s = s.union_all(
|
|
select([regional_sales.c.region]).
|
|
where(
|
|
regional_sales.c.amount < 300
|
|
)
|
|
)
|
|
self.assert_compile(s,
|
|
"WITH regional_sales AS "
|
|
"(SELECT orders.region AS region, "
|
|
"orders.amount AS amount FROM orders) "
|
|
"SELECT regional_sales.region FROM regional_sales "
|
|
"WHERE regional_sales.amount > :amount_1 "
|
|
"UNION ALL SELECT regional_sales.region "
|
|
"FROM regional_sales WHERE "
|
|
"regional_sales.amount < :amount_2")
|
|
|
|
def test_reserved_quote(self):
|
|
orders = table('orders',
|
|
column('order'),
|
|
)
|
|
s = select([orders.c.order]).cte("regional_sales", recursive=True)
|
|
s = select([s.c.order])
|
|
self.assert_compile(s,
|
|
'WITH RECURSIVE regional_sales("order") AS '
|
|
'(SELECT orders."order" AS "order" '
|
|
"FROM orders)"
|
|
' SELECT regional_sales."order" '
|
|
"FROM regional_sales"
|
|
)
|
|
|
|
def test_multi_subq_quote(self):
|
|
cte = select([literal(1).label("id")]).cte(name='CTE')
|
|
|
|
s1 = select([cte.c.id]).alias()
|
|
s2 = select([cte.c.id]).alias()
|
|
|
|
s = select([s1, s2])
|
|
self.assert_compile(
|
|
s,
|
|
'WITH "CTE" AS (SELECT :param_1 AS id) '
|
|
'SELECT anon_1.id, anon_2.id FROM '
|
|
'(SELECT "CTE".id AS id FROM "CTE") AS anon_1, '
|
|
'(SELECT "CTE".id AS id FROM "CTE") AS anon_2'
|
|
)
|
|
|
|
def test_positional_binds(self):
|
|
orders = table('orders',
|
|
column('order'),
|
|
)
|
|
s = select([orders.c.order, literal("x")]).cte("regional_sales")
|
|
s = select([s.c.order, literal("y")])
|
|
dialect = default.DefaultDialect()
|
|
dialect.positional = True
|
|
dialect.paramstyle = 'numeric'
|
|
self.assert_compile(
|
|
s,
|
|
'WITH regional_sales AS (SELECT orders."order" '
|
|
'AS "order", :1 AS anon_2 FROM orders) SELECT '
|
|
'regional_sales."order", :2 AS anon_1 FROM regional_sales',
|
|
checkpositional=(
|
|
'x',
|
|
'y'),
|
|
dialect=dialect)
|
|
|
|
self.assert_compile(
|
|
s.union(s), 'WITH regional_sales AS (SELECT orders."order" '
|
|
'AS "order", :1 AS anon_2 FROM orders) SELECT '
|
|
'regional_sales."order", :2 AS anon_1 FROM regional_sales '
|
|
'UNION SELECT regional_sales."order", :3 AS anon_1 '
|
|
'FROM regional_sales', checkpositional=(
|
|
'x', 'y', 'y'), dialect=dialect)
|
|
|
|
s = select([orders.c.order]).\
|
|
where(orders.c.order == 'x').cte("regional_sales")
|
|
s = select([s.c.order]).where(s.c.order == "y")
|
|
self.assert_compile(
|
|
s, 'WITH regional_sales AS (SELECT orders."order" AS '
|
|
'"order" FROM orders WHERE orders."order" = :1) '
|
|
'SELECT regional_sales."order" FROM regional_sales '
|
|
'WHERE regional_sales."order" = :2', checkpositional=(
|
|
'x', 'y'), dialect=dialect)
|
|
|
|
def test_positional_binds_2(self):
|
|
orders = table('orders',
|
|
column('order'),
|
|
)
|
|
s = select([orders.c.order, literal("x")]).cte("regional_sales")
|
|
s = select([s.c.order, literal("y")])
|
|
dialect = default.DefaultDialect()
|
|
dialect.positional = True
|
|
dialect.paramstyle = 'numeric'
|
|
s1 = select([orders.c.order]).where(orders.c.order == 'x').\
|
|
cte("regional_sales_1")
|
|
|
|
s1a = s1.alias()
|
|
|
|
s2 = select([orders.c.order == 'y', s1a.c.order,
|
|
orders.c.order, s1.c.order]).\
|
|
where(orders.c.order == 'z').\
|
|
cte("regional_sales_2")
|
|
|
|
s3 = select([s2])
|
|
|
|
self.assert_compile(
|
|
s3,
|
|
'WITH regional_sales_1 AS (SELECT orders."order" AS "order" '
|
|
'FROM orders WHERE orders."order" = :1), regional_sales_2 AS '
|
|
'(SELECT orders."order" = :2 AS anon_1, '
|
|
'anon_2."order" AS "order", '
|
|
'orders."order" AS "order", '
|
|
'regional_sales_1."order" AS "order" FROM orders, '
|
|
'regional_sales_1 '
|
|
'AS anon_2, regional_sales_1 '
|
|
'WHERE orders."order" = :3) SELECT regional_sales_2.anon_1, '
|
|
'regional_sales_2."order" FROM regional_sales_2',
|
|
checkpositional=('x', 'y', 'z'), dialect=dialect)
|
|
|
|
def test_positional_binds_2_asliteral(self):
|
|
orders = table('orders',
|
|
column('order'),
|
|
)
|
|
s = select([orders.c.order, literal("x")]).cte("regional_sales")
|
|
s = select([s.c.order, literal("y")])
|
|
dialect = default.DefaultDialect()
|
|
dialect.positional = True
|
|
dialect.paramstyle = 'numeric'
|
|
s1 = select([orders.c.order]).where(orders.c.order == 'x').\
|
|
cte("regional_sales_1")
|
|
|
|
s1a = s1.alias()
|
|
|
|
s2 = select([orders.c.order == 'y', s1a.c.order,
|
|
orders.c.order, s1.c.order]).\
|
|
where(orders.c.order == 'z').\
|
|
cte("regional_sales_2")
|
|
|
|
s3 = select([s2])
|
|
|
|
self.assert_compile(
|
|
s3,
|
|
'WITH regional_sales_1 AS '
|
|
'(SELECT orders."order" AS "order" '
|
|
'FROM orders '
|
|
'WHERE orders."order" = \'x\'), '
|
|
'regional_sales_2 AS '
|
|
'(SELECT orders."order" = \'y\' AS anon_1, '
|
|
'anon_2."order" AS "order", orders."order" AS "order", '
|
|
'regional_sales_1."order" AS "order" '
|
|
'FROM orders, regional_sales_1 AS anon_2, regional_sales_1 '
|
|
'WHERE orders."order" = \'z\') '
|
|
'SELECT regional_sales_2.anon_1, regional_sales_2."order" '
|
|
'FROM regional_sales_2',
|
|
checkpositional=(), dialect=dialect,
|
|
literal_binds=True)
|
|
|
|
def test_all_aliases(self):
|
|
orders = table('order', column('order'))
|
|
s = select([orders.c.order]).cte("regional_sales")
|
|
|
|
r1 = s.alias()
|
|
r2 = s.alias()
|
|
|
|
s2 = select([r1, r2]).where(r1.c.order > r2.c.order)
|
|
|
|
self.assert_compile(
|
|
s2,
|
|
'WITH regional_sales AS (SELECT "order"."order" '
|
|
'AS "order" FROM "order") '
|
|
'SELECT anon_1."order", anon_2."order" '
|
|
'FROM regional_sales AS anon_1, '
|
|
'regional_sales AS anon_2 WHERE anon_1."order" > anon_2."order"'
|
|
)
|
|
|
|
s3 = select(
|
|
[orders]).select_from(
|
|
orders.join(
|
|
r1,
|
|
r1.c.order == orders.c.order))
|
|
|
|
self.assert_compile(
|
|
s3,
|
|
'WITH regional_sales AS '
|
|
'(SELECT "order"."order" AS "order" '
|
|
'FROM "order")'
|
|
' SELECT "order"."order" '
|
|
'FROM "order" JOIN regional_sales AS anon_1 '
|
|
'ON anon_1."order" = "order"."order"'
|
|
)
|