New tests for model_sql to cover some gaps.

This commit is contained in:
Charles Leifer
2026-03-22 10:25:00 -05:00
parent 6418faf6ce
commit c17dc048fb
+294 -5
View File
@@ -210,11 +210,6 @@ class TestModelSQL(ModelDatabaseTestCase):
'FROM "huey"."with_schema" AS "t1" '
'WHERE ("t1"."data" = ?)'), ['zaizee'])
# ===========================================================================
# ON CONFLICT / upsert SQL with Models
# ===========================================================================
def test_where_coerce(self):
query = Person.select(Person.last).where(Person.id == '1337')
self.assertSQL(query, (
@@ -974,6 +969,300 @@ class TestModelSQL(ModelDatabaseTestCase):
('DELETE FROM "tweet" WHERE ("tweet"."user_id" = ?)', [1]),
])
# ===========================================================================
# Advanced Model SQL: RETURNING, window functions, CTE, LATERAL
# ===========================================================================
class TestModelAdvancedSQL(ModelDatabaseTestCase):
database = get_in_memory_db()
requires = [Category, Note, Person, Sample, User]
# -- RETURNING on UPDATE and DELETE --
def test_update_returning(self):
query = (User
.update({User.username: 'zaizee'})
.where(User.username == 'charlie')
.returning(User))
self.assertSQL(query, (
'UPDATE "users" SET "username" = ? '
'WHERE ("users"."username" = ?) '
'RETURNING "users"."id", "users"."username"'),
['zaizee', 'charlie'])
query = (User
.update({User.username: 'zaizee'})
.returning(User.id))
self.assertSQL(query, (
'UPDATE "users" SET "username" = ? '
'RETURNING "users"."id"'), ['zaizee'])
def test_update_returning_expression(self):
query = (User
.update({User.username: User.username.concat('-x')})
.where(User.id > 2)
.returning(User.id, User.username.alias('new_name')))
self.assertSQL(query, (
'UPDATE "users" SET "username" = ("users"."username" || ?) '
'WHERE ("users"."id" > ?) '
'RETURNING "users"."id", "users"."username" AS "new_name"'),
['-x', 2])
def test_delete_returning(self):
query = (User
.delete()
.where(User.username == 'zaizee')
.returning(User))
self.assertSQL(query, (
'DELETE FROM "users" '
'WHERE ("users"."username" = ?) '
'RETURNING "users"."id", "users"."username"'), ['zaizee'])
query = (User
.delete()
.returning(User.id.alias('removed_id')))
self.assertSQL(query, (
'DELETE FROM "users" '
'RETURNING "users"."id" AS "removed_id"'), [])
def test_delete_returning_no_fields(self):
query = (User
.delete()
.where(User.id > 3)
.returning())
self.assertSQL(query, (
'DELETE FROM "users" WHERE ("users"."id" > ?)'), [3])
# -- Window functions --
def test_window_partition(self):
query = (Sample
.select(
Sample.counter,
Sample.value,
fn.AVG(Sample.value).over(
partition_by=[Sample.counter]))
.order_by(Sample.counter))
self.assertSQL(query, (
'SELECT "t1"."counter", "t1"."value", AVG("t1"."value") '
'OVER (PARTITION BY "t1"."counter") '
'FROM "sample" AS "t1" ORDER BY "t1"."counter"'), [])
def test_window_order(self):
query = (Sample
.select(
Sample.value,
fn.RANK().over(order_by=[Sample.value])))
self.assertSQL(query, (
'SELECT "t1"."value", RANK() '
'OVER (ORDER BY "t1"."value") '
'FROM "sample" AS "t1"'), [])
def test_window_empty_over(self):
query = (Sample
.select(
Sample.value,
fn.LAG(Sample.value, 1).over())
.order_by(Sample.value))
self.assertSQL(query, (
'SELECT "t1"."value", LAG("t1"."value", ?) OVER () '
'FROM "sample" AS "t1" '
'ORDER BY "t1"."value"'), [1])
def test_window_frame(self):
query = (Sample
.select(
Sample.value,
fn.SUM(Sample.value).over(
partition_by=[Sample.counter],
order_by=[Sample.value],
start=Window.preceding(),
end=Window.CURRENT_ROW)))
self.assertSQL(query, (
'SELECT "t1"."value", SUM("t1"."value") '
'OVER (PARTITION BY "t1"."counter" ORDER BY "t1"."value" '
'ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) '
'FROM "sample" AS "t1"'), [])
def test_window_named(self):
window = Window(partition_by=[Sample.counter],
order_by=[Sample.value])
query = (Sample
.select(
Sample.counter,
fn.SUM(Sample.value).over(window),
fn.AVG(Sample.value).over(window))
.window(window))
self.assertSQL(query, (
'SELECT "t1"."counter", '
'SUM("t1"."value") OVER "w", '
'AVG("t1"."value") OVER "w" '
'FROM "sample" AS "t1" '
'WINDOW "w" AS (PARTITION BY "t1"."counter" '
'ORDER BY "t1"."value")'), [])
def test_window_filter(self):
query = (Sample
.select(
fn.COUNT(Sample.id)
.filter(Sample.counter > 1)
.over(partition_by=[Sample.counter])))
self.assertSQL(query, (
'SELECT COUNT("t1"."id") '
'FILTER (WHERE ("t1"."counter" > ?)) '
'OVER (PARTITION BY "t1"."counter") '
'FROM "sample" AS "t1"'), [1])
# -- CTE (Common Table Expressions) --
def test_simple_cte(self):
cte = (Category
.select(Category.name, Category.parent)
.cte('catz', columns=('name', 'parent')))
query = (cte
.select_from(cte.c.name)
.order_by(cte.c.name))
self.assertSQL(query, (
'WITH "catz" ("name", "parent") AS ('
'SELECT "t1"."name", "t1"."parent_id" '
'FROM "category" AS "t1") '
'SELECT "catz"."name" FROM "catz" '
'ORDER BY "catz"."name"'), [])
def test_recursive_cte(self):
base = (Category
.select(Category.name, Category.parent)
.where(Category.name == 'root')
.cte('tree', recursive=True, columns=('name', 'parent_id')))
CA = Category.alias()
recursive = (CA
.select(CA.name, CA.parent)
.join(base, on=(CA.parent == base.c.name)))
cte = base.union_all(recursive)
query = (cte
.select_from(cte.c.name)
.order_by(cte.c.name))
self.assertSQL(query, (
'WITH RECURSIVE "tree" ("name", "parent_id") AS ('
'SELECT "t1"."name", "t1"."parent_id" '
'FROM "category" AS "t1" '
'WHERE ("t1"."name" = ?) '
'UNION ALL '
'SELECT "t2"."name", "t2"."parent_id" '
'FROM "category" AS "t2" '
'INNER JOIN "tree" ON ("t2"."parent_id" = "tree"."name")) '
'SELECT "tree"."name" FROM "tree" '
'ORDER BY "tree"."name"'), ['root'])
def test_cte_in_subquery(self):
cte = (User
.select(User.id)
.where(User.username.startswith('h'))
.cte('filtered'))
query = (User
.select()
.where(User.id.in_(cte.select(cte.c.id)))
.with_cte(cte))
self.assertSQL(query, (
'WITH "filtered" AS ('
'SELECT "t1"."id" FROM "users" AS "t1" '
'WHERE ("t1"."username" ILIKE ?)) '
'SELECT "t2"."id", "t2"."username" FROM "users" AS "t2" '
'WHERE ("t2"."id" IN ('
'SELECT "filtered"."id" FROM "filtered"))'), ['h%'])
def test_cte_update(self):
cte = (User
.select(User.id)
.where(User.username == 'zaizee')
.cte('to_update'))
query = (User
.update({User.username: 'zaizee-x'})
.where(User.id.in_(cte.select(cte.c.id)))
.with_cte(cte))
self.assertSQL(query, (
'WITH "to_update" AS ('
'SELECT "t1"."id" FROM "users" AS "t1" '
'WHERE ("t1"."username" = ?)) '
'UPDATE "users" SET "username" = ? '
'WHERE ("users"."id" IN ('
'SELECT "to_update"."id" FROM "to_update"))'),
['zaizee', 'zaizee-x'])
def test_cte_delete(self):
cte = (User
.select(User.id)
.where(User.username.startswith('z'))
.cte('to_delete'))
query = (User
.delete()
.where(User.id.in_(cte.select(cte.c.id)))
.with_cte(cte))
self.assertSQL(query, (
'WITH "to_delete" AS ('
'SELECT "t1"."id" FROM "users" AS "t1" '
'WHERE ("t1"."username" ILIKE ?)) '
'DELETE FROM "users" '
'WHERE ("users"."id" IN ('
'SELECT "to_delete"."id" FROM "to_delete"))'), ['z%'])
def test_materialized_cte(self):
for materialized, clause in ((True, 'MATERIALIZED '),
(False, 'NOT MATERIALIZED '),
(None, '')):
cte = (User
.select(User.id)
.cte('uids', materialized=materialized))
query = cte.select_from(cte.c.id)
self.assertSQL(query, (
'WITH "uids" AS %s('
'SELECT "t1"."id" FROM "users" AS "t1") '
'SELECT "uids"."id" FROM "uids"') % clause, [])
# -- LATERAL join --
def test_lateral_join(self):
PA = Person.alias()
subq = (Note
.select(Note.content)
.where(Note.author == PA.id)
.limit(1))
query = (PA
.select(PA.first, subq.c.content)
.join(subq, JOIN.LEFT_LATERAL, on=True))
self.assertSQL(query, (
'SELECT "t1"."first", "t2"."content" '
'FROM "person" AS "t1" '
'LEFT JOIN LATERAL ('
'SELECT "t3"."content" FROM "note" AS "t3" '
'WHERE ("t3"."author_id" = "t1"."id") '
'LIMIT ?) AS "t2" ON ?'), [1, True])
def test_lateral_join_inner(self):
subq = (Note
.select(Note.content)
.where(Note.author == Person.id)
.order_by(Note.id.desc())
.limit(2))
query = (Person
.select(Person.first, subq.c.content)
.join(subq, JOIN.LATERAL, on=True))
self.assertSQL(query, (
'SELECT "t1"."first", "t2"."content" '
'FROM "person" AS "t1" '
'LATERAL ('
'SELECT "t3"."content" FROM "note" AS "t3" '
'WHERE ("t3"."author_id" = "t1"."id") '
'ORDER BY "t3"."id" DESC LIMIT ?) AS "t2" ON ?'), [2, True])
# ===========================================================================
# ON CONFLICT / upsert SQL with Models
# ===========================================================================
@requires_pglike
class TestOnConflictSQL(ModelDatabaseTestCase):
requires = [Emp, OCTest, UKVP]