Files
2026-04-17 10:04:44 -05:00

3074 lines
124 KiB
Python

"""
SQL generation tests for Table-level (non-Model) queries.
These tests verify that the query builder produces correct SQL from raw Table
objects, without involving Model metaclass machinery. Table objects are
lightweight and internal to query building.
Test case ordering:
* Core DML: SELECT, INSERT, UPDATE, DELETE
* Advanced SELECT features: window functions, VALUES lists, CASE
* Miscellaneous SELECT features: FOR UPDATE, RETURNING, etc.
* Expression SQL
* ON CONFLICT (per-dialect: SQLite, MySQL, PostgreSQL)
* Index generation
* Utilities and edge cases
"""
import datetime
import re
from peewee import *
from peewee import Alias
from peewee import AliasManager
from peewee import Context
from peewee import Entity
from peewee import Expression
from peewee import ForUpdate
from peewee import Function
from peewee import Join
from peewee import Negated
from peewee import NodeList
from peewee import Ordering
from peewee import QualifiedNames
from peewee import Value
from peewee import ValueLiterals
from peewee import Window
from peewee import query_to_string
from .base import BaseTestCase
from .base import TestModel
from .base import db
from .base import requires_mysql
from .base import requires_sqlite
from .base import __sql__
# ---------------------------------------------------------------------------
# Module-level Table objects shared across test cases in this module.
# These are Table instances (not Model classes) — they test the low-level
# query builder without Model metaclass involvement.
# ---------------------------------------------------------------------------
# User and Tweet declare no columns: their columns are reached ad hoc through
# the ``.c`` accessor (e.g. ``User.c.username``), and group_by(User) is
# rejected because there is nothing to expand.
User = Table('users')
Tweet = Table('tweets')
# Person and Note declare their columns, so the columns are exposed as plain
# attributes (e.g. ``Person.name``) and an argument-less select() expands to
# the full declared column list.
Person = Table('person', ['id', 'name', 'dob'], primary_key='id')
Note = Table('note', ['id', 'person_id', 'content'])
# ===========================================================================
# Core DML: SELECT, INSERT, UPDATE, DELETE
# ===========================================================================
class TestSelectQuery(BaseTestCase):
def test_select(self):
    """Columns resolve identically via attribute or item lookup on ``.c``."""
    expected = ('SELECT "t1"."id", "t1"."username" '
                'FROM "users" AS "t1" '
                'WHERE ("t1"."username" = ?)')
    by_attribute = (User
                    .select(User.c.id, User.c.username)
                    .where(User.c.username == 'foo'))
    self.assertSQL(by_attribute, expected, ['foo'])
    by_item = (User
               .select(User.c['id'], User.c['username'])
               .where(User.c['username'] == 'test'))
    self.assertSQL(by_item, expected, ['test'])
def test_select_extend(self):
    """select() replaces the column list; select_extend() appends to it."""
    initial = User.select(User.c.id, User.c.username)
    self.assertSQL(initial, (
        'SELECT "t1"."id", "t1"."username" FROM "users" AS "t1"'), [])
    # A second select() discards the previously selected columns.
    replaced = initial.select(User.c.username, User.c.is_admin)
    self.assertSQL(replaced, (
        'SELECT "t1"."username", "t1"."is_admin" FROM "users" AS "t1"'),
        [])
    # select_extend() keeps what is there and adds to the end.
    extended = replaced.select_extend(User.c.is_active, User.c.id)
    self.assertSQL(extended, (
        'SELECT "t1"."username", "t1"."is_admin", "t1"."is_active", '
        '"t1"."id" FROM "users" AS "t1"'), [])
def test_selected_columns(self):
    """selected_columns exposes, and can replace, the query's column nodes."""
    query = (User
             .select(User.c.id, User.c.username, fn.COUNT(Tweet.c.id))
             .join(Tweet, JOIN.LEFT_OUTER,
                   on=(User.c.id == Tweet.c.user_id)))
    # Column nodes overload ==, so compare their attributes instead of the
    # nodes themselves.
    col_id, col_username, col_count = query.selected_columns
    for column, expected_name in ((col_id, 'id'),
                                  (col_username, 'username')):
        self.assertEqual(column.name, expected_name)
        self.assertIs(column.source, User)
    self.assertIsInstance(col_count, Function)
    self.assertEqual(col_count.name, 'COUNT')
    (counted,) = col_count.arguments
    self.assertEqual(counted.name, 'id')
    self.assertIs(counted.source, Tweet)
    # Assigning to selected_columns replaces the column list.
    query.selected_columns = (User.c.username,)
    (only_column,) = query.selected_columns
    self.assertEqual(only_column.name, 'username')
    self.assertIs(only_column.source, User)
def test_select_explicit_columns(self):
    """An argument-less select() expands to the Table's declared columns."""
    cutoff = datetime.date(1980, 1, 1)
    stmt = Person.select().where(Person.dob < cutoff)
    expected = ('SELECT "t1"."id", "t1"."name", "t1"."dob" '
                'FROM "person" AS "t1" '
                'WHERE ("t1"."dob" < ?)')
    self.assertSQL(stmt, expected, [cutoff])
def test_star(self):
    """__star__ renders as "<alias>".* and can be mixed with plain columns."""
    self.assertSQL(User.select(User.__star__),
                   'SELECT "t1".* FROM "users" AS "t1"', [])
    cases = (
        (User.__star__, (
            'SELECT "t1".*, "t2".* '
            'FROM "tweets" AS "t1" '
            'INNER JOIN "users" AS "t2" ON ("t1"."user_id" = "t2"."id")')),
        (User.c.id, (
            'SELECT "t1".*, "t2"."id" '
            'FROM "tweets" AS "t1" '
            'INNER JOIN "users" AS "t2" ON ("t1"."user_id" = "t2"."id")')),
    )
    for user_part, expected in cases:
        joined = (Tweet
                  .select(Tweet.__star__, user_part)
                  .join(User, on=(Tweet.c.user_id == User.c.id)))
        self.assertSQL(joined, expected, [])
def test_from_clause(self):
    """from_() with several sources yields a comma-separated FROM list."""
    stmt = (Note
            .select(Note.content, Person.name)
            .from_(Note, Person)
            .where(Note.person_id == Person.id)
            .order_by(Note.id))
    expected = ('SELECT "t1"."content", "t2"."name" '
                'FROM "note" AS "t1", "person" AS "t2" '
                'WHERE ("t1"."person_id" = "t2"."id") '
                'ORDER BY "t1"."id"')
    self.assertSQL(stmt, expected, [])
def test_from_query(self):
    """A query can select FROM an aliased subquery."""
    names_sq = Person.select(Person.name).alias('i1')
    outer = Person.select(Person.name).from_(names_sq)
    self.assertSQL(outer, (
        'SELECT "t1"."name" '
        'FROM (SELECT "t1"."name" FROM "person" AS "t1") AS "i1"'), [])
    # Columns of the aliased subquery are referenced through its ``.c``.
    pa = Person.alias('pa')
    aliased_sq = pa.select(pa.name).alias('i1')
    ordered = (Person
               .select(aliased_sq.c.name)
               .from_(aliased_sq)
               .order_by(aliased_sq.c.name))
    self.assertSQL(ordered, (
        'SELECT "i1"."name" '
        'FROM (SELECT "pa"."name" FROM "person" AS "pa") AS "i1" '
        'ORDER BY "i1"."name"'), [])
def test_multiple_where(self):
    """Successive where() calls are combined with AND."""
    upper = datetime.date(1980, 1, 1)
    lower = datetime.date(1950, 1, 1)
    stmt = (Person
            .select(Person.name)
            .where(Person.dob < upper)
            .where(Person.dob > lower))
    expected = ('SELECT "t1"."name" '
                'FROM "person" AS "t1" '
                'WHERE (("t1"."dob" < ?) AND ("t1"."dob" > ?))')
    self.assertSQL(stmt, expected, [upper, lower])
def test_orwhere(self):
    """Successive orwhere() calls are combined with OR."""
    after = datetime.date(1980, 1, 1)
    before = datetime.date(1950, 1, 1)
    stmt = (Person
            .select(Person.name)
            .orwhere(Person.dob > after)
            .orwhere(Person.dob < before))
    expected = ('SELECT "t1"."name" '
                'FROM "person" AS "t1" '
                'WHERE (("t1"."dob" > ?) OR ("t1"."dob" < ?))')
    self.assertSQL(stmt, expected, [after, before])
def test_where_convert_to_is_null(self):
    """Comparing a column to None with == renders as IS NULL."""
    # Use a distinct local name rather than shadowing the module-level Note.
    notes = Table('notes', ('id', 'content', 'user_id'))
    stmt = notes.select().where(notes.user_id == None)  # noqa: E711
    expected = ('SELECT "t1"."id", "t1"."content", "t1"."user_id" '
                'FROM "notes" AS "t1" WHERE ("t1"."user_id" IS NULL)')
    self.assertSQL(stmt, expected, [])
def test_select_in_list_of_values(self):
    """IN accepts lists, tuples, sets, frozensets, generators and ranges."""
    containers = [
        ['charlie', 'huey'],
        ('charlie', 'huey'),
        set(('charlie', 'huey')),
        frozenset(('charlie', 'huey')),
        (name for name in ('charlie', 'huey'))]
    expected_sql = ('SELECT "t1"."id", "t1"."name", "t1"."dob" '
                    'FROM "person" AS "t1" '
                    'WHERE ("t1"."name" IN (?, ?))')
    for container in containers:
        stmt = Person.select().where(Person.name.in_(container))
        # Sets have no defined iteration order, so render directly and
        # compare the parameters order-insensitively.
        sql, params = Context().sql(stmt).query()
        self.assertEqual(sql, expected_sql)
        self.assertEqual(sorted(params), ['charlie', 'huey'])
    odd_ids = (Person
               .select()
               .where(Person.id.in_(range(1, 10, 2))))
    self.assertSQL(odd_ids, (
        'SELECT "t1"."id", "t1"."name", "t1"."dob" '
        'FROM "person" AS "t1" '
        'WHERE ("t1"."id" IN (?, ?, ?, ?, ?))'), [1, 3, 5, 7, 9])
def test_in_value_representation(self):
    """Each IN-list element becomes its own bound parameter, in order."""
    stmt = User.select(User.c.id).where(
        User.c.username.in_(['foo', 'bar', 'baz']))
    expected = ('SELECT "t1"."id" FROM "users" AS "t1" '
                'WHERE ("t1"."username" IN (?, ?, ?))')
    self.assertSQL(stmt, expected, ['foo', 'bar', 'baz'])
def test_empty_in(self):
    """An empty IN list is always false; an empty NOT IN list always true.

    Covers in_(), not_in(), the ``<<`` shorthand and an explicit
    ``Value([])`` wrapper. None of these forms binds any parameters.
    """
    query = User.select(User.c.id).where(User.c.username.in_([]))
    self.assertSQL(query, (
        'SELECT "t1"."id" FROM "users" AS "t1" '
        'WHERE (0 = 1)'), [])
    query = User.select(User.c.id).where(User.c.username.not_in([]))
    self.assertSQL(query, (
        'SELECT "t1"."id" FROM "users" AS "t1" '
        'WHERE (1 = 1)'), [])
    query = User.select(User.c.id).where(User.c.username.in_(Value([])))
    self.assertSQL(query, (
        'SELECT "t1"."id" FROM "users" AS "t1" '
        'WHERE (0 = 1)'), [])
    # Pass [] explicitly, as the assertions above do, so the bound
    # parameters are verified too. Presumably assertSQL skips the params
    # comparison when the argument is omitted -- confirm in .base.
    query = User.select(User.c.id).where(User.c.id << [])
    self.assertSQL(query, (
        'SELECT "t1"."id" FROM "users" AS "t1" WHERE (0 = 1)'), [])
    query = User.select(User.c.id).where(User.c.id.not_in([]))
    self.assertSQL(query, (
        'SELECT "t1"."id" FROM "users" AS "t1" WHERE (1 = 1)'), [])
def test_between_via_slice(self):
    """col[lo:hi] renders as BETWEEN lo AND hi; open slices are rejected."""
    stmt = User.select(User.c.id).where(User.c.age[18:65])
    expected = ('SELECT "t1"."id" FROM "users" AS "t1" '
                'WHERE ("t1"."age" BETWEEN ? AND ?)')
    self.assertSQL(stmt, expected, [18, 65])
    with self.assertRaises(ValueError):
        User.c.age[18:]
    with self.assertRaises(ValueError):
        User.c.age[:65]
def test_eq_none_produces_is_null(self):
    """== None and != None render as IS NULL and IS NOT NULL."""
    cases = (
        (User.c.name == None, (  # noqa: E711
            'SELECT "t1"."id" FROM "users" AS "t1" '
            'WHERE ("t1"."name" IS NULL)')),
        (User.c.name != None, (  # noqa: E711
            'SELECT "t1"."id" FROM "users" AS "t1" '
            'WHERE ("t1"."name" IS NOT NULL)')),
    )
    for condition, expected in cases:
        self.assertSQL(User.select(User.c.id).where(condition), expected, [])
def test_join_explicit_columns(self):
    """INNER JOIN on declared-column tables, filtered and ordered DESC."""
    stmt = (Note
            .select(Note.content)
            .join(Person, on=(Note.person_id == Person.id))
            .where(Person.name == 'charlie')
            .order_by(Note.id.desc()))
    expected = ('SELECT "t1"."content" '
                'FROM "note" AS "t1" '
                'INNER JOIN "person" AS "t2" ON ("t1"."person_id" = "t2"."id") '
                'WHERE ("t2"."name" = ?) '
                'ORDER BY "t1"."id" DESC')
    self.assertSQL(stmt, expected, ['charlie'])
def test_simple_join(self):
    """An aliased aggregate over an INNER JOIN, grouped by two columns."""
    tweet_count = fn.COUNT(Tweet.c.id).alias('ct')
    stmt = (User
            .select(User.c.id, User.c.username, tweet_count)
            .join(Tweet, on=(Tweet.c.user_id == User.c.id))
            .group_by(User.c.id, User.c.username))
    expected = ('SELECT "t1"."id", "t1"."username", COUNT("t2"."id") AS "ct" '
                'FROM "users" AS "t1" '
                'INNER JOIN "tweets" AS "t2" ON ("t2"."user_id" = "t1"."id") '
                'GROUP BY "t1"."id", "t1"."username"')
    self.assertSQL(stmt, expected, [])
def test_multi_join(self):
    """Chained joins, including a self-join through a user-defined alias.

    Auto-generated aliases follow order of appearance in the SQL, so the
    selected tables are t1/t2 and the FROM table comes last as t3.
    """
    likes = Table('likes')
    liker = User.alias('lu')
    stmt = (likes
            .select(Tweet.c.content, User.c.username, liker.c.username)
            .join(Tweet, on=(likes.c.tweet_id == Tweet.c.id))
            .join(User, on=(Tweet.c.user_id == User.c.id))
            .join(liker, on=(likes.c.user_id == liker.c.id))
            .where(liker.c.username == 'charlie')
            .order_by(Tweet.c.timestamp))
    self.assertSQL(stmt, (
        'SELECT "t1"."content", "t2"."username", "lu"."username" '
        'FROM "likes" AS "t3" '
        'INNER JOIN "tweets" AS "t1" ON ("t3"."tweet_id" = "t1"."id") '
        'INNER JOIN "users" AS "t2" ON ("t1"."user_id" = "t2"."id") '
        'INNER JOIN "users" AS "lu" ON ("t3"."user_id" = "lu"."id") '
        'WHERE ("lu"."username" = ?) '
        'ORDER BY "t1"."timestamp"'), ['charlie'])
def test_join_on_query(self):
    """A subquery aliased as a source can be the target of a JOIN."""
    user_ids = User.select(User.c.id).alias('j1')
    stmt = (Tweet
            .select(Tweet.c.content)
            .join(user_ids, on=(Tweet.c.user_id == user_ids.c.id)))
    expected = ('SELECT "t1"."content" FROM "tweets" AS "t1" '
                'INNER JOIN (SELECT "t2"."id" FROM "users" AS "t2") AS "j1" '
                'ON ("t1"."user_id" = "j1"."id")')
    self.assertSQL(stmt, expected, [])
def test_join_on_misc(self):
    """Any expression, even an aliased function call, can serve as ON."""
    magic = fn.Magic(Person.id, Note.id).alias('magic')
    stmt = Person.select(Person.id).join(Note, on=magic)
    expected = ('SELECT "t1"."id" FROM "person" AS "t1" '
                'INNER JOIN "note" AS "t2" '
                'ON Magic("t1"."id", "t2"."id") AS "magic"')
    self.assertSQL(stmt, expected, [])
def test_left_outer_join_shortcut(self):
    """left_outer_join() emits a LEFT OUTER JOIN with the given ON clause."""
    query = (User
             .select(User.c.id, Tweet.c.content)
             .left_outer_join(Tweet, on=(User.c.id == Tweet.c.user_id)))
    # Pass [] explicitly, like the other assertions in this class, so the
    # bound parameters are verified too. Presumably assertSQL skips the
    # params comparison when the argument is omitted -- confirm in .base.
    self.assertSQL(query, (
        'SELECT "t1"."id", "t2"."content" '
        'FROM "users" AS "t1" '
        'LEFT OUTER JOIN "tweets" AS "t2" '
        'ON ("t1"."id" = "t2"."user_id")'), [])
def test_operator_joins(self):
    """Binary operators between tables build Join objects of each type."""
    left, right = Table('a'), Table('b')
    inner = left & right
    self.assertIsInstance(inner, Join)
    self.assertEqual(inner.join_type, 'INNER JOIN')
    expectations = (
        (left + right, 'LEFT OUTER JOIN'),
        (left - right, 'RIGHT OUTER JOIN'),
        (left | right, 'FULL OUTER JOIN'),
        (left * right, 'CROSS JOIN'),
    )
    for joined, join_type in expectations:
        self.assertEqual(joined.join_type, join_type)
def test_user_defined_alias(self):
    """An explicit alias is rendered verbatim instead of an auto alias."""
    alt = User.alias('alt')
    stmt = (User
            .select(User.c.id, User.c.username, alt.c.nuggz)
            .join(alt, on=(User.c.id == alt.c.id))
            .order_by(alt.c.nuggz))
    expected = ('SELECT "t1"."id", "t1"."username", "alt"."nuggz" '
                'FROM "users" AS "t1" '
                'INNER JOIN "users" AS "alt" ON ("t1"."id" = "alt"."id") '
                'ORDER BY "alt"."nuggz"')
    self.assertSQL(stmt, expected, [])
def test_group_by(self):
    """GROUP BY a column, or by a Table (expanded to its declared columns)."""
    by_name = (User
               .select(User.c.username, fn.COUNT(Tweet.c.id).alias('ct'))
               .join(Tweet, on=(User.c.id == Tweet.c.user_id))
               .group_by(User.c.username))
    self.assertSQL(by_name, (
        'SELECT "t1"."username", COUNT("t2"."id") AS "ct" '
        'FROM "users" AS "t1" '
        'INNER JOIN "tweets" AS "t2" ON ("t1"."id" = "t2"."user_id") '
        'GROUP BY "t1"."username"'), [])
    by_person = (Person
                 .select(Person, fn.COUNT(Note.id).alias('ct'))
                 .join(Note, on=(Person.id == Note.person_id))
                 .group_by(Person))
    self.assertSQL(by_person, (
        'SELECT "person" AS "t1", COUNT("t2"."id") AS "ct" '
        'FROM "person" AS "t1" '
        'INNER JOIN "note" AS "t2" ON ("t1"."id" = "t2"."person_id") '
        'GROUP BY "t1"."id", "t1"."name", "t1"."dob"'), [])
def test_group_by_table_no_columns_error(self):
    """group_by(<Table>) raises ValueError when the table declares no columns."""
    unfiltered = User.select()
    with self.assertRaises(ValueError):
        unfiltered.group_by(User)
def test_having(self):
    """HAVING with one condition, several (ANDed) and a compound expression."""
    def per_user_counts():
        # Build a fresh per-user tweet-count query for each scenario.
        return (User
                .select(User.c.username, fn.COUNT(Tweet.c.id).alias('ct'))
                .join(Tweet, on=(User.c.id == Tweet.c.user_id))
                .group_by(User.c.username))
    single = per_user_counts().having(fn.COUNT(Tweet.c.id) > 3)
    self.assertSQL(single, (
        'SELECT "t1"."username", COUNT("t2"."id") AS "ct" '
        'FROM "users" AS "t1" '
        'INNER JOIN "tweets" AS "t2" ON ("t1"."id" = "t2"."user_id") '
        'GROUP BY "t1"."username" '
        'HAVING (COUNT("t2"."id") > ?)'), [3])
    several = per_user_counts().having(
        fn.COUNT(Tweet.c.id) > 3, fn.MAX(Tweet.c.id) < 10)
    self.assertSQL(several, (
        'SELECT "t1"."username", COUNT("t2"."id") AS "ct" '
        'FROM "users" AS "t1" '
        'INNER JOIN "tweets" AS "t2" ON ("t1"."id" = "t2"."user_id") '
        'GROUP BY "t1"."username" '
        'HAVING ((COUNT("t2"."id") > ?) AND (MAX("t2"."id") < ?))'),
        [3, 10])
    # An Entity can refer to the select-list alias inside HAVING.
    by_alias = (Person
                .select(Person, fn.COUNT(Note.id).alias('ct'))
                .join(Note, on=(Person.id == Note.person_id))
                .group_by(Person)
                .having((Entity('ct') > 2) & (fn.COUNT(Note.id) < 10)))
    self.assertSQL(by_alias, (
        'SELECT "person" AS "t1", COUNT("t2"."id") AS "ct" '
        'FROM "person" AS "t1" '
        'INNER JOIN "note" AS "t2" ON ("t1"."id" = "t2"."person_id") '
        'GROUP BY "t1"."id", "t1"."name", "t1"."dob" '
        'HAVING (("ct" > ?) AND (COUNT("t2"."id") < ?))'), [2, 10])
def test_limit(self):
    """LIMIT/OFFSET rendering, including limit_max for OFFSET-only queries.

    Each case is (query, expected SQL, expected params, extra context state).
    """
    base = User.select(User.c.id)
    cases = (
        (base.limit(None),
         'SELECT "t1"."id" FROM "users" AS "t1"', [], {}),
        (base.limit(10),
         'SELECT "t1"."id" FROM "users" AS "t1" LIMIT ?', [10], {}),
        (base.limit(10).offset(3),
         ('SELECT "t1"."id" FROM "users" AS "t1" '
          'LIMIT ? OFFSET ?'), [10, 3], {}),
        (base.limit(0),
         'SELECT "t1"."id" FROM "users" AS "t1" LIMIT ?', [0], {}),
        (base.offset(3),
         'SELECT "t1"."id" FROM "users" AS "t1" OFFSET ?', [3],
         {'limit_max': None}),
        # Some databases do not support offset without corresponding LIMIT:
        (base.offset(3),
         'SELECT "t1"."id" FROM "users" AS "t1" LIMIT ? OFFSET ?', [-1, 3],
         {'limit_max': -1}),
        (base.limit(0).offset(3),
         'SELECT "t1"."id" FROM "users" AS "t1" LIMIT ? OFFSET ?', [0, 3],
         {'limit_max': -1}),
    )
    for query, sql, params, state in cases:
        self.assertSQL(query, sql, params, **state)
def test_subquery(self):
    """A correlated scalar subquery can be selected under an alias."""
    tweet_count = (Tweet
                   .select(fn.COUNT(Tweet.c.id).alias('ct'))
                   .where(Tweet.c.user == User.c.id))
    stmt = (User
            .select(User.c.username, tweet_count.alias('iq'))
            .order_by(User.c.username))
    expected = ('SELECT "t1"."username", '
                '(SELECT COUNT("t2"."id") AS "ct" '
                'FROM "tweets" AS "t2" '
                'WHERE ("t2"."user" = "t1"."id")) AS "iq" '
                'FROM "users" AS "t1" ORDER BY "t1"."username"')
    self.assertSQL(stmt, expected, [])
def test_subquery_in_expr(self):
    """A subquery used as an arithmetic operand is wrapped in parentheses."""
    team = Table('team')
    challenge = Table('challenge')
    team_size = team.select(fn.COUNT(team.c.id) + 1)
    score = (challenge.c.points / team_size).alias('score')
    stmt = challenge.select(score).order_by(SQL('score'))
    expected = ('SELECT ("t1"."points" / ('
                'SELECT (COUNT("t2"."id") + ?) FROM "team" AS "t2")) AS "score" '
                'FROM "challenge" AS "t1" ORDER BY score')
    self.assertSQL(stmt, expected, [1])
def test_correlated_subquery(self):
    """A subquery on an aliased table can reference the outer table."""
    employee = Table('employee', ['id', 'name', 'salary', 'dept'])
    peer = employee.alias('e2')
    dept_average = (peer
                    .select(fn.AVG(peer.salary))
                    .where(peer.dept == employee.dept))
    stmt = (employee
            .select(employee.id, employee.name)
            .where(employee.salary > dept_average))
    expected = ('SELECT "t1"."id", "t1"."name" '
                'FROM "employee" AS "t1" '
                'WHERE ("t1"."salary" > ('
                'SELECT AVG("e2"."salary") '
                'FROM "employee" AS "e2" '
                'WHERE ("e2"."dept" = "t1"."dept")))')
    self.assertSQL(stmt, expected, [])
def test_select_subselect_function(self):
    """Parenthesization of subquery arguments passed to SQL functions."""
    # A sole subquery argument gets no extra parentheses -- some databases
    # would report a syntax error if it did.
    has_tweet = fn.EXISTS(Tweet
                          .select(Tweet.c.id)
                          .where(Tweet.c.user_id == User.c.id))
    stmt = User.select(User.c.username, has_tweet.alias('has_tweet'))
    self.assertSQL(stmt, (
        'SELECT "t1"."username", EXISTS('
        'SELECT "t2"."id" FROM "tweets" AS "t2" '
        'WHERE ("t2"."user_id" = "t1"."id")) AS "has_tweet" '
        'FROM "users" AS "t1"'), [])
    # Alongside other arguments, the subquery must be wrapped.
    stat = Table('stat', ['id', 'val'])
    sa = stat.alias('sa')
    total = sa.select(fn.SUM(sa.val).alias('val_sum'))
    stmt = stat.select(fn.COALESCE(total, 0))
    self.assertSQL(stmt, (
        'SELECT COALESCE(('
        'SELECT SUM("sa"."val") AS "val_sum" FROM "stat" AS "sa"'
        '), ?) FROM "stat" AS "t1"'), [0])
def test_subquery_in_select_sql(self):
    """IN (subquery) in the SELECT list; an explicit subquery alias is kept."""
    huey_ids = User.select(User.c.id).where(User.c.username == 'huey')
    plain = Tweet.select(Tweet.c.content,
                         Tweet.c.user_id.in_(huey_ids).alias('is_huey'))
    self.assertSQL(plain, (
        'SELECT "t1"."content", ("t1"."user_id" IN ('
        'SELECT "t2"."id" FROM "users" AS "t2" WHERE ("t2"."username" = ?)'
        ')) AS "is_huey" FROM "tweets" AS "t1"'), ['huey'])
    # An explicitly specified alias is rendered after the subquery.
    named_ids = huey_ids.alias('sq')
    named = Tweet.select(Tweet.c.content,
                         Tweet.c.user_id.in_(named_ids).alias('is_huey'))
    self.assertSQL(named, (
        'SELECT "t1"."content", ("t1"."user_id" IN ('
        'SELECT "t2"."id" FROM "users" AS "t2" WHERE ("t2"."username" = ?)'
        ') AS "sq") AS "is_huey" FROM "tweets" AS "t1"'), ['huey'])
def test_subquery_in_select_expression_sql(self):
    """A correlated aggregate subquery selected alongside plain columns."""
    point = Table('point', ('x', 'y'))
    pa = point.alias('pa')
    y_sum = pa.select(fn.SUM(pa.y).alias('sa')).where(pa.x == point.x)
    stmt = (point
            .select(point.x, point.y, y_sum.alias('sy'))
            .order_by(point.x, point.y))
    expected = ('SELECT "t1"."x", "t1"."y", ('
                'SELECT SUM("pa"."y") AS "sa" FROM "point" AS "pa" '
                'WHERE ("pa"."x" = "t1"."x")) AS "sy" '
                'FROM "point" AS "t1" '
                'ORDER BY "t1"."x", "t1"."y"')
    self.assertSQL(stmt, expected, [])
def test_select_from_subquery(self):
    """Filter and order on a computed column exposed by a FROM subquery."""
    name_lengths = (User
                    .select(User.c.username,
                            fn.LENGTH(User.c.username).alias('name_len'))
                    .alias('sub'))
    stmt = (User
            .select(name_lengths.c.username, name_lengths.c.name_len)
            .from_(name_lengths)
            .where(name_lengths.c.name_len > 3)
            .order_by(name_lengths.c.name_len))
    expected = ('SELECT "sub"."username", "sub"."name_len" '
                'FROM ('
                'SELECT "t1"."username", LENGTH("t1"."username") AS "name_len" '
                'FROM "users" AS "t1") AS "sub" '
                'WHERE ("sub"."name_len" > ?) '
                'ORDER BY "sub"."name_len"')
    self.assertSQL(stmt, expected, [3])
def test_select_from_subquery_no_columns_error(self):
    """select_from() with no columns raises ValueError."""
    ids = User.select(User.c.id)
    with self.assertRaises(ValueError):
        ids.select_from()
def test_simple_cte(self):
    """A CTE attached with with_cte() can be the right-hand side of IN."""
    ids = User.select(User.c.id).cte('user_ids')
    stmt = (User
            .select(User.c.username)
            .where(User.c.id.in_(ids))
            .with_cte(ids))
    expected = ('WITH "user_ids" AS (SELECT "t1"."id" FROM "users" AS "t1") '
                'SELECT "t2"."username" FROM "users" AS "t2" '
                'WHERE ("t2"."id" IN "user_ids")')
    self.assertSQL(stmt, expected, [])
def test_two_ctes(self):
    """Several CTEs passed to with_cte() are emitted comma-separated."""
    ids = User.select(User.c.id).cte('user_ids')
    names = User.select(User.c.username).cte('user_names')
    stmt = (User
            .select(ids.c.id, names.c.username)
            .where((ids.c.id == User.c.id) &
                   (names.c.username == User.c.username))
            .with_cte(ids, names))
    expected = ('WITH "user_ids" AS (SELECT "t1"."id" FROM "users" AS "t1"), '
                '"user_names" AS (SELECT "t2"."username" FROM "users" AS "t2") '
                'SELECT "user_ids"."id", "user_names"."username" '
                'FROM "users" AS "t3" '
                'WHERE (("user_ids"."id" = "t3"."id") AND '
                '("user_names"."username" = "t3"."username"))')
    self.assertSQL(stmt, expected, [])
def test_select_from_cte(self):
    """Select from one CTE via select_from(), or several via a raw Select."""
    names = User.select(User.c.username).cte('user_cte')
    stmt = names.select_from(names.c.username).order_by(names.c.username)
    self.assertSQL(stmt, (
        'WITH "user_cte" AS (SELECT "t1"."username" FROM "users" AS "t1") '
        'SELECT "user_cte"."username" FROM "user_cte" '
        'ORDER BY "user_cte"."username"'), [])
    # Selecting from several CTEs at once is done with a manual Select.
    admins = (User
              .select(User.c.username)
              .where(User.c.is_admin == 1)
              .cte('c1'))
    staff = (User
             .select(User.c.username)
             .where(User.c.is_staff == 1)
             .cte('c2'))
    both = (Select((admins, staff), (admins.c.username, staff.c.username))
            .with_cte(admins, staff))
    self.assertSQL(both, (
        'WITH "c1" AS ('
        'SELECT "t1"."username" FROM "users" AS "t1" '
        'WHERE ("t1"."is_admin" = ?)), '
        '"c2" AS ('
        'SELECT "t2"."username" FROM "users" AS "t2" '
        'WHERE ("t2"."is_staff" = ?)) '
        'SELECT "c1"."username", "c2"."username" FROM "c1", "c2"'), [1, 1])
def test_cte_select_from_2(self):
    """select_from() on a filtered CTE carries the CTE's parameters."""
    filtered = (User
                .select(User.c.username)
                .where(User.c.username != 'x')
                .cte('filtered'))
    expected = ('WITH "filtered" AS ('
                'SELECT "t1"."username" FROM "users" AS "t1" '
                'WHERE ("t1"."username" != ?)) '
                'SELECT "filtered"."username" FROM "filtered"')
    self.assertSQL(filtered.select_from(filtered.c.username), expected, ['x'])
def test_cte_select_from_with_aggregate(self):
    """Filter on an aggregate column exposed by a grouped CTE."""
    stats = (User
             .select(User.c.username,
                     fn.COUNT(Tweet.c.id).alias('tweet_ct'))
             .join(Tweet, JOIN.LEFT_OUTER, on=(Tweet.c.user_id == User.c.id))
             .group_by(User.c.username)
             .cte('user_stats'))
    active = (stats
              .select_from(stats.c.username, stats.c.tweet_ct)
              .where(stats.c.tweet_ct > 0))
    expected = ('WITH "user_stats" AS ('
                'SELECT "t1"."username", COUNT("t2"."id") AS "tweet_ct" '
                'FROM "users" AS "t1" '
                'LEFT OUTER JOIN "tweets" AS "t2" '
                'ON ("t2"."user_id" = "t1"."id") '
                'GROUP BY "t1"."username") '
                'SELECT "user_stats"."username", "user_stats"."tweet_ct" '
                'FROM "user_stats" '
                'WHERE ("user_stats"."tweet_ct" > ?)')
    self.assertSQL(active, expected, [0])
def test_two_ctes_with_join(self):
    """Join one CTE to another and order by a column of the joined CTE."""
    active = (User
              .select(User.c.id, User.c.username)
              .cte('active_users'))
    counts = (Tweet
              .select(Tweet.c.user_id, fn.COUNT(Tweet.c.id).alias('ct'))
              .group_by(Tweet.c.user_id)
              .cte('tweet_counts'))
    stmt = (active
            .select_from(active.c.username, counts.c.ct)
            .join(counts, on=(active.c.id == counts.c.user_id))
            .with_cte(active, counts)
            .order_by(counts.c.ct.desc()))
    expected = ('WITH "active_users" AS ('
                'SELECT "t1"."id", "t1"."username" '
                'FROM "users" AS "t1"), '
                '"tweet_counts" AS ('
                'SELECT "t2"."user_id", COUNT("t2"."id") AS "ct" '
                'FROM "tweets" AS "t2" '
                'GROUP BY "t2"."user_id") '
                'SELECT "active_users"."username", "tweet_counts"."ct" '
                'FROM "active_users" '
                'INNER JOIN "tweet_counts" '
                'ON ("active_users"."id" = "tweet_counts"."user_id") '
                'ORDER BY "tweet_counts"."ct" DESC')
    self.assertSQL(stmt, expected, [])
def test_materialize_cte(self):
    """materialized=True/False/None -> MATERIALIZED / NOT MATERIALIZED / ''."""
    template = ('WITH "user_ids" AS %s('
                'SELECT "t1"."id" FROM "users" AS "t1") '
                'SELECT "user_ids"."id" FROM "user_ids" '
                'WHERE ("user_ids"."id" < ?)')
    for flag, keyword in ((True, 'MATERIALIZED '),
                          (False, 'NOT MATERIALIZED '),
                          (None, '')):
        ids = (User
               .select(User.c.id)
               .cte('user_ids', materialized=flag))
        stmt = ids.select_from(ids.c.id).where(ids.c.id < 10)
        self.assertSQL(stmt, template % keyword, [10])
def test_cte_union_distinct(self):
    """CTE.union() joins the recursive term with UNION, not UNION ALL."""
    anchor = (User
              .select(User.c.id)
              .where(User.c.id == 1)
              .cte('cte1', recursive=True))
    step = User.select(User.c.id).where(User.c.id < 5)
    combined = anchor.union(step)
    stmt = combined.select_from(combined.c.id)
    expected = ('WITH RECURSIVE "cte1" AS ('
                'SELECT "t1"."id" FROM "users" AS "t1" WHERE ("t1"."id" = ?) '
                'UNION '
                'SELECT "t2"."id" FROM "users" AS "t2" WHERE ("t2"."id" < ?)) '
                'SELECT "cte1"."id" FROM "cte1"')
    self.assertSQL(stmt, expected, [1, 5])
def test_cte_select_from_no_columns_error(self):
    """CTE.select_from() with no columns raises ValueError."""
    ids = User.select(User.c.id).cte('cte1')
    with self.assertRaises(ValueError):
        ids.select_from()
def test_source_cte_shortcut(self):
    """query.cte(name, columns=...) renders an explicit CTE column list."""
    positive = User.select(User.c.id).where(User.c.id > 0)
    named = positive.cte('my_cte', columns=['id'])
    expected = ('WITH "my_cte" ("id") AS ('
                'SELECT "t1"."id" FROM "users" AS "t1" WHERE ("t1"."id" > ?)) '
                'SELECT "my_cte"."id" FROM "my_cte"')
    self.assertSQL(named.select_from(named.c.id), expected, [0])
def test_fibonacci_cte(self):
    """Recursive Fibonacci CTE whose recursive term is joined by either
    UNION ALL or UNION."""
    seed = Select(columns=(
        Value(1).alias('n'),
        Value(0).alias('fib_n'),
        Value(1).alias('next_fib_n'))).cte('fibonacci', recursive=True)
    # The aliased expression renders as "n" when reused in WHERE.
    next_n = (seed.c.n + 1).alias('n')
    recursive_term = Select(columns=(
        next_n,
        seed.c.next_fib_n,
        seed.c.fib_n + seed.c.next_fib_n)).from_(seed).where(next_n < 10)
    template = (
        'WITH RECURSIVE "fibonacci" AS ('
        'SELECT ? AS "n", ? AS "fib_n", ? AS "next_fib_n" '
        '%s '
        'SELECT ("fibonacci"."n" + ?) AS "n", "fibonacci"."next_fib_n", '
        '("fibonacci"."fib_n" + "fibonacci"."next_fib_n") '
        'FROM "fibonacci" '
        'WHERE ("n" < ?)) '
        'SELECT "fibonacci"."n", "fibonacci"."fib_n" '
        'FROM "fibonacci"')
    for combine, keyword in ((seed.union_all, 'UNION ALL'),
                             (seed.union, 'UNION')):
        fib = combine(recursive_term)
        self.assertSQL(fib.select_from(fib.c.n, fib.c.fib_n),
                       template % keyword, [1, 0, 1, 1, 10])
def test_cte_with_count(self):
    """Wrapping a CTE-bearing query in a COUNT keeps the WITH inside."""
    ids = User.select(User.c.id).cte('user_ids')
    joined = (User
              .select(User.c.username)
              .join(ids, on=(User.c.id == ids.c.id))
              .with_cte(ids))
    counter = Select([joined], [fn.COUNT(SQL('1'))])
    expected = ('SELECT COUNT(1) FROM ('
                'WITH "user_ids" AS (SELECT "t1"."id" FROM "users" AS "t1") '
                'SELECT "t2"."username" FROM "users" AS "t2" '
                'INNER JOIN "user_ids" ON ("t2"."id" = "user_ids"."id")) '
                'AS "t3"')
    self.assertSQL(counter, expected, [])
def test_cte_subquery_in_expression(self):
    """A subquery carrying its own WITH clause is valid inside IN (...)."""
    orders = Table('order', ('id', 'description'))
    items = Table('item', ('id', 'order_id', 'description'))
    latest = orders.select(fn.MAX(orders.id).alias('max_id')).cte('max_order')
    latest_ids = (orders
                  .select(orders.id)
                  .join(latest, on=(orders.id == latest.c.max_id))
                  .with_cte(latest))
    stmt = (items
            .select(items.id, items.order_id, items.description)
            .where(items.order_id.in_(latest_ids)))
    expected = ('SELECT "t1"."id", "t1"."order_id", "t1"."description" '
                'FROM "item" AS "t1" '
                'WHERE ("t1"."order_id" IN ('
                'WITH "max_order" AS ('
                'SELECT MAX("t2"."id") AS "max_id" FROM "order" AS "t2") '
                'SELECT "t3"."id" '
                'FROM "order" AS "t3" '
                'INNER JOIN "max_order" '
                'ON ("t3"."id" = "max_order"."max_id")))')
    self.assertSQL(stmt, expected, [])
def test_multi_update_cte(self):
    """UPDATE driven by a VALUES-backed CTE in both SET and WHERE."""
    rows = [(n, 'u%sx' % n) for n in (1, 2)]
    uv = ValuesList(rows).select().cte('uv', columns=('id', 'username'))
    new_name = uv.select(uv.c.username).where(uv.c.id == User.c.id)
    stmt = (User
            .update(username=new_name)
            .where(User.c.id.in_(uv.select(uv.c.id)))
            .with_cte(uv))
    expected = ('WITH "uv" ("id", "username") AS ('
                'SELECT * FROM (VALUES (?, ?), (?, ?)) AS "t1") '
                'UPDATE "users" SET "username" = ('
                'SELECT "uv"."username" FROM "uv" '
                'WHERE ("uv"."id" = "users"."id")) '
                'WHERE ("users"."id" IN (SELECT "uv"."id" FROM "uv"))')
    self.assertSQL(stmt, expected, [1, 'u1x', 2, 'u2x'])
def test_data_modifying_cte_delete(self):
    """DELETE ... RETURNING used as a CTE, and a DELETE driven by a
    recursive CTE."""
    products = Table('products', ('id', 'name', 'timestamp'))
    archive = Table('archive', ('id', 'name', 'timestamp'))
    cutoff = datetime.date(2022, 1, 1)
    removed = (products.delete()
               .where(products.timestamp < cutoff)
               .returning(products.id, products.name, products.timestamp))
    moved = removed.cte('moved_rows')
    source = Select((moved,), (moved.c.id, moved.c.name, moved.c.timestamp))
    copy_rows = (archive
                 .insert(source, (archive.id, archive.name, archive.timestamp))
                 .with_cte(moved))
    self.assertSQL(copy_rows, (
        'WITH "moved_rows" AS ('
        'DELETE FROM "products" WHERE ("products"."timestamp" < ?) '
        'RETURNING "products"."id", "products"."name", '
        '"products"."timestamp") '
        'INSERT INTO "archive" ("id", "name", "timestamp") '
        'SELECT "moved_rows"."id", "moved_rows"."name", '
        '"moved_rows"."timestamp" FROM "moved_rows"'),
        [cutoff])
    # Delete every part reachable from 'p' via a recursive CTE.
    parts = Table('parts', ('id', 'part', 'sub_part'))
    anchor = (parts
              .select(parts.sub_part, parts.part)
              .where(parts.part == 'p')
              .cte('included_parts', recursive=True,
                   columns=('sub_part', 'part')))
    p = parts.alias('p')
    step = (p
            .select(p.sub_part, p.part)
            .join(anchor, on=(p.part == anchor.c.sub_part)))
    included = anchor.union_all(step)
    part_names = Select((included,), (included.c.part,))
    purge = (parts.delete()
             .where(parts.part.in_(part_names))
             .with_cte(included))
    self.assertSQL(purge, (
        'WITH RECURSIVE "included_parts" ("sub_part", "part") AS ('
        'SELECT "t1"."sub_part", "t1"."part" FROM "parts" AS "t1" '
        'WHERE ("t1"."part" = ?) '
        'UNION ALL '
        'SELECT "p"."sub_part", "p"."part" '
        'FROM "parts" AS "p" '
        'INNER JOIN "included_parts" '
        'ON ("p"."part" = "included_parts"."sub_part")) '
        'DELETE FROM "parts" '
        'WHERE ("parts"."part" IN ('
        'SELECT "included_parts"."part" FROM "included_parts"))'), ['p'])
def test_data_modifying_cte_update(self):
    """UPDATE ... RETURNING as a CTE, read directly and via UPDATE ... FROM."""
    products = Table('products', ('id', 'name', 'price'))
    archive = Table('archive', ('id', 'name', 'price'))
    bump = (products
            .update(price=products.price * 1.05)
            .returning(products.id, products.name, products.price))
    updated = bump.cte('t')
    readback = updated.select_from(
        updated.c.id, updated.c.name, updated.c.price)
    self.assertSQL(readback, (
        'WITH "t" AS ('
        'UPDATE "products" SET "price" = ("products"."price" * ?) '
        'RETURNING "products"."id", "products"."name", "products"."price")'
        ' SELECT "t"."id", "t"."name", "t"."price" FROM "t"'), [1.05])
    changed = Select((updated,), (updated.c.id, updated.c.price))
    sync = (archive
            .update(price=changed.c.price)
            .from_(changed)
            .where(archive.id == changed.c.id)
            .with_cte(updated))
    self.assertSQL(sync, (
        'WITH "t" AS ('
        'UPDATE "products" SET "price" = ("products"."price" * ?) '
        'RETURNING "products"."id", "products"."name", "products"."price")'
        ' UPDATE "archive" SET "price" = "t1"."price"'
        ' FROM (SELECT "t"."id", "t"."price" FROM "t") AS "t1"'
        ' WHERE ("archive"."id" = "t1"."id")'), [1.05])
def test_data_modifying_cte_insert(self):
    """INSERT ... RETURNING as a CTE, read directly and fed into an INSERT."""
    products = Table('products', ('id', 'name', 'price'))
    archive = Table('archive', ('id', 'name', 'price'))
    add = (products
           .insert({'name': 'p1', 'price': 10})
           .returning(products.id, products.name, products.price))
    inserted = add.cte('t')
    readback = inserted.select_from(
        inserted.c.id, inserted.c.name, inserted.c.price)
    self.assertSQL(readback, (
        'WITH "t" AS ('
        'INSERT INTO "products" ("name", "price") VALUES (?, ?) '
        'RETURNING "products"."id", "products"."name", "products"."price")'
        ' SELECT "t"."id", "t"."name", "t"."price" FROM "t"'),
        ['p1', 10])
    new_rows = Select((inserted,),
                      (inserted.c.id, inserted.c.name, inserted.c.price))
    copy_rows = (archive
                 .insert(new_rows,
                         (new_rows.c.id, new_rows.c.name, new_rows.c.price))
                 .with_cte(inserted))
    self.assertSQL(copy_rows, (
        'WITH "t" AS ('
        'INSERT INTO "products" ("name", "price") VALUES (?, ?) '
        'RETURNING "products"."id", "products"."name", "products"."price")'
        ' INSERT INTO "archive" ("id", "name", "price")'
        ' SELECT "t"."id", "t"."name", "t"."price" FROM "t"'), ['p1', 10])
def test_compound_select(self):
    """Chaining | builds a flat UNION across three selects."""
    charlie = User.select(User.c.id).where(User.c.username == 'charlie')
    admins = User.select(User.c.username).where(User.c.admin == True)  # noqa: E712
    u2 = User.alias('U2')
    regular = u2.select(u2.c.id).where(u2.c.superuser == False)  # noqa: E712
    stmt = (charlie | admins) | regular
    expected = ('SELECT "t1"."id" '
                'FROM "users" AS "t1" '
                'WHERE ("t1"."username" = ?) '
                'UNION '
                'SELECT "t2"."username" '
                'FROM "users" AS "t2" '
                'WHERE ("t2"."admin" = ?) '
                'UNION '
                'SELECT "U2"."id" '
                'FROM "users" AS "U2" '
                'WHERE ("U2"."superuser" = ?)')
    self.assertSQL(stmt, expected, ['charlie', True, False])
def test_compound_operations(self):
    """union() and except_() methods, with aliased literal columns."""
    admins = (User
              .select(User.c.username, Value('admin').alias('role'))
              .where(User.c.is_admin == True))  # noqa: E712
    editors = (User
               .select(User.c.username, Value('editor').alias('role'))
               .where(User.c.is_editor == True))  # noqa: E712
    self.assertSQL(admins.union(editors), (
        'SELECT "t1"."username", ? AS "role" '
        'FROM "users" AS "t1" '
        'WHERE ("t1"."is_admin" = ?) '
        'UNION '
        'SELECT "t2"."username", ? AS "role" '
        'FROM "users" AS "t2" '
        'WHERE ("t2"."is_editor" = ?)'), ['admin', 1, 'editor', 1])
    self.assertSQL(editors.except_(admins), (
        'SELECT "t1"."username", ? AS "role" '
        'FROM "users" AS "t1" '
        'WHERE ("t1"."is_editor" = ?) '
        'EXCEPT '
        'SELECT "t2"."username", ? AS "role" '
        'FROM "users" AS "t2" '
        'WHERE ("t2"."is_admin" = ?)'), ['editor', 1, 'admin', 1])
def test_compound_parentheses_handling(self):
    # How the compound_select_parentheses setting wraps the members of a
    # compound SELECT: 1 (always) also wraps nested compounds, 2
    # (un-nested) wraps only the individual queries.
    def ranked_by_id(flag, role, count):
        # Role-tagged usernames where ``flag`` is set, highest id first.
        return (User
                .select(User.c.username, Value(role).alias('role'))
                .where(flag == True)
                .order_by(User.c.id.desc())
                .limit(count))

    admin = ranked_by_id(User.c.is_admin, 'admin', 3)
    editors = ranked_by_id(User.c.is_editor, 'editor', 5)
    self.assertSQL(admin | editors, (
        '(SELECT "t1"."username", ? AS "role" FROM "users" AS "t1" '
        'WHERE ("t1"."is_admin" = ?) ORDER BY "t1"."id" DESC LIMIT ?) '
        'UNION '
        '(SELECT "t2"."username", ? AS "role" FROM "users" AS "t2" '
        'WHERE ("t2"."is_editor" = ?) ORDER BY "t2"."id" DESC LIMIT ?)'),
        ['admin', 1, 3, 'editor', 1, 5], compound_select_parentheses=True)

    Reg = Table('register', ('value',))
    compound = (Reg.select().where(Reg.value < 2) |
                Reg.select().where(Reg.value > 7))

    # A flat two-member compound renders the same under both settings.
    two_member_sql = (
        '(SELECT "t1"."value" FROM "register" AS "t1" '
        'WHERE ("t1"."value" < ?)) '
        'UNION '
        '(SELECT "t2"."value" FROM "register" AS "t2" '
        'WHERE ("t2"."value" > ?))')
    for csq_setting in (1, 2):
        self.assertSQL(compound, two_member_sql, [2, 7],
                       compound_select_parentheses=csq_setting)

    c2 = compound | Reg.select().where(Reg.value == 5)

    # Setting 1 (always): the inner compound gets its own parentheses.
    self.assertSQL(c2, (
        '((SELECT "t1"."value" FROM "register" AS "t1" '
        'WHERE ("t1"."value" < ?)) '
        'UNION '
        '(SELECT "t2"."value" FROM "register" AS "t2" '
        'WHERE ("t2"."value" > ?))) '
        'UNION '
        '(SELECT "t3"."value" FROM "register" AS "t3" '
        'WHERE ("t3"."value" = ?))'),
        [2, 7, 5], compound_select_parentheses=1)

    # Setting 2 (un-nested): only each individual query is wrapped.
    self.assertSQL(c2, (
        '(SELECT "t1"."value" FROM "register" AS "t1" '
        'WHERE ("t1"."value" < ?)) '
        'UNION '
        '(SELECT "t2"."value" FROM "register" AS "t2" '
        'WHERE ("t2"."value" > ?)) '
        'UNION '
        '(SELECT "t3"."value" FROM "register" AS "t3" '
        'WHERE ("t3"."value" = ?))'),
        [2, 7, 5], compound_select_parentheses=2)
def test_compound_select_order_limit(self):
    # ORDER BY / LIMIT on a three-way UNION appear once, after the last
    # member, and refer to the shared output alias "foo".
    A = Table('a', ('col_a',))
    B = Table('b', ('col_b',))
    C = Table('c', ('col_c',))
    union = (A.select(A.col_a.alias('foo')) |
             B.select(B.col_b.alias('foo')) |
             C.select(C.col_c.alias('foo')))
    query = union.order_by(union.c.foo.desc()).limit(3)

    self.assertSQL(query, (
        'SELECT "t1"."col_a" AS "foo" FROM "a" AS "t1" UNION '
        'SELECT "t2"."col_b" AS "foo" FROM "b" AS "t2" UNION '
        'SELECT "t3"."col_c" AS "foo" FROM "c" AS "t3" '
        'ORDER BY "foo" DESC LIMIT ?'), [3])

    self.assertSQL(query, (
        '((SELECT "t1"."col_a" AS "foo" FROM "a" AS "t1") UNION '
        '(SELECT "t2"."col_b" AS "foo" FROM "b" AS "t2")) UNION '
        '(SELECT "t3"."col_c" AS "foo" FROM "c" AS "t3") '
        'ORDER BY "foo" DESC LIMIT ?'),
        [3], compound_select_parentheses=1)
def test_compound_select_as_subquery(self):
    # A UNION used as the FROM source of an outer, grouped SELECT.
    A = Table('a', ('col_a',))
    B = Table('b', ('col_b',))
    union = A.select(A.col_a.alias('foo')) | B.select(B.col_b.alias('foo'))

    # The outer query counts occurrences of each value of the union column.
    outer = union.select_from(union.c.foo,
                              fn.COUNT(union.c.foo).alias('ct'))
    outer = outer.group_by(union.c.foo)

    self.assertSQL(outer, (
        'SELECT "t1"."foo", COUNT("t1"."foo") AS "ct" FROM ('
        'SELECT "t2"."col_a" AS "foo" FROM "a" AS "t2" UNION '
        'SELECT "t3"."col_b" AS "foo" FROM "b" AS "t3") AS "t1" '
        'GROUP BY "t1"."foo"'), [])
def test_union_with_order_and_limit(self):
    # ORDER BY a raw SQL ordinal plus LIMIT applied to a UNION.
    low_ids = User.select(User.c.username).where(User.c.id < 5)
    high_ids = User.select(User.c.username).where(User.c.id > 95)
    union = low_ids | high_ids
    query = union.order_by(SQL('1')).limit(10)
    self.assertSQL(query, (
        'SELECT "t1"."username" FROM "users" AS "t1" '
        'WHERE ("t1"."id" < ?) '
        'UNION '
        'SELECT "t2"."username" FROM "users" AS "t2" '
        'WHERE ("t2"."id" > ?) '
        'ORDER BY 1 LIMIT ?'), [5, 95, 10])
def test_intersect(self):
    # The ``&`` operator between two selects renders INTERSECT.
    below_ten = User.select(User.c.username).where(User.c.id < 10)
    above_five = User.select(User.c.username).where(User.c.id > 5)
    self.assertSQL(below_ten & above_five, (
        'SELECT "t1"."username" FROM "users" AS "t1" '
        'WHERE ("t1"."id" < ?) '
        'INTERSECT '
        'SELECT "t2"."username" FROM "users" AS "t2" '
        'WHERE ("t2"."id" > ?)'), [10, 5])
def test_except(self):
    # The ``-`` operator between two selects renders EXCEPT.
    everyone = User.select(User.c.username)
    above_five = User.select(User.c.username).where(User.c.id > 5)
    self.assertSQL(everyone - above_five, (
        'SELECT "t1"."username" FROM "users" AS "t1" '
        'EXCEPT '
        'SELECT "t2"."username" FROM "users" AS "t2" '
        'WHERE ("t2"."id" > ?)'), [5])
def test_complex_select(self):
    # Two CTEs, the second reading the first (including from a scalar
    # subquery), feeding an IN filter on the outer aggregate query.
    Orders = Table('orders', columns=(
        'region',
        'amount',
        'product',
        'quantity'))

    regional_sales = (Orders
                      .select(Orders.region,
                              fn.SUM(Orders.amount).alias('total_sales'))
                      .group_by(Orders.region)
                      .cte('regional_sales'))

    # Scalar subquery: the summed regional totals divided by ten.
    tenth_of_total = regional_sales.select(
        fn.SUM(regional_sales.c.total_sales) / 10)
    top_regions = (regional_sales
                   .select(regional_sales.c.region)
                   .where(regional_sales.c.total_sales > tenth_of_total)
                   .cte('top_regions'))

    in_top_region = Orders.region << top_regions.select(top_regions.c.region)
    query = (Orders
             .select(
                 Orders.region,
                 Orders.product,
                 fn.SUM(Orders.quantity).alias('product_units'),
                 fn.SUM(Orders.amount).alias('product_sales'))
             .where(in_top_region)
             .group_by(Orders.region, Orders.product)
             .with_cte(regional_sales, top_regions))

    self.assertSQL(query, (
        'WITH "regional_sales" AS ('
        'SELECT "t1"."region", SUM("t1"."amount") AS "total_sales" '
        'FROM "orders" AS "t1" '
        'GROUP BY "t1"."region"'
        '), '
        '"top_regions" AS ('
        'SELECT "regional_sales"."region" '
        'FROM "regional_sales" '
        'WHERE ("regional_sales"."total_sales" > '
        '(SELECT (SUM("regional_sales"."total_sales") / ?) '
        'FROM "regional_sales"))'
        ') '
        'SELECT "t2"."region", "t2"."product", '
        'SUM("t2"."quantity") AS "product_units", '
        'SUM("t2"."amount") AS "product_sales" '
        'FROM "orders" AS "t2" '
        'WHERE ('
        '"t2"."region" IN ('
        'SELECT "top_regions"."region" '
        'FROM "top_regions")'
        ') GROUP BY "t2"."region", "t2"."product"'), [10])
def test_lateral_subquery_model(self):
    # LEFT JOIN LATERAL against a correlated subquery returning the content
    # of each user's tweet with the greatest timestamp.
    latest = (Tweet
              .select(Tweet.c.content)
              .where(Tweet.c.user_id == User.c.id)
              .order_by(Tweet.c.timestamp.desc())
              .limit(1))
    query = User.select(User.c.username, latest.c.content)
    query = query.join(latest, JOIN.LEFT_LATERAL, on=True)
    self.assertSQL(query, (
        'SELECT "t1"."username", "t2"."content" '
        'FROM "users" AS "t1" '
        'LEFT JOIN LATERAL ('
        'SELECT "t3"."content" FROM "tweets" AS "t3" '
        'WHERE ("t3"."user_id" = "t1"."id") '
        'ORDER BY "t3"."timestamp" DESC LIMIT ?) AS "t2" ON ?'),
        [1, True])
def test_all_clauses(self):
    # JOIN, WHERE, GROUP BY, HAVING and ORDER BY together; the aliased
    # aggregate is referenced by its alias in HAVING and ORDER BY.
    tweet_count = fn.COUNT(Tweet.c.id).alias('ct')
    authored = (User.c.id == Tweet.c.user_id)
    query = (User
             .select(User.c.username, tweet_count)
             .join(Tweet, JOIN.LEFT_OUTER, on=authored)
             .where(User.c.is_admin == 1)
             .group_by(User.c.username)
             .having(tweet_count > 10)
             .order_by(tweet_count.desc()))
    self.assertSQL(query, (
        'SELECT "t1"."username", COUNT("t2"."id") AS "ct" '
        'FROM "users" AS "t1" '
        'LEFT OUTER JOIN "tweets" AS "t2" '
        'ON ("t1"."id" = "t2"."user_id") '
        'WHERE ("t1"."is_admin" = ?) '
        'GROUP BY "t1"."username" '
        'HAVING ("ct" > ?) '
        'ORDER BY "ct" DESC'), [1, 10])
def test_order_by_collate(self):
    # A collation passed to asc() is emitted after the direction.
    binary_order = User.c.username.asc(collation='binary')
    query = User.select(User.c.username).order_by(binary_order)
    self.assertSQL(query, (
        'SELECT "t1"."username" FROM "users" AS "t1" '
        'ORDER BY "t1"."username" ASC COLLATE binary'), [])
def test_order_by_nulls(self):
    # NULLS FIRST/LAST is rendered natively when the dialect supports it
    # (nulls_ordering=True) and emulated with a leading CASE key otherwise.
    emulated_sql = (
        'SELECT "t1"."username" FROM "users" AS "t1" '
        'ORDER BY CASE WHEN ("t1"."ts" IS NULL) THEN ? ELSE ? END, '
        '"t1"."ts" DESC')
    cases = (
        # (nulls argument, native SQL, params for the emulated CASE)
        ('LAST',
         ('SELECT "t1"."username" FROM "users" AS "t1" '
          'ORDER BY "t1"."ts" DESC NULLS LAST'),
         [1, 0]),
        ('first',
         ('SELECT "t1"."username" FROM "users" AS "t1" '
          'ORDER BY "t1"."ts" DESC NULLS first'),
         [0, 1]),
    )
    for nulls, native_sql, emulated_params in cases:
        query = (User
                 .select(User.c.username)
                 .order_by(User.c.ts.desc(nulls=nulls)))
        self.assertSQL(query, native_sql, [], nulls_ordering=True)
        self.assertSQL(query, emulated_sql, emulated_params,
                       nulls_ordering=False)
def test_ordering_invalid_nulls_error(self):
    # An unrecognized ``nulls`` value is rejected at construction time.
    with self.assertRaises(ValueError):
        Ordering(User.c.id, 'ASC', nulls='middle')
def test_coalesce(self):
    # fn.COALESCE with an alias and a WHERE filter.
    Sample = Table('sample', ('counter', 'value'))
    value_or_zero = fn.COALESCE(Sample.value, 0).alias('val')
    query = Sample.select(value_or_zero).where(Sample.counter == 1)
    self.assertSQL(query, (
        'SELECT COALESCE("t1"."value", ?) AS "val" '
        'FROM "sample" AS "t1" '
        'WHERE ("t1"."counter" = ?)'), [0, 1])
def test_nullif(self):
    # fn.NULLIF with an alias.
    Sample = Table('sample', ('counter', 'value'))
    zero_as_null = fn.NULLIF(Sample.value, 0).alias('val')
    self.assertSQL(Sample.select(zero_as_null), (
        'SELECT NULLIF("t1"."value", ?) AS "val" '
        'FROM "sample" AS "t1"'), [0])
def test_like_escape(self):
    # contains/startswith/endswith escape the LIKE wildcards "_" and "%" in
    # the search term, and add an ESCAPE clause only when something was
    # actually escaped.
    T = Table('tbl', ('key',))
    # Captures the "ILIKE ..." fragment of the WHERE clause. The dots are
    # escaped so they only match the literal table/column separator.
    like_re = re.compile(r'\("t1"\."key" (ILIKE[^\)]+)\)')

    def assertLike(expr, expected):
        # Build the query once and compare (LIKE fragment, params).
        sql, params = __sql__(T.select().where(expr))
        match_obj = like_re.search(sql)
        if match_obj is None:
            self.fail('LIKE expression not found in query.')
        self.assertEqual((match_obj.group(1), params), expected)

    cases = (
        (T.key.contains('base'), ('ILIKE ?', ['%base%'])),
        (T.key.contains('x_y'), ("ILIKE ? ESCAPE ?", ['%x\\_y%', '\\'])),
        (T.key.contains('__y'), ("ILIKE ? ESCAPE ?", ['%\\_\\_y%', '\\'])),
        (T.key.contains('%'), ("ILIKE ? ESCAPE ?", ['%\\%%', '\\'])),
        (T.key.contains('_%'), ("ILIKE ? ESCAPE ?", ['%\\_\\%%', '\\'])),
        (T.key.startswith('base'), ("ILIKE ?", ['base%'])),
        (T.key.startswith('x_y'), ("ILIKE ? ESCAPE ?", ['x\\_y%', '\\'])),
        (T.key.startswith('x%'), ("ILIKE ? ESCAPE ?", ['x\\%%', '\\'])),
        (T.key.startswith('_%'), ("ILIKE ? ESCAPE ?", ['\\_\\%%', '\\'])),
        (T.key.endswith('base'), ("ILIKE ?", ['%base'])),
        (T.key.endswith('x_y'), ("ILIKE ? ESCAPE ?", ['%x\\_y', '\\'])),
        (T.key.endswith('x%'), ("ILIKE ? ESCAPE ?", ['%x\\%', '\\'])),
        (T.key.endswith('_%'), ("ILIKE ? ESCAPE ?", ['%\\_\\%', '\\'])),
    )
    # subTest reports every failing case instead of stopping at the first.
    for index, (expr, expected) in enumerate(cases):
        with self.subTest(case=index, expected=expected):
            assertLike(expr, expected)
def test_like_expr(self):
    # like() and ilike() map to the LIKE and ILIKE operators.
    cases = (
        (User.c.username.like('%foo%'),
         ('SELECT "t1"."id" FROM "users" AS "t1" '
          'WHERE ("t1"."username" LIKE ?)')),
        (User.c.username.ilike('%foo%'),
         ('SELECT "t1"."id" FROM "users" AS "t1" '
          'WHERE ("t1"."username" ILIKE ?)')),
    )
    for condition, expected_sql in cases:
        query = User.select(User.c.id).where(condition)
        self.assertSQL(query, expected_sql, ['%foo%'])
def test_field_ops(self):
    # regexp() renders REGEXP; contains() renders a wrapped ILIKE pattern.
    cases = (
        (User.c.username.regexp('[a-z]+'),
         ('SELECT "t1"."id" FROM "users" AS "t1" '
          'WHERE ("t1"."username" REGEXP ?)'),
         ['[a-z]+']),
        (User.c.username.contains('abc'),
         ('SELECT "t1"."id" FROM "users" AS "t1" '
          'WHERE ("t1"."username" ILIKE ?)'),
         ['%abc%']),
    )
    for condition, expected_sql, expected_params in cases:
        query = User.select(User.c.id).where(condition)
        self.assertSQL(query, expected_sql, expected_params)
def test_bitwise_ops(self):
    # bin_and() and bin_or() render the & and | operators.
    cases = (
        (User.c.id.bin_and(4),
         ('SELECT "t1"."id" FROM "users" AS "t1" '
          'WHERE ("t1"."id" & ?)'),
         [4]),
        (User.c.id.bin_or(1),
         ('SELECT "t1"."id" FROM "users" AS "t1" '
          'WHERE ("t1"."id" | ?)'),
         [1]),
    )
    for condition, expected_sql, expected_params in cases:
        query = User.select(User.c.id).where(condition)
        self.assertSQL(query, expected_sql, expected_params)
def test_add_custom_op(self):
    # A user-defined binary operator built directly from Expression.
    def modulo(lhs, rhs):
        return Expression(lhs, '%', rhs)

    Stat = Table('stats')
    divisible_by_ten = modulo(Stat.c.index, 10) == 0
    query = Stat.select(fn.COUNT(Stat.c.id)).where(divisible_by_ten)
    self.assertSQL(query, (
        'SELECT COUNT("t1"."id") FROM "stats" AS "t1" '
        'WHERE (("t1"."index" % ?) = ?)'), [10, 0])
def test_entity_escaping(self):
    # A table name containing a double quote, under both quote styles.
    Quirky = Table('te"st')
    query = Quirky.select(Quirky.c.id).where(Quirky.c.value > 5)

    # Default double-quote quoting doubles the embedded quote character.
    self.assertSQL(query, (
        'SELECT "t1"."id" FROM "te""st" AS "t1" '
        'WHERE ("t1"."value" > ?)'), [5])

    # With backtick quoting the embedded double quote is left as-is.
    self.assertSQL(query, (
        'SELECT `t1`.`id` FROM `te"st` AS `t1` '
        'WHERE (`t1`.`value` > ?)'), [5], quote='``')
def test_tuple_comparison(self):
    # Row-value comparison; the right-hand side may be a plain Python tuple
    # or an explicit Tuple() and renders the same either way.
    name_dob = Tuple(Person.name, Person.dob)
    expected = ('SELECT "t1"."id" FROM "person" AS "t1" '
                'WHERE (("t1"."name", "t1"."dob") = (?, ?))')
    for rhs in (('foo', '2017-01-01'), Tuple('foo', '2017-01-01')):
        query = Person.select(Person.id).where(name_dob == rhs)
        self.assertSQL(query, expected, ['foo', '2017-01-01'])
def test_tuple_comparison_subquery(self):
    # Row-value IN against a subquery over an aliased copy of the table.
    other = Person.alias('pa')
    not_huey = other.select(other.name, other.id).where(other.name != 'huey')
    pair = Tuple(Person.name, Person.id)
    query = Person.select(Person.name).where(pair.in_(not_huey))
    self.assertSQL(query, (
        'SELECT "t1"."name" FROM "person" AS "t1" '
        'WHERE (("t1"."name", "t1"."id") IN ('
        'SELECT "pa"."name", "pa"."id" FROM "person" AS "pa" '
        'WHERE ("pa"."name" != ?)))'), ['huey'])
def test_tuple_in_subquery(self):
    # Row-value IN against a two-column subquery on another table.
    special = (Tweet
               .select(Tweet.c.user_id, Tweet.c.content)
               .where(Tweet.c.content == 'special'))
    id_and_name = Tuple(User.c.id, User.c.username)
    query = (User
             .select(User.c.id, User.c.username)
             .where(id_and_name.in_(special)))
    self.assertSQL(query, (
        'SELECT "t1"."id", "t1"."username" '
        'FROM "users" AS "t1" '
        'WHERE (("t1"."id", "t1"."username") IN ('
        'SELECT "t2"."user_id", "t2"."content" '
        'FROM "tweets" AS "t2" '
        'WHERE ("t2"."content" = ?)))'), ['special'])
class TestInsertQuery(BaseTestCase):
    """SQL generation for INSERT (and REPLACE) queries on Table objects."""

    def test_insert_simple(self):
        # Single-row INSERT from a Column-keyed dict; the expected SQL lists
        # the columns in name order.
        row = {
            User.c.username: 'charlie',
            User.c.superuser: False,
            User.c.admin: True}
        self.assertSQL(User.insert(row), (
            'INSERT INTO "users" ("admin", "superuser", "username") '
            'VALUES (?, ?, ?)'), [True, False, 'charlie'])

    def test_table_insert_with_kwargs(self):
        # Keyword arguments work in place of a Column-keyed dict.
        query = User.insert(username='charlie', superuser=False, admin=True)
        self.assertSQL(query, (
            'INSERT INTO "users" ("admin", "superuser", "username") '
            'VALUES (?, ?, ?)'), [True, False, 'charlie'])

    @requires_sqlite
    def test_replace_sqlite(self):
        # SQLite renders replace() as INSERT OR REPLACE, for both the dict
        # and keyword-argument forms.
        expected = (
            'INSERT OR REPLACE INTO "users" ("superuser", "username") '
            'VALUES (?, ?)')
        from_dict = User.replace({
            User.c.username: 'charlie',
            User.c.superuser: False})
        from_kwargs = User.replace(username='charlie', superuser=False)
        for query in (from_dict, from_kwargs):
            self.assertSQL(query, expected, [False, 'charlie'])

    @requires_mysql
    def test_replace_mysql(self):
        # MySQL renders replace() as a plain REPLACE statement.
        row = {User.c.username: 'charlie', User.c.superuser: False}
        self.assertSQL(User.replace(row), (
            'REPLACE INTO "users" ("superuser", "username") '
            'VALUES (?, ?)'), [False, 'charlie'])

    def test_insert_list(self):
        # Multi-row INSERT from a list of Column-keyed dicts.
        rows = [{Person.name: name} for name in ('charlie', 'huey', 'zaizee')]
        self.assertSQL(Person.insert(rows), (
            'INSERT INTO "person" ("name") VALUES (?), (?), (?)'),
            ['charlie', 'huey', 'zaizee'])

    def test_insert_list_with_columns(self):
        # Tuple rows with explicit columns, given either as Column objects
        # or as column-name strings.
        rows = [(name,) for name in ('charlie', 'huey', 'zaizee')]
        for columns in ([Person.name], ['name']):
            query = Person.insert(rows, columns=columns)
            self.assertSQL(query, (
                'INSERT INTO "person" ("name") VALUES (?), (?), (?)'),
                ['charlie', 'huey', 'zaizee'])

    def test_insert_list_infer_columns(self):
        # Without explicit columns, tuple rows are matched to the declared
        # columns. Person's primary key ("id") is left out of the result.
        person_rows = [('p1', '1980-01-01'), ('p2', '1980-02-02')]
        self.assertSQL(Person.insert(person_rows), (
            'INSERT INTO "person" ("name", "dob") VALUES (?, ?), (?, ?)'),
            ['p1', '1980-01-01', 'p2', '1980-02-02'])

        # User declares no columns, so nothing can be inferred.
        user_query = User.insert([('u1',), ('u2',)])
        with self.assertRaises(ValueError):
            user_query.sql()

        # Note declares columns but no primary key, so rows must supply
        # every column, "id" included.
        short_note_query = Note.insert([(1, 'p1-n'), (2, 'p2-n')])
        with self.assertRaises(ValueError):
            short_note_query.sql()

        note_rows = [(1, 1, 'p1-n'), (2, 2, 'p2-n')]
        self.assertSQL(Note.insert(note_rows), (
            'INSERT INTO "note" ("id", "person_id", "content") '
            'VALUES (?, ?, ?), (?, ?, ?)'), [1, 1, 'p1-n', 2, 2, 'p2-n'])

    def test_insert_query(self):
        # INSERT ... SELECT.
        non_admins = User.select(User.c.username).where(User.c.admin == False)
        query = Person.insert(non_admins, columns=[Person.name])
        self.assertSQL(query, (
            'INSERT INTO "person" ("name") '
            'SELECT "t1"."username" FROM "users" AS "t1" '
            'WHERE ("t1"."admin" = ?)'), [False])

    def test_insert_query_cte(self):
        # INSERT ... SELECT reading from a CTE attached to the INSERT.
        foo = User.select(User.c.username).cte('foo')
        query = (Person
                 .insert(foo.select(foo.c.username), columns=[Person.name])
                 .with_cte(foo))
        self.assertSQL(query, (
            'WITH "foo" AS (SELECT "t1"."username" FROM "users" AS "t1") '
            'INSERT INTO "person" ("name") '
            'SELECT "foo"."username" FROM "foo"'), [])

    def test_insert_single_value_query(self):
        # A subquery supplies the value for a single column.
        huey_id = Person.select(Person.id).where(Person.name == 'huey')
        query = Note.insert({
            Note.person_id: huey_id,
            Note.content: 'hello'})
        self.assertSQL(query, (
            'INSERT INTO "note" ("content", "person_id") VALUES (?, '
            '(SELECT "t1"."id" FROM "person" AS "t1" '
            'WHERE ("t1"."name" = ?)))'), ['hello', 'huey'])

    def test_insert_returning(self):
        # RETURNING on INSERT; a second returning() call replaces the
        # previous column list.
        born = datetime.date(2000, 1, 2)
        query = Person.insert({Person.name: 'zaizee', Person.dob: born})
        query = query.returning(Person.id, Person.name, Person.dob)
        self.assertSQL(query, (
            'INSERT INTO "person" ("dob", "name") '
            'VALUES (?, ?) '
            'RETURNING "person"."id", "person"."name", "person"."dob"'),
            [born, 'zaizee'])

        query = query.returning(Person.id, Person.name.alias('new_name'))
        self.assertSQL(query, (
            'INSERT INTO "person" ("dob", "name") '
            'VALUES (?, ?) '
            'RETURNING "person"."id", "person"."name" AS "new_name"'),
            [born, 'zaizee'])

    def test_insert_returning_expression(self):
        # RETURNING may contain arbitrary aliased expressions.
        name_length = fn.LENGTH(Person.name).alias('ulen')
        query = (Person
                 .insert(name='huey')
                 .returning(Person.id, Person.name, name_length))
        self.assertSQL(query, (
            'INSERT INTO "person" ("name") VALUES (?) '
            'RETURNING "person"."id", '
            '"person"."name", '
            'LENGTH("person"."name") AS "ulen"'), ['huey'])

    def test_empty(self):
        # An INSERT with no values renders per-backend default-row syntax.
        # The model must be named Empty: that name yields the "empty" table.
        class Empty(TestModel):
            pass

        if isinstance(db, MySQLDatabase):
            expected = 'INSERT INTO "empty" () VALUES ()'
        elif isinstance(db, PostgresqlDatabase):
            expected = 'INSERT INTO "empty" DEFAULT VALUES RETURNING "empty"."id"'
        else:
            expected = 'INSERT INTO "empty" DEFAULT VALUES'

        # No argument, an empty dict and an empty list are equivalent.
        empty_inserts = (Empty.insert(), Empty.insert({}), Empty.insert([]))
        for query in empty_inserts:
            self.assertSQL(query, expected, [])

    def test_insert_where_raises(self):
        # INSERT queries do not support WHERE.
        query = User.insert({User.c.username: 'huey'})
        with self.assertRaises(NotImplementedError):
            query.where(User.c.val > 0)
class TestUpdateQuery(BaseTestCase):
    """SQL generation for UPDATE queries built from Table objects."""

    def test_update_query(self):
        # Column-keyed SET clause mixing literals and a column expression.
        changes = {
            User.c.username: 'nuggie',
            User.c.admin: False,
            User.c.counter: User.c.counter + 1}
        query = User.update(changes).where(User.c.username == 'nugz')
        self.assertSQL(query, (
            'UPDATE "users" SET '
            '"admin" = ?, '
            '"counter" = ("users"."counter" + ?), '
            '"username" = ? '
            'WHERE ("users"."username" = ?)'), [False, 1, 'nuggie', 'nugz'])

    def test_table_update_with_kwargs(self):
        # Keyword arguments work in place of a Column-keyed dict.
        bumped = User.c.counter + 1
        query = (User
                 .update(username='nuggie', admin=False, counter=bumped)
                 .where(User.c.username == 'nugz'))
        self.assertSQL(query, (
            'UPDATE "users" SET '
            '"admin" = ?, '
            '"counter" = ("users"."counter" + ?), '
            '"username" = ? '
            'WHERE ("users"."username" = ?)'), [False, 1, 'nuggie', 'nugz'])

    def test_update_subquery(self):
        # UPDATE filtered by IN against an aggregating subquery.
        ct = fn.COUNT(Tweet.c.id).alias('ct')
        prolific = (User
                    .select(User.c.id, ct)
                    .join(Tweet, on=(Tweet.c.user_id == User.c.id))
                    .group_by(User.c.id)
                    .having(ct > 100))
        query = (User
                 .update({User.c.muted: True, User.c.counter: 0})
                 .where(User.c.id << prolific))
        self.assertSQL(query, (
            'UPDATE "users" SET '
            '"counter" = ?, '
            '"muted" = ? '
            'WHERE ("users"."id" IN ('
            'SELECT "users"."id", COUNT("t1"."id") AS "ct" '
            'FROM "users" AS "users" '
            'INNER JOIN "tweets" AS "t1" '
            'ON ("t1"."user_id" = "users"."id") '
            'GROUP BY "users"."id" '
            'HAVING ("ct" > ?)))'), [0, True, 100])

    def test_update_value_subquery(self):
        # A correlated subquery as the new value of a column.
        newest_tweet_id = (Tweet
                           .select(fn.MAX(Tweet.c.id))
                           .where(Tweet.c.user_id == User.c.id))
        query = User.update({User.c.last_tweet_id: newest_tweet_id})
        query = query.where(User.c.last_tweet_id.is_null(True))
        self.assertSQL(query, (
            'UPDATE "users" SET '
            '"last_tweet_id" = (SELECT MAX("t1"."id") FROM "tweets" AS "t1" '
            'WHERE ("t1"."user_id" = "users"."id")) '
            'WHERE ("users"."last_tweet_id" IS NULL)'), [])

    def test_update_from_cte(self):
        # UPDATE ... FROM a CTE attached with with_cte().
        counts = (Tweet
                  .select(Tweet.c.user_id, fn.COUNT(Tweet.c.id).alias('ct'))
                  .group_by(Tweet.c.user_id)
                  .cte('t'))
        query = (User
                 .update(tweet_count=counts.c.ct)
                 .from_(counts)
                 .where(User.c.id == counts.c.user_id)
                 .with_cte(counts))
        self.assertSQL(query, (
            'WITH "t" AS '
            '(SELECT "t1"."user_id", COUNT("t1"."id") AS "ct" '
            'FROM "tweets" AS "t1" '
            'GROUP BY "t1"."user_id") '
            'UPDATE "users" SET "tweet_count" = "t"."ct" '
            'FROM "t" '
            'WHERE ("users"."id" = "t"."user_id")'), [])

    def test_update_from(self):
        # UPDATE ... FROM a VALUES list, directly and wrapped in a SELECT.
        tmp = ValuesList([(1, 'u1-x'), (2, 'u2-x')],
                         columns=('id', 'username'), alias='tmp')
        query = (User
                 .update(username=tmp.c.username)
                 .from_(tmp)
                 .where(User.c.id == tmp.c.id))
        self.assertSQL(query, (
            'UPDATE "users" SET "username" = "tmp"."username" '
            'FROM (VALUES (?, ?), (?, ?)) AS "tmp"("id", "username") '
            'WHERE ("users"."id" = "tmp"."id")'), [1, 'u1-x', 2, 'u2-x'])

        wrapped = tmp.select(tmp.c.id, tmp.c.username)
        query = (User
                 .update({User.c.username: wrapped.c.username})
                 .from_(wrapped)
                 .where(User.c.id == wrapped.c.id))
        self.assertSQL(query, (
            'UPDATE "users" SET "username" = "t1"."username" FROM ('
            'SELECT "tmp"."id", "tmp"."username" '
            'FROM (VALUES (?, ?), (?, ?)) AS "tmp"("id", "username")) AS "t1" '
            'WHERE ("users"."id" = "t1"."id")'), [1, 'u1-x', 2, 'u2-x'])

    def test_update_from_subquery(self):
        # UPDATE ... FROM an explicitly aliased aggregate subquery.
        tweet_ct = (Tweet
                    .select(Tweet.c.user_id,
                            fn.COUNT(Tweet.c.id).alias('ct'))
                    .group_by(Tweet.c.user_id)
                    .alias('tweet_ct'))
        labelled = fn.CONCAT(User.c.username, ' (', tweet_ct.c.ct, ')')
        query = (User
                 .update({User.c.username: labelled})
                 .from_(tweet_ct)
                 .where(User.c.id == tweet_ct.c.user_id))
        self.assertSQL(query, (
            'UPDATE "users" SET "username" = CONCAT('
            '"users"."username", ?, "tweet_ct"."ct", ?) '
            'FROM ('
            'SELECT "t1"."user_id", COUNT("t1"."id") AS "ct" '
            'FROM "tweets" AS "t1" '
            'GROUP BY "t1"."user_id") AS "tweet_ct" '
            'WHERE ("users"."id" = "tweet_ct"."user_id")'),
            [' (', ')'])

    def test_update_returning(self):
        # RETURNING on UPDATE; a later returning() call replaces the list.
        query = (User
                 .update({User.c.is_admin: True})
                 .where(User.c.username == 'charlie'))
        query = query.returning(User.c.id)
        self.assertSQL(query, (
            'UPDATE "users" SET "is_admin" = ? WHERE ("users"."username" = ?) '
            'RETURNING "users"."id"'), [True, 'charlie'])

        query = query.returning(User.c.is_admin.alias('new_is_admin'))
        self.assertSQL(query, (
            'UPDATE "users" SET "is_admin" = ? WHERE ("users"."username" = ?) '
            'RETURNING "users"."is_admin" AS "new_is_admin"'),
            [True, 'charlie'])

    def test_update_with_order_limit(self):
        # UPDATE accepts ORDER BY and LIMIT.
        suffixed = User.c.username.concat('-x')
        query = (User
                 .update(username=suffixed)
                 .where(User.c.active == False)
                 .order_by(User.c.id)
                 .limit(10))
        self.assertSQL(query, (
            'UPDATE "users" SET "username" = ("users"."username" || ?) '
            'WHERE ("users"."active" = ?) '
            'ORDER BY "id" LIMIT ?'), ['-x', False, 10])
class TestDeleteQuery(BaseTestCase):
    """SQL generation for DELETE queries built from Table objects."""

    def test_delete_query(self):
        # DELETE with a WHERE filter and LIMIT.
        not_charlie = User.c.username != 'charlie'
        query = User.delete().where(not_charlie).limit(3)
        self.assertSQL(query, (
            'DELETE FROM "users" WHERE ("users"."username" != ?) LIMIT ?'),
            ['charlie', 3])

    def test_delete_subquery(self):
        # DELETE filtered by IN against an aggregating subquery.
        ct = fn.COUNT(Tweet.c.id).alias('ct')
        prolific = (User
                    .select(User.c.id, ct)
                    .join(Tweet, on=(Tweet.c.user_id == User.c.id))
                    .group_by(User.c.id)
                    .having(ct > 100))
        query = User.delete().where(User.c.id << prolific)
        self.assertSQL(query, (
            'DELETE FROM "users" '
            'WHERE ("users"."id" IN ('
            'SELECT "users"."id", COUNT("t1"."id") AS "ct" '
            'FROM "users" AS "users" '
            'INNER JOIN "tweets" AS "t1" ON ("t1"."user_id" = "users"."id") '
            'GROUP BY "users"."id" '
            'HAVING ("ct" > ?)))'), [100])

    def test_delete_cte(self):
        # DELETE filtered by IN against a CTE attached with with_cte().
        admins = (User
                  .select(User.c.id)
                  .where(User.c.admin == True)
                  .cte('u'))
        in_admins = User.c.id << admins.select(admins.c.id)
        query = User.delete().where(in_admins).with_cte(admins)
        self.assertSQL(query, (
            'WITH "u" AS '
            '(SELECT "t1"."id" FROM "users" AS "t1" WHERE ("t1"."admin" = ?)) '
            'DELETE FROM "users" '
            'WHERE ("users"."id" IN (SELECT "u"."id" FROM "u"))'), [True])

    def test_delete_returning(self):
        # RETURNING on DELETE; each returning() call replaces the list.
        query = User.delete().where(User.c.id > 2)

        query = query.returning(User.c.username)
        self.assertSQL(query, (
            'DELETE FROM "users" '
            'WHERE ("users"."id" > ?) '
            'RETURNING "users"."username"'), [2])

        query = query.returning(User.c.id, User.c.username, SQL('1'))
        self.assertSQL(query, (
            'DELETE FROM "users" '
            'WHERE ("users"."id" > ?) '
            'RETURNING "users"."id", "users"."username", 1'), [2])

        query = query.returning(User.c.id.alias('old_id'))
        self.assertSQL(query, (
            'DELETE FROM "users" '
            'WHERE ("users"."id" > ?) '
            'RETURNING "users"."id" AS "old_id"'), [2])

    def test_delete_with_order_limit(self):
        # DELETE accepts ORDER BY and LIMIT.
        inactive = User.c.active == False
        query = User.delete().where(inactive).order_by(User.c.id).limit(10)
        self.assertSQL(query, (
            'DELETE FROM "users" WHERE ("users"."active" = ?) '
            'ORDER BY "id" LIMIT ?'), [False, 10])
# ===========================================================================
# Advanced SELECT features: window functions, VALUES lists, CASE expressions
# ===========================================================================
# Three-column table shared by the tests in this section.
Register = Table('register', columns=('id', 'value', 'category'))
class TestWindowFunctions(BaseTestCase):
def test_partition_unordered(self):
    # PARTITION BY with no ORDER BY inside the window.
    avg_by_category = fn.AVG(Register.value).over(
        partition_by=[Register.category])
    query = Register.select(Register.category, Register.value,
                            avg_by_category)
    query = query.order_by(Register.id)
    self.assertSQL(query, (
        'SELECT "t1"."category", "t1"."value", AVG("t1"."value") '
        'OVER (PARTITION BY "t1"."category") '
        'FROM "register" AS "t1" ORDER BY "t1"."id"'), [])
def test_ordered_unpartitioned(self):
    # ORDER BY with no PARTITION BY inside the window.
    rank = fn.RANK().over(order_by=[Register.value])
    self.assertSQL(Register.select(Register.value, rank), (
        'SELECT "t1"."value", RANK() OVER (ORDER BY "t1"."value") '
        'FROM "register" AS "t1"'), [])
def test_ordered_partitioned(self):
    # PARTITION BY is rendered before ORDER BY regardless of kwarg order.
    running_sum = fn.SUM(Register.value).over(
        order_by=Register.id,
        partition_by=Register.category)
    query = Register.select(Register.value, running_sum.alias('rsum'))
    self.assertSQL(query, (
        'SELECT "t1"."value", SUM("t1"."value") '
        'OVER (PARTITION BY "t1"."category" ORDER BY "t1"."id") AS "rsum" '
        'FROM "register" AS "t1"'), [])
def test_empty_over(self):
    # over() with no arguments renders an empty OVER ().
    lagged = fn.LAG(Register.value, 1).over()
    query = Register.select(Register.value, lagged).order_by(Register.value)
    self.assertSQL(query, (
        'SELECT "t1"."value", LAG("t1"."value", ?) OVER () '
        'FROM "register" AS "t1" '
        'ORDER BY "t1"."value"'), [1])
def test_frame(self):
    # start/end boundaries produce a ROWS BETWEEN frame, with and without
    # an ORDER BY in the window.
    through_next_two = fn.AVG(Register.value).over(
        partition_by=[Register.category],
        start=Window.preceding(),
        end=Window.following(2))
    self.assertSQL(Register.select(Register.value, through_next_two), (
        'SELECT "t1"."value", AVG("t1"."value") '
        'OVER (PARTITION BY "t1"."category" '
        'ROWS BETWEEN UNBOUNDED PRECEDING AND 2 FOLLOWING) '
        'FROM "register" AS "t1"'), [])

    through_end = fn.AVG(Register.value).over(
        partition_by=[Register.category],
        order_by=[Register.value],
        start=Window.CURRENT_ROW,
        end=Window.following())
    self.assertSQL(Register.select(Register.value, through_end), (
        'SELECT "t1"."value", AVG("t1"."value") '
        'OVER (PARTITION BY "t1"."category" '
        'ORDER BY "t1"."value" '
        'ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) '
        'FROM "register" AS "t1"'), [])
def test_frame_types(self):
    # The text rendered inside OVER (...) for combinations of frame_type,
    # start/end boundaries, ORDER BY and PARTITION BY.
    def assertFrame(over_kwargs, expected):
        query = Register.select(
            Register.value,
            fn.SUM(Register.value).over(**over_kwargs))
        sql, params = __sql__(query)
        match_obj = re.search(r'OVER \((.*?)\) FROM', sql)
        # Include the generated SQL in the failure message.
        self.assertIsNotNone(match_obj, sql)
        self.assertEqual(match_obj.group(1), expected)
        self.assertEqual(params, [])

    cases = (
        # No parameters -- empty OVER().
        ({}, ''),

        # Explicitly specify RANGE / ROWS frame-types.
        ({'frame_type': Window.RANGE}, 'RANGE UNBOUNDED PRECEDING'),
        ({'frame_type': Window.ROWS}, 'ROWS UNBOUNDED PRECEDING'),

        # Start and end boundaries.
        ({'start': Window.preceding(), 'end': Window.following()},
         'ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING'),
        ({'start': Window.preceding(),
          'end': Window.following(),
          'frame_type': Window.RANGE},
         'RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING'),
        ({'start': Window.preceding(),
          'end': Window.following(),
          'frame_type': Window.ROWS},
         'ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING'),

        # Start boundary.
        ({'start': Window.preceding()}, 'ROWS UNBOUNDED PRECEDING'),
        ({'start': Window.preceding(), 'frame_type': Window.RANGE},
         'RANGE UNBOUNDED PRECEDING'),
        ({'start': Window.preceding(), 'frame_type': Window.ROWS},
         'ROWS UNBOUNDED PRECEDING'),

        # Ordered or partitioned.
        ({'order_by': Register.value}, 'ORDER BY "t1"."value"'),
        ({'frame_type': Window.RANGE, 'order_by': Register.value},
         'ORDER BY "t1"."value" RANGE UNBOUNDED PRECEDING'),
        ({'frame_type': Window.ROWS, 'order_by': Register.value},
         'ORDER BY "t1"."value" ROWS UNBOUNDED PRECEDING'),
        ({'partition_by': Register.category},
         'PARTITION BY "t1"."category"'),
        ({'frame_type': Window.RANGE, 'partition_by': Register.category},
         'PARTITION BY "t1"."category" RANGE UNBOUNDED PRECEDING'),
        ({'frame_type': Window.ROWS, 'partition_by': Register.category},
         'PARTITION BY "t1"."category" ROWS UNBOUNDED PRECEDING'),

        # Ordering and boundaries.
        ({'order_by': Register.value, 'start': Window.CURRENT_ROW,
          'end': Window.following()},
         ('ORDER BY "t1"."value" '
          'ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING')),
        ({'order_by': Register.value, 'start': Window.CURRENT_ROW,
          'end': Window.following(), 'frame_type': Window.RANGE},
         ('ORDER BY "t1"."value" '
          'RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING')),
        ({'order_by': Register.value, 'start': Window.CURRENT_ROW,
          'end': Window.following(), 'frame_type': Window.ROWS},
         ('ORDER BY "t1"."value" '
          'ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING')),
    )
    # subTest reports every mismatching frame instead of the first one.
    for index, (over_kwargs, expected) in enumerate(cases):
        with self.subTest(case=index, expected=expected):
            assertFrame(over_kwargs, expected)
def test_window_end_without_start_error(self):
    # An end boundary without a start boundary is rejected.
    with self.assertRaises(ValueError):
        Window(order_by=[User.c.id], end=Window.CURRENT_ROW)
def test_window_as_groups(self):
    # as_groups() switches the named window's frame type to GROUPS.
    frame = Window(order_by=[User.c.id],
                   start=Window.preceding(),
                   end=Window.CURRENT_ROW)
    w = frame.as_groups()
    total = fn.SUM(User.c.val).over(window=w)
    query = User.select(total).window(w)
    self.assertSQL(query, (
        'SELECT SUM("t1"."val") OVER "w" FROM "users" AS "t1" '
        'WINDOW "w" AS (ORDER BY "t1"."id" '
        'GROUPS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)'))
def test_window_exclude_with_string(self):
    # exclude() accepts a plain string and appends an EXCLUDE clause.
    frame = Window(order_by=[User.c.id],
                   start=Window.preceding(),
                   end=Window.CURRENT_ROW)
    w = frame.exclude('TIES')
    total = fn.SUM(User.c.val).over(window=w)
    query = User.select(total).window(w)
    self.assertSQL(query, (
        'SELECT SUM("t1"."val") OVER "w" FROM "users" AS "t1" '
        'WINDOW "w" AS (ORDER BY "t1"."id" '
        'ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW EXCLUDE TIES)'))
def test_running_total(self):
    # Running SUM over timestamp order, without and then with a partition.
    EventLog = Table('evtlog', ('id', 'timestamp', 'data'))
    cases = (
        ({'order_by': [EventLog.timestamp]},
         ('SELECT "t1"."timestamp", "t1"."data", '
          'SUM("t1"."timestamp") OVER (ORDER BY "t1"."timestamp") '
          'AS "elapsed" '
          'FROM "evtlog" AS "t1" ORDER BY "t1"."timestamp"')),
        ({'order_by': [EventLog.timestamp],
          'partition_by': [EventLog.data]},
         ('SELECT "t1"."timestamp", "t1"."data", '
          'SUM("t1"."timestamp") OVER '
          '(PARTITION BY "t1"."data" ORDER BY "t1"."timestamp") AS "elapsed"'
          ' FROM "evtlog" AS "t1" ORDER BY "t1"."timestamp"')),
    )
    for over_kwargs, expected_sql in cases:
        elapsed = fn.SUM(EventLog.timestamp).over(**over_kwargs)
        query = (EventLog
                 .select(EventLog.timestamp, EventLog.data,
                         elapsed.alias('elapsed'))
                 .order_by(EventLog.timestamp))
        self.assertSQL(query, expected_sql, [])
def test_named_window(self):
    # A Window registered via .window() is declared in a WINDOW clause and
    # referenced by name (default "w") from OVER.
    by_category = Window(partition_by=[Register.category])
    query = (Register
             .select(Register.category, Register.value,
                     fn.AVG(Register.value).over(by_category))
             .window(by_category))
    self.assertSQL(query, (
        'SELECT "t1"."category", "t1"."value", AVG("t1"."value") '
        'OVER "w" '
        'FROM "register" AS "t1" '
        'WINDOW "w" AS (PARTITION BY "t1"."category")'), [])

    ranked = Window(
        partition_by=[Register.category],
        order_by=[Register.value.desc()])
    query = (Register
             .select(Register.value, fn.RANK().over(ranked))
             .window(ranked))
    self.assertSQL(query, (
        'SELECT "t1"."value", RANK() OVER "w" '
        'FROM "register" AS "t1" '
        'WINDOW "w" AS ('
        'PARTITION BY "t1"."category" '
        'ORDER BY "t1"."value" DESC)'), [])
def test_multiple_windows(self):
w1 = Window(partition_by=[Register.category]).alias('w1')
w2 = Window(order_by=[Register.value]).alias('w2')
query = (Register
.select(
Register.value,
fn.AVG(Register.value).over(w1),
fn.RANK().over(w2))
.window(w1, w2))
self.assertSQL(query, (
'SELECT "t1"."value", AVG("t1"."value") OVER "w1", '
'RANK() OVER "w2" '
'FROM "register" AS "t1" '
'WINDOW "w1" AS (PARTITION BY "t1"."category"), '
'"w2" AS (ORDER BY "t1"."value")'), [])
def test_alias_window(self):
w = Window(order_by=Register.value).alias('wx')
query = Register.select(Register.value, fn.RANK().over(w)).window(w)
# We can re-alias the window and it's updated alias is reflected
# correctly in the final query.
w.alias('wz')
self.assertSQL(query, (
'SELECT "t1"."value", RANK() OVER "wz" '
'FROM "register" AS "t1" '
'WINDOW "wz" AS (ORDER BY "t1"."value")'), [])
def test_reuse_window(self):
EventLog = Table('evt', ('id', 'timestamp', 'key'))
window = Window(partition_by=[EventLog.key],
order_by=[EventLog.timestamp])
query = (EventLog
.select(EventLog.timestamp, EventLog.key,
fn.NTILE(4).over(window).alias('quartile'),
fn.NTILE(5).over(window).alias('quintile'),
fn.NTILE(100).over(window).alias('percentile'))
.order_by(EventLog.timestamp)
.window(window))
self.assertSQL(query, (
'SELECT "t1"."timestamp", "t1"."key", '
'NTILE(?) OVER "w" AS "quartile", '
'NTILE(?) OVER "w" AS "quintile", '
'NTILE(?) OVER "w" AS "percentile" '
'FROM "evt" AS "t1" '
'WINDOW "w" AS ('
'PARTITION BY "t1"."key" ORDER BY "t1"."timestamp") '
'ORDER BY "t1"."timestamp"'), [4, 5, 100])
def test_filter_clause(self):
condsum = fn.SUM(Register.value).filter(Register.value > 1).over(
order_by=[Register.id], partition_by=[Register.category],
start=Window.preceding(1))
query = (Register
.select(Register.category, Register.value, condsum)
.order_by(Register.category))
self.assertSQL(query, (
'SELECT "t1"."category", "t1"."value", SUM("t1"."value") FILTER ('
'WHERE ("t1"."value" > ?)) OVER (PARTITION BY "t1"."category" '
'ORDER BY "t1"."id" ROWS 1 PRECEDING) '
'FROM "register" AS "t1" '
'ORDER BY "t1"."category"'), [1])
def test_window_in_orderby(self):
Register = Table('register', ['id', 'value'])
w = Window(partition_by=[Register.value], order_by=[Register.id])
query = (Register
.select()
.window(w)
.order_by(fn.FIRST_VALUE(Register.id).over(w)))
self.assertSQL(query, (
'SELECT "t1"."id", "t1"."value" FROM "register" AS "t1" '
'WINDOW "w" AS (PARTITION BY "t1"."value" ORDER BY "t1"."id") '
'ORDER BY FIRST_VALUE("t1"."id") OVER "w"'), [])
fv = fn.FIRST_VALUE(Register.id).over(
partition_by=[Register.value],
order_by=[Register.id])
query = Register.select().order_by(fv)
self.assertSQL(query, (
'SELECT "t1"."id", "t1"."value" FROM "register" AS "t1" '
'ORDER BY FIRST_VALUE("t1"."id") '
'OVER (PARTITION BY "t1"."value" ORDER BY "t1"."id")'), [])
def test_window_extends(self):
Tbl = Table('tbl', ('b', 'c'))
w1 = Window(partition_by=[Tbl.b], alias='win1')
w2 = Window(extends=w1, order_by=[Tbl.c], alias='win2')
query = Tbl.select(fn.GROUP_CONCAT(Tbl.c).over(w2)).window(w1, w2)
self.assertSQL(query, (
'SELECT GROUP_CONCAT("t1"."c") OVER "win2" FROM "tbl" AS "t1" '
'WINDOW "win1" AS (PARTITION BY "t1"."b"), '
'"win2" AS ("win1" ORDER BY "t1"."c")'), [])
w1 = Window(partition_by=[Tbl.b], alias='w1')
w2 = Window(extends=w1).alias('w2')
w3 = Window(extends=w2).alias('w3')
w4 = Window(extends=w3, order_by=[Tbl.c]).alias('w4')
query = (Tbl
.select(fn.GROUP_CONCAT(Tbl.c).over(w4))
.window(w1, w2, w3, w4))
self.assertSQL(query, (
'SELECT GROUP_CONCAT("t1"."c") OVER "w4" FROM "tbl" AS "t1" '
'WINDOW "w1" AS (PARTITION BY "t1"."b"), "w2" AS ("w1"), '
'"w3" AS ("w2"), '
'"w4" AS ("w3" ORDER BY "t1"."c")'), [])
def test_window_ranged(self):
Tbl = Table('tbl', ('a', 'b'))
query = (Tbl
.select(Tbl.a, fn.SUM(Tbl.b).over(
order_by=[Tbl.a.desc()],
frame_type=Window.RANGE,
start=Window.preceding(1),
end=Window.following(2)))
.order_by(Tbl.a.asc()))
self.assertSQL(query, (
'SELECT "t1"."a", SUM("t1"."b") OVER ('
'ORDER BY "t1"."a" DESC RANGE BETWEEN 1 PRECEDING AND 2 FOLLOWING)'
' FROM "tbl" AS "t1" ORDER BY "t1"."a" ASC'), [])
query = (Tbl
.select(Tbl.a, fn.SUM(Tbl.b).over(
order_by=[Tbl.a],
frame_type=Window.GROUPS,
start=Window.preceding(3),
end=Window.preceding(1))))
self.assertSQL(query, (
'SELECT "t1"."a", SUM("t1"."b") OVER ('
'ORDER BY "t1"."a" GROUPS BETWEEN 3 PRECEDING AND 1 PRECEDING) '
'FROM "tbl" AS "t1"'), [])
query = (Tbl
.select(Tbl.a, fn.SUM(Tbl.b).over(
order_by=[Tbl.a],
frame_type=Window.GROUPS,
start=Window.following(1),
end=Window.following(5))))
self.assertSQL(query, (
'SELECT "t1"."a", SUM("t1"."b") OVER ('
'ORDER BY "t1"."a" GROUPS BETWEEN 1 FOLLOWING AND 5 FOLLOWING) '
'FROM "tbl" AS "t1"'), [])
def test_window_frametypes(self):
Tbl = Table('tbl', ('b', 'c'))
fts = (('as_range', Window.RANGE, 'RANGE'),
('as_rows', Window.ROWS, 'ROWS'),
('as_groups', Window.GROUPS, 'GROUPS'))
for method, arg, sql in fts:
w = getattr(Window(order_by=[Tbl.b + 1]), method)()
self.assertSQL(Tbl.select(fn.SUM(Tbl.c).over(w)).window(w), (
'SELECT SUM("t1"."c") OVER "w" FROM "tbl" AS "t1" '
'WINDOW "w" AS (ORDER BY ("t1"."b" + ?) '
'%s UNBOUNDED PRECEDING)') % sql, [1])
query = Tbl.select(fn.SUM(Tbl.c)
.over(order_by=[Tbl.b + 1], frame_type=arg))
self.assertSQL(query, (
'SELECT SUM("t1"."c") OVER (ORDER BY ("t1"."b" + ?) '
'%s UNBOUNDED PRECEDING) FROM "tbl" AS "t1"') % sql, [1])
def test_window_frame_exclusion(self):
Tbl = Table('tbl', ('b', 'c'))
fts = ((Window.CURRENT_ROW, 'CURRENT ROW'),
(Window.TIES, 'TIES'),
(Window.NO_OTHERS, 'NO OTHERS'),
(Window.GROUP, 'GROUP'))
for arg, sql in fts:
query = Tbl.select(fn.MAX(Tbl.b).over(
order_by=[Tbl.c],
start=Window.preceding(4),
end=Window.following(),
frame_type=Window.ROWS,
exclude=arg))
self.assertSQL(query, (
'SELECT MAX("t1"."b") OVER (ORDER BY "t1"."c" '
'ROWS BETWEEN 4 PRECEDING AND UNBOUNDED FOLLOWING '
'EXCLUDE %s) FROM "tbl" AS "t1"') % sql, [])
def test_filter_window(self):
# Example derived from sqlite window test 5.1.3.2.
Tbl = Table('tbl', ('a', 'c'))
win = Window(partition_by=fn.COALESCE(Tbl.a, ''),
frame_type=Window.RANGE,
start=Window.CURRENT_ROW,
end=Window.following(),
exclude=Window.NO_OTHERS)
query = (Tbl
.select(fn.SUM(Tbl.c).filter(Tbl.c < 5).over(win),
fn.RANK().over(win),
fn.DENSE_RANK().over(win))
.window(win))
self.assertSQL(query, (
'SELECT SUM("t1"."c") FILTER (WHERE ("t1"."c" < ?)) OVER "w", '
'RANK() OVER "w", DENSE_RANK() OVER "w" '
'FROM "tbl" AS "t1" '
'WINDOW "w" AS (PARTITION BY COALESCE("t1"."a", ?) '
'RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING '
'EXCLUDE NO OTHERS)'), [5, ''])
class TestValuesList(BaseTestCase):
    """SQL generation for VALUES lists used as query sources."""
    _data = [(1, 'one'), (2, 'two'), (3, 'three')]

    def test_values_list(self):
        """Without column names, the VALUES list gets only an auto alias."""
        rows = ValuesList(self._data)
        self.assertSQL(rows.select(SQL('*')), (
            'SELECT * FROM (VALUES (?, ?), (?, ?), (?, ?)) AS "t1"'),
            [1, 'one', 2, 'two', 3, 'three'])

    def test_values_list_named_columns(self):
        """.columns() names the derived-table columns after the alias."""
        rows = ValuesList(self._data).columns('idx', 'name')
        idx_col, name_col = rows.c.idx, rows.c.name
        query = rows.select(idx_col, name_col).order_by(idx_col)
        self.assertSQL(query, (
            'SELECT "t1"."idx", "t1"."name" '
            'FROM (VALUES (?, ?), (?, ?), (?, ?)) AS "t1"("idx", "name") '
            'ORDER BY "t1"."idx"'), [1, 'one', 2, 'two', 3, 'three'])

    def test_named_values_list(self):
        """Columns passed to the constructor plus an explicit alias."""
        rows = ValuesList(self._data, ['idx', 'name']).alias('vl')
        idx_col, name_col = rows.c.idx, rows.c.name
        query = rows.select(idx_col, name_col).order_by(idx_col)
        self.assertSQL(query, (
            'SELECT "vl"."idx", "vl"."name" '
            'FROM (VALUES (?, ?), (?, ?), (?, ?)) AS "vl"("idx", "name") '
            'ORDER BY "vl"."idx"'), [1, 'one', 2, 'two', 3, 'three'])

    def test_docs_examples(self):
        """The usage examples shown in the documentation."""
        pairs = [(1, 'first'), (2, 'second')]

        by_keyword = ValuesList(pairs, columns=('idx', 'name'))
        ordered = (by_keyword
                   .select(by_keyword.c.idx, by_keyword.c.name)
                   .order_by(by_keyword.c.idx))
        self.assertSQL(ordered, (
            'SELECT "t1"."idx", "t1"."name" '
            'FROM (VALUES (?, ?), (?, ?)) AS "t1"("idx", "name") '
            'ORDER BY "t1"."idx"'), [1, 'first', 2, 'second'])

        chained = ValuesList(pairs).columns('idx', 'name').alias('v')
        self.assertSQL(chained.select(chained.c.idx, chained.c.name), (
            'SELECT "v"."idx", "v"."name" '
            'FROM (VALUES (?, ?), (?, ?)) AS "v"("idx", "name")'),
            [1, 'first', 2, 'second'])

    def test_join_on_valueslist(self):
        """A VALUES list can be the target of a JOIN."""
        names = ValuesList([('huey',), ('zaizee',)], columns=['username'])
        query = (User
                 .select(names.c.username)
                 .join(names, on=(User.c.username == names.c.username))
                 .order_by(names.c.username.desc()))
        self.assertSQL(query, (
            'SELECT "t1"."username" FROM "users" AS "t2" '
            'INNER JOIN (VALUES (?), (?)) AS "t1"("username") '
            'ON ("t2"."username" = "t1"."username") '
            'ORDER BY "t1"."username" DESC'), ['huey', 'zaizee'])
class TestCaseFunction(BaseTestCase):
    """SQL generation for CASE expressions."""

    def test_case_function(self):
        """Simple CASE (with a subject) and searched CASE (subject None)."""
        NameNum = Table('nn', ('name', 'number'))

        simple = Case(NameNum.number, ((1, 'one'), (2, 'two')), '?')
        query = NameNum.select(NameNum.name, simple.alias('num_str'))
        self.assertSQL(query, (
            'SELECT "t1"."name", CASE "t1"."number" '
            'WHEN ? THEN ? '
            'WHEN ? THEN ? '
            'ELSE ? END AS "num_str" '
            'FROM "nn" AS "t1"'), [1, 'one', 2, 'two', '?'])

        searched = Case(None, ((NameNum.number == 1, 'one'),
                               (NameNum.number == 2, 'two')), '?')
        self.assertSQL(NameNum.select(NameNum.name, searched), (
            'SELECT "t1"."name", CASE '
            'WHEN ("t1"."number" = ?) THEN ? '
            'WHEN ("t1"."number" = ?) THEN ? '
            'ELSE ? END '
            'FROM "nn" AS "t1"'), [1, 'one', 2, 'two', '?'])

    def test_multiple_case_expressions(self):
        """Two aliased CASE expressions in the same SELECT list."""
        Sample = Table('sample', ('id', 'counter', 'value'))
        tier = Case(None, [(Sample.counter < 5, 'low'),
                           (Sample.counter < 10, 'mid')], 'high')
        is_large = Case(None, [(Sample.value > 100, True)], False)
        query = Sample.select(Sample.counter,
                              tier.alias('tier'),
                              is_large.alias('is_large'))
        self.assertSQL(query, (
            'SELECT "t1"."counter", '
            'CASE WHEN ("t1"."counter" < ?) THEN ? '
            'WHEN ("t1"."counter" < ?) THEN ? '
            'ELSE ? END AS "tier", '
            'CASE WHEN ("t1"."value" > ?) THEN ? '
            'ELSE ? END AS "is_large" '
            'FROM "sample" AS "t1"'),
            [5, 'low', 10, 'mid', 'high', 100, True, False])

    def test_case_subquery(self):
        """Subqueries may appear as CASE conditions, results and default."""
        Name = Table('n', ('id', 'name',))

        membership = Case(None, [(Name.id.in_(Name.select(Name.id)), 1)], 0)
        self.assertSQL(Name.select(fn.SUM(membership)), (
            'SELECT SUM('
            'CASE WHEN ("t1"."id" IN (SELECT "t1"."id" FROM "n" AS "t1")) '
            'THEN ? ELSE ? END) FROM "n" AS "t1"'), [1, 0])

        branches = [
            (Name.id < 5, Name.select(fn.SUM(Name.id))),
            (Name.id > 5, Name.select(fn.COUNT(Name.name)).distinct()),
        ]
        magic = Case(None, branches, Name.select(fn.MAX(Name.id)))
        self.assertSQL(Name.select(Name.name, magic.alias('magic')), (
            'SELECT "t1"."name", CASE '
            'WHEN ("t1"."id" < ?) '
            'THEN (SELECT SUM("t1"."id") FROM "n" AS "t1") '
            'WHEN ("t1"."id" > ?) '
            'THEN (SELECT DISTINCT COUNT("t1"."name") FROM "n" AS "t1") '
            'ELSE (SELECT MAX("t1"."id") FROM "n" AS "t1") END AS "magic" '
            'FROM "n" AS "t1"'), [5, 5])
# ===========================================================================
# Miscellaneous SELECT features and expression SQL
# ===========================================================================
class TestSelectFeatures(BaseTestCase):
    """Miscellaneous SELECT features.

    Covers re-selecting columns, DISTINCT / DISTINCT ON, FILTER and ordered
    aggregates, row-locking clauses (FOR UPDATE / FOR SHARE and their
    options), and parenthesization of function arguments.
    """

    def test_reselect(self):
        """columns() replaces the previously selected column list."""
        query = Person.select(Person.name)
        self.assertSQL(query, 'SELECT "t1"."name" FROM "person" AS "t1"', [])
        query = query.columns(Person.id, Person.name, Person.dob)
        self.assertSQL(query, (
            'SELECT "t1"."id", "t1"."name", "t1"."dob" '
            'FROM "person" AS "t1"'), [])

    def test_distinct_on(self):
        """distinct(<columns>) renders DISTINCT ON (...)."""
        # Person, which appears first in the SELECT list, is aliased "t1"
        # even though Note is the FROM source.
        query = (Note
                 .select(Person.name, Note.content)
                 .join(Person, on=(Note.person_id == Person.id))
                 .order_by(Person.name, Note.content)
                 .distinct(Person.name))
        self.assertSQL(query, (
            'SELECT DISTINCT ON ("t1"."name") '
            '"t1"."name", "t2"."content" '
            'FROM "note" AS "t2" '
            'INNER JOIN "person" AS "t1" ON ("t2"."person_id" = "t1"."id") '
            'ORDER BY "t1"."name", "t2"."content"'), [])
        query = (Person
                 .select(Person.name)
                 .distinct(Person.name))
        self.assertSQL(query, (
            'SELECT DISTINCT ON ("t1"."name") "t1"."name" '
            'FROM "person" AS "t1"'), [])

    def test_distinct(self):
        """distinct() with no arguments renders a plain SELECT DISTINCT."""
        query = Person.select(Person.name).distinct()
        self.assertSQL(query,
                       'SELECT DISTINCT "t1"."name" FROM "person" AS "t1"', [])

    def test_distinct_count(self):
        """column.distinct() inside an aggregate renders COUNT(DISTINCT ...)."""
        query = Person.select(fn.COUNT(Person.name.distinct()))
        self.assertSQL(query, (
            'SELECT COUNT(DISTINCT "t1"."name") FROM "person" AS "t1"'), [])

    def test_filtered_count(self):
        """Aggregate .filter() renders a FILTER (WHERE ...) clause."""
        filtered_count = (fn.COUNT(Person.name)
                          .filter(Person.dob < datetime.date(2000, 1, 1)))
        query = Person.select(fn.COUNT(Person.name), filtered_count)
        self.assertSQL(query, (
            'SELECT COUNT("t1"."name"), COUNT("t1"."name") '
            'FILTER (WHERE ("t1"."dob" < ?)) '
            'FROM "person" AS "t1"'), [datetime.date(2000, 1, 1)])

    def test_ordered_aggregate(self):
        """Aggregate .order_by() puts ORDER BY inside the call.

        Calling .order_by() with no arguments removes it again.
        """
        agg = fn.array_agg(Person.name).order_by(Person.id.desc())
        self.assertSQL(Person.select(agg.alias('names')), (
            'SELECT array_agg("t1"."name" ORDER BY "t1"."id" DESC) AS "names" '
            'FROM "person" AS "t1"'), [])
        agg = fn.string_agg(Person.name, ',').order_by(Person.dob, Person.id)
        self.assertSQL(Person.select(agg), (
            'SELECT string_agg("t1"."name", ? ORDER BY "t1"."dob", "t1"."id")'
            ' FROM "person" AS "t1"'), [','])
        agg = (fn.string_agg(Person.name.concat('-x'), ',')
               .order_by(Person.name.desc(), Person.dob.asc()))
        self.assertSQL(Person.select(agg), (
            'SELECT string_agg(("t1"."name" || ?), ? ORDER BY "t1"."name" DESC'
            ', "t1"."dob" ASC) '
            'FROM "person" AS "t1"'), ['-x', ','])
        agg = agg.order_by()
        self.assertSQL(Person.select(agg), (
            'SELECT string_agg(("t1"."name" || ?), ?) '
            'FROM "person" AS "t1"'), ['-x', ','])

    def test_for_update(self):
        """for_update() with no arguments, then with a raw lock string."""
        query = (Person
                 .select()
                 .where(Person.name == 'charlie')
                 .for_update())
        self.assertSQL(query, (
            'SELECT "t1"."id", "t1"."name", "t1"."dob" '
            'FROM "person" AS "t1" '
            'WHERE ("t1"."name" = ?) '
            'FOR UPDATE'), ['charlie'], for_update=True)
        query = query.for_update('FOR SHARE NOWAIT')
        self.assertSQL(query, (
            'SELECT "t1"."id", "t1"."name", "t1"."dob" '
            'FROM "person" AS "t1" '
            'WHERE ("t1"."name" = ?) '
            'FOR SHARE NOWAIT'), ['charlie'], for_update=True)

    def test_for_update_nested(self):
        """A locking clause on a subquery is kept inside a DELETE."""
        PA = Person.alias('pa')
        subq = PA.select(PA.id).where(PA.name == 'charlie').for_update()
        query = (Person
                 .delete()
                 .where(Person.id.in_(subq)))
        self.assertSQL(query, (
            'DELETE FROM "person" WHERE ("person"."id" IN ('
            'SELECT "pa"."id" FROM "person" AS "pa" '
            'WHERE ("pa"."name" = ?) FOR UPDATE))'),
            ['charlie'],
            for_update=True)

    def test_for_update_options(self):
        """of=/nowait= options, resetting, clearing, and legacy lock strings."""
        query = (Person
                 .select(Person.id)
                 .where(Person.name == 'huey')
                 .for_update(of=Person, nowait=True))
        self.assertSQL(query, (
            'SELECT "t1"."id" FROM "person" AS "t1" WHERE ("t1"."name" = ?) '
            'FOR UPDATE OF "t1" NOWAIT'), ['huey'], for_update=True)

        # Check default behavior.
        query = query.for_update()
        self.assertSQL(query, (
            'SELECT "t1"."id" FROM "person" AS "t1" WHERE ("t1"."name" = ?) '
            'FOR UPDATE'), ['huey'], for_update=True)

        # Clear flag.
        query = query.for_update(None)
        self.assertSQL(query, (
            'SELECT "t1"."id" FROM "person" AS "t1" WHERE ("t1"."name" = ?)'),
            ['huey'])

        # Old-style is still supported.
        query = query.for_update('FOR UPDATE NOWAIT')
        self.assertSQL(query, (
            'SELECT "t1"."id" FROM "person" AS "t1" WHERE ("t1"."name" = ?) '
            'FOR UPDATE NOWAIT'), ['huey'], for_update=True)

        # Mix of old and new is OK.
        query = query.for_update('FOR SHARE NOWAIT', of=Person)
        self.assertSQL(query, (
            'SELECT "t1"."id" FROM "person" AS "t1" WHERE ("t1"."name" = ?) '
            'FOR SHARE OF "t1" NOWAIT'), ['huey'], for_update=True)

    def test_skip_locked_sql(self):
        """skip_locked=True appends SKIP LOCKED."""
        query = (Person
                 .select(Person.id)
                 .where(Person.name == 'huey')
                 .for_update(skip_locked=True))
        self.assertSQL(query, (
            'SELECT "t1"."id" FROM "person" AS "t1" '
            'WHERE ("t1"."name" = ?) '
            'FOR UPDATE SKIP LOCKED'), ['huey'], for_update=True)

    def test_skip_locked_string_parsing(self):
        """SKIP LOCKED inside a raw lock string is emitted unchanged."""
        query = (Person
                 .select(Person.id)
                 .for_update('FOR SHARE SKIP LOCKED'))
        # Pass the (empty) expected params explicitly so they are checked too.
        self.assertSQL(query, (
            'SELECT "t1"."id" FROM "person" AS "t1" '
            'FOR SHARE SKIP LOCKED'), [], for_update=True)

    def test_nowait_and_skip_locked_error(self):
        """nowait and skip_locked cannot both be requested."""
        self.assertRaises(ValueError, ForUpdate, 'FOR UPDATE',
                          nowait=True, skip_locked=True)

    def test_for_update_false_with_of_enables(self):
        """Passing of= enables locking even when the first argument is False."""
        query = (Person
                 .select(Person.id)
                 .for_update(False, of=Person, nowait=True))
        # Pass the (empty) expected params explicitly so they are checked too.
        self.assertSQL(query, (
            'SELECT "t1"."id" FROM "person" AS "t1" '
            'FOR UPDATE OF "t1" NOWAIT'), [], for_update=True)

    def test_parentheses(self):
        """Nested calls are parenthesized once; EXISTS wraps its subquery."""
        query = (Person
                 .select(fn.MAX(
                     fn.IFNULL(1, 10) * 151,
                     fn.IFNULL(None, 10))))
        self.assertSQL(query, (
            'SELECT MAX((IFNULL(?, ?) * ?), IFNULL(?, ?)) '
            'FROM "person" AS "t1"'), [1, 10, 151, None, 10])
        query = (Person
                 .select(Person.name)
                 .where(fn.EXISTS(
                     User.select(User.c.id).where(
                         User.c.username == Person.name))))
        self.assertSQL(query, (
            'SELECT "t1"."name" FROM "person" AS "t1" '
            'WHERE EXISTS('
            'SELECT "t2"."id" FROM "users" AS "t2" '
            'WHERE ("t2"."username" = "t1"."name"))'), [])
class TestExpressionSQL(BaseTestCase):
    """Parenthesization of arithmetic expressions passed to functions."""

    def test_parentheses_functions(self):
        """A single expression argument is not re-wrapped; nested ones are."""
        income_plus = User.c.income + 100
        squared = income_plus * income_plus
        self.assertSQL(User.select(fn.sum(income_plus), fn.avg(squared)), (
            'SELECT sum("t1"."income" + ?), '
            'avg(("t1"."income" + ?) * ("t1"."income" + ?)) '
            'FROM "users" AS "t1"'), [100, 100, 100])
# ===========================================================================
# ON CONFLICT / upsert SQL (per-dialect: SQLite, MySQL, PostgreSQL)
# ===========================================================================
class TestOnConflictSqlite(BaseTestCase):
    """ON CONFLICT handling as rendered for SqliteDatabase."""
    database = SqliteDatabase(None)

    def test_replace_sqlite(self):
        """on_conflict('replace') renders INSERT OR REPLACE."""
        upsert = Person.insert(name='huey').on_conflict('replace')
        self.assertSQL(upsert, (
            'INSERT OR REPLACE INTO "person" ("name") VALUES (?)'), ['huey'])

    def test_ignore(self):
        """on_conflict('ignore') renders INSERT OR IGNORE."""
        upsert = Person.insert(name='huey').on_conflict('ignore')
        self.assertSQL(upsert, (
            'INSERT OR IGNORE INTO "person" ("name") VALUES (?)'), ['huey'])

    def test_update_not_supported(self):
        """preserve/update without a conflict target fails when parsed."""
        new_name = Person.name.concat(' (updated)')
        upsert = Person.insert(name='huey').on_conflict(
            preserve=(Person.dob,),
            update={Person.name: new_name})
        with self.assertRaisesCtx(ValueError):
            self.database.get_sql_context().parse(upsert)
class TestOnConflictMySQL(BaseTestCase):
    """ON CONFLICT handling as rendered for MySQLDatabase."""
    database = MySQLDatabase(None)

    def setUp(self):
        super(TestOnConflictMySQL, self).setUp()
        # Start every test without a server version; the MariaDB test sets
        # one explicitly and would otherwise leak it into later tests.
        self.database.server_version = None

    def test_replace_mysql(self):
        """on_conflict('replace') renders REPLACE INTO."""
        stmt = Person.insert(name='huey').on_conflict('replace')
        self.assertSQL(stmt, (
            'REPLACE INTO "person" ("name") VALUES (?)'), ['huey'])

    def test_ignore(self):
        """on_conflict('ignore') renders INSERT IGNORE."""
        stmt = Person.insert(name='huey').on_conflict('ignore')
        self.assertSQL(stmt, (
            'INSERT IGNORE INTO "person" ("name") VALUES (?)'), ['huey'])

    def test_update(self):
        """preserve= uses VALUES("col"); update= takes arbitrary expressions."""
        dob = datetime.date(2010, 1, 1)
        head = ('INSERT INTO "person" ("dob", "name") VALUES (?, ?) '
                'ON DUPLICATE KEY ')

        both = (Person
                .insert(name='huey', dob=dob)
                .on_conflict(
                    preserve=(Person.dob,),
                    update={Person.name: Person.name.concat('-x')}))
        self.assertSQL(both, head + (
            'UPDATE "dob" = VALUES("dob"), "name" = ("name" || ?)'),
            [dob, 'huey', '-x'])

        preserve_only = (Person
                         .insert(name='huey', dob=dob)
                         .on_conflict(preserve='dob'))
        self.assertSQL(preserve_only, head + 'UPDATE "dob" = VALUES("dob")',
                       [dob, 'huey'])

    def test_update_use_value_mariadb(self):
        """The preserve function name depends on the server version.

        At version (10, 3, 3) the SQL uses VALUE("dob"); at (10, 3, 2) it
        uses VALUES("dob").
        """
        # Verify that we use "VALUE" (not "VALUES") for MariaDB 10.3.3.
        dob = datetime.date(2010, 1, 1)
        stmt = (Person
                .insert(name='huey', dob=dob)
                .on_conflict(preserve=(Person.dob,)))
        cases = (((10, 3, 3), 'VALUE'),
                 ((10, 3, 2), 'VALUES'))
        for version, func_name in cases:
            self.database.server_version = version
            expected = ('INSERT INTO "person" ("dob", "name") VALUES (?, ?) '
                        'ON DUPLICATE KEY '
                        'UPDATE "dob" = %s("dob")') % func_name
            self.assertSQL(stmt, expected, [dob, 'huey'])

    def test_where_not_supported(self):
        """A where= on the conflict clause fails when parsed."""
        stmt = Person.insert(name='huey').on_conflict(
            preserve=(Person.dob,),
            where=(Person.name == 'huey'))
        with self.assertRaisesCtx(ValueError):
            self.database.get_sql_context().parse(stmt)
class TestOnConflictPostgresql(BaseTestCase):
    """ON CONFLICT handling as rendered for PostgresqlDatabase."""
    database = PostgresqlDatabase(None)

    def test_ignore(self):
        """on_conflict('ignore') renders ON CONFLICT DO NOTHING."""
        stmt = Person.insert(name='huey').on_conflict('ignore')
        self.assertSQL(stmt, (
            'INSERT INTO "person" ("name") VALUES (?) '
            'ON CONFLICT DO NOTHING'), ['huey'])

    def test_conflict_target_required(self):
        """preserve= without a conflict target fails when parsed."""
        stmt = Person.insert(name='huey').on_conflict(preserve=(Person.dob,))
        with self.assertRaisesCtx(ValueError):
            self.database.get_sql_context().parse(stmt)

    def test_conflict_resolution_required(self):
        """A conflict target without any resolution fails when parsed."""
        stmt = Person.insert(name='huey').on_conflict(conflict_target='name')
        with self.assertRaisesCtx(ValueError):
            self.database.get_sql_context().parse(stmt)

    def test_conflict_update_excluded(self):
        """EXCLUDED.<col> may appear in both the SET and WHERE clauses."""
        KV = Table('kv', ('key', 'value', 'extra'), _database=self.database)
        stmt = (KV
                .insert(key='k1', value='v1', extra=1)
                .on_conflict(conflict_target=(KV.key, KV.value),
                             update={KV.extra: EXCLUDED.extra + 2},
                             where=(EXCLUDED.extra < KV.extra)))
        self.assertSQL(stmt, (
            'INSERT INTO "kv" ("extra", "key", "value") VALUES (?, ?, ?) '
            'ON CONFLICT ("key", "value") DO UPDATE '
            'SET "extra" = (EXCLUDED."extra" + ?) '
            'WHERE (EXCLUDED."extra" < "kv"."extra")'), [1, 'k1', 'v1', 2])

    def test_conflict_target_or_constraint(self):
        """Target columns and a named constraint are mutually exclusive."""
        KV = Table('kv', ('key', 'value', 'extra'), _database=self.database)
        row = dict(key='k1', value='v1', extra='e1')

        by_columns = (KV
                      .insert(**row)
                      .on_conflict(conflict_target=[KV.key, KV.value],
                                   preserve=[KV.extra]))
        self.assertSQL(by_columns, (
            'INSERT INTO "kv" ("extra", "key", "value") VALUES (?, ?, ?) '
            'ON CONFLICT ("key", "value") DO UPDATE '
            'SET "extra" = EXCLUDED."extra"'), ['e1', 'k1', 'v1'])

        by_constraint = (KV
                         .insert(**row)
                         .on_conflict(conflict_constraint='kv_key_value',
                                      preserve=[KV.extra]))
        self.assertSQL(by_constraint, (
            'INSERT INTO "kv" ("extra", "key", "value") VALUES (?, ?, ?) '
            'ON CONFLICT ON CONSTRAINT "kv_key_value" DO UPDATE '
            'SET "extra" = EXCLUDED."extra"'), ['e1', 'k1', 'v1'])

        # Giving both kinds of target is rejected by on_conflict() itself,
        # before any SQL is generated.
        plain = KV.insert(**row)
        self.assertRaises(ValueError, plain.on_conflict,
                          conflict_target=[KV.key, KV.value],
                          conflict_constraint='kv_key_value')

    def test_update(self):
        """Targets/preserve accept tuples, names or nodes; where= is allowed."""
        dob = datetime.date(2010, 1, 1)

        tupled = (Person
                  .insert(name='huey', dob=dob)
                  .on_conflict(
                      conflict_target=(Person.name,),
                      preserve=(Person.dob,),
                      update={Person.name: Person.name.concat('-x')}))
        self.assertSQL(tupled, (
            'INSERT INTO "person" ("dob", "name") VALUES (?, ?) '
            'ON CONFLICT ("name") DO '
            'UPDATE SET "dob" = EXCLUDED."dob", '
            '"name" = ("person"."name" || ?)'),
            [dob, 'huey', '-x'])

        by_name = (Person
                   .insert(name='huey', dob=dob)
                   .on_conflict(conflict_target='name', preserve='dob'))
        self.assertSQL(by_name, (
            'INSERT INTO "person" ("dob", "name") VALUES (?, ?) '
            'ON CONFLICT ("name") DO '
            'UPDATE SET "dob" = EXCLUDED."dob"'), [dob, 'huey'])

        single_nodes = (Person
                        .insert(name='huey')
                        .on_conflict(
                            conflict_target=Person.name,
                            preserve=Person.dob,
                            update={Person.name: Person.name.concat('-x')},
                            where=(Person.name != 'zaizee')))
        self.assertSQL(single_nodes, (
            'INSERT INTO "person" ("name") VALUES (?) '
            'ON CONFLICT ("name") DO '
            'UPDATE SET "dob" = EXCLUDED."dob", '
            '"name" = ("person"."name" || ?) '
            'WHERE ("person"."name" != ?)'), ['huey', '-x', 'zaizee'])

    def test_conflict_target_partial_index(self):
        """conflict_where= renders a WHERE after the conflict target columns."""
        KVE = Table('kve', ('key', 'value', 'extra'))
        rows = [('k1', 1, 2), ('k2', 2, 3)]
        stmt = (KVE
                .insert(rows, [KVE.key, KVE.value, KVE.extra])
                .on_conflict(
                    conflict_target=(KVE.key, KVE.value),
                    conflict_where=(KVE.extra > 1),
                    preserve=(KVE.extra,),
                    where=(KVE.key != 'kx')))
        self.assertSQL(stmt, (
            'INSERT INTO "kve" ("key", "value", "extra") '
            'VALUES (?, ?, ?), (?, ?, ?) '
            'ON CONFLICT ("key", "value") WHERE ("extra" > ?) '
            'DO UPDATE SET "extra" = EXCLUDED."extra" '
            'WHERE ("kve"."key" != ?)'),
            ['k1', 1, 2, 'k2', 2, 3, 1, 'kx'])
# ===========================================================================
# Index generation
# ===========================================================================
class TestIndex(BaseTestCase):
    """CREATE INDEX generation, including partial and expression indexes."""

    def test_simple_index(self):
        """A unique single-column index, then the same index made partial."""
        cutoff = datetime.date(1950, 1, 1)
        name_idx = Index('person_name', Person, (Person.name,), unique=True)
        self.assertSQL(name_idx, (
            'CREATE UNIQUE INDEX "person_name" ON "person" ("name")'), [])

        partial = name_idx.where(Person.dob > cutoff)
        self.assertSQL(partial, (
            'CREATE UNIQUE INDEX "person_name" ON "person" '
            '("name") WHERE ("dob" > ?)'), [cutoff])

    def test_advanced_index(self):
        """Ordered and function-based parts, IF NOT EXISTS, WHERE, literals."""
        Article = Table('article')
        parts = (Article.c.status,
                 Article.c.timestamp.desc(),
                 fn.SUBSTR(Article.c.title, 1, 1))
        expr_idx = Index('foo_idx', Article, parts, safe=True)
        self.assertSQL(expr_idx, (
            'CREATE INDEX IF NOT EXISTS "foo_idx" ON "article" '
            '("status", "timestamp" DESC, SUBSTR("title", ?, ?))'), [1, 1])

        expr_idx = expr_idx.where(Article.c.flags.bin_and(4) == 4)
        self.assertSQL(expr_idx, (
            'CREATE INDEX IF NOT EXISTS "foo_idx" ON "article" '
            '("status", "timestamp" DESC, SUBSTR("title", ?, ?)) '
            'WHERE (("flags" & ?) = ?)'), [1, 1, 4, 4])

        # Check behavior when value-literals are enabled.
        self.assertSQL(expr_idx, (
            'CREATE INDEX IF NOT EXISTS "foo_idx" ON "article" '
            '("status", "timestamp" DESC, SUBSTR("title", 1, 1)) '
            'WHERE (("flags" & 4) = 4)'), [], value_literals=True)

    def test_str_cols(self):
        """String index parts are emitted verbatim, without quoting."""
        raw_idx = Index('users_info', User, ('username DESC', 'id'))
        self.assertSQL(raw_idx, (
            'CREATE INDEX "users_info" ON "users" (username DESC, id)'), [])
# ===========================================================================
# Utilities and edge cases
# ===========================================================================
class TestSqlToString(BaseTestCase):
    """query_to_string() inlines parameters as SQL literals."""

    def _test_sql_to_string(self, _param):
        """Check inlined literals for a database using *_param* placeholders."""
        class FakeDB(SqliteDatabase):
            param = _param

        # Named fake_db so it does not shadow the module-level ``db``.
        fake_db = FakeDB(None)
        T = Table('tbl', ('id', 'val')).bind(fake_db)

        # Equality against a range of Python literal types, followed by the
        # NULL checks and an IN list.
        literals = ('foo', b'bar', True, False, 2, -3.14,
                    datetime.datetime(2018, 1, 1),
                    datetime.date(2018, 1, 2))
        terms = [T.val == literal for literal in literals]
        terms.extend([T.val.is_null(),
                      T.val.is_null(False),
                      T.val.in_(['aa', 'bb', 'cc'])])

        # Fold left-to-right with OR, i.e. ((a | b) | c) | ..., exactly as a
        # chained ``a | b | c`` expression associates.
        condition = terms[0]
        for term in terms[1:]:
            condition = condition | term

        query = T.select().where(condition)
        self.assertEqual(query_to_string(query), (
            'SELECT "t1"."id", "t1"."val" FROM "tbl" AS "t1" WHERE ((((((((((('
            '"t1"."val" = \'foo\') OR '
            '("t1"."val" = \'bar\')) OR '
            '("t1"."val" = 1)) OR '
            '("t1"."val" = 0)) OR '
            '("t1"."val" = 2)) OR '
            '("t1"."val" = -3.14)) OR '
            '("t1"."val" = \'2018-01-01 00:00:00\')) OR '
            '("t1"."val" = \'2018-01-02\')) OR '
            '("t1"."val" IS NULL)) OR '
            '("t1"."val" IS NOT NULL)) OR '
            '("t1"."val" IN (\'aa\', \'bb\', \'cc\')))'))

    def test_sql_to_string_qmark(self):
        """Databases using '?' placeholders."""
        self._test_sql_to_string('?')

    def test_sql_to_string_default(self):
        """Databases using '%s' placeholders."""
        self._test_sql_to_string('%s')
class TestSubqueryFunctionCall(BaseTestCase):
    """A subquery passed to a function gets only the call's parentheses."""

    def test_subquery_function_call(self):
        """NOT EXISTS(<subquery>) against an aliased copy of the same table."""
        Sample = Table('sample')
        SA = Sample.alias('s2')
        matching = SA.select(SQL('1')).where(SA.c.key == 'foo')
        query = Sample.select(Sample.c.data).where(~fn.EXISTS(matching))
        self.assertSQL(query, (
            'SELECT "t1"."data" FROM "sample" AS "t1" '
            'WHERE NOT EXISTS('
            'SELECT 1 FROM "sample" AS "s2" WHERE ("s2"."key" = ?))'), ['foo'])
class TestFunctionInfiniteLoop(BaseTestCase):
    """Iterating a function node must fail fast instead of looping."""

    def test_function_infinite_loop(self):
        """list() over fn.COUNT() raises TypeError."""
        with self.assertRaises(TypeError):
            list(fn.COUNT())
# ===========================================================================
# Gap coverage: Node fundamentals
# ===========================================================================
class TestNodeClone(BaseTestCase):
    """Node.clone() and the unwrap()/is_alias() helpers."""

    def test_clone_produces_independent_copy(self):
        """clone() returns a distinct node; refining it leaves the original alone."""
        t = Table('t')
        query = t.select(t.c.id).where(t.c.id > 1)
        clone = query.clone()

        # clone() must return a new object that initially renders exactly
        # like the original. Without this check the test could pass even if
        # clone() returned ``self``, if where() itself happens to return a
        # copy.
        self.assertIsNot(clone, query)
        self.assertEqual(__sql__(clone), __sql__(query))

        # Further refining the clone must not affect the original.
        clone = clone.where(t.c.id < 10)
        sql1, p1 = __sql__(query)
        sql2, p2 = __sql__(clone)
        self.assertNotEqual(sql1, sql2)
        self.assertEqual(p1, [1])
        self.assertEqual(p2, [1, 10])

    def test_unwrap_and_is_alias(self):
        """A plain column unwraps to itself and is not an alias."""
        t = Table('t1')
        col = t.c.id
        self.assertIs(col.unwrap(), col)
        self.assertFalse(col.is_alias())
# ===========================================================================
# Gap coverage: Table and Source operations
# ===========================================================================
class TestTableOperations(BaseTestCase):
    """Table cloning and column access on tables with explicit columns."""

    def test_table_clone(self):
        """clone() copies name, columns, primary key, schema and alias."""
        t = Table('users', columns=['id', 'name'], primary_key='id',
                  schema='public', alias='u')
        clone = t.clone()
        self.assertEqual(clone.__name__, 'users')
        self.assertEqual(clone._columns, ['id', 'name'])
        self.assertEqual(clone._primary_key, 'id')
        self.assertEqual(clone._schema, 'public')
        self.assertEqual(clone._alias, 'u')
        # assertIsNot gives a clearer failure message than assertTrue(is not).
        self.assertIsNot(clone, t)

    def test_table_explicit_column_error(self):
        """With explicit columns, t.<col> works but t.c.<col> raises."""
        t = Table('users', columns=['id', 'name'])
        # Direct attribute access works.
        self.assertSQL(t.select(t.id, t.name),
                       'SELECT "t1"."id", "t1"."name" FROM "users" AS "t1"',
                       [])
        # Dynamic column access via .c should raise.
        self.assertRaises(AttributeError, lambda: t.c.id)
# ===========================================================================
# Gap coverage: Alias, Negated, Cast, and wrapped node types
# ===========================================================================
class TestWrappedNodes(BaseTestCase):
    """Alias, Negated, Cast and ValueLiterals wrappers around other nodes."""

    def test_alias_realias_to_none(self):
        """alias(None) on an Alias returns the original, unwrapped column."""
        col = User.c.id
        aliased = col.alias('uid')
        self.assertIsInstance(aliased, Alias)
        self.assertTrue(aliased.is_alias())
        unwrapped = aliased.alias(None)
        # Should return the original column, not an Alias.
        self.assertNotIsInstance(unwrapped, Alias)
        self.assertFalse(unwrapped.is_alias())
        self.assertIs(unwrapped, col)

    def test_alias_realias_to_new_name(self):
        """Re-aliasing an Alias yields an Alias carrying the new name."""
        col = User.c.id
        a1 = col.alias('uid')
        a2 = a1.alias('user_id')
        self.assertIsInstance(a2, Alias)
        self.assertEqual(a2._alias, 'user_id')

    def test_alias_unalias(self):
        """unalias() returns the wrapped column itself."""
        col = User.c.id
        aliased = col.alias('uid')
        self.assertIs(aliased.unalias(), col)

    def test_alias_name_property(self):
        """Alias.name reads and writes the underlying _alias attribute."""
        col = User.c.id
        aliased = col.alias('uid')
        self.assertEqual(aliased.name, 'uid')
        aliased.name = 'new_name'
        self.assertEqual(aliased.name, 'new_name')
        self.assertEqual(aliased._alias, 'new_name')

    def test_double_negation(self):
        """Negating a Negated node returns the original expression."""
        # ~~expr returns the original expression (Negated.__invert__).
        col = User.c.active
        negated = ~col
        self.assertIsInstance(negated, Negated)
        double_neg = ~negated
        # Should unwrap back to original.
        self.assertIs(double_neg, col)

    def test_negated_sql(self):
        """A negated comparison renders as NOT (<expr>)."""
        query = User.select(User.c.id).where(~(User.c.active == True))
        self.assertSQL(query, (
            'SELECT "t1"."id" FROM "users" AS "t1" '
            'WHERE NOT ("t1"."active" = ?)'), [True])

    def test_cast_sql(self):
        """cast() renders CAST(<expr> AS <type>) and can be aliased."""
        expr = User.c.age.cast('TEXT')
        query = User.select(expr.alias('age_text'))
        self.assertSQL(query, (
            'SELECT CAST("t1"."age" AS TEXT) AS "age_text" '
            'FROM "users" AS "t1"'), [])

    def test_value_literals(self):
        """ValueLiterals renders a wrapped Value inline, without params."""
        expr = ValueLiterals(Value(42))
        self.assertSQL(expr, '42', [])
class TestQueryBuilderMisc(BaseTestCase):
    """Assorted nodes: Entity, Check, NodeList and namespaced attributes."""

    def test_entity_chained_getattr(self):
        """Attribute access on an Entity appends another quoted path part."""
        e = Entity('schema', 'table')
        e2 = e.column
        self.assertIsInstance(e2, Entity)
        self.assertSQL(e2, '"schema"."table"."column"', [])

    def test_check_with_name(self):
        """A named Check is prefixed with CONSTRAINT "<name>"."""
        c = Check('val > 0', name='positive_val')
        self.assertSQL(c, 'CONSTRAINT "positive_val" CHECK (val > 0)', [])

    def test_check_without_name(self):
        """An unnamed Check renders just CHECK (<expr>)."""
        c = Check('val > 0')
        self.assertSQL(c, 'CHECK (val > 0)', [])

    def test_nodelist_empty(self):
        """An empty NodeList renders as '' ('()' when parens=True)."""
        nl = NodeList([])
        self.assertSQL(nl, '', [])
        nl = NodeList([], parens=True)
        self.assertSQL(nl, '()', [])

    def test_namespace_attribute_sql(self):
        """EXCLUDED.<name> renders an unquoted namespace and quoted column."""
        expr = EXCLUDED.name
        self.assertSQL(expr, 'EXCLUDED."name"', [])
# ===========================================================================
# Gap coverage: Context and AliasManager internals
# ===========================================================================
class TestContextAndAliasManager(BaseTestCase):
def test_alias_manager_pop_empty_error(self):
am = AliasManager()
self.assertRaises(ValueError, am.pop)
def test_alias_manager_push_pop(self):
am = AliasManager()
t = Table('users')
alias1 = am.add(t)
self.assertEqual(alias1, 't1')
am.push()
t2 = Table('tweets')
alias2 = am.add(t2)
self.assertEqual(alias2, 't2')
am.pop()
# After pop, the inner mapping is cleared, but counter persists.
t3 = Table('notes')
alias3 = am.add(t3)
self.assertEqual(alias3, 't3')
def test_alias_manager_get_any_depth(self):
am = AliasManager()
t = Table('users')
am.add(t)
am.push()
# From inner scope, source is not in current mapping.
result = am.get(t, any_depth=True)
self.assertEqual(result, 't1')
def test_hashable_source_equality(self):
t1 = Table('users')
t2 = Table('users')
# Same table name, no alias - should be equal.
self.assertEqual(t1, t2)
self.assertEqual(hash(t1), hash(t2))
# Aliased differently - should not be equal.
t3 = t1.alias('u')
self.assertNotEqual(t1, t3)
def test_hashable_source_in_set(self):
t1 = Table('users')
t2 = Table('users')
t3 = Table('tweets')
s = {t1, t3}
self.assertIn(t2, s) # Same as t1.
self.assertEqual(len(s), 2)