mirror of
https://github.com/coleifer/peewee.git
synced 2026-05-06 07:56:41 -04:00
6949 lines
248 KiB
Python
6949 lines
248 KiB
Python
import datetime
|
|
import threading
|
|
import time
|
|
import unittest
|
|
import uuid
|
|
from unittest import mock
|
|
|
|
from peewee import *
|
|
from peewee import Entity
|
|
from peewee import NodeList
|
|
from peewee import SubclassAwareMetadata
|
|
from peewee import __sqlite_version__
|
|
from peewee import sort_models
|
|
|
|
from .base import db
|
|
from .base import get_in_memory_db
|
|
from .base import new_connection
|
|
from .base import requires_models
|
|
from .base import requires_mysql
|
|
from .base import requires_pglike
|
|
from .base import requires_postgresql
|
|
from .base import requires_sqlite
|
|
from .base import skip_if
|
|
from .base import skip_unless
|
|
from .base import BaseTestCase
|
|
from .base import IS_CRDB
|
|
from .base import IS_MYSQL
|
|
from .base import IS_MYSQL_ADVANCED_FEATURES
|
|
from .base import IS_POSTGRESQL
|
|
from .base import IS_SQLITE
|
|
from .base import IS_SQLITE_OLD
|
|
from .base import IS_SQLITE_15 # Row-values.
|
|
from .base import IS_SQLITE_24 # Upsert.
|
|
from .base import IS_SQLITE_25 # Window functions.
|
|
from .base import IS_SQLITE_30 # FILTER clause functions.
|
|
from .base import IS_SQLITE_35 # RETURNING.
|
|
from .base import IS_SQLITE_9
|
|
from .base import ModelTestCase
|
|
from .base import TestModel
|
|
from .base_models import *
|
|
|
|
|
|
|
|
# ===========================================================================
|
|
# Core Model CRUD operations
|
|
# ===========================================================================
|
|
|
|
class Color(TestModel):
    # The color's name is declared as the primary key; "is_neutral" falls
    # back to False when no value is supplied.
    name = CharField(primary_key=True)
    is_neutral = BooleanField(default=False)
|
|
|
|
|
|
class Post(TestModel):
    # Both attributes are stored under mixed-case column names that differ
    # from the Python attribute names.
    content = TextField(column_name='Content')
    # The default is the callable itself, not a value computed at import.
    timestamp = DateTimeField(
        column_name='TimeStamp',
        default=datetime.datetime.now)
|
|
|
|
|
|
class PostNote(TestModel):
    # The foreign key to Post is also declared as this model's primary key;
    # the reverse accessor on Post is "notes".
    post = ForeignKeyField(Post, backref='notes', primary_key=True)
    note = TextField()
|
|
|
|
|
|
class Point(TestModel):
    # Two integer coordinates; Meta sets primary_key = False.
    x = IntegerField()
    y = IntegerField()

    class Meta:
        primary_key = False
|
|
|
|
|
|
class CPK(TestModel):
    # Identified by the composite (key, value) pair declared in Meta;
    # "extra" is a plain payload column.
    key = CharField()
    value = IntegerField()
    extra = IntegerField()

    class Meta:
        primary_key = CompositeKey('key', 'value')
|
|
|
|
|
|
class City(TestModel):
    # Referenced by Venue.city and Venue.city_n.
    name = CharField()
|
|
|
|
class Venue(TestModel):
    # Holds two links to City: a required one ("city") and a nullable one
    # ("city_n"), each with its own backref name.
    name = CharField()
    city = ForeignKeyField(City, backref='venues')
    city_n = ForeignKeyField(City, backref='venues_n', null=True)
|
|
|
|
class Event(TestModel):
    # The venue link is nullable, so an event may exist without a venue.
    name = CharField()
    venue = ForeignKeyField(Venue, backref='events', null=True)
|
|
class TestModelAPIs(ModelTestCase):
|
|
def add_user(self, username):
    # Create a User row with the given username and hand back the instance.
    new_user = User.create(username=username)
    return new_user
|
|
|
|
def add_tweets(self, user, *tweets):
    # Create one Tweet per content string for ``user`` and return the new
    # instances in the order the contents were given.
    return [Tweet.create(user=user, content=content) for content in tweets]
|
|
|
|
@requires_models(Point)
def test_no_primary_key(self):
    # Rows of a model declared with primary_key = False can be created and
    # then fetched back by filtering on their columns.
    for coord in (1, 3):
        Point.create(x=coord, y=coord)

    fetched = Point.get((Point.x == 3) & (Point.y == 3))
    self.assertEqual(fetched.x, 3)
    self.assertEqual(fetched.y, 3)
|
|
|
|
@requires_models(Post, PostNote)
def test_pk_is_fk(self):
    # PostNote's primary key is its foreign key to Post.
    with self.database.atomic():
        first_post = Post.create(content='p1')
        second_post = Post.create(content='p2')
        PostNote.create(post=first_post, note='p1n')
        PostNote.create(post=second_post, note='p2n')

    # Fetching the note alone, then following .post, costs a second query.
    with self.assertQueryCount(2):
        note = PostNote.get(PostNote.note == 'p1n')
        self.assertEqual(note.post.content, 'p1')

    # Selecting both models through a join needs just one query.
    with self.assertQueryCount(1):
        joined = (PostNote
                  .select(PostNote, Post)
                  .join(Post)
                  .where(PostNote.note == 'p2n'))
        note = joined.get()
        self.assertEqual(note.post.content, 'p2')

    if IS_SQLITE:
        return

    # Outside of SQLite, creating a note without a post must be rejected.
    expected_errors = (ProgrammingError, IntegrityError)
    with self.database.atomic() as txn:
        self.assertRaises(expected_errors, PostNote.create, note='pxn')
        txn.rollback()
|
|
|
|
@requires_models(Post)
def test_column_field_translation(self):
    # Post's fields use custom column names; results must still be exposed
    # under the field names, both on instances and in dict rows.
    first_ts = datetime.datetime(2017, 2, 1, 13, 37)
    second_ts = datetime.datetime(2017, 2, 2, 13, 37)
    Post.create(content='p1', timestamp=first_ts)
    Post.create(content='p2', timestamp=second_ts)

    fetched = Post.get(Post.content == 'p1')
    self.assertEqual(fetched.content, 'p1')
    self.assertEqual(fetched.timestamp, first_ts)

    # Unpacking also checks that exactly two rows come back.
    first_row, second_row = Post.select().order_by(Post.id).dicts()
    expectations = (
        (first_row, 'p1', first_ts),
        (second_row, 'p2', second_ts))
    for row, content, timestamp in expectations:
        self.assertEqual(row['content'], content)
        self.assertEqual(row['timestamp'], timestamp)
|
|
|
|
def test_table_schema(self):
    # Assigning _meta.schema after class creation must be reflected in the
    # generated SQL as a schema qualifier on the table name.
    class Schema(TestModel):
        pass

    self.assertTrue(Schema._meta.schema is None)
    unqualified = 'SELECT "t1"."id" FROM "schema" AS "t1"'
    self.assertSQL(Schema.select(), unqualified, [])

    qualified = (
        ('test', 'SELECT "t1"."id" FROM "test"."schema" AS "t1"'),
        ('another', 'SELECT "t1"."id" FROM "another"."schema" AS "t1"'))
    for schema_name, expected_sql in qualified:
        Schema._meta.schema = schema_name
        self.assertSQL(Schema.select(), expected_sql, [])
|
|
|
|
@requires_models(User, Tweet)
def test_create(self):
    # Model.create() must issue exactly one query and return an instance
    # whose primary key has been populated with a positive integer.
    with self.assertQueryCount(1):
        huey = self.add_user('huey')
        self.assertEqual(huey.username, 'huey')
        # assertIsInstance/assertGreater report the offending value when
        # they fail, unlike assertTrue(<bool expression>).
        self.assertIsInstance(huey.id, int)
        self.assertGreater(huey.id, 0)

    with self.assertQueryCount(1):
        tweet = Tweet.create(user=huey, content='meow')
        # The related user was supplied as an instance, so reading it back
        # must not trigger another query.
        self.assertEqual(tweet.user.id, huey.id)
        self.assertEqual(tweet.user.username, 'huey')
        self.assertEqual(tweet.content, 'meow')
        self.assertIsInstance(tweet.id, int)
        self.assertGreater(tweet.id, 0)
|
|
|
|
@requires_models(User)
def test_insert_many(self):
    # Insert 100 single-column rows in chunks of 10 inside one transaction.
    usernames = ['u%02d' % i for i in range(100)]
    rows = [(username,) for username in usernames]
    with self.database.atomic():
        for batch in chunked(rows, 10):
            User.insert_many(batch).execute()

    self.assertEqual(User.select().count(), 100)
    stored = User.select().order_by(User.username)
    self.assertEqual([user.username for user in stored], usernames)
|
|
|
|
@requires_models(DfltM)
def test_insert_many_defaults_nullable(self):
    # Each row omits different columns; the expected results below show the
    # omitted columns coming back as 1, 2 and None respectively.
    rows = [
        {'name': 'd1'},
        {'name': 'd2', 'dflt1': 10},
        {'name': 'd3', 'dflt2': 30},
        {'name': 'd4', 'dfltn': 40}]
    columns = [DfltM.name, DfltM.dflt1, DfltM.dflt2, DfltM.dfltn]
    DfltM.insert_many(rows, columns).execute()

    stored = [(obj.name, obj.dflt1, obj.dflt2, obj.dfltn)
              for obj in DfltM.select().order_by(DfltM.name)]
    self.assertEqual(stored, [
        ('d1', 1, 2, None),
        ('d2', 10, 2, None),
        ('d3', 1, 30, None),
        ('d4', 1, 2, 40)])
|
|
|
|
@requires_models(User, Tweet)
def test_insert_query_value(self):
    # A SELECT query may be supplied as the value of a column in an INSERT.
    huey = self.add_user('huey')
    user_id_query = User.select(User.id).where(User.username == 'huey')
    tweet_id = Tweet.insert(content='meow', user=user_id_query).execute()

    stored = Tweet[tweet_id]
    self.assertEqual(stored.user.id, huey.id)
    self.assertEqual(stored.user.username, 'huey')
|
|
|
|
@requires_models(User)
def test_insert_rowcount(self):
    # as_rowcount() must make each kind of INSERT report how many rows it
    # wrote.
    def rowcount(insert_query):
        return insert_query.as_rowcount().execute()

    User.create(username='u0')  # Ensure that last insert ID != rowcount.

    plain = User.insert_many([(u,) for u in ('u1', 'u2', 'u3')])
    self.assertEqual(rowcount(plain), 3)

    # Now explicitly specify empty returning() for all DBs.
    no_returning = User.insert_many(
        [(u,) for u in ('u4', 'u5')]).returning()
    self.assertEqual(rowcount(no_returning), 2)

    # INSERT ... SELECT.
    source_x = (User
                .select(User.username.concat('-x'))
                .where(User.username.in_(['u1', 'u2'])))
    self.assertEqual(rowcount(User.insert_from(source_x, ['username'])), 2)

    # INSERT ... SELECT with an explicit empty returning().
    source_y = (User
                .select(User.username.concat('-y'))
                .where(User.username.in_(['u3', 'u4'])))
    from_select = User.insert_from(source_y, ['username']).returning()
    self.assertEqual(rowcount(from_select), 2)

    # Single-row insert from a dict.
    self.assertEqual(rowcount(User.insert({'username': 'u5'})), 1)
|
|
|
|
@skip_if(IS_POSTGRESQL or IS_CRDB, 'requires sqlite or mysql')
@requires_models(Emp)
def test_replace_rowcount(self):
    # Seed one row, then replace_many() two rows; afterwards only the two
    # replacement rows are expected to be present.
    Emp.create(first='beanie', last='cat', empno='998')

    replacements = [
        ('beanie', 'cat', '999'),
        ('mickey', 'dog', '123')]
    columns = (Emp.first, Emp.last, Emp.empno)

    # MySQL returns 3, Sqlite 2. However, older stdlib sqlite3 does not
    # work properly, so we don't assert a result count here.
    Emp.replace_many(replacements, fields=columns).execute()

    stored = Emp.select(Emp.first, Emp.last, Emp.empno).order_by(Emp.last)
    self.assertEqual(list(stored.tuples()), [
        ('beanie', 'cat', '999'),
        ('mickey', 'dog', '123')])
|
|
|
|
@requires_models(User)
def test_bulk_create(self):
    # bulk_create() without a batch size writes every instance in a single
    # query.
    def users_by_id():
        # Build a fresh query each time so nothing is shared between checks.
        return User.select().order_by(User.id)

    pending = [User(username='u%s' % i) for i in range(5)]
    self.assertEqual(User.select().count(), 0)

    with self.assertQueryCount(1):
        User.bulk_create(pending)

    self.assertEqual(User.select().count(), 5)
    self.assertEqual([row.username for row in users_by_id()],
                     ['u0', 'u1', 'u2', 'u3', 'u4'])

    # The in-memory instances' IDs are only compared on Postgres.
    if IS_POSTGRESQL:
        self.assertEqual([row.id for row in users_by_id()],
                         [user.id for user in pending])
|
|
|
|
@requires_models(User)
def test_bulk_create_empty(self):
    # bulk_create() with an empty list must neither raise nor write rows.
    self.assertEqual(User.select().count(), 0)
    User.bulk_create([])
    # Check the outcome as well, not just the absence of an exception.
    self.assertEqual(User.select().count(), 0)
|
|
|
|
@requires_models(User)
def test_bulk_create_batching(self):
    # 10 instances with a batch size of 3 -> batches of 3, 3, 3 and 1, i.e.
    # four queries.
    def users_by_id():
        return User.select().order_by(User.id)

    pending = [User(username=str(i)) for i in range(10)]
    with self.assertQueryCount(4):
        User.bulk_create(pending, 3)

    self.assertEqual(User.select().count(), 10)
    self.assertEqual([row.username for row in users_by_id()],
                     [str(i) for i in range(10)])

    # The in-memory instances' IDs are only compared on Postgres.
    if IS_POSTGRESQL:
        self.assertEqual([row.id for row in users_by_id()],
                         [user.id for user in pending])
|
|
|
|
@requires_models(Person)
def test_bulk_create_error(self):
    # The first and last entries share the same first/last name; the test
    # expects that to raise IntegrityError and leave the table empty.
    duplicated = [Person(first='a', last='b'),
                  Person(first='b', last='c'),
                  Person(first='a', last='b')]
    with self.assertRaises(IntegrityError), self.database.atomic():
        Person.bulk_create(duplicated)
    self.assertEqual(Person.select().count(), 0)
|
|
|
|
@requires_models(CPK)
def test_bulk_create_composite_key(self):
    # bulk_create() on a composite-key model must store the rows and leave
    # the in-memory instances' values unchanged.
    expected = [('k1', 1, 1), ('k2', 2, 2)]
    self.assertEqual(CPK.select().count(), 0)

    instances = [CPK(key='k1', value=1, extra=1),
                 CPK(key='k2', value=2, extra=2)]
    CPK.bulk_create(instances)
    in_memory = [(obj.key, obj.value, obj.extra) for obj in instances]
    self.assertEqual(in_memory, expected)

    stored = CPK.select().order_by(CPK.key).tuples()
    self.assertEqual(list(stored), expected)
|
|
|
|
@requires_models(Person)
def test_save(self):
    # save() on a new instance must INSERT and populate the primary key; a
    # second save() must UPDATE the same row rather than add another.
    def assert_stored(expected_dob):
        # Re-read the row and compare every column with the in-memory state.
        stored = Person.get(first='huey', last='cat')
        self.assertEqual(stored.id, huey.id)
        self.assertEqual(stored.first, 'huey')
        self.assertEqual(stored.last, 'cat')
        self.assertEqual(stored.dob, expected_dob)

    huey = Person(first='huey', last='cat', dob=datetime.date(2010, 7, 1))
    # assertGreater/assertIsNotNone report the actual value on failure.
    self.assertGreater(huey.save(), 0)
    self.assertIsNotNone(huey.id)  # Ensure PK is set.
    orig_id = huey.id

    # Test initial save (INSERT) worked and data is all present.
    assert_stored(datetime.date(2010, 7, 1))

    # Make a change and do a second save (UPDATE).
    huey.dob = datetime.date(2010, 7, 2)
    self.assertGreater(huey.save(), 0)
    self.assertEqual(huey.id, orig_id)

    # Test UPDATE worked correctly.
    assert_stored(datetime.date(2010, 7, 2))

    self.assertEqual(Person.select().count(), 1)
|
|
|
|
@requires_models(Person)
def test_save_only(self):
    # save(only=...) must write only the named fields, whether they are
    # given by name or as field objects.
    birthday = datetime.date(2010, 7, 1)

    def assert_stored(first, last):
        stored = Person.get_by_id(huey.id)
        self.assertEqual(stored.first, first)
        self.assertEqual(stored.last, last)
        self.assertEqual(stored.dob, birthday)

    huey = Person(first='huey', last='cat', dob=datetime.date(2010, 7, 1))
    huey.save()

    # Change two attributes but persist only "first".
    huey.first = 'huker'
    huey.last = 'kitten'
    self.assertTrue(huey.save(only=('first',)) > 0)
    assert_stored('huker', 'cat')

    # Change "first" again but persist only "last"; the earlier in-memory
    # change to "last" is what gets written.
    huey.first = 'hubie'
    self.assertTrue(huey.save(only=[Person.last]) > 0)
    assert_stored('huker', 'kitten')

    self.assertEqual(Person.select().count(), 1)
|
|
|
|
@requires_models(Color, User)
def test_save_force(self):
    # force_insert=True must INSERT even when the instance already carries
    # a primary key.
    huey = User(username='huey')
    self.assertTrue(huey.save() > 0)
    first_id = huey.id

    huey.username = 'zaizee'
    self.assertTrue(huey.save(force_insert=True, only=('username',)) > 0)
    second_id = huey.id
    self.assertTrue(first_id != second_id)

    usernames = [user.username
                 for user in User.select().order_by(User.username)]
    self.assertEqual(usernames, ['huey', 'zaizee'])

    # Color's primary key is set by the caller; a plain save() is expected
    # to store nothing.
    red = Color(name='red')
    self.assertFalse(bool(red.save()))
    self.assertEqual(Color.select().count(), 0)

    color = Color(name='blue')
    color.save(force_insert=True)
    self.assertEqual(Color.select().count(), 1)

    # Forcing a second insert of the same key must fail.
    with self.database.atomic():
        with self.assertRaises(IntegrityError):
            color.save(force_insert=True)
|
|
|
|
@requires_models(User, Tweet)
def test_populate_unsaved_relations(self):
    # A tweet created against an unsaved user must pick up the user's ID
    # once that user has been saved.
    author = User(username='charlie')
    pending_tweet = Tweet(user=author, content='foo')

    self.assertTrue(author.save())
    self.assertTrue(author.id is not None)

    # Only the tweet's own save should hit the database here.
    with self.assertQueryCount(1):
        self.assertEqual(pending_tweet.user_id, author.id)
        self.assertTrue(pending_tweet.save())
        self.assertEqual(pending_tweet.user_id, author.id)

    stored = Tweet.get(Tweet.content == 'foo')
    self.assertEqual(stored.user.username, 'charlie')
|
|
|
|
@requires_models(Person)
def test_bulk_update(self):
    # bulk_update() writes only the listed fields, using one query per
    # batch.
    def stored_names():
        return [(person.first, person.last)
                for person in Person.select().order_by(Person.id)]

    rows = [('f%s' % i, 'l%s' % i, datetime.date(1980, i, i))
            for i in range(1, 5)]
    Person.insert_many(rows).execute()

    p1, p2, p3, p4 = Person.select().order_by(Person.id)
    people = [p1, p2, p3, p4]
    p1.first = 'f1-x'
    p1.last = 'l1-x'
    p2.first = 'f2-y'
    p3.last = 'l3-z'

    with self.assertQueryCount(1):
        n = Person.bulk_update(people, ['first', 'last'])
        # The expected count differs on MySQL (3) versus the rest (4).
        self.assertEqual(n, 3 if IS_MYSQL else 4)

    self.assertEqual(stored_names(), [
        ('f1-x', 'l1-x'),
        ('f2-y', 'l2'),
        ('f3', 'l3-z'),
        ('f4', 'l4')])

    # Modify multiple fields, but only update "first".
    p1.first = 'f1-x2'
    p1.last = 'l1-x2'
    p2.first = 'f2-y2'
    p3.last = 'f3-z2'
    with self.assertQueryCount(2):  # Two batches, so two queries.
        n = Person.bulk_update(people, [Person.first], 2)
        self.assertEqual(n, 2 if IS_MYSQL else 4)

    # The "last" changes made above must not have been written.
    self.assertEqual(stored_names(), [
        ('f1-x2', 'l1-x'),
        ('f2-y2', 'l2'),
        ('f3', 'l3-z'),
        ('f4', 'l4')])
|
|
|
|
@requires_models(User, Tweet)
def test_bulk_update_foreign_key(self):
    # bulk_update() must handle foreign-key fields next to plain columns.
    for username in ('charlie', 'huey', 'zaizee'):
        author = User.create(username=username)
        for idx in range(2):
            Tweet.create(user=author, content='%s-%s' % (username, idx))

    charlie, huey, zaizee = User.select().order_by(User.id)
    tweets = list(Tweet.select().order_by(Tweet.id))
    c0, c1, h0, h1, z0, z1 = tweets

    # Mix of content-only, owner-only and combined changes; z1 is untouched.
    c0.content = 'charlie-0x'
    c1.user = huey
    h0.user = zaizee
    h1.content = 'huey-1x'
    z0.user = charlie
    z0.content = 'zaizee-0x'

    with self.assertQueryCount(1):
        Tweet.bulk_update(tweets, ['user', 'content'])

    stored = (Tweet
              .select(Tweet.content, User.username)
              .join(User)
              .order_by(Tweet.id)
              .objects())
    self.assertEqual([(row.username, row.content) for row in stored], [
        ('charlie', 'charlie-0x'),
        ('huey', 'charlie-1'),
        ('zaizee', 'huey-0'),
        ('huey', 'huey-1x'),
        ('charlie', 'zaizee-0x'),
        ('zaizee', 'zaizee-1')])
|
|
|
|
@requires_models(Person)
def test_bulk_update_integrityerror(self):
    # A bulk_update() that fails inside atomic() must leave no changes.
    Person.bulk_create([
        Person(first='f%s' % i, last='l%s' % i, dob='1980-01-01')
        for i in range(10)])

    # Get list of people w/the IDs populated. They will not be set if the
    # underlying DB is Sqlite or MySQL.
    people = list(Person.select().order_by(Person.id))

    # First we'll just modify all the first and last names.
    for person in people:
        person.first += '-x'
        person.last += '-x'

    # Now we'll introduce an issue that will cause an integrity error.
    for clashing in (people[3], people[7]):
        clashing.first = 'fx'
        clashing.last = 'lx'

    # Single batch: exactly one query is attempted.
    with self.assertRaises(IntegrityError), \
            self.assertQueryCount(1), \
            self.database.atomic():
        Person.bulk_update(people, fields=['first', 'last'])

    # 10 objects, batch size=4, so 0-3, 4-7, 8&9. But we never get to 8
    # and 9 because of the integrity error processing the 2nd batch.
    with self.assertRaises(IntegrityError), \
            self.assertQueryCount(2), \
            self.database.atomic():
        Person.bulk_update(people, ['first', 'last'], 4)

    # Ensure no changes were made.
    stored = [(person.first, person.last)
              for person in Person.select().order_by(Person.id)]
    self.assertEqual(stored, [('f%s' % i, 'l%s' % i) for i in range(10)])
|
|
|
|
@requires_models(User, Tweet)
def test_bulk_update_apply_dbvalue(self):
    # bulk_update() has to pass each value through the field's db_value();
    # otherwise the date and integer timestamps below would end up as bad
    # data or trigger an error during the update.
    author = User.create(username='u')
    tweets = [Tweet.create(user=author, content=str(i)) for i in (1, 2, 3)]
    t1, t2, t3 = tweets

    t1.timestamp = datetime.datetime(2019, 1, 2, 3, 4, 5)
    t2.timestamp = datetime.date(2019, 1, 3)
    # 1337133700 is 2012-05-16T02:01:40 UTC; the expected value is derived
    # with fromtimestamp(), i.e. expressed in local time.
    t3.timestamp = 1337133700
    expected_t3 = datetime.datetime.fromtimestamp(1337133700)
    Tweet.bulk_update(tweets, fields=['timestamp'])

    # Ensure that the values were handled appropriately.
    s1, s2, s3 = Tweet.select().order_by(Tweet.id)
    self.assertEqual(s1.timestamp, datetime.datetime(2019, 1, 2, 3, 4, 5))
    self.assertEqual(s2.timestamp, datetime.datetime(2019, 1, 3, 0, 0, 0))
    self.assertEqual(s3.timestamp, expected_t3)
|
|
|
|
@skip_if(IS_SQLITE_OLD or IS_MYSQL or IS_CRDB)
@requires_models(CPK)
def test_bulk_update_cte(self):
    CPK.insert_many([('k1', 1, 1), ('k2', 2, 2), ('k3', 3, 3)]).execute()

    # We can also do a bulk-update using ValuesList when the primary-key of
    # the model is a composite-pk.
    replacements = [('k1', 1, 10), ('k3', 3, 30)]
    cte = ValuesList(replacements).cte('new_values',
                                       columns=('k', 'v', 'x'))
    composite_pk = CPK._meta.primary_key

    # We have to use a subquery to update the individual column, as SQLite
    # does not support UPDATE/FROM syntax.
    new_extra = (cte
                 .select(cte.c.x)
                 .where(composite_pk == (cte.c.k, cte.c.v)))

    # Perform the update, assigning extra the new value from the values
    # list, and restricting the overall update using the composite pk.
    affected_keys = cte.select(cte.c.k, cte.c.v)
    (CPK
     .update(extra=new_extra)
     .where(composite_pk.in_(affected_keys))
     .with_cte(cte)
     .execute())

    # 'k2' is absent from the values list and must keep its original extra.
    self.assertEqual(sorted(CPK.select().tuples()), [
        ('k1', 1, 10), ('k2', 2, 2), ('k3', 3, 30)])
|
|
|
|
@skip_if(IS_SQLITE_OLD or IS_MYSQL or IS_CRDB)
@requires_models(User)
def test_multi_update(self):
    # Update several rows at once, pulling the per-row values from a
    # ValuesList exposed through a CTE.
    seed = [(i, 'u%s' % i) for i in range(1, 4)]
    User.insert_many(seed, fields=[User.id, User.username]).execute()

    # Only users 1 and 2 receive a new username.
    renames = [(i, 'u%sx' % i) for i in range(1, 3)]
    cte = ValuesList(renames).select().cte('uv',
                                           columns=('id', 'username'))
    new_username = cte.select(cte.c.username).where(cte.c.id == User.id)
    (User
     .update(username=new_username)
     .where(User.id.in_(cte.select(cte.c.id)))
     .with_cte(cte)
     .execute())

    stored = [(user.id, user.username)
              for user in User.select().order_by(User.id)]
    self.assertEqual(stored, [
        (1, 'u1x'),
        (2, 'u2x'),
        (3, 'u3')])
|
|
|
|
@requires_models(User, Tweet)
def test_get_shortcut(self):
    # Exercises the lookup styles accepted by Model.get(): a bare ID, an
    # expression, and keyword filters (including ones spanning a relation).
    huey = self.add_user('huey')
    self.add_tweets(huey, 'meow', 'purr', 'wheeze')
    mickey = self.add_user('mickey')
    self.add_tweets(mickey, 'woof', 'yip')

    # Lookup using just the ID.
    huey_db = User.get(huey.id)
    self.assertEqual(huey.id, huey_db.id)

    # Lookup using an expression.
    huey_db = User.get(User.username == 'huey')
    self.assertEqual(huey.id, huey_db.id)
    mickey_db = User.get(User.username == 'mickey')
    self.assertEqual(mickey.id, mickey_db.id)
    self.assertEqual(User.get(username='mickey').id, mickey.id)

    # No results is an exception.
    self.assertRaises(User.DoesNotExist, User.get, User.username == 'x')
    # Multiple results is OK.
    tweet = Tweet.get(Tweet.user == huey_db)
    # assertIn shows the unexpected content if this ever fails.
    self.assertIn(tweet.content, ('meow', 'purr', 'wheeze'))

    # We cannot traverse a join like this.
    @self.database.atomic()
    def has_error():
        Tweet.get(User.username == 'huey')
    self.assertRaises(Exception, has_error)

    # This is OK, though.
    tweet = Tweet.get(user__username='mickey')
    self.assertIn(tweet.content, ('woof', 'yip'))

    tweet = Tweet.get(content__ilike='w%',
                      user__username__ilike='%ck%')
    self.assertEqual(tweet.content, 'woof')
|
|
|
|
@requires_models(User)
def test_get_with_alias(self):
    # An aliased column must be exposed under its alias by both the dict
    # and the object row types.
    self.add_user('huey')
    aliased = (User
               .select(User.username.alias('name'))
               .where(User.username == 'huey'))

    as_dict = aliased.dicts().get()
    self.assertEqual(as_dict, {'name': 'huey'})

    as_object = aliased.objects().get()
    self.assertEqual(as_object.name, 'huey')
|
|
|
|
@requires_models(User, Tweet)
def test_get_or_none(self):
    # get_or_none() returns the match, or None when nothing matches.
    self.add_user('huey')
    found = User.get_or_none(User.username == 'huey')
    self.assertEqual(found.username, 'huey')
    missing = User.get_or_none(User.username == 'foo')
    self.assertIsNone(missing)
|
|
|
|
@requires_models(User, Tweet)
def test_model_select_get_or_none(self):
    # The same contract as Model.get_or_none(), but on a select query.
    self.add_user('huey')
    found = User.select().where(User.username == 'huey').get_or_none()
    self.assertEqual(found.username, 'huey')
    missing = User.select().where(User.username == 'foo').get_or_none()
    self.assertIsNone(missing)
|
|
|
|
@requires_models(User, Color)
def test_get_by_id(self):
    # get_by_id() and Model[pk] must work for an integer ID (User) and for
    # a string primary key (Color).
    huey = self.add_user('huey')
    self.assertEqual(User.get_by_id(huey.id).username, 'huey')

    colors = [
        {'name': 'red', 'is_neutral': False},
        {'name': 'blue', 'is_neutral': False}]
    Color.insert_many(colors).execute()
    self.assertEqual(Color.get_by_id('red').name, 'red')
    with self.assertRaises(Color.DoesNotExist):
        Color.get_by_id('green')

    # Item access on the model class behaves the same way.
    self.assertEqual(Color['red'].name, 'red')
    with self.assertRaises(Color.DoesNotExist):
        Color['green']
|
|
|
|
@requires_models(User, Color)
def test_get_set_item(self):
    # Item access on the model class: read, update and delete by ID, plus
    # len() for the row count.
    huey = self.add_user('huey')
    self.assertEqual(User[huey.id].username, 'huey')

    User[huey.id] = {'username': 'huey-x'}
    self.assertEqual(User[huey.id].username, 'huey-x')

    del User[huey.id]
    self.assertEqual(len(User), 0)

    # Allow creation by specifying None for key.
    User[None] = {'username': 'zaizee'}
    # Raises DoesNotExist if the row was not created.
    User.get(User.username == 'zaizee')
|
|
|
|
@requires_models(User)
def test_get_or_create(self):
    # The first call creates the row; the second finds the same row.
    original, was_created = User.get_or_create(username='huey')
    self.assertTrue(was_created)

    again, was_created_again = User.get_or_create(username='huey')
    self.assertFalse(was_created_again)
    self.assertEqual(original.id, again.id)
|
|
|
|
@requires_models(Category)
def test_get_or_create_self_referential_fk(self):
    # get_or_create() must accept a model instance as the value of a
    # self-referential foreign key.
    root = Category.create(name='parent')
    Category.get_or_create(parent=root, name='child')

    stored_child = Category.get(Category.parent == root)
    self.assertEqual(stored_child.parent.name, 'parent')
    self.assertEqual(stored_child.name, 'child')
|
|
|
|
@requires_models(Person)
def test_get_or_create_defaults(self):
    # "defaults" are applied only when the row is created; an existing
    # match comes back unchanged.
    def assert_is_original(person):
        self.assertEqual(person.first, 'huey')
        self.assertEqual(person.last, 'cat')
        self.assertEqual(person.dob, datetime.date(2010, 7, 1))

    _, created = Person.get_or_create(first='huey', defaults={
        'last': 'cat',
        'dob': datetime.date(2010, 7, 1)})
    self.assertTrue(created)
    assert_is_original(Person.get(Person.first == 'huey'))

    # A second call with different defaults must not create or modify.
    existing, created = Person.get_or_create(first='huey', defaults={
        'last': 'kitten',
        'dob': datetime.date(2020, 1, 1)})
    self.assertFalse(created)
    assert_is_original(existing)
|
|
|
|
@requires_models(User, Tweet)
def test_model_select(self):
    # Selecting columns from both sides of a join must populate the related
    # instance without extra queries.
    huey = self.add_user('huey')
    mickey = self.add_user('mickey')
    self.add_user('zaizee')  # Has no tweets, so never appears below.

    self.add_tweets(huey, 'meow', 'hiss', 'purr')
    self.add_tweets(mickey, 'woof', 'whine')

    with self.assertQueryCount(1):
        query = (Tweet
                 .select(Tweet.content, User.username)
                 .join(User)
                 .order_by(User.username, Tweet.content))
        expected_sql = (
            'SELECT "t1"."content", "t2"."username" '
            'FROM "tweet" AS "t1" '
            'INNER JOIN "users" AS "t2" '
            'ON ("t1"."user_id" = "t2"."id") '
            'ORDER BY "t2"."username", "t1"."content"')
        self.assertSQL(query, expected_sql, [])

        rows = [(tweet.content, tweet.user.username)
                for tweet in list(query)]
        self.assertEqual(rows, [
            ('hiss', 'huey'),
            ('meow', 'huey'),
            ('purr', 'huey'),
            ('whine', 'mickey'),
            ('woof', 'mickey')])
|
|
|
|
@requires_pglike
@requires_models(User, Tweet)
def test_distinct_on(self):
    # distinct(<column>) keeps one row per user; with the ordering used
    # here the expected row is each user's first tweet.
    u1 = self.add_user('u1')
    u2 = self.add_user('u2')
    self.add_tweets(u1, 'u1-t1', 'u1-t2', 'u1-t3')
    self.add_tweets(u2, 'u2-t1')

    one_per_user = (Tweet
                    .select(Tweet.user, Tweet.content)
                    .join(User)
                    .distinct(Tweet.user)
                    .order_by(Tweet.user, Tweet.timestamp))
    rows = [(tweet.user_id, tweet.content) for tweet in one_per_user]
    self.assertEqual(rows, [(u1.id, 'u1-t1'), (u2.id, 'u2-t1')])
|
|
|
|
@requires_models(User, Tweet)
def test_filtering(self):
    # Model.filter() with keyword lookups that span a relation, in both
    # directions.
    with self.database.atomic():
        huey = self.add_user('huey')
        mickey = self.add_user('mickey')
        self.add_tweets(huey, 'meow', 'hiss', 'purr')
        self.add_tweets(mickey, 'woof', 'wheeze')

    # Tweet -> User.
    with self.assertQueryCount(1):
        hueys_tweets = (Tweet
                        .filter(user__username='huey')
                        .order_by(Tweet.content))
        self.assertEqual([tweet.content for tweet in hueys_tweets],
                         ['hiss', 'meow', 'purr'])

    # User -> Tweet via the backref; one row comes back per matching tweet.
    with self.assertQueryCount(1):
        w_authors = User.filter(tweets__content__ilike='w%')
        self.assertEqual([author.username for author in w_authors],
                         ['mickey', 'mickey'])
|
|
|
|
@requires_models(User)
def test_select_count(self):
    # query.count() asks the database; the executed result's .count
    # attribute only reflects rows consumed so far.
    for username in ('huey', 'charlie', 'mickey'):
        self.add_user(username)
    self.assertEqual(User.select().count(), 3)

    result = User.select().execute()
    self.assertEqual(result.count, 0)

    list(result)  # Consume every row.
    self.assertEqual(result.count, 3)
|
|
|
|
@requires_models(User)
def test_peek(self):
    # peek(n=1) returns a single row, peek(n=2) a list; repeated peeks on
    # the same query must not re-execute it.
    for username in ('huey', 'mickey', 'zaizee'):
        self.add_user(username)

    with self.assertQueryCount(1):
        by_name = (User
                   .select(User.username)
                   .order_by(User.username)
                   .dicts())
        self.assertEqual(by_name.peek(n=1), {'username': 'huey'})
        self.assertEqual(by_name.peek(n=2), [{'username': 'huey'},
                                             {'username': 'mickey'}])
|
|
|
|
@requires_models(User)
def test_first(self):
    for username in ('a', 'b', 'c'):
        self.add_user(username)

    # Multiple calls to first() do not result in multiple executions.
    with self.assertQueryCount(1):
        by_name = User.select().order_by(User.username)
        self.assertEqual(by_name.first().username, 'a')
        self.assertEqual(by_name.first().username, 'a')
|
|
|
|
@requires_models(User, Tweet, Favorite)
def test_join_two_fks(self):
    # Favorite reaches User along two paths: directly, and through the
    # favorited tweet's author.
    with self.database.atomic():
        huey = self.add_user('huey')
        mickey = self.add_user('mickey')
        h_meow, h_purr, _ = self.add_tweets(huey, 'meow', 'purr', 'hiss')
        m_woof, _ = self.add_tweets(mickey, 'woof', 'bark')
        Favorite.create(user=huey, tweet=m_woof)
        Favorite.create(user=mickey, tweet=h_meow)
        Favorite.create(user=mickey, tweet=h_purr)

    # Everything selected: the tweet author comes from User, the
    # favoriting user from the alias, all in one query.
    with self.assertQueryCount(1):
        Favoriter = User.alias()
        query = (Favorite
                 .select(Favorite, Tweet, User, Favoriter)
                 .join(Tweet)
                 .join(User)
                 .switch(Favorite)
                 .join(Favoriter, on=Favorite.user)
                 .order_by(Favorite.id))
        rows = [(fav.tweet.user.username,
                 fav.tweet.content,
                 fav.user.username) for fav in query]

    self.assertEqual(rows, [
        ('mickey', 'woof', 'huey'),
        ('huey', 'meow', 'mickey'),
        ('huey', 'purr', 'mickey')])

    # Only Favorite selected: 1 query for the two rows, plus a user and a
    # tweet lookup for each row.
    with self.assertQueryCount(5):
        # Test intermediate models not selected.
        query = (Favorite
                 .select()
                 .join(Tweet)
                 .switch(Favorite)
                 .join(User)
                 .where(User.username == 'mickey')
                 .order_by(Favorite.id))
        rows = [(fav.user.username, fav.tweet.content) for fav in query]

    self.assertEqual(rows, [('mickey', 'meow'), ('mickey', 'purr')])
|
|
|
|
@requires_models(A, B, C)
def test_join_issue_1482(self):
    # Only C is selected, so following c.b and then c.b.a costs one extra
    # query each: three in total for the single row.
    a1 = A.create(a='a1')
    b1 = B.create(a=a1, b='b1')
    C.create(b=b1, c='c1')

    with self.assertQueryCount(3):
        matching = C.select().join(B).join(A).where(A.a == 'a1')
        rows = [(c.c, c.b.b, c.b.a.a) for c in matching]

    self.assertEqual(rows, [('c1', 'b1', 'a1')])
|
|
|
|
@requires_models(A, B, C)
def test_join_empty_intermediate_model(self):
    # A must be reachable through B in one query even when no column of
    # the intermediate model B is selected.
    a1 = A.create(a='a1')
    a2 = A.create(a='a2')
    b11 = B.create(a=a1, b='b11')
    B.create(a=a1, b='b12')  # Has no C rows pointing at it.
    b21 = B.create(a=a2, b='b21')
    for parent, label in ((b11, 'c111'), (b11, 'c112'), (b21, 'c211')):
        C.create(b=parent, c=label)

    # No B columns selected.
    with self.assertQueryCount(1):
        without_b = C.select(C, A.a).join(B).join(A).order_by(C.c)
        rows = [(c.c, c.b.a.a) for c in without_b]

    self.assertEqual(rows, [
        ('c111', 'a1'),
        ('c112', 'a1'),
        ('c211', 'a2')])

    # All three models selected.
    with self.assertQueryCount(1):
        with_b = C.select(C, B, A).join(B).join(A).order_by(C.c)
        rows = [(c.c, c.b.b, c.b.a.a) for c in with_b]

    self.assertEqual(rows, [
        ('c111', 'b11', 'a1'),
        ('c112', 'b11', 'a1'),
        ('c211', 'b21', 'a2')])
|
|
|
|
@requires_models(City, Venue, Event)
def test_join_empty_relations(self):
    # When a LEFT OUTER join yields no related row, the relation attribute
    # must be None rather than an empty model instance.
    with self.database.atomic():
        city = City.create(name='Topeka')
        venue1 = Venue.create(name='House', city=city, city_n=city)
        venue2 = Venue.create(name='Nowhere', city=city, city_n=None)

        Event.create(name='House Party', venue=venue1)
        Event.create(name='Holiday')
        Event.create(name='Nowhere Party', venue=venue2)

    with self.assertQueryCount(1):
        query = (Event
                 .select(Event, Venue, City)
                 .join(Venue, JOIN.LEFT_OUTER)
                 .join(City, JOIN.LEFT_OUTER, on=Venue.city)
                 .order_by(Event.id))

        # Here we have two left-outer joins, and the second Event
        # ("Holiday"), does not have an associated Venue (hence, no City).
        # Peewee would attach an empty Venue() model to the event, however.
        # It did this since we are selecting from Venue/City and Venue is
        # an intermediary model. It is more correct for Event.venue to be
        # None in this case. This is now patched / fixed.
        # A conditional expression replaces the fragile "x and y or None".
        r = [(e.name,
              e.venue.city.name if e.venue is not None else None)
             for e in query]
        self.assertEqual(r, [
            ('House Party', 'Topeka'),
            ('Holiday', None),
            ('Nowhere Party', 'Topeka')])

    with self.assertQueryCount(1):
        query = (Event
                 .select(Event, Venue, City)
                 .join(Venue, JOIN.INNER)
                 .join(City, JOIN.LEFT_OUTER, on=Venue.city_n)
                 .order_by(Event.id))

        # Here we have an inner join and a left-outer join. The furthest
        # object (City) will be NULL for the "Nowhere Party". Make sure
        # that the object is left as None and not populated with an empty
        # City instance.
        accum = []
        for event in query:
            city_name = event.venue.city_n and event.venue.city_n.name
            accum.append((event.name, event.venue.name, city_name))

        self.assertEqual(accum, [
            ('House Party', 'House', 'Topeka'),
            ('Nowhere Party', 'Nowhere', None)])
|
|
|
|
@requires_models(Relationship, Person)
def test_join_same_model_twice(self):
    # Join Person twice (once via an alias) so that both ends of each
    # Relationship are populated by a single query.
    birthday = datetime.date(2010, 1, 1)
    people = {}
    for first, last in (('huey', 'cat'), ('zaizee', 'cat'),
                        ('mickey', 'dog')):
        people[first] = Person.create(first=first, last=last, dob=birthday)

    pairs = (
        ('huey', 'zaizee'),
        ('zaizee', 'huey'),
        ('mickey', 'huey'),
    )
    for src_name, dest_name in pairs:
        Relationship.create(from_person=people[src_name],
                            to_person=people[dest_name])

    PersonAlias = Person.alias()
    with self.assertQueryCount(1):
        query = (Relationship
                 .select(Relationship, Person, PersonAlias)
                 .join(Person, on=Relationship.from_person)
                 .switch(Relationship)
                 .join(PersonAlias, on=Relationship.to_person)
                 .order_by(Relationship.id))
        results = []
        for rel in query:
            results.append((rel.from_person.first, rel.to_person.first))

    self.assertEqual(results, [
        ('huey', 'zaizee'),
        ('zaizee', 'huey'),
        ('mickey', 'huey')])
|
|
|
|
@requires_models(User, Tweet)
def test_join_to_dict(self):
    # Join Tweet against a plain Select with attr='u'; the assertions
    # below expect the joined columns to surface as a dict on `t.u`.
    huey = self.add_user('huey')
    mickey = self.add_user('mickey')
    self.add_tweets(huey, 'meow', 'hiss', 'purr')
    self.add_tweets(mickey, 'woof')

    expected_sql = (
        'SELECT "t1"."content", "t2"."username" FROM "tweet" AS "t1" '
        'INNER JOIN (SELECT "t3"."id", "t3"."username" FROM "users" '
        'AS "t3") AS "t2" ON ("t1"."user_id" = "t2"."id") '
        'ORDER BY "t2"."username", "t1"."content"')

    with self.assertQueryCount(1):
        user_q = Select((User,), (User.id, User.username,))
        query = (Tweet
                 .select(Tweet.content, user_q.c.username)
                 .join(user_q, on=(Tweet.user == user_q.c.id), attr='u')
                 .order_by(user_q.c.username, Tweet.content))
        self.assertSQL(query, expected_sql, [])

        rows = [(tweet.content, tweet.u) for tweet in query]
        self.assertEqual(rows, [
            ('hiss', {'username': 'huey'}),
            ('meow', {'username': 'huey'}),
            ('purr', {'username': 'huey'}),
            ('woof', {'username': 'mickey'})])
|
|
|
|
@requires_models(User, Tweet, Favorite)
def test_multi_join(self):
    # Favorite -> Tweet -> aliased User (the tweet's author), then switch
    # back to Favorite and join User (the user who favorited).
    u1, u2, u3 = [User.create(username=name) for name in ('u1', 'u2', 'u3')]
    t1_1 = Tweet.create(user=u1, content='t1-1')
    t1_2 = Tweet.create(user=u1, content='t1-2')
    t2_1 = Tweet.create(user=u2, content='t2-1')
    t2_2 = Tweet.create(user=u2, content='t2-2')

    for fan, tweet in ((u1, t2_1),
                       (u1, t2_2),
                       (u2, t1_1),
                       (u3, t1_2),
                       (u3, t2_2)):
        Favorite.create(user=fan, tweet=tweet)

    Author = User.alias('u2')

    with self.assertQueryCount(1):
        query = (Favorite
                 .select(Favorite.id,
                         Tweet.content,
                         User.username,
                         Author.username)
                 .join(Tweet)
                 .join(Author, on=(Tweet.user == Author.id))
                 .switch(Favorite)
                 .join(User)
                 .order_by(Tweet.content, Favorite.id))
        self.assertSQL(query, (
            'SELECT '
            '"t1"."id", "t2"."content", "t3"."username", "u2"."username" '
            'FROM "favorite" AS "t1" '
            'INNER JOIN "tweet" AS "t2" ON ("t1"."tweet_id" = "t2"."id") '
            'INNER JOIN "users" AS "u2" ON ("t2"."user_id" = "u2"."id") '
            'INNER JOIN "users" AS "t3" ON ("t1"."user_id" = "t3"."id") '
            'ORDER BY "t2"."content", "t1"."id"'), [])

        rows = []
        for fav in query:
            rows.append((fav.tweet.user.username,
                         fav.tweet.content,
                         fav.user.username))
        self.assertEqual(rows, [
            ('u1', 't1-1', 'u2'),
            ('u1', 't1-2', 'u3'),
            ('u2', 't2-1', 'u1'),
            ('u2', 't2-2', 'u1'),
            ('u2', 't2-2', 'u3')])

    # count() is issued outside the single-query assertion above.
    self.assertEqual(query.count(), 5)
|
|
|
|
def _create_user_tweets(self):
    # Fixture helper: creates three users and their tweets, giving every
    # tweet a timestamp one greater than the previously created tweet.
    fixtures = (('huey', ('meow', 'purr', 'hiss')),
                ('zaizee', ()),
                ('mickey', ('woof', 'grr')))

    with self.database.atomic():
        next_ts = int(time.time())
        for username, contents in fixtures:
            owner = User.create(username=username)
            for content in contents:
                Tweet.create(user=owner, content=content,
                             timestamp=next_ts)
                next_ts += 1
|
|
|
|
@requires_pglike
@requires_models(User)
def test_join_on_valueslist(self):
    # Join the users table against an inline ValuesList of usernames.
    for name in ('huey', 'mickey', 'zaizee'):
        User.create(username=name)

    values = ValuesList([('huey',), ('zaizee',)], columns=['username'])
    with self.assertQueryCount(1):
        query = (User
                 .select(values.c.username)
                 .join(values, on=(User.username == values.c.username))
                 .order_by(values.c.username.desc()))
        usernames = [user.username for user in query]
        self.assertEqual(usernames, ['zaizee', 'huey'])
|
|
|
|
@requires_models(User, Tweet)
def test_join_subquery(self):
    # Join against an aggregate subquery to find each user's tweet with
    # the largest timestamp.
    self._create_user_tweets()

    # Select tweet user and timestamp of most recent tweet.
    with self.assertQueryCount(1):
        TweetAlias = Tweet.alias()
        max_q = (TweetAlias
                 .select(TweetAlias.user,
                         fn.MAX(TweetAlias.timestamp).alias('max_ts'))
                 .group_by(TweetAlias.user)
                 .alias('max_q'))

        predicate = ((Tweet.user == max_q.c.user_id) &
                     (Tweet.timestamp == max_q.c.max_ts))
        latest = (Tweet
                  .select(Tweet.user, Tweet.content, Tweet.timestamp)
                  .join(max_q, on=predicate)
                  .alias('latest'))

        query = (User
                 .select(User, latest.c.content, latest.c.timestamp)
                 .join(latest, on=(User.id == latest.c.user_id)))

        user_rows = [(user.username, user.tweet.content) for user in query]

    # Failing on travis-ci...old SQLite?
    if not IS_SQLITE_OLD:
        self.assertEqual(user_rows, [
            ('huey', 'hiss'),
            ('mickey', 'grr')])

    # Same aggregate subquery, this time selecting from Tweet and joining
    # User directly.
    with self.assertQueryCount(1):
        query = (Tweet
                 .select(Tweet, User)
                 .join(max_q, on=predicate)
                 .switch(Tweet)
                 .join(User))
        tweet_rows = [(tweet.user.username, tweet.content)
                      for tweet in query]

    self.assertEqual(tweet_rows, [
        ('huey', 'hiss'),
        ('mickey', 'grr')])
|
|
|
|
@requires_models(User, Tweet)
def test_join_subquery_2(self):
    # Join Tweet against a filtered User subquery and check both the
    # generated SQL and the reconstructed rows.
    self._create_user_tweets()

    expected_sql = (
        'SELECT "t1"."content" AS "content", '
        '"t2"."username" AS "username"'
        ' FROM "tweet" AS "t1" '
        'INNER JOIN (SELECT "t3"."id", "t3"."username" '
        'FROM "users" AS "t3" '
        'WHERE ("t3"."username" IN (?, ?))) AS "t2" '
        'ON ("t1"."user_id" = "t2"."id") '
        'ORDER BY "t1"."id"')

    with self.assertQueryCount(1):
        cats = (User
                .select(User.id, User.username)
                .where(User.username.in_(['huey', 'zaizee'])))
        query = (Tweet
                 .select(Tweet.content.alias('content'),
                         cats.c.username.alias('username'))
                 .join(cats, on=(Tweet.user == cats.c.id))
                 .order_by(Tweet.id))

        self.assertSQL(query, expected_sql, ['huey', 'zaizee'])

        rows = [(tweet.content, tweet.user.username) for tweet in query]
        self.assertEqual(rows, [
            ('meow', 'huey'),
            ('purr', 'huey'),
            ('hiss', 'huey')])
|
|
|
|
@skip_if(IS_SQLITE_OLD or (IS_MYSQL and not IS_MYSQL_ADVANCED_FEATURES))
@requires_models(User, Tweet)
def test_join_subquery_cte(self):
    # Join Tweet against a User subquery expressed as a common-table
    # expression named "cats".
    self._create_user_tweets()

    # The expression is already parenthesised, so no backslash line
    # continuation is needed.
    cte = (User
           .select(User.id, User.username)
           .where(User.username.in_(['huey', 'zaizee']))
           .cte('cats'))

    with self.assertQueryCount(1):
        # Attempt join with subquery as common-table expression.
        query = (Tweet
                 .select(Tweet.content, cte.c.username)
                 .join(cte, on=(Tweet.user == cte.c.id))
                 .order_by(Tweet.id)
                 .with_cte(cte))
        self.assertSQL(query, (
            'WITH "cats" AS ('
            'SELECT "t1"."id", "t1"."username" FROM "users" AS "t1" '
            'WHERE ("t1"."username" IN (?, ?))) '
            'SELECT "t2"."content", "cats"."username" FROM "tweet" AS "t2" '
            'INNER JOIN "cats" ON ("t2"."user_id" = "cats"."id") '
            'ORDER BY "t2"."id"'), ['huey', 'zaizee'])

        self.assertEqual([t.content for t in query],
                         ['meow', 'purr', 'hiss'])
|
|
|
|
@skip_if(IS_MYSQL)  # MariaDB does not support LIMIT in subqueries!
@requires_models(User)
def test_subquery_emulate_window(self):
    # We have duplicated users. Select a maximum of 2 instances of the
    # username.
    name2count = {
        'beanie': 6,
        'huey': 5,
        'mickey': 3,
        'pipey': 1,
        'zaizee': 4}
    # Repeat each name `count` times, in alphabetical order, and assign
    # explicit ids starting at 1.
    names = [name
             for name, count in sorted(name2count.items())
             for _ in range(count)]
    rows = list(enumerate(names, 1))
    User.insert_many(rows, [User.id, User.username]).execute()

    # The results we are trying to obtain.
    expected = [
        ('beanie', 1), ('beanie', 2),
        ('huey', 7), ('huey', 8),
        ('mickey', 12), ('mickey', 13),
        ('pipey', 15),
        ('zaizee', 16), ('zaizee', 17)]

    with self.assertQueryCount(1):
        # Using a self-join.
        UA = User.alias()
        same_name_not_earlier = ((UA.username == User.username) &
                                 (UA.id >= User.id))
        query = (User
                 .select(User.username, UA.id)
                 .join(UA, on=same_name_not_earlier)
                 .group_by(User.username, UA.id)
                 .having(fn.COUNT(UA.id) < 3)
                 .order_by(User.username, UA.id))
        self.assertEqual(query.tuples()[:], expected)

    with self.assertQueryCount(1):
        # Using a correlated subquery.
        first_two_ids = (UA
                         .select(UA.id)
                         .where(User.username == UA.username)
                         .order_by(UA.id)
                         .limit(2))
        query = (User
                 .select(User.username, User.id)
                 .where(User.id.in_(first_two_ids.alias('subq')))
                 .order_by(User.username, User.id))
        self.assertEqual(query.tuples()[:], expected)
|
|
|
|
@requires_models(User, Tweet)
def test_subquery_alias_selection(self):
    # Select a correlated COUNT subquery as an aliased column.
    fixtures = (
        ('huey', ('meow', 'hiss', 'purr')),
        ('mickey', ('woof', 'bark')),
        ('zaizee', ()))
    with self.database.atomic():
        for username, contents in fixtures:
            owner = User.create(username=username)
            for content in contents:
                Tweet.create(user=owner, content=content)

    with self.assertQueryCount(1):
        tweet_count = (Tweet
                       .select(fn.COUNT(Tweet.id))
                       .where(Tweet.user == User.id))
        query = (User
                 .select(User.username, tweet_count.alias('tweet_count'))
                 .order_by(User.id))
        counts = [(user.username, user.tweet_count) for user in query]
        self.assertEqual(counts, [
            ('huey', 3),
            ('mickey', 2),
            ('zaizee', 0)])
|
|
|
|
@requires_models(Point)
def test_subquery_in_select_expression(self):
    # Use a correlated SUM subquery both as a selected column and as an
    # operand inside a larger select expression.
    for x, y in ((1, 1), (1, 2), (10, 10), (10, 20)):
        Point.create(x=x, y=y)

    with self.assertQueryCount(1):
        PA = Point.alias('pa')
        sum_y = PA.select(fn.SUM(PA.y)).where(PA.x == Point.x)
        query = (Point
                 .select(Point.x, Point.y, sum_y.alias('sy'))
                 .order_by(Point.x, Point.y))
        rows = list(query.tuples())
        self.assertEqual(rows, [
            (1, 1, 3),
            (1, 2, 3),
            (10, 10, 30),
            (10, 20, 30)])

    with self.assertQueryCount(1):
        query = (Point
                 .select(Point.x, (Point.y + sum_y).alias('sy'))
                 .order_by(Point.x, Point.y))
        rows = list(query.tuples())
        self.assertEqual(rows, [
            (1, 4), (1, 5),
            (10, 40), (10, 50)])
|
|
|
|
@skip_if(IS_SQLITE and not IS_SQLITE_9, 'requires sqlite >= 3.9')
@requires_models(Register)
def test_compound_select(self):
    # UNION two selects, then UNION the compound with a third select,
    # ordering each compound query by its second column.
    for value in range(10):
        Register.create(value=value)

    low = Register.select().where(Register.value < 2)
    high = Register.select().where(Register.value > 7)
    two_way = (low | high).order_by(SQL('2'))

    self.assertSQL(two_way, (
        'SELECT "t1"."id", "t1"."value" FROM "register" AS "t1" '
        'WHERE ("t1"."value" < ?) UNION '
        'SELECT "t2"."id", "t2"."value" FROM "register" AS "t2" '
        'WHERE ("t2"."value" > ?) ORDER BY 2'), [2, 7])

    # The third argument is the failure message: the raw row data.
    self.assertEqual([row.value for row in two_way], [0, 1, 8, 9],
                     [row.__data__ for row in two_way])
    self.assertEqual(two_way.count(), 4)

    middle = Register.select().where(Register.value == 5)
    # order_by() with no arguments is applied to the inner compound
    # before it is combined with the third query.
    three_way = (two_way.order_by() | middle).order_by(SQL('2'))

    self.assertSQL(three_way, (
        'SELECT "t1"."id", "t1"."value" FROM "register" AS "t1" '
        'WHERE ("t1"."value" < ?) UNION '
        'SELECT "t2"."id", "t2"."value" FROM "register" AS "t2" '
        'WHERE ("t2"."value" > ?) UNION '
        'SELECT "t3"."id", "t3"."value" FROM "register" AS "t3" '
        'WHERE ("t3"."value" = ?) ORDER BY 2'), [2, 7, 5])

    self.assertEqual([row.value for row in three_way], [0, 1, 5, 8, 9])
    self.assertEqual(three_way.count(), 5)
|
|
|
|
@requires_models(User, Tweet)
def test_union_column_resolution(self):
    # Rows coming back from a UNION must resolve to the right model
    # attributes, for a single-model union, a joined union, and a
    # flattened (.objects()) joined union.
    u1 = User.create(id=1, username='u1')
    u2 = User.create(id=2, username='u2')

    user_q1 = User.select().where(User.id == 1)
    user_q2 = User.select()
    user_union = user_q1 | user_q2
    self.assertSQL(user_union, (
        'SELECT "t1"."id", "t1"."username" FROM "users" AS "t1" '
        'WHERE ("t1"."id" = ?) '
        'UNION '
        'SELECT "t2"."id", "t2"."username" FROM "users" AS "t2"'), [1])

    results = [(user.id, user.username) for user in user_union]
    self.assertEqual(sorted(results), [
        (1, 'u1'),
        (2, 'u2')])

    Tweet.create(id=1, user=u1, content='u1-t1')
    Tweet.create(id=2, user=u1, content='u1-t2')
    Tweet.create(id=3, user=u2, content='u2-t1')

    with self.assertQueryCount(1):
        tweet_q1 = Tweet.select(Tweet, User).join(User).where(User.id == 1)
        tweet_q2 = Tweet.select(Tweet, User).join(User)
        tweet_union = tweet_q1 | tweet_q2

        self.assertSQL(tweet_union, (
            'SELECT "t1"."id", "t1"."user_id", "t1"."content", '
            '"t1"."timestamp", "t2"."id", "t2"."username" '
            'FROM "tweet" AS "t1" '
            'INNER JOIN "users" AS "t2" ON ("t1"."user_id" = "t2"."id") '
            'WHERE ("t2"."id" = ?) '
            'UNION '
            'SELECT "t3"."id", "t3"."user_id", "t3"."content", '
            '"t3"."timestamp", "t4"."id", "t4"."username" '
            'FROM "tweet" AS "t3" '
            'INNER JOIN "users" AS "t4" ON ("t3"."user_id" = "t4"."id")'),
            [1])

        results = [(t.id, t.content, t.user.username) for t in tweet_union]
        self.assertEqual(sorted(results), [
            (1, 'u1-t1', 'u1'),
            (2, 'u1-t2', 'u1'),
            (3, 'u2-t1', 'u2')])

    with self.assertQueryCount(1):
        # With .objects() every selected column lands on the Tweet
        # instance; the assertion reads the second "id" column as `id_2`.
        union_flat = (tweet_q1 | tweet_q2).objects()
        results = [(t.id, t.content, t.username, t.id_2)
                   for t in union_flat]
        self.assertEqual(sorted(results), [
            (1, 'u1-t1', 'u1', 1),
            (2, 'u1-t2', 'u1', 1),
            (3, 'u2-t1', 'u2', 2)])
|
|
|
|
@requires_models(User, Tweet)
def test_compound_select_as_subquery(self):
    # Aggregate over a UNION by selecting from the compound query.
    with self.database.atomic():
        # User "uN" receives 2*N tweets.
        for i in range(5):
            owner = User.create(username='u%s' % i)
            for j in range(i * 2):
                Tweet.create(user=owner, content='t%s-%s' % (i, j))

    u3_tweets = (Tweet
                 .select(Tweet.id, Tweet.content, User.username)
                 .join(User)
                 .where(User.username == 'u3'))
    u2_u4_tweets = (Tweet
                    .select(Tweet.id, Tweet.content, User.username)
                    .join(User)
                    .where(User.username.in_(['u2', 'u4'])))
    union = (u3_tweets | u2_u4_tweets)

    tweet_count = fn.COUNT(union.c.id)
    per_user = (union
                .select_from(union.c.username, tweet_count.alias('ct'))
                .group_by(union.c.username)
                .order_by(fn.COUNT(union.c.id).desc())
                .dicts())
    self.assertEqual(list(per_user), [
        {'username': 'u4', 'ct': 8},
        {'username': 'u3', 'ct': 6},
        {'username': 'u2', 'ct': 4}])
|
|
|
|
@requires_models(User, Tweet)
def test_union_with_join(self):
    # A UNION ALL of two identical joined queries should yield every
    # joined row twice, with the aliased join attribute ("foo") intact.
    u1, u2 = [User.create(username='u%s' % i) for i in (1, 2)]
    for owner, suffixes in ((u1, ('t1', 't2')), (u2, ('t1',))):
        for suffix in suffixes:
            Tweet.create(user=owner,
                         content='%s-%s' % (owner.username, suffix))

    with self.assertQueryCount(1):
        q1 = (User
              .select(User, Tweet)
              .join(Tweet, on=(Tweet.user == User.id).alias('foo')))
        q2 = (User
              .select(User, Tweet)
              .join(Tweet, on=(Tweet.user == User.id).alias('foo')))

        single = [(user.username, user.foo.content) for user in q1]
        self.assertEqual(
            sorted(single),
            [('u1', 'u1-t1'), ('u1', 'u1-t2'), ('u2', 'u2-t1')])

    with self.assertQueryCount(1):
        doubled = [(user.username, user.foo.content)
                   for user in q1.union_all(q2)]
        self.assertEqual(sorted(doubled), [
            ('u1', 'u1-t1'),
            ('u1', 'u1-t1'),
            ('u1', 'u1-t2'),
            ('u1', 'u1-t2'),
            ('u2', 'u2-t1'),
            ('u2', 'u2-t1'),
        ])
|
|
|
|
@skip_if(IS_SQLITE_OLD or (IS_MYSQL and not IS_MYSQL_ADVANCED_FEATURES))
@requires_models(User)
def test_union_cte(self):
    # Use a UNION as a common-table expression and join against it.
    with self.database.atomic():
        rows = ({'username': 'u%s' % i} for i in range(10))
        User.insert_many(rows).execute()

    lhs = User.select().where(User.username.in_(['u1', 'u3']))
    rhs = User.select().where(User.username.in_(['u5', 'u7']))
    u_cte = (lhs | rhs).cte('users_union')

    query = (User
             .select(User.username)
             .join(u_cte, on=(User.id == u_cte.c.id))
             .where(User.username.in_(['u1', 'u7']))
             .with_cte(u_cte))
    usernames = sorted([user.username for user in query])
    self.assertEqual(usernames, ['u1', 'u7'])
|
|
|
|
@requires_models(Category)
def test_self_referential_fk(self):
    # Category.parent points back at Category; join through an alias to
    # load each child together with its parent in one query.
    self.assertTrue(Category.parent.rel_model is Category)

    root = Category.create(name='root')
    Category.create(parent=root, name='child-1')
    Category.create(parent=root, name='child-2')

    with self.assertQueryCount(1):
        Parent = Category.alias('p')
        query = (Category
                 .select(Parent.name, Category.name)
                 .where(Category.parent == root)
                 .order_by(Category.name)
                 .join(Parent, on=(Category.parent == Parent.name)))
        first, second = list(query)

        self.assertEqual(first.name, 'child-1')
        self.assertEqual(first.parent.name, 'root')
        self.assertEqual(second.name, 'child-2')
        self.assertEqual(second.parent.name, 'root')
|
|
|
|
def test_deferred_fk(self):
    # Note refers to Foo by name before Foo exists; once Foo is declared
    # the deferred key must resolve and the backref must work.
    class Note(TestModel):
        foo = DeferredForeignKey('Foo', backref='notes')

    class Foo(TestModel):
        note = ForeignKeyField(Note)

    self.assertTrue(Note.foo.rel_model is Foo)
    self.assertTrue(Foo.note.rel_model is Note)

    foo = Foo(id=1337)
    expected_sql = (
        'SELECT "t1"."id", "t1"."foo_id" FROM "note" AS "t1" '
        'WHERE ("t1"."foo_id" = ?)')
    self.assertSQL(foo.notes, expected_sql, [1337])
|
|
|
|
def test_deferred_fk_dependency_graph(self):
    # sort_models() must keep AUser ahead of ZTweet even though AUser
    # declares a deferred foreign key.
    class AUser(TestModel):
        foo = DeferredForeignKey('Tweet')

    class ZTweet(TestModel):
        user = ForeignKeyField(AUser, backref='ztweets')

    ordered = sort_models([AUser, ZTweet])
    self.assertEqual(ordered, [AUser, ZTweet])
|
|
|
|
@requires_models(Category)
def test_empty_joined_instance(self):
    # A LEFT OUTER self-join must leave `parent` as None for the row that
    # has no parent.
    root = Category.create(name='a')
    Category.create(name='c1', parent=root)
    Category.create(name='c2', parent=root)

    with self.assertQueryCount(1):
        Parent = Category.alias('p')
        query = (Category
                 .select(Category, Parent)
                 .join(Parent, JOIN.LEFT_OUTER,
                       on=(Category.parent == Parent.name))
                 .order_by(Category.name))
        result = []
        for category in query:
            result.append((category.name, category.parent is None))

    self.assertEqual(result, [('a', True), ('c1', False), ('c2', False)])
|
|
|
|
@requires_models(User, Tweet)
def test_from_multi_table(self):
    # Select from two tables listed in FROM, relating them in WHERE
    # instead of with an explicit JOIN.
    self.add_tweets(self.add_user('huey'), 'meow', 'hiss', 'purr')
    self.add_tweets(self.add_user('mickey'), 'woof', 'wheeze')

    with self.assertQueryCount(1):
        hueys_tweets = ((Tweet.user == User.id) &
                        (User.username == 'huey'))
        query = (Tweet
                 .select(Tweet, User)
                 .from_(Tweet, User)
                 .where(hueys_tweets)
                 .order_by(Tweet.id)
                 .dicts())

        # The query is iterated twice inside the single-query assertion.
        contents = [row['content'] for row in query]
        self.assertEqual(contents, ['meow', 'hiss', 'purr'])
        usernames = [row['username'] for row in query]
        self.assertEqual(usernames, ['huey', 'huey', 'huey'])
|
|
|
|
@requires_models(User)
def test_noop(self):
    # A noop() query yields no rows.
    rows = list(User.noop())
    self.assertEqual(rows, [])
|
|
|
|
@requires_models(User)
def test_iteration(self):
    # The model class itself supports iteration and len(), and is truthy
    # whether or not the table has rows.
    self.assertEqual(list(User), [])
    self.assertEqual(len(User), 0)
    self.assertTrue(User)

    rows = (['charlie'], ['huey'])
    User.insert_many(rows, [User.username]).execute()

    usernames = sorted(user.username for user in User)
    self.assertEqual(usernames, ['charlie', 'huey'])
    self.assertEqual(len(User), 2)
    self.assertTrue(User)
|
|
|
|
@requires_models(User)
def test_iterator(self):
    # Rows consumed through iterator() are not available on a second
    # pass over the same object.
    usernames = ['charlie', 'huey', 'zaizee']
    with self.database.atomic():
        for name in usernames:
            User.create(username=name)

    with self.assertQueryCount(1):
        rows = User.select().order_by(User.username).iterator()
        first_pass = [user.username for user in rows]
        self.assertEqual(first_pass, usernames)

        self.assertEqual(list(rows), [])
|
|
|
|
@requires_models(User)
def test_batch_commit(self):
    # Count how many times commit() is called while inserting rows
    # through database.batch_commit() with various batch sizes.
    real_commit = self.database.commit

    def assertBatch(n_rows, batch_size, n_commits):
        User.delete().execute()
        user_data = [{'username': 'u%s' % i} for i in range(n_rows)]
        with mock.patch.object(self.database, 'commit') as mock_commit:
            # Delegate to the real commit so the rows are persisted.
            mock_commit.side_effect = real_commit
            for row in self.database.batch_commit(user_data, batch_size):
                User.create(**row)

        self.assertEqual(mock_commit.call_count, n_commits)
        self.assertEqual(User.select().count(), n_rows)

    # (batch_size, expected commit count) for six rows.
    cases = ((1, 6), (2, 3), (3, 2), (4, 2), (6, 1), (7, 1))
    for batch_size, n_commits in cases:
        assertBatch(6, batch_size, n_commits)
|
|
|
|
|
|
@requires_models(User, Tweet)
def test_assertQueryCount(self):
    # assertQueryCount must fail for any expected count other than the
    # one this un-joined loop actually produces (4).
    self.add_tweets(self.add_user('charlie'), 'foo', 'bar', 'baz')

    def do_test(n):
        with self.assertQueryCount(n):
            for tweet in Tweet.select():
                tweet.user.username

    for wrong_count in (1, 3):
        self.assertRaises(AssertionError, do_test, wrong_count)
    do_test(4)
    self.assertRaises(AssertionError, do_test, 5)
|
|
|
|
|
|
|
|
class TestRaw(ModelTestCase):
    # Tests for Model.raw() queries.
    database = get_in_memory_db()
    requires = [User]

    def test_raw(self):
        # Extra columns selected by a raw query become attributes on the
        # returned rows.
        with self.database.atomic():
            for name in ('charlie', 'chuck', 'huey', 'zaizee'):
                User.create(username=name)

        sql = ('SELECT username, SUBSTR(username, 1, 1) AS first '
               'FROM users '
               'WHERE SUBSTR(username, 1, 1) = ? '
               'ORDER BY username DESC')
        query = User.raw(sql, 'c')
        rows = [(row.username, row.first) for row in query]
        self.assertEqual(rows, [('chuck', 'c'), ('charlie', 'c')])

    def test_raw_iterator(self):
        rows = [('charlie',), ('huey',)]
        User.insert_many(rows, fields=[User.username]).execute()

        with self.assertQueryCount(1):
            query = User.raw('SELECT * FROM users ORDER BY id')
            usernames = []
            for user in query.iterator():
                usernames.append(user.username)
            self.assertEqual(usernames, ['charlie', 'huey'])

            # Since we used iterator(), the results were not cached.
            self.assertEqual([u.username for u in query], [])
|
|
|
|
|
|
class TestDefaultValues(ModelTestCase):
    # Tests for when field defaults are, and are not, applied.
    database = get_in_memory_db()
    requires = [Sample, SampleMeta]

    def test_default_present_on_insert(self):
        # Although value is not specified, it has a default, which is
        # included in the INSERT.
        single = Sample.insert(counter=0)
        self.assertSQL(single, (
            'INSERT INTO "sample" ("counter", "value") '
            'VALUES (?, ?)'), [0, 1.0])

        # Default values are also included when doing bulk inserts.
        dict_rows = [
            {'counter': '0'},
            {'counter': 1, 'value': 2},
            {'counter': '2'}]
        bulk_dicts = Sample.insert_many(dict_rows)
        self.assertSQL(bulk_dicts, (
            'INSERT INTO "sample" ("counter", "value") '
            'VALUES (?, ?), (?, ?), (?, ?)'), [0, 1.0, 1, 2.0, 2, 1.0])

        tuple_rows = [(0,), (1, 2.)]
        bulk_tuples = Sample.insert_many(tuple_rows,
                                         fields=[Sample.counter])
        self.assertSQL(bulk_tuples, (
            'INSERT INTO "sample" ("counter", "value") '
            'VALUES (?, ?), (?, ?)'), [0, 1.0, 1, 2.0])

    def test_default_present_on_create(self):
        # The default is persisted, not just set on the local instance.
        Sample.create(counter=3)
        stored = Sample.get(Sample.counter == 3)
        self.assertEqual(stored.value, 1.)

    def test_defaults_from_cursor(self):
        sample = Sample.create(counter=1)
        SampleMeta.create(sample=sample, value=1.)
        SampleMeta.create(sample=sample, value=2.)

        # Defaults are not present when doing a read query.
        with self.assertQueryCount(1):
            # Simple query.
            query = (SampleMeta.select(SampleMeta.sample)
                     .order_by(SampleMeta.value))
            rows = list(query)
            self.assertEqual(len(rows), 2)
            for row in rows:
                self.assertIsNone(row.value)

        with self.assertQueryCount(1):
            # Join-graph query.
            query = (SampleMeta
                     .select(SampleMeta.sample,
                             Sample.counter)
                     .join(Sample)
                     .order_by(SampleMeta.value))

            first, second = list(query)
            self.assertIsNone(first.value)
            self.assertIsNone(second.value)
            self.assertIsNone(first.sample.value)
            self.assertIsNone(second.sample.value)
            self.assertEqual(first.sample.counter, 1)
            self.assertEqual(second.sample.counter, 1)
|
|
def incrementer():
    """Return a zero-argument callable that counts its own invocations.

    Each call to the returned function yields the next integer, starting
    at 1. Every call to ``incrementer()`` creates an independent counter.
    """
    value = 0

    def increment():
        # ``nonlocal`` rebinds the enclosing counter directly; no mutable
        # container is needed to carry state between calls.
        nonlocal value
        value += 1
        return value

    return increment
|
|
|
|
class AutoCounter(TestModel):
    # `counter` uses a callable default built by incrementer();
    # `control` uses a constant default.
    counter = IntegerField(default=incrementer())
    control = IntegerField(default=1)
|
|
|
|
|
|
class TestDefaultDirtyBehavior(ModelTestCase):
    # Tests for how defaults and dirty-field tracking interact with the
    # `only_save_dirty` model option.
    database = get_in_memory_db()
    requires = [AutoCounter]

    def tearDown(self):
        super(TestDefaultDirtyBehavior, self).tearDown()
        # Restore the option toggled by test_default_dirty.
        AutoCounter._meta.only_save_dirty = False

    def test_default_dirty(self):
        AutoCounter._meta.only_save_dirty = True

        first = AutoCounter()
        first.save()
        self.assertEqual(first.counter, 1)
        self.assertEqual(first.control, 1)

        both_one = ((AutoCounter.counter == 1) &
                    (AutoCounter.control == 1))
        first_db = AutoCounter.get(both_one)
        self.assertEqual(first_db.counter, 1)
        self.assertEqual(first_db.control, 1)

        # No changes.
        self.assertFalse(first_db.save())

        second = AutoCounter.create()
        self.assertEqual(second.counter, 2)
        self.assertEqual(second.control, 1)

        AutoCounter._meta.only_save_dirty = False

        # The defaults are asserted before save() is called here.
        third = AutoCounter()
        self.assertEqual(third.counter, 3)
        self.assertEqual(third.control, 1)
        third.save()

        third_db = AutoCounter.get(AutoCounter.id == third.id)
        self.assertEqual(third_db.counter, 3)

    @requires_models(Person)
    def test_save_only_dirty(self):
        # save(only=...) writes just the named fields and leaves the rest
        # marked dirty, with only_save_dirty both off and on.
        today = datetime.date.today()
        try:
            for flag in (False, True):
                Person._meta.only_save_dirty = flag

                person = Person.create(first='f', last='l', dob=today)
                person.first = 'f2'
                person.last = 'l2'
                person.save(only=[Person.first])
                self.assertEqual(person.dirty_fields, [Person.last])
                self.assertFalse('first' in person.dirty_field_names)
                self.assertTrue('last' in person.dirty_field_names)

                stored = Person.get(Person.id == person.id)
                self.assertEqual((stored.first, stored.last), ('f2', 'l'))

                person.save()
                self.assertEqual(person.dirty_fields, [])

                stored = Person.get(Person.id == person.id)
                self.assertEqual((stored.first, stored.last), ('f2', 'l2'))

                person.delete_instance()
        finally:
            # Reset only_save_dirty property for other tests.
            Person._meta.only_save_dirty = False
|
|
|
|
|
|
class TestFunctionCoerce(ModelTestCase):
    # Tests for how SQL function results are converted to Python values.
    database = get_in_memory_db()
    requires = [Sample]

    def _create_samples(self, n):
        # Insert n rows whose counter and value both equal the row index.
        for i in range(n):
            Sample.create(counter=i, value=i)

    def test_coerce(self):
        self._create_samples(3)

        # coerce(False): the result is compared as the raw string, even
        # when aliased to the name of the integer field.
        concat = fn.GROUP_CONCAT(Sample.counter).coerce(False)
        for alias in ('counter', 'counter_group'):
            query = Sample.select(concat.alias(alias))
            self.assertEqual(getattr(query.get(), alias), '0,1,2')

        query = Sample.select(concat)
        self.assertEqual(query.scalar(), '0,1,2')

    def test_scalar(self):
        self._create_samples(4)
        query = Sample.select(fn.SUM(Sample.counter).alias('total'))
        self.assertEqual(query.scalar(), 6)
        self.assertEqual(query.scalar(as_tuple=True), (6,))
        self.assertEqual(query.scalar(as_dict=True), {'total': 6})

        # With no rows the aggregate is None in every output shape.
        Sample.delete().execute()
        self.assertTrue(query.scalar() is None)
        self.assertEqual(query.scalar(as_tuple=True), (None,))
        self.assertEqual(query.scalar(as_dict=True), {'total': None})

    def test_safe_python_value(self):
        self._create_samples(3)

        # No explicit coerce() here; the string result must still come
        # back unchanged.
        concat = fn.GROUP_CONCAT(Sample.counter)
        for alias in ('counter', 'counter_group'):
            query = Sample.select(concat.alias(alias))
            self.assertEqual(getattr(query.get(), alias), '0,1,2')
            self.assertEqual(query.scalar(), '0,1,2')

    def test_conv_using_python_value(self):
        self._create_samples(3)

        def to_int_list(text):
            return [int(i) for i in text.split(',')]

        concat = (fn
                  .GROUP_CONCAT(Sample.counter)
                  .python_value(to_int_list))
        for alias in ('counter', 'counter_group'):
            query = Sample.select(concat.alias(alias))
            self.assertEqual(getattr(query.get(), alias), [0, 1, 2])

        query = Sample.select(concat)
        self.assertEqual(query.scalar(), [0, 1, 2])

    @requires_models(Category, Sample)
    def test_no_coerce_count_avg(self):
        for i in range(10):
            Category.create(name=str(i))

        # COUNT() does not result in the value being coerced.
        count_q = Category.select(fn.COUNT(Category.name))
        self.assertEqual(count_q.scalar(), 10)

        # Force the value to be coerced using the field's db_value().
        coerced_q = Category.select(fn.COUNT(Category.name).coerce(True))
        self.assertEqual(coerced_q.scalar(), '10')

        # Ensure avg over an integer field is returned as a float.
        Sample.insert_many([(1, 0), (2, 0)]).execute()
        avg_q = Sample.select(fn.AVG(Sample.counter).alias('a'))
        self.assertEqual(avg_q.get().a, 1.5)
|
|
|
|
|
|
class T1(TestModel):
    # Model with an AutoField primary key.
    pk = AutoField()
    value = IntegerField()
|
|
|
|
class T2(TestModel):
    # Integer primary key carrying a "DEFAULT 3" column constraint.
    pk = IntegerField(primary_key=True, constraints=[SQL('DEFAULT 3')])
    value = IntegerField()
|
|
|
|
class T3(TestModel):
    # Plain integer primary key (declared as IntegerField, not AutoField).
    pk = IntegerField(primary_key=True)
    value = IntegerField()
|
|
|
|
class T4(TestModel):
    # Model whose primary key is the composite (pk1, pk2).
    pk1 = IntegerField()
    pk2 = IntegerField()
    value = IntegerField()

    class Meta:
        primary_key = CompositeKey('pk1', 'pk2')
|
|
|
|
|
|
class TestPrimaryKeySaveHandling(ModelTestCase):
|
|
requires = [T1, T2, T3, T4]
|
|
|
|
def test_auto_field(self):
|
|
# AutoField will be inserted if the PK is not set, after which the new
|
|
# ID will be populated.
|
|
t11 = T1(value=1)
|
|
self.assertEqual(t11.save(), 1)
|
|
self.assertTrue(t11.pk is not None)
|
|
|
|
# Calling save() a second time will issue an update.
|
|
t11.value = 100
|
|
self.assertEqual(t11.save(), 1)
|
|
|
|
# Verify the record was updated.
|
|
t11_db = T1[t11.pk]
|
|
self.assertEqual(t11_db.value, 100)
|
|
|
|
# We can explicitly specify the value of an auto-incrementing
|
|
# primary-key, but we must be sure to call save(force_insert=True),
|
|
# otherwise peewee will attempt to do an update.
|
|
t12 = T1(pk=1337, value=2)
|
|
self.assertEqual(t12.save(), 0)
|
|
self.assertEqual(T1.select().count(), 1)
|
|
self.assertEqual(t12.save(force_insert=True), 1)
|
|
|
|
# Attempting to force-insert an already-existing PK will fail with an
|
|
# integrity error.
|
|
with self.database.atomic():
|
|
with self.assertRaises(IntegrityError):
|
|
t12.value = 3
|
|
t12.save(force_insert=True)
|
|
|
|
query = T1.select().order_by(T1.value).tuples()
|
|
self.assertEqual(list(query), [(1337, 2), (t11.pk, 100)])
|
|
|
|
@requires_pglike
|
|
def test_server_default_pk(self):
|
|
# The new value of the primary-key will be returned to us, since
|
|
# postgres supports RETURNING.
|
|
t2 = T2(value=1)
|
|
self.assertEqual(t2.save(), 1)
|
|
self.assertEqual(t2.pk, 3)
|
|
|
|
# Saving after the PK is set will issue an update.
|
|
t2.value = 100
|
|
self.assertEqual(t2.save(), 1)
|
|
|
|
t2_db = T2[3]
|
|
self.assertEqual(t2_db.value, 100)
|
|
|
|
# If we just set the pk and try to save, peewee issues an update which
|
|
# doesn't have any effect.
|
|
t22 = T2(pk=2, value=20)
|
|
self.assertEqual(t22.save(), 0)
|
|
self.assertEqual(T2.select().count(), 1)
|
|
|
|
# We can force-insert the value we specify explicitly.
|
|
self.assertEqual(t22.save(force_insert=True), 1)
|
|
self.assertEqual(T2[2].value, 20)
|
|
|
|
def test_integer_field_pk(self):
|
|
# For a non-auto-incrementing primary key, we have to use force_insert.
|
|
t3 = T3(pk=2, value=1)
|
|
self.assertEqual(t3.save(), 0) # Oops, attempts to do an update.
|
|
self.assertEqual(T3.select().count(), 0)
|
|
|
|
# Force to be an insert.
|
|
self.assertEqual(t3.save(force_insert=True), 1)
|
|
|
|
# Now we can update the value and call save() to issue an update.
|
|
t3.value = 100
|
|
self.assertEqual(t3.save(), 1)
|
|
|
|
# Verify data is correct.
|
|
t3_db = T3[2]
|
|
self.assertEqual(t3_db.value, 100)
|
|
|
|
def test_composite_pk(self):
    row = T4(pk1=1, pk2=2, value=10)

    # A plain save() attempts an UPDATE of a non-existent row; the INSERT has
    # to be forced.
    self.assertEqual(row.save(), 0)
    self.assertEqual(row.save(force_insert=True), 1)

    # After changing one part of the composite key the UPDATE no longer
    # matches the stored row, so it reports zero rows.
    row.pk2 = 3
    row.value = 30
    self.assertEqual(row.save(), 0)

    # With the original key restored, the pending value change is written.
    row.pk2 = 2
    self.assertEqual(row.save(), 1)

    stored = T4[1, 2]
    self.assertEqual(stored.value, 30)
|
|
|
|
@requires_pglike
def test_returning_object(self):
    # The insert query is built with returning(T2).objects(); iterating it
    # must yield exactly one row (the unpacking below enforces that).
    insert_query = T2.insert(value=10).returning(T2).objects()
    [created] = list(insert_query)
    self.assertEqual(created.pk, 3)
    self.assertEqual(created.value, 10)
|
|
|
|
# Model whose only explicitly declared column is a nullable integer; it is the
# fixture for TestSaveNoData.
class T5(TestModel):
    val = IntegerField(null=True)
|
|
|
|
|
|
class TestSaveNoData(ModelTestCase):
    # Behaviour of Model.save() when there is little or no column data to
    # write for a T5 row.
    requires = [T5]

    def test_save_no_data(self):
        # A row can be created without passing any column values.
        obj = T5.create()
        self.assertGreaterEqual(obj.id, 1)

        # Setting a value and saving persists it.
        obj.val = 3
        obj.save()
        fetched = T5.get(T5.id == obj.id)
        self.assertEqual(fetched.val, 3)

        # Setting the value back to None is persisted too.
        obj.val = None
        obj.save()
        fetched = T5.get(T5.id == obj.id)
        self.assertIsNone(fetched.val)

    def test_save_no_data2(self):
        created = T5.create()

        # Saving an instance that was loaded from the database, without
        # touching it, must succeed.
        loaded = T5.get(T5.id == created.id)
        loaded.save()

        loaded = T5.get(T5.id == created.id)
        self.assertIsNone(loaded.val)

    def test_save_no_data3(self):
        # Re-saving the instance returned by create(), with nothing assigned,
        # is expected to raise ValueError.
        created = T5.create()
        with self.assertRaises(ValueError):
            created.save()

    def test_save_only_no_data(self):
        obj = T5.create(val=1)
        obj.val = 2

        # Restricting the save to an empty field list raises ValueError ...
        with self.assertRaises(ValueError):
            obj.save(only=[])

        # ... and the stored value is left as it was.
        stored = T5.get(T5.id == obj.id)
        self.assertEqual(stored.val, 1)
|
|
|
|
|
|
class TestDeleteInstance(ModelTestCase):
    # Verifies the statements issued by delete_instance(recursive=True), with
    # and without delete_nullable, against an in-memory database.
    database = get_in_memory_db()
    requires = [User, Account, Tweet, Favorite, Relationship]

    def setUp(self):
        super(TestDeleteInstance, self).setUp()
        with self.database.atomic():
            # huey: an account plus the tweets "meow" and "purr".
            huey = User.create(username='huey')
            Account.create(user=huey, email='huey@meow.com')
            for content in ('meow', 'purr'):
                Tweet.create(user=huey, content=content)

            # mickey: a single tweet, which huey favorites.
            mickey = User.create(username='mickey')
            woof = Tweet.create(user=mickey, content='woof')
            Favorite.create(user=huey, tweet=woof)

            # A third huey tweet, favorited by mickey.
            hiss = Tweet.create(user=huey, content='hiss')
            Favorite.create(user=mickey, tweet=hiss)

    def test_delete_instance_recursive(self):
        huey = User.get(User.username == 'huey')

        # Walk huey's dependencies once before the counted section; the
        # yielded items themselves are not inspected.
        list(huey.dependencies())

        with self.assertQueryCount(5):
            huey.delete_instance(recursive=True)

        uid = huey.id
        expected_history = [
            ('DELETE FROM "favorite" WHERE ("favorite"."user_id" = ?)',
             [uid]),
            ('DELETE FROM "favorite" WHERE ('
             '"favorite"."tweet_id" IN ('
             'SELECT "t1"."id" FROM "tweet" AS "t1" WHERE ('
             '"t1"."user_id" = ?)))', [uid]),
            ('DELETE FROM "tweet" WHERE ("tweet"."user_id" = ?)', [uid]),
            # Without delete_nullable the account row is kept and only its
            # foreign key is set to NULL.
            ('UPDATE "account" SET "user_id" = ? '
             'WHERE ("account"."user_id" = ?)',
             [None, uid]),
            ('DELETE FROM "users" WHERE ("users"."id" = ?)', [uid]),
        ]
        self.assertHistory(5, expected_history)

        # mickey is the only remaining user.
        self.assertEqual(User.select().count(), 1)

        # huey's account survives with its user reference cleared.
        account = Account.get(Account.email == 'huey@meow.com')
        self.assertIsNone(account.user)

        # One favorite belonged to huey, the other pointed at a huey tweet;
        # both are gone.
        self.assertEqual(Favorite.select().count(), 0)

        # mickey's tweet is the only one left.
        self.assertEqual(Tweet.select().count(), 1)
        remaining = Tweet.get()
        self.assertEqual(remaining.content, 'woof')

    def test_delete_nullable(self):
        huey = User.get(User.username == 'huey')

        # Five statements are expected:
        #   Favorite -> Tweet -> User  (favorites of huey's tweets)
        #   Favorite -> User           (huey's own favorites)
        #   Tweet    -> User           (huey's tweets)
        #   Account  -> User           (huey's account)
        #   User                       (huey)
        with self.assertQueryCount(5):
            huey.delete_instance(recursive=True, delete_nullable=True)

        # Compare against the five most recent statements.
        uid = huey.id
        expected_history = [
            ('DELETE FROM "favorite" WHERE ("favorite"."user_id" = ?)',
             [uid]),
            ('DELETE FROM "favorite" WHERE ('
             '"favorite"."tweet_id" IN ('
             'SELECT "t1"."id" FROM "tweet" AS "t1" WHERE ('
             '"t1"."user_id" = ?)))', [uid]),
            ('DELETE FROM "tweet" WHERE ("tweet"."user_id" = ?)', [uid]),
            # With delete_nullable the account row is deleted outright.
            ('DELETE FROM "account" WHERE ("account"."user_id" = ?)',
             [uid]),
            ('DELETE FROM "users" WHERE ("users"."id" = ?)', [uid]),
        ]
        self.assertHistory(5, expected_history)

        self.assertEqual(User.select().count(), 1)
        self.assertEqual(Account.select().count(), 0)
        self.assertEqual(Favorite.select().count(), 0)

        self.assertEqual(Tweet.select().count(), 1)
        remaining = Tweet.get()
        self.assertEqual(remaining.content, 'woof')
|
|
|
|
|
|
|
|
|
|
# Parent side of the ON DELETE CASCADE fixture: a single text column.
class CascadeParent(TestModel):
    name = TextField()
|
|
|
|
# Child side of the cascade fixture: the foreign key to CascadeParent is
# declared with on_delete='CASCADE' and exposed on the parent as ``children``.
class CascadeChild(TestModel):
    parent = ForeignKeyField(
        CascadeParent,
        backref='children',
        on_delete='CASCADE')
    data = TextField()
|
|
|
|
|
|
class TestCascadeDeleteIntegration(ModelTestCase):
    # Deleting a parent row must remove its children through the foreign
    # key's ON DELETE CASCADE rule.
    requires = [CascadeParent, CascadeChild]

    def setUp(self):
        super(TestCascadeDeleteIntegration, self).setUp()
        # On SQLite the foreign_keys pragma is switched on for this test.
        if IS_SQLITE:
            self.database.pragma('foreign_keys', 1)

    def test_cascade_delete(self):
        first = CascadeParent.create(name='p1')
        second = CascadeParent.create(name='p2')
        for owner, payload in ((first, 'c1'), (first, 'c2'), (second, 'c3')):
            CascadeChild.create(parent=owner, data=payload)

        self.assertEqual(CascadeChild.select().count(), 3)

        # Removing the first parent leaves only the second parent's child.
        first.delete_instance()
        self.assertEqual(CascadeChild.select().count(), 1)
        self.assertEqual(CascadeChild.get().data, 'c3')
|
|
|
|
# ===========================================================================
|
|
# Joins and aliases
|
|
# ===========================================================================
|
|
|
|
|
|
|
|
class TestJoinModelAlias(ModelTestCase):
    # Joins against model aliases, from the foreign-key side (Tweet -> User)
    # and from the back-reference side (User -> Tweet), using the different
    # on= / attr= spellings.

    # (username, tweet content) fixture pairs.
    data = (
        ('huey', 'meow'),
        ('huey', 'purr'),
        ('zaizee', 'hiss'),
        ('mickey', 'woof'))
    requires = [User, Tweet]

    def setUp(self):
        super(TestJoinModelAlias, self).setUp()
        # Primary keys are assigned explicitly: tweets are numbered in
        # fixture order, users in order of first appearance.
        authors = {}
        for tweet_id, (username, content) in enumerate(self.data, 1):
            author = authors.get(username)
            if author is None:
                author = User.create(id=len(authors) + 1, username=username)
                authors[username] = author
            Tweet.create(id=tweet_id, user=author, content=content)

    def _test_query(self, alias_expr):
        # Tweet query selecting Tweet plus whatever alias_expr() returns,
        # ordered by username then content; the caller adds the join.
        user_model = alias_expr()
        selected = Tweet.select(Tweet, user_model)
        return selected.order_by(user_model.username, Tweet.content)

    def assertTweets(self, query, user_attr='user'):
        # Iterating the query and following ``user_attr`` on every tweet must
        # cost one query and reproduce the fixture pairs in sorted order.
        with self.assertQueryCount(1):
            pairs = []
            for tweet in query:
                related_user = getattr(tweet, user_attr)
                pairs.append((related_user.username, tweet.content))
        self.assertEqual(sorted(self.data), pairs)

    def test_control(self):
        # Baseline using the plain User model instead of an alias.
        query = self._test_query(lambda: User).join(User)
        self.assertTweets(query)

    def test_join_aliased_columns(self):
        # An aliased column must be readable under its alias, with or
        # without a join in the query.
        plain = (Tweet
                 .select(Tweet.id.alias('tweet_id'), Tweet.content)
                 .order_by(Tweet.id))
        all_rows = [(t.tweet_id, t.content) for t in plain]
        self.assertEqual(all_rows, [
            (1, 'meow'),
            (2, 'purr'),
            (3, 'hiss'),
            (4, 'woof')])

        joined = (Tweet
                  .select(Tweet.id.alias('tweet_id'), Tweet.content)
                  .join(User)
                  .where(User.username == 'huey')
                  .order_by(Tweet.id))
        huey_rows = [(t.tweet_id, t.content) for t in joined]
        self.assertEqual(huey_rows, [
            (1, 'meow'),
            (2, 'purr')])

    def test_join(self):
        UA = User.alias('ua')
        base = self._test_query(lambda: UA)
        self.assertTweets(base.join(UA))

    def test_join_on(self):
        UA = User.alias('ua')
        base = self._test_query(lambda: UA)
        self.assertTweets(base.join(UA, on=(Tweet.user == UA.id)))

    def test_join_on_field(self):
        UA = User.alias('ua')
        base = self._test_query(lambda: UA)
        self.assertTweets(base.join(UA, on=Tweet.user))

    def test_join_on_alias(self):
        # Aliasing the join predicate names the attribute that carries the
        # joined instance.
        UA = User.alias('ua')
        predicate = (Tweet.user == UA.id).alias('foo')
        query = self._test_query(lambda: UA).join(UA, on=predicate)
        self.assertTweets(query, 'foo')

    def test_join_attr(self):
        UA = User.alias('ua')
        base = self._test_query(lambda: UA)
        self.assertTweets(base.join(UA, attr='baz'), 'baz')

    def test_join_on_alias_attr(self):
        # Both a predicate alias and attr= are given; the joined instance is
        # read from the attr= name.
        UA = User.alias('ua')
        predicate = (Tweet.user == UA.id).alias('foo')
        query = self._test_query(lambda: UA).join(UA, on=predicate,
                                                  attr='bar')
        self.assertTweets(query, 'bar')

    def _test_query_backref(self, alias_expr):
        # Mirror image of _test_query: User query selecting User plus
        # whatever alias_expr() returns, ordered by username then content.
        tweet_model = alias_expr()
        selected = User.select(User, tweet_model)
        return selected.order_by(User.username, tweet_model.content)

    def assertUsers(self, query, tweet_attr='tweet'):
        # Iterating the query and following ``tweet_attr`` on every user must
        # cost one query and reproduce the fixture pairs in sorted order.
        with self.assertQueryCount(1):
            pairs = []
            for user in query:
                username = user.username
                pairs.append((username, getattr(user, tweet_attr).content))
        self.assertEqual(sorted(self.data), pairs)

    def test_control_backref(self):
        # Baseline using the plain Tweet model instead of an alias.
        query = self._test_query_backref(lambda: Tweet).join(Tweet)
        self.assertUsers(query)

    def test_join_backref(self):
        TA = Tweet.alias('ta')
        base = self._test_query_backref(lambda: TA)
        self.assertUsers(base.join(TA))

    def test_join_on_backref(self):
        TA = Tweet.alias('ta')
        base = self._test_query_backref(lambda: TA)
        self.assertUsers(base.join(TA, on=(User.id == TA.user_id)))

    def test_join_on_field_backref(self):
        TA = Tweet.alias('ta')
        base = self._test_query_backref(lambda: TA)
        self.assertUsers(base.join(TA, on=TA.user))

    def test_join_on_alias_backref(self):
        TA = Tweet.alias('ta')
        predicate = (User.id == TA.user_id).alias('foo')
        query = self._test_query_backref(lambda: TA).join(TA, on=predicate)
        self.assertUsers(query, 'foo')

    def test_join_attr_backref(self):
        TA = Tweet.alias('ta')
        base = self._test_query_backref(lambda: TA)
        self.assertUsers(base.join(TA, attr='baz'), 'baz')

    def test_join_alias_twice(self):
        # A model alias may be the destination of one join and the source of
        # the next: User -> Tweet (alias) -> User (alias, exposed as "foo").
        TA = Tweet.alias('ta')
        UA = User.alias('ua')
        with self.assertQueryCount(1):
            query = (User
                     .select(User, TA, UA)
                     .join(TA)
                     .join(UA, on=(TA.user_id == UA.id).alias('foo'))
                     .order_by(User.username, TA.content))

            triples = [
                (row.username, row.tweet.content, row.tweet.foo.username)
                for row in query]

        self.assertEqual(triples, [
            ('huey', 'meow', 'huey'),
            ('huey', 'purr', 'huey'),
            ('mickey', 'woof', 'mickey'),
            ('zaizee', 'hiss', 'zaizee')])

    def test_alias_filter(self):
        # filter() lookups spelled with either the alias name or the
        # foreign-key name must produce the same SQL and rows.
        UA = User.alias('ua')
        expected_sql = (
            'SELECT "t1"."content", "ua"."username" '
            'FROM "tweet" AS "t1" '
            'INNER JOIN "users" AS "ua" '
            'ON ("t1"."user_id" = "ua"."id") '
            'WHERE ("ua"."username" = ?) '
            'ORDER BY "t1"."content"')

        for lookup in ({'ua__username': 'huey'}, {'user__username': 'huey'}):
            with self.assertQueryCount(1):
                query = (Tweet
                         .select(Tweet.content, UA.username)
                         .join(UA)
                         .filter(**lookup)
                         .order_by(Tweet.content))

                self.assertSQL(query, expected_sql, ['huey'])

                rows = [(t.content, t.user.username) for t in query]
                self.assertEqual(rows, [('meow', 'huey'), ('purr', 'huey')])
|
|
|
|
|
|
class TestModelAliasFieldProperties(ModelTestCase):
    # Field properties (here DateField.year) must resolve against the alias
    # they are accessed through.
    database = get_in_memory_db()

    def test_field_properties(self):
        class Person(TestModel):
            name = TextField()
            dob = DateField()

            class Meta:
                database = self.database

        class Job(TestModel):
            worker = ForeignKeyField(Person, backref='jobs')
            client = ForeignKeyField(Person, backref='jobs_hired')

            class Meta:
                database = self.database

        Worker = Person.alias()
        Client = Person.alias()

        # The client side is always joined first and the WHERE clause always
        # filters on the client's dob, so every variant below must render the
        # same SQL.
        expected_sql = (
            'SELECT "t1"."id", "t1"."worker_id", "t1"."client_id" '
            'FROM "job" AS "t1" '
            'INNER JOIN "person" AS "t2" ON ("t1"."client_id" = "t2"."id") '
            'INNER JOIN "person" AS "t3" ON ("t1"."worker_id" = "t3"."id") '
            'WHERE (date_part(?, "t2"."dob") = ?)')
        expected_params = ['year', 1983]

        # (client model, worker model): two aliases, then the plain model
        # standing in for the worker, then for the client.
        variants = (
            (Client, Worker),
            (Client, Person),
            (Person, Worker))
        for client_model, worker_model in variants:
            query = (Job
                     .select()
                     .join(client_model, on=(Job.client == client_model.id))
                     .switch(Job)
                     .join(worker_model, on=(Job.worker == worker_model.id))
                     .where(client_model.dob.year == 1983))
            self.assertSQL(query, expected_sql, expected_params)
|
|
|
|
|
|
class TestJoinSubquery(ModelTestCase):
    # LEFT OUTER JOIN from Person onto an aliased subquery.
    requires = [Person, Relationship]

    def test_join_subquery(self):
        # (person, related-to) pairs; a None on the right means the person
        # has no outgoing relationship.
        data = (
            ('charlie', None),
            ('huey', 'charlie'),
            ('mickey', 'charlie'),
            ('zaizee', 'charlie'),
            ('zaizee', 'huey'))
        people = {}

        def get_person(name):
            # Create each person at most once, keyed by name.
            found = people.get(name)
            if found is None:
                found = Person.create(first=name, last=name,
                                      dob=datetime.date(2017, 1, 1))
                people[name] = found
            return found

        for source_name, target_name in data:
            source = get_person(source_name)
            if target_name is not None:
                target = get_person(target_name)
                Relationship.create(from_person=source, to_person=target)

        # Subquery: every relationship together with the first name of the
        # person it points at.
        Friend = Person.alias('friend')
        subq = (Relationship
                .select(Friend.first.alias('friend_name'),
                        Relationship.from_person)
                .join(Friend, on=(Relationship.to_person == Friend.id))
                .alias('subq'))

        # The subquery already contains the INNER JOIN, so the outer query
        # needs just one LEFT OUTER JOIN instead of two.
        query = (Person
                 .select(Person.first, subq.c.friend_name)
                 .join(subq, JOIN.LEFT_OUTER,
                       on=(Person.id == subq.c.from_person_id))
                 .order_by(Person.first, subq.c.friend_name))
        expected_sql = (
            'SELECT "t1"."first", "subq"."friend_name" '
            'FROM "person" AS "t1" '
            'LEFT OUTER JOIN ('
            'SELECT "friend"."first" AS "friend_name", "t2"."from_person_id" '
            'FROM "relationship" AS "t2" '
            'INNER JOIN "person" AS "friend" '
            'ON ("t2"."to_person_id" = "friend"."id")) AS "subq" '
            'ON ("t1"."id" = "subq"."from_person_id") '
            'ORDER BY "t1"."first", "subq"."friend_name"')
        self.assertSQL(query, expected_sql, [])

        # The rows come back in exactly the shape and order of the fixture.
        db_data = list(query.tuples())
        self.assertEqual(db_data, list(data))
|
|
|
|
|
|
# Self-referential model: a task may point at another task as its heading
# and/or as its project. ``type`` stores one of the constants below.
class Task(TestModel):
    heading = ForeignKeyField('self', backref='tasks', null=True)
    project = ForeignKeyField('self', backref='projects', null=True)
    title = TextField()
    type = IntegerField()

    # Values stored in ``type``.
    PROJECT = 1
    HEADING = 2
|
|
|
|
|
|
class TestMultiSelfJoin(ModelTestCase):
    # Joins Task to itself twice (heading and project) in a single query.
    requires = [Task]

    def setUp(self):
        super(TestMultiSelfJoin, self).setUp()

        with self.database.atomic():
            # One root project with two sub-projects.
            dev = Task.create(title='dev', type=Task.PROJECT)
            peewee_proj = Task.create(title='peewee', project=dev,
                                      type=Task.PROJECT)
            huey_proj = Task.create(title='huey', project=dev,
                                    type=Task.PROJECT)

            # (heading title, owning project, number of tasks under it)
            heading_data = (
                ('peewee-1', peewee_proj, 2),
                ('peewee-2', peewee_proj, 0),
                ('huey-1', huey_proj, 1),
                ('huey-2', huey_proj, 1))
            for title, proj, n_subtasks in heading_data:
                heading = Task.create(title=title, project=proj,
                                      type=Task.HEADING)
                for number in range(1, n_subtasks + 1):
                    Task.create(title='%s-%s' % (title, number),
                                project=proj, heading=heading,
                                type=Task.HEADING)

    def test_multi_self_join(self):
        Project = Task.alias()
        Heading = Task.alias()
        heading_on = (Task.heading == Heading.id).alias('heading')
        project_on = (Task.project == Project.id).alias('project')
        query = (Task
                 .select(Task, Project, Heading)
                 .join(Heading, JOIN.LEFT_OUTER, on=heading_on)
                 .switch(Task)
                 .join(Project, JOIN.LEFT_OUTER, on=project_on)
                 .order_by(Task.id))

        # Reading both related tasks off every row must not trigger any
        # additional queries.
        with self.assertQueryCount(1):
            accum = []
            for task in query:
                if task.heading:
                    h_title = task.heading.title
                else:
                    h_title = None
                if task.project:
                    p_title = task.project.title
                else:
                    p_title = None
                accum.append((task.title, h_title, p_title))

        # Columns: title, heading title, project title.
        expected = [
            ('dev', None, None),
            ('peewee', None, 'dev'),
            ('huey', None, 'dev'),
            ('peewee-1', None, 'peewee'),
            ('peewee-1-1', 'peewee-1', 'peewee'),
            ('peewee-1-2', 'peewee-1', 'peewee'),
            ('peewee-2', None, 'peewee'),
            ('huey-1', None, 'huey'),
            ('huey-1-1', 'huey-1', 'huey'),
            ('huey-2', None, 'huey'),
            ('huey-2-1', 'huey-2', 'huey'),
        ]
        self.assertEqual(accum, expected)
|
|
|
|
# Cross-join fixture "A": only an explicitly supplied integer primary key.
class CJ_A(TestModel):
    id = IntegerField(primary_key=True)
|
|
# Cross-join fixture "B": only an explicitly supplied integer primary key.
class CJ_B(TestModel):
    id = IntegerField(primary_key=True)
|
|
# Cross-join fixture "C": each row references one CJ_A row and one CJ_B row.
class CJ_C(TestModel):
    id = IntegerField(primary_key=True)
    a = ForeignKeyField(CJ_A)
    b = ForeignKeyField(CJ_B)
|
|
|
|
class TestCrossJoin(ModelTestCase):
    # CROSS JOIN of A and B, LEFT OUTER JOINed to C, to find the (a, b)
    # combinations that have no C row.
    requires = [CJ_A, CJ_B, CJ_C]

    def setUp(self):
        super(TestCrossJoin, self).setUp()
        a_ids = [(1,), (2,), (3,)]
        b_ids = [(1,), (2,)]
        # (id, a, b): the combinations that already have a C row.
        c_rows = [
            (1, 1, 1),
            (2, 1, 2),
            (3, 2, 1)]
        CJ_A.insert_many(a_ids, fields=[CJ_A.id]).execute()
        CJ_B.insert_many(b_ids, fields=[CJ_B.id]).execute()
        CJ_C.insert_many(c_rows,
                         fields=[CJ_C.id, CJ_C.a, CJ_C.b]).execute()

    def test_cross_join(self):
        matches_pair = (CJ_C.a == CJ_A.id) & (CJ_C.b == CJ_B.id)
        query = (CJ_A
                 .select(CJ_A.id.alias('aid'), CJ_B.id.alias('bid'))
                 .join(CJ_B, JOIN.CROSS)
                 .join(CJ_C, JOIN.LEFT_OUTER, on=matches_pair)
                 .where(CJ_C.id.is_null())
                 .order_by(CJ_A.id, CJ_B.id))

        # Of the six (a, b) combinations, three have no C row.
        missing = list(query.tuples())
        self.assertEqual(missing, [(2, 2), (3, 1), (3, 2)])
|
|
|
|
|
|
|
|
# Many-to-many fixture: a student, identified by name.
class Student(TestModel):
    name = TextField()
|
|
|
|
# Many-to-many fixture: a course, identified by name.
class Course(TestModel):
    name = TextField()
|
|
|
|
# Link table of the many-to-many fixture: one row per (student, course) pair.
class Attendance(TestModel):
    student = ForeignKeyField(Student)
    course = ForeignKeyField(Course)
|
|
|
|
|
|
class TestManyToManyJoining(ModelTestCase):
    # Restricts attendance rows to the first five courses, once by joining a
    # subquery and once by filtering with IN (subquery).
    requires = [Student, Course, Attendance]

    def setUp(self):
        super(TestManyToManyJoining, self).setUp()

        enrolment = (
            ('charlie', ('eng101', 'cs101', 'cs111')),
            ('huey', ('cats1', 'cats2', 'cats3')),
            ('zaizee', ('cats2', 'cats3')))
        # Courses shared between students are created only once.
        courses_by_name = {}
        with self.database.atomic():
            for student_name, course_names in enrolment:
                student = Student.create(name=student_name)
                for course_name in course_names:
                    if course_name not in courses_by_name:
                        courses_by_name[course_name] = Course.create(
                            name=course_name)
                    Attendance.create(student=student,
                                      course=courses_by_name[course_name])

    def assertQuery(self, query):
        # Order the query by attendance id, iterate it in a single query and
        # compare the (student name, course name) pairs.
        with self.assertQueryCount(1):
            ordered = query.order_by(Attendance.id)
            results = [(att.student.name, att.course.name)
                       for att in ordered]

        # 'cats3' is the sixth course created, so attendance rows pointing
        # at it are absent.
        expected = [
            ('charlie', 'eng101'),
            ('charlie', 'cs101'),
            ('charlie', 'cs111'),
            ('huey', 'cats1'),
            ('huey', 'cats2'),
            ('zaizee', 'cats2')]
        self.assertEqual(results, expected)

    def test_join_subquery(self):
        # The limited course list is joined in as a subquery.
        courses = (Course
                   .select(Course.id, Course.name)
                   .order_by(Course.id)
                   .limit(5))
        on_course = (Attendance.course == courses.c.id)
        query = (Attendance
                 .select(Attendance, Student, courses.c.name)
                 .join_from(Attendance, Student)
                 .join_from(Attendance, courses, on=on_course))
        self.assertQuery(query)

    @skip_if(IS_MYSQL)
    def test_join_where_subquery(self):
        # The limited course list is used as an IN (...) filter instead.
        courses = Course.select().order_by(Course.id).limit(5)
        in_first_five = Attendance.course.in_(courses)
        query = (Attendance
                 .select(Attendance, Student, Course)
                 .join_from(Attendance, Student)
                 .join_from(Attendance, Course)
                 .where(in_first_five))
        self.assertQuery(query)
|
|
|
|
# Aggregate-join fixture: a player, identified by name.
class Player(TestModel):
    name = TextField()
|
|
|
|
# Aggregate-join fixture: a named game that belongs to one player.
class Game(TestModel):
    name = TextField()
    player = ForeignKeyField(Player)
|
|
|
|
# Aggregate-join fixture: a number of points recorded for one game.
class Score(TestModel):
    game = ForeignKeyField(Game)
    points = IntegerField()
|
|
|
|
|
|
class TestJoinSubqueryAggregateViaLeftOuter(ModelTestCase):
    # Joins Player to a grouped subquery that aggregates scores per player
    # (Game LEFT OUTER JOIN Score).
    requires = [Player, Game, Score]

    def test_join_subquery_aggregate_left_outer(self):
        with self.database.atomic():
            # Two players with two games each: p1-g1, p1-g2, p2-g1, p2-g2.
            players = [Player.create(name=name) for name in ('p1', 'p2')]
            p1, p2 = players
            games = []
            for player in (p1, p2):
                for gnum in (1, 2):
                    game_name = '%s-g%s' % (player.name, gnum)
                    games.append(Game.create(name=game_name, player=player))

            # Scores per game, in the same order as ``games``; the second
            # game has none.
            score_list = (
                (10, 20, 30),
                (),
                (100, 110, 100),
                (50, 50))
            for game, point_values in zip(games, score_list):
                for points in point_values:
                    Score.create(game=game, points=points)

        total = fn.SUM(Score.points).alias('ptotal')
        average = fn.AVG(Score.points).alias('pavg')
        subq = (Game
                .select(Game.player, total, average)
                .join(Score, JOIN.LEFT_OUTER)
                .group_by(Game.player))
        query = (Player
                 .select(Player, subq.c.ptotal, subq.c.pavg)
                 .join(subq, on=(Player.id == subq.c.player_id))
                 .order_by(Player.name))

        expected = [('p1', 60, 20), ('p2', 410, 82)]

        # Default row type: the aggregates are read through ``game``.
        with self.assertQueryCount(1):
            results = [(p.name, p.game.ptotal, p.game.pavg) for p in query]

        self.assertEqual(results, expected)

        # With objects() the aggregates sit directly on the player.
        with self.assertQueryCount(1):
            flat_query = query.objects()
            results = [(p.name, p.ptotal, p.pavg) for p in flat_query]

        self.assertEqual(results, expected)
|
|
|
|
# ===========================================================================
|
|
# Advanced query features (window functions, tuples, compound selects, etc.)
|
|
# ===========================================================================
|
|
|
|
|
|
|
|
@skip_unless(
    IS_POSTGRESQL or IS_MYSQL_ADVANCED_FEATURES or IS_SQLITE_25 or IS_CRDB,
    'window function')
class TestWindowFunctionIntegration(ModelTestCase):
    # Runs window-function queries against the database and checks the rows
    # that come back.
    requires = [Sample]

    def setUp(self):
        super(TestWindowFunctionIntegration, self).setUp()
        # Fixture rows as (counter, value), inserted in this order:
        #   (1, 10), (1, 20), (2, 1), (2, 3), (3, 100)
        fixture = ((1, 10), (1, 20), (2, 1), (2, 3), (3, 100))
        with self.database.atomic():
            for counter, value in fixture:
                Sample.create(counter=counter, value=value)

    def test_simple_partition(self):
        # Average of ``value`` within each ``counter`` partition.
        expected = [
            (1, 10.0, 15.0),
            (1, 20.0, 15.0),
            (2, 1.0, 2.0),
            (2, 3.0, 2.0),
            (3, 100.0, 100.0)]

        # Window given inline through over(partition_by=...).
        inline = (Sample
                  .select(Sample.counter, Sample.value,
                          fn.AVG(Sample.value).over(
                              partition_by=[Sample.counter]))
                  .order_by(Sample.counter, Sample.value)
                  .tuples())
        self.assertEqual(list(inline), expected)

        # Same window, declared as a Window object and attached with
        # .window().
        window = Window(partition_by=[Sample.counter])
        named = (Sample
                 .select(Sample.counter, Sample.value,
                         fn.AVG(Sample.value).over(window))
                 .window(window)
                 .order_by(Sample.counter, Sample.value)
                 .tuples())
        self.assertEqual(list(named), expected)

    def test_mixed_ordering(self):
        running_total = fn.SUM(Sample.value).over(order_by=[Sample.value])
        query = (Sample
                 .select(Sample.counter, Sample.value,
                         running_total.alias('rtotal'))
                 .order_by(Sample.id))
        # The window walks the values in ascending order (1, 3, 10, 20, 100)
        # while the rows are returned in id order, hence:
        #   counter | value | running total
        #      1    |   10  | 1 + 3 + 10
        #      1    |   20  | 1 + 3 + 10 + 20
        #      2    |    1  | 1
        #      2    |    3  | 1 + 3
        #      3    |  100  | 1 + 3 + 10 + 20 + 100
        rows = [(r.counter, r.value, r.rtotal) for r in query]
        self.assertEqual(rows, [
            (1, 10.0, 14.0),
            (1, 20.0, 34.0),
            (2, 1.0, 1.0),
            (2, 3.0, 4.0),
            (3, 100.0, 134.0)])

    def test_reuse_window(self):
        # A single Window object shared by several functions in one query.
        w = Window(order_by=[Sample.value])

        # Replace the fixture with ten rows: counter i, value 10 * i.
        with self.database.atomic():
            Sample.delete().execute()
            for i in range(10):
                Sample.create(counter=i, value=10 * i)

        query = (Sample
                 .select(Sample.counter, Sample.value,
                         fn.NTILE(4).over(w).alias('quartile'),
                         fn.NTILE(5).over(w).alias('quintile'),
                         fn.NTILE(100).over(w).alias('percentile'))
                 .window(w)
                 .order_by(Sample.id))
        results = [
            (r.counter, r.value, r.quartile, r.quintile, r.percentile)
            for r in query]

        # Columns: counter, value, NTILE(4), NTILE(5), NTILE(100).
        self.assertEqual(results, [
            (0, 0.0, 1, 1, 1),
            (1, 10.0, 1, 1, 2),
            (2, 20.0, 1, 2, 3),
            (3, 30.0, 2, 2, 4),
            (4, 40.0, 2, 3, 5),
            (5, 50.0, 2, 3, 6),
            (6, 60.0, 3, 4, 7),
            (7, 70.0, 3, 4, 8),
            (8, 80.0, 4, 5, 9),
            (9, 90.0, 4, 5, 10),
        ])

    def test_ordered_window(self):
        # Rank within each counter partition, highest value first; the rank
        # expression also appears in ORDER BY.
        window = Window(partition_by=[Sample.counter],
                        order_by=[Sample.value.desc()])
        query = (Sample
                 .select(Sample.counter, Sample.value,
                         fn.RANK().over(window=window).alias('rank'))
                 .window(window)
                 .order_by(Sample.counter, fn.RANK().over(window=window))
                 .tuples())
        self.assertEqual(list(query), [
            (1, 20.0, 1),
            (1, 10.0, 2),
            (2, 3.0, 1),
            (2, 1.0, 2),
            (3, 100.0, 1)])

    def test_two_windows(self):
        # Two separately named windows used in the same query.
        w1 = Window(partition_by=[Sample.counter]).alias('w1')
        w2 = Window(order_by=[Sample.counter]).alias('w2')
        query = (Sample
                 .select(Sample.counter, Sample.value,
                         fn.AVG(Sample.value).over(window=w1),
                         fn.RANK().over(window=w2))
                 .window(w1, w2)
                 .order_by(Sample.id)
                 .tuples())

        # Columns: counter, value, AVG over w1, RANK over w2.
        self.assertEqual(list(query), [
            (1, 10.0, 15.0, 1),
            (1, 20.0, 15.0, 1),
            (2, 1.0, 2.0, 3),
            (2, 3.0, 2.0, 3),
            (3, 100.0, 100.0, 5)])

    def test_empty_over(self):
        # LAG by one row in id order: each row carries the previous row's
        # counter, and the first row has none.
        previous_counter = fn.LAG(Sample.counter, 1).over(
            order_by=[Sample.id])
        query = (Sample
                 .select(Sample.counter, Sample.value, previous_counter)
                 .order_by(Sample.id)
                 .tuples())
        self.assertEqual(list(query), [
            (1, 10.0, None),
            (1, 20.0, 1),
            (2, 1.0, 1),
            (2, 3.0, 2),
            (3, 100.0, 2)])

    def test_bounds(self):
        # Frame built from Window.preceding() to Window.following(1) inside
        # each counter partition.
        partition_sum = fn.SUM(Sample.value).over(
            partition_by=[Sample.counter],
            start=Window.preceding(),
            end=Window.following(1))
        query = (Sample
                 .select(Sample.value, partition_sum)
                 .order_by(Sample.id)
                 .tuples())
        self.assertEqual(list(query), [
            (10.0, 30.0),
            (20.0, 30.0),
            (1.0, 4.0),
            (3.0, 4.0),
            (100.0, 100.0)])

        # Frame starting at Window.preceding(2) in id order, with no
        # explicit end.
        sliding_sum = fn.SUM(Sample.value).over(
            order_by=[Sample.id],
            start=Window.preceding(2))
        query = (Sample
                 .select(Sample.counter, Sample.value, sliding_sum)
                 .order_by(Sample.id)
                 .tuples())
        self.assertEqual(list(query), [
            (1, 10.0, 10.0),
            (1, 20.0, 30.0),
            (2, 1.0, 31.0),
            (2, 3.0, 24.0),
            (3, 100.0, 104.0)])

    def test_frame_types(self):
        # Add duplicates of two fixture rows so that logical peers exist.
        Sample.create(counter=1, value=20.0)
        Sample.create(counter=2, value=1.0)

        # No frame type given: the default is RANGE, so peer rows share the
        # same running sum.
        query = (Sample
                 .select(Sample.counter, Sample.value,
                         fn.SUM(Sample.value).over(
                             order_by=[Sample.counter, Sample.value]))
                 .order_by(Sample.id))
        self.assertEqual(list(query.tuples()), [
            (1, 10.0, 10.0),
            (1, 20.0, 50.0),
            (2, 1.0, 52.0),
            (2, 3.0, 55.0),
            (3, 100.0, 155.0),
            (1, 20.0, 50.0),
            (2, 1.0, 52.0)])

        # ROWS requested explicitly: every row gets its own running sum.
        query = (Sample
                 .select(Sample.counter, Sample.value,
                         fn.SUM(Sample.value).over(
                             order_by=[Sample.counter, Sample.value],
                             frame_type=Window.ROWS))
                 .order_by(Sample.counter, Sample.value))
        self.assertEqual(list(query.tuples()), [
            (1, 10.0, 10.0),
            (1, 20.0, 30.0),
            (1, 20.0, 50.0),
            (2, 1.0, 51.0),
            (2, 1.0, 52.0),
            (2, 3.0, 55.0),
            (3, 100.0, 155.0)])

        # Supplying a frame boundary results in ROWS as well.
        query = (Sample
                 .select(Sample.counter, Sample.value,
                         fn.SUM(Sample.value).over(
                             order_by=[Sample.counter, Sample.value],
                             start=Window.preceding(2)))
                 .order_by(Sample.counter, Sample.value))
        self.assertEqual(list(query.tuples()), [
            (1, 10.0, 10.0),
            (1, 20.0, 30.0),
            (1, 20.0, 50.0),
            (2, 1.0, 41.0),
            (2, 1.0, 22.0),
            (2, 3.0, 5.0),
            (3, 100.0, 104.0)])

    @skip_if(IS_MYSQL, 'requires OVER() with FILTER')
    def test_filter_clause(self):
        # SUM restricted by FILTER to rows with counter > 1, over a frame
        # starting at Window.preceding(1) in id order.
        condsum = fn.SUM(Sample.value).filter(Sample.counter > 1).over(
            order_by=[Sample.id], start=Window.preceding(1))
        query = (Sample
                 .select(Sample.counter, Sample.value, condsum.alias('cs'))
                 .order_by(Sample.value))
        self.assertEqual(list(query.tuples()), [
            (2, 1.0, 1.0),
            (2, 3.0, 4.0),
            (1, 10.0, None),
            (1, 20.0, None),
            (3, 100.0, 103.0),
        ])

    @skip_if(IS_MYSQL or (IS_SQLITE and not IS_SQLITE_30),
             'requires FILTER with aggregates')
    def test_filter_with_aggregate(self):
        # FILTER on a plain GROUP BY aggregate, no window involved.
        condsum = fn.SUM(Sample.value).filter(Sample.counter > 1)
        query = (Sample
                 .select(Sample.counter, condsum.alias('cs'))
                 .group_by(Sample.counter)
                 .order_by(Sample.counter))
        self.assertEqual(list(query.tuples()), [
            (1, None),
            (2, 4.0),
            (3, 100.0)])

    def test_row_number(self):
        numbering = fn.ROW_NUMBER().over(order_by=[Sample.counter])
        query = (Sample
                 .select(Sample.counter, numbering.alias('rn'))
                 .order_by(Sample.counter)
                 .tuples())
        expected = [(1, 1), (1, 2), (2, 3), (2, 4), (3, 5)]
        self.assertEqual(list(query), expected)

    def test_sum_with_frame(self):
        # ROWS frame covering the previous row and the current row.
        w = Window(order_by=[Sample.counter],
                   frame_type=Window.ROWS,
                   start=Window.preceding(1),
                   end=Window.CURRENT_ROW)
        query = (Sample
                 .select(Sample.counter,
                         fn.SUM(Sample.value).over(w).alias('rsum'))
                 .window(w)
                 .order_by(Sample.counter)
                 .tuples())
        results = list(query)

        # Each sum is the current value plus the one before it.
        self.assertEqual(results, [
            (1, 10.0),     # 10 only
            (1, 30.0),     # 10 + 20
            (2, 21.0),     # 20 + 1
            (2, 4.0),      # 1 + 3
            (3, 103.0)])   # 3 + 100

    @skip_if(IS_MYSQL, 'flaky on mysql')
    def test_lag_lead(self):
        # Previous and next ``value`` for every row, in counter order.
        previous_value = fn.LAG(Sample.value, 1).over(
            order_by=[Sample.counter])
        next_value = fn.LEAD(Sample.value, 1).over(
            order_by=[Sample.counter])
        query = (Sample
                 .select(Sample.counter,
                         previous_value.alias('prev'),
                         next_value.alias('next'))
                 .order_by(Sample.counter)
                 .tuples())
        results = list(query)
        self.assertEqual(results, [
            (1, None, 20.0),
            (1, 10.0, 1.0),
            (2, 20.0, 3.0),
            (2, 1.0, 100.0),
            (3, 3.0, None)])
|
|
|
|
|
|
@skip_if(not IS_SQLITE_15, 'requires row-values')
class TestTupleComparison(ModelTestCase):
    # Row-value (tuple) comparisons in WHERE clauses.
    requires = [User]

    def test_tuples(self):
        _, ub, _ = [User.create(username=username) for username in 'abc']
        matches_b = Tuple(User.username, User.id) == ('b', ub.id)
        query = User.select().where(matches_b)
        self.assertSQL(query, (
            'SELECT "t1"."id", "t1"."username" FROM "users" AS "t1" '
            'WHERE (("t1"."username", "t1"."id") = (?, ?))'), ['b', ub.id])

        # Exactly the "b" user matches.
        self.assertEqual(query.count(), 1)
        found = query.get()
        self.assertEqual(found, ub)

    def test_tuple_subquery(self):
        # A row-value may be tested with IN against a two-column subquery.
        for username in 'abc':
            User.create(username=username)
        UA = User.alias()
        not_b = (UA
                 .select(UA.username, UA.id)
                 .where(UA.username != 'b'))

        query = (User
                 .select(User.username)
                 .where(Tuple(User.username, User.id).in_(not_b))
                 .order_by(User.username))
        usernames = [u.username for u in query]
        self.assertEqual(usernames, ['a', 'c'])

    @requires_models(CPK)
    def test_row_value_composite_key(self):
        CPK.insert_many([('k1', 1, 1), ('k2', 2, 2), ('k3', 3, 3)]).execute()

        # Look up by comparing the composite primary key to a tuple ...
        by_expr = CPK.get(CPK._meta.primary_key == ('k2', 2))
        self.assertEqual(by_expr._pk, ('k2', 2))

        # ... or by item access on the model class.
        by_item = CPK['k3', 3]
        self.assertEqual(by_item._pk, ('k3', 3))

        # Update every row except the one whose key is ('k2', 2).
        update = CPK.update(extra=20).where(
            CPK._meta.primary_key != ('k2', 2))
        update.execute()

        self.assertEqual(sorted(CPK.select().tuples()), [
            ('k1', 1, 20), ('k2', 2, 2), ('k3', 3, 20)])
|
|
|
|
|
|
# Compound-select fixture: a note with text content and a timestamp.
class CNote(TestModel):
    content = TextField()
    timestamp = TimestampField()
|
|
|
|
# Compound-select fixture: a file keyed by its filename, with text data and a
# timestamp.
class CFile(TestModel):
    filename = CharField(primary_key=True)
    data = TextField()
    timestamp = TimestampField()
|
|
|
|
|
|
class TestCompoundSelectModels(ModelTestCase):
    # UNION queries whose two sides select from different models, read back
    # as model instances, tuples and dicts.
    requires = [CFile, CNote]

    def setUp(self):
        super(TestCompoundSelectModels, self).setUp()

        def generate_ts():
            # Returns a callable producing 2018-01-01, 2018-01-02, ... on
            # successive calls.
            day = 0

            def _inner():
                nonlocal day
                day += 1
                return datetime.datetime(2018, 1, day)
            return _inner
        make_ts = generate_ts()

        # Expected timestamp of the i-th row created below.
        self.ts = lambda i: datetime.datetime(2018, 1, i)

        with self.database.atomic():
            # Notes receive days 1-3 ...
            note_contents = ('note-a', 'note-b', 'note-c')
            for note_id, content in enumerate(note_contents, 1):
                CNote.create(id=note_id, content=content,
                             timestamp=make_ts())

            # ... and files days 4-6.
            file_data = (
                ('peewee.txt', 'peewee orm'),
                ('walrus.txt', 'walrus redis toolkit'),
                ('huey.txt', 'huey task queue'))
            for filename, data in file_data:
                CFile.create(filename=filename, data=data,
                             timestamp=make_ts())

    def test_mix_models_with_model_row_type(self):
        # The note id is cast to text so that it lines up with the filename
        # column on the other side of the union.
        if IS_MYSQL:
            cast = 'CHAR'
        else:
            cast = 'TEXT'
        lhs = CNote.select(CNote.id.cast(cast).alias('id_text'),
                           CNote.content, CNote.timestamp)
        rhs = CFile.select(CFile.filename, CFile.data, CFile.timestamp)
        query = (lhs | rhs).order_by(SQL('timestamp')).limit(4)

        # Attribute names are those of the left-hand select.
        rows = [(n.id_text, n.content, n.timestamp) for n in query]
        self.assertEqual(rows, [
            ('1', 'note-a', self.ts(1)),
            ('2', 'note-b', self.ts(2)),
            ('3', 'note-c', self.ts(3)),
            ('peewee.txt', 'peewee orm', self.ts(4))])

    def test_mixed_models_tuple_row_type(self):
        if IS_MYSQL:
            cast = 'CHAR'
        else:
            cast = 'TEXT'
        lhs = CNote.select(CNote.id.cast(cast).alias('id'),
                           CNote.content, CNote.timestamp)
        rhs = CFile.select(CFile.filename, CFile.data, CFile.timestamp)
        query = (lhs | rhs).order_by(SQL('timestamp')).limit(5)

        rows = list(query.tuples())
        self.assertEqual(rows, [
            ('1', 'note-a', self.ts(1)),
            ('2', 'note-b', self.ts(2)),
            ('3', 'note-c', self.ts(3)),
            ('peewee.txt', 'peewee orm', self.ts(4)),
            ('walrus.txt', 'walrus redis toolkit', self.ts(5))])

    def test_mixed_models_dict_row_type(self):
        notes = CNote.select(CNote.content, CNote.timestamp)
        files = CFile.select(CFile.filename, CFile.timestamp)

        # Newest first; the dict keys are those of the left-hand select, so
        # filenames show up under 'content'.
        union = notes | files
        query = union.order_by(SQL('timestamp').desc()).limit(4)
        rows = list(query.dicts())
        self.assertEqual(rows, [
            {'content': 'huey.txt', 'timestamp': self.ts(6)},
            {'content': 'walrus.txt', 'timestamp': self.ts(5)},
            {'content': 'peewee.txt', 'timestamp': self.ts(4)},
            {'content': 'note-c', 'timestamp': self.ts(3)}])
|
|
def _create_users_tweets(db):
    """Insert three users and their tweets inside one ``db.atomic()`` block.

    huey gets three tweets, mickey two and zaizee none.
    """
    fixtures = (
        ('huey', ('meow', 'hiss', 'purr')),
        ('mickey', ('woof', 'bark')),
        ('zaizee', ()))
    with db.atomic():
        for name, messages in fixtures:
            author = User.create(username=name)
            for message in messages:
                Tweet.create(user=author, content=message)
|
|
|
|
|
|
class TestSubqueryInSelect(ModelTestCase):
    """An ``IN (subquery)`` expression used as a selected column."""
    requires = [User, Tweet]

    def setUp(self):
        super(TestSubqueryInSelect, self).setUp()
        _create_users_tweets(self.database)

    def test_subquery_in_select(self):
        # Flag each tweet with whether its author is in the "huey" subquery.
        huey_only = User.select().where(User.username == 'huey')
        is_huey = Tweet.user.in_(huey_only).alias('is_huey')
        tweets = (Tweet
                  .select(Tweet.content, is_huey)
                  .order_by(Tweet.content))

        flagged = [(tweet.content, tweet.is_huey) for tweet in tweets]
        self.assertEqual(flagged, [
            ('bark', False),
            ('hiss', True),
            ('meow', True),
            ('purr', True),
            ('woof', False)])
|
|
|
|
|
|
class TUser(TestModel):
    """User model referenced by Transaction.user."""
    username = TextField()
|
|
|
|
|
|
class Transaction(TestModel):
    """An amount recorded against a TUser."""
    # Owning user; the reverse accessor is named "transactions".
    user = ForeignKeyField(TUser, backref='transactions')
    amount = FloatField(default=0.0)
|
|
|
|
|
|
class TestSumCase(ModelTestCase):
    """SUM() wrapped around a CASE expression."""

    @requires_models(User)
    def test_sum_case(self):
        for name in ('charlie', 'huey', 'zaizee'):
            User.create(username=name)

        # Contributes 1 for usernames ending in "e" and 0 otherwise.
        ends_with_e = Case(None, [(User.username.endswith('e'), 1)], 0)
        total = fn.SUM(ends_with_e)
        query = (User
                 .select(User.username, total.alias('e_sum'))
                 .group_by(User.username)
                 .order_by(User.username))

        expected_sql = (
            'SELECT "t1"."username", '
            'SUM(CASE WHEN ("t1"."username" ILIKE ?) THEN ? ELSE ? END) '
            'AS "e_sum" '
            'FROM "users" AS "t1" '
            'GROUP BY "t1"."username" '
            'ORDER BY "t1"."username"')
        self.assertSQL(query, expected_sql, ['%e', 1, 0])

        per_user = [(row.username, row.e_sum) for row in query]
        self.assertEqual(per_user, [
            ('charlie', 1),
            ('huey', 0),
            ('zaizee', 1)])
|
|
|
|
|
|
class TestMaxAlias(ModelTestCase):
    """An aggregate aliased to the same name as the field it aggregates."""
    requires = [Transaction, TUser]

    def test_max_alias(self):
        with self.database.atomic():
            charlie = TUser.create(username='charlie')
            huey = TUser.create(username='huey')

            ledger = (
                (charlie, 10.),
                (charlie, 20.),
                (charlie, 30.),
                (huey, 1.5),
                (huey, 2.5))
            for owner, value in ledger:
                Transaction.create(user=owner, amount=value)

        # Reading the joined username must not trigger extra queries.
        with self.assertQueryCount(1):
            largest = fn.MAX(Transaction.amount).alias('amount')
            per_user = (Transaction
                        .select(largest, TUser.username)
                        .join(TUser)
                        .group_by(TUser.username)
                        .order_by(TUser.username))
            summary = [(txn.amount, txn.user.username) for txn in per_user]

        self.assertEqual(summary, [
            (30., 'charlie'),
            (2.5, 'huey')])
|
|
|
|
|
|
class Datum(TestModel):
    """Key with a nullable integer value, used by the NULL-ordering test."""
    key = TextField()
    # Nullable so that NULLS FIRST / NULLS LAST ordering can be exercised.
    value = IntegerField(null=True)
|
|
|
|
class TestNullOrdering(ModelTestCase):
    """Explicit placement of NULLs when ordering ascending or descending."""
    requires = [Datum]

    def test_null_ordering(self):
        rows = [('k1', 1), ('ka', None), ('k2', 2), ('kb', None)]
        Datum.insert_many(rows, fields=[Datum.key, Datum.value]).execute()

        # (primary ordering, expected keys); Datum.key breaks ties so the
        # two NULL rows come back in a stable order.
        cases = (
            (Datum.value.asc(nulls='last'), ['k1', 'k2', 'ka', 'kb']),
            (Datum.value.asc(nulls='first'), ['ka', 'kb', 'k1', 'k2']),
            (Datum.value.desc(nulls='last'), ['k2', 'k1', 'ka', 'kb']),
            (Datum.value.desc(nulls='first'), ['ka', 'kb', 'k2', 'k1']),
        )
        for primary, expected in cases:
            query = Datum.select().order_by(primary, Datum.key)
            self.assertEqual([datum.key for datum in query], expected)

        # Anything other than a recognised nulls value is rejected.
        with self.assertRaises(ValueError):
            Datum.value.desc(nulls='bar')
        with self.assertRaises(ValueError):
            Datum.value.asc(nulls='foo')
|
|
|
|
|
|
class TestColumnNameStripping(ModelTestCase):
    """Un-aliased ``fn.MIN(Person.dob)`` is exposed under the name ``dob``."""
    database = get_in_memory_db()
    requires = [Person]

    def test_column_name_stripping(self):
        # Use two distinct dates so the MIN and MAX columns cannot be
        # confused with one another: identical dates would let a swapped
        # column mapping pass unnoticed.
        d1 = datetime.date(1990, 1, 1)
        d2 = datetime.date(1991, 1, 1)
        Person.create(first='f1', last='l1', dob=d1)
        Person.create(first='f2', last='l2', dob=d2)

        query = Person.select(
            fn.MIN(Person.dob),
            fn.MAX(Person.dob).alias('mdob'))

        # Get the row as a model.
        row = query.get()
        self.assertEqual(row.dob, d1)
        self.assertEqual(row.mdob, d2)

        # Get the row as a dict.
        row = query.dicts().get()
        self.assertEqual(row['dob'], d1)
        self.assertEqual(row['mdob'], d2)
|
|
|
|
class TestSelectValueConversion(ModelTestCase):
    """Value conversion of a CTE column selected under a field's name."""
    requires = [User]

    @skip_if(IS_SQLITE_OLD or IS_MYSQL)
    def test_select_value_conversion(self):
        created = User.create(username='u1')
        # The CTE exposes the user id cast to text.
        as_text = User.select(User.id.cast('text')).cte('tmp', columns=('id',))

        # Selected under the alias "id", the value compares equal to the
        # model's own id.
        converted_q = (User
                       .select(as_text.c.id.alias('id'))
                       .with_cte(as_text)
                       .from_(as_text))
        [converted] = [user.id for user in converted_q]
        self.assertEqual(converted, created.id)

        # With coerce(False) the text value comes back as a string.
        raw_q = (User
                 .select(as_text.c.id.coerce(False))
                 .with_cte(as_text)
                 .from_(as_text))
        [raw] = [user.id for user in raw_q]
        self.assertEqual(raw, str(created.id))
|
|
|
|
|
|
class TestOrWhere(ModelTestCase):
    """Query.orwhere() combines its conditions with OR."""
    requires = [User]

    def test_orwhere(self):
        rows = [{'username': name} for name in ('huey', 'mickey', 'zaizee')]
        User.insert_many(rows).execute()

        either = (User
                  .select()
                  .orwhere(User.username == 'huey')
                  .orwhere(User.username == 'zaizee')
                  .order_by(User.username))
        matched = [user.username for user in either]
        self.assertEqual(matched, ['huey', 'zaizee'])

    def test_where_then_orwhere(self):
        rows = [{'username': name} for name in ('huey', 'mickey', 'zaizee')]
        User.insert_many(rows).execute()

        # where + orwhere: the where is OR'd with subsequent orwhere calls.
        either = (User
                  .select()
                  .where(User.username == 'huey')
                  .orwhere(User.username == 'zaizee')
                  .order_by(User.username))
        matched = [user.username for user in either]
        self.assertEqual(matched, ['huey', 'zaizee'])
|
|
|
|
|
|
class TestEnsureJoin(ModelTestCase):
    """Query.ensure_join() with and without a pre-existing join."""
    requires = [User, Tweet]

    def test_ensure_join_noop(self):
        """If join already exists, ensure_join doesn't duplicate it."""
        huey = User.create(username='huey')
        Tweet.create(user=huey, content='meow')

        joined = User.select(User, Tweet.content).join(Tweet).switch(User)
        query = joined.ensure_join(User, Tweet)

        pairs = [(user.username, user.tweet.content) for user in query]
        self.assertEqual(pairs, [('huey', 'meow')])

    def test_ensure_join_adds_when_missing(self):
        """If join doesn't exist, ensure_join adds it."""
        huey = User.create(username='huey')
        Tweet.create(user=huey, content='meow')

        unjoined = User.select(User, Tweet.content)
        query = unjoined.ensure_join(User, Tweet)

        pairs = [(user.username, user.tweet.content) for user in query]
        self.assertEqual(pairs, [('huey', 'meow')])
|
|
|
|
|
|
class TestScalarIntegration(ModelTestCase):
    """Query.scalar() in its plain, tuple and dict forms."""
    requires = [User, Sample]

    @requires_models(User)
    def test_scalar(self):
        for u in ('huey', 'mickey', 'zaizee'):
            User.create(username=u)
        count = User.select(fn.COUNT(User.id)).scalar()
        self.assertEqual(count, 3)

    @requires_models(User)
    def test_scalar_as_tuple(self):
        for u in ('huey', 'mickey', 'zaizee'):
            User.create(username=u)
        count, mx = (User
                     .select(fn.COUNT(User.id), fn.MAX(User.id))
                     .scalar(as_tuple=True))
        self.assertEqual(count, 3)
        # assertGreater reports the offending value if this ever fails.
        self.assertGreater(mx, 0)

    @requires_models(User)
    def test_scalar_as_dict(self):
        for u in ('huey', 'mickey', 'zaizee'):
            User.create(username=u)
        result = (User
                  .select(fn.COUNT(User.id).alias('ct'))
                  .scalar(as_dict=True))
        self.assertEqual(result, {'ct': 3})

    @requires_models(Sample)
    def test_scalar_empty_result(self):
        # No Sample rows are inserted, so MAX() has nothing to aggregate.
        val = (Sample
               .select(fn.MAX(Sample.value))
               .scalar())
        self.assertIsNone(val)
|
|
|
|
|
|
class TestExistsIntegration(ModelTestCase):
    """Query.exists() for matching and non-matching filters."""
    requires = [User]

    @requires_models(User)
    def test_exists_true(self):
        User.create(username='huey')
        present = User.select().where(User.username == 'huey').exists()
        self.assertTrue(present)

    @requires_models(User)
    def test_exists_false(self):
        User.create(username='huey')
        present = User.select().where(User.username == 'nobody').exists()
        self.assertFalse(present)
|
|
|
|
|
|
class TestObjectsIntegration(ModelTestCase):
    """Query.objects() flattens joined columns onto the selecting model."""
    requires = [User, Tweet]

    @requires_models(User, Tweet)
    def test_objects_returns_target_model(self):
        u = User.create(username='huey')
        Tweet.create(user=u, content='meow')
        Tweet.create(user=u, content='purr')

        query = (Tweet
                 .select(Tweet, User)
                 .join(User)
                 .order_by(Tweet.content)
                 .objects())
        results = list(query)

        # .objects() maps all columns onto the Tweet model.
        self.assertEqual(len(results), 2)
        # assertIsInstance reports the actual type if this ever fails.
        self.assertIsInstance(results[0], Tweet)
        self.assertEqual(results[0].content, 'meow')
        self.assertEqual(results[0].username, 'huey')
        self.assertEqual(results[1].content, 'purr')
|
|
|
|
|
|
class TestRowTypeIntegration(ModelTestCase):
    """Row types returned by tuples(), dicts() and namedtuples()."""
    requires = [User]

    @requires_models(User)
    def test_tuples(self):
        User.create(username='huey')
        rows = list(User.select(User.username).tuples())
        self.assertEqual(rows, [('huey',)])

    @requires_models(User)
    def test_dicts(self):
        User.create(username='huey')
        rows = list(User.select(User.username).dicts())
        self.assertEqual(rows, [{'username': 'huey'}])

    @requires_models(User)
    def test_namedtuples(self):
        User.create(username='huey')
        rows = list(User.select(User.username).namedtuples())
        self.assertEqual(len(rows), 1)
        only_row = rows[0]
        self.assertEqual(only_row.username, 'huey')
|
|
|
|
|
|
class VL(TestModel):
    """Number/string pair written and read by the ValuesList tests."""
    n = IntegerField()
    s = CharField()
|
|
|
|
|
|
@skip_if(IS_SQLITE_OLD or IS_MYSQL or IS_CRDB)
class TestValuesListIntegration(ModelTestCase):
    """ValuesList used as a query source, in CTEs, and to drive writes."""
    requires = [VL]
    _data = [(1, 'one'), (2, 'two'), (3, 'three')]

    def test_insert_into_select_from_vl(self):
        # INSERT ... SELECT whose source rows come from a VALUES-backed CTE.
        vl = ValuesList(self._data)
        cte = vl.cte('newvals', columns=['n', 's'])
        (VL
         .insert_from(cte.select(cte.c.n, cte.c.s), fields=[VL.n, VL.s])
         .with_cte(cte)
         .execute())

        vq = VL.select().order_by(VL.n)
        self.assertEqual([(v.n, v.s) for v in vq], self._data)

    def test_update_vl_cte(self):
        VL.insert_many(self._data).execute()

        new_values = [(1, 'One'), (3, 'Three'), (4, 'Four')]
        cte = ValuesList(new_values).cte('new_values', columns=('n', 's'))

        # A correlated subquery supplies the new value for each row. This is
        # used instead of UPDATE ... FROM so the test does not depend on
        # that syntax being available.
        subq = (cte
                .select(cte.c.s)
                .where(VL.n == cte.c.n))

        # Perform the update, assigning "s" the new value from the values
        # list, and restricting the update to rows whose "n" appears in it.
        (VL
         .update(s=subq)
         .where(VL.n.in_(cte.select(cte.c.n)))
         .with_cte(cte)
         .execute())

        # n=2 is absent from new_values and keeps its old string; n=4 has
        # no matching row in the table.
        vq = VL.select().order_by(VL.n)
        self.assertEqual([(v.n, v.s) for v in vq], [
            (1, 'One'), (2, 'two'), (3, 'Three')])

    def test_values_list(self):
        vl = ValuesList(self._data)

        query = vl.select(SQL('*'))
        self.assertEqual(list(query.tuples().bind(self.database)), self._data)

    @requires_postgresql
    def test_values_list_named_columns(self):
        vl = ValuesList(self._data).columns('idx', 'name')
        query = (vl
                 .select(vl.c.idx, vl.c.name)
                 .order_by(vl.c.idx.desc()))
        self.assertEqual(list(query.tuples().bind(self.database)),
                         self._data[::-1])

    def test_values_list_named_columns_in_cte(self):
        vl = ValuesList(self._data)
        cte = vl.cte('val', columns=('idx', 'name'))
        query = (cte
                 .select(cte.c.idx, cte.c.name)
                 .order_by(cte.c.idx.desc())
                 .with_cte(cte))
        self.assertEqual(list(query.tuples().bind(self.database)),
                         self._data[::-1])

    def test_named_values_list(self):
        vl = ValuesList(self._data).alias('vl')
        query = vl.select()
        self.assertEqual(list(query.tuples().bind(self.database)), self._data)
|
|
|
|
# ===========================================================================
|
|
# Common Table Expressions (CTE)
|
|
# ===========================================================================
|
|
|
|
|
|
|
|
class Member(TestModel):
    """Member that may have been recommended by another member."""
    name = TextField()
    # Self-referential and nullable: a member may have no recommender.
    recommendedby = ForeignKeyField('self', null=True)
|
|
|
|
|
|
class TestCTEIntegration(ModelTestCase):
    """Common table expressions executed against a small category tree."""
    requires = [Category]

    def setUp(self):
        # Tree: root -> (p1, p2, p3); p1 -> (c11, c12); p3 -> (c31,).
        # Only nodes that later act as parents are bound to a name.
        super(TestCTEIntegration, self).setUp()
        CC = Category.create
        root = CC(name='root')
        p1 = CC(name='p1', parent=root)
        CC(name='p2', parent=root)
        p3 = CC(name='p3', parent=root)
        CC(name='c11', parent=p1)
        CC(name='c12', parent=p1)
        CC(name='c31', parent=p3)

    @skip_if(IS_SQLITE_OLD or (IS_MYSQL and not IS_MYSQL_ADVANCED_FEATURES)
             or IS_CRDB)
    @requires_models(Member)
    def test_docs_example(self):
        # Creation order is kept as-is because the final query orders by
        # Member.id. Only members referenced later are bound to a name.
        f = Member.create(name='founder')
        gen2_1 = Member.create(name='g2-1', recommendedby=f)
        Member.create(name='g2-2', recommendedby=f)
        gen2_3 = Member.create(name='g2-3', recommendedby=f)
        Member.create(name='g3-1-1', recommendedby=gen2_1)
        Member.create(name='g3-1-2', recommendedby=gen2_1)
        gen3_3_1 = Member.create(name='g3-3-1', recommendedby=gen2_3)

        # Get recommender chain for 331.
        base = (Member
                .select(Member.recommendedby)
                .where(Member.id == gen3_3_1.id)
                .cte('recommenders', recursive=True, columns=('recommender',)))

        # Recursive term: the recommender of each member already collected.
        MA = Member.alias()
        recursive = (MA
                     .select(MA.recommendedby)
                     .join(base, on=(MA.id == base.c.recommender)))

        cte = base.union_all(recursive)
        query = (cte
                 .select_from(cte.c.recommender, Member.name)
                 .join(Member, on=(cte.c.recommender == Member.id))
                 .order_by(Member.id.desc()))
        self.assertEqual([m.name for m in query], ['g2-3', 'founder'])

    @skip_if(IS_SQLITE_OLD or (IS_MYSQL and not IS_MYSQL_ADVANCED_FEATURES))
    def test_simple_cte(self):
        cte = (Category
               .select(Category.name, Category.parent)
               .cte('catz', columns=('name', 'parent')))

        cte_sql = ('WITH "catz" ("name", "parent") AS ('
                   'SELECT "t1"."name", "t1"."parent_id" '
                   'FROM "category" AS "t1") '
                   'SELECT "catz"."name", "catz"."parent" AS "pname" '
                   'FROM "catz" '
                   'ORDER BY "catz"."name"')

        # Spelled out with from_() / with_cte() ...
        query = (Category
                 .select(cte.c.name, cte.c.parent.alias('pname'))
                 .from_(cte)
                 .order_by(cte.c.name)
                 .with_cte(cte))
        self.assertSQL(query, cte_sql, [])

        # ... and through the select_from() shortcut; both must yield the
        # same SQL and the same rows.
        query2 = (cte.select_from(cte.c.name, cte.c.parent.alias('pname'))
                  .order_by(cte.c.name))
        self.assertSQL(query2, cte_sql, [])

        self.assertEqual([(row.name, row.pname) for row in query], [
            ('c11', 'p1'),
            ('c12', 'p1'),
            ('c31', 'p3'),
            ('p1', 'root'),
            ('p2', 'root'),
            ('p3', 'root'),
            ('root', None)])
        self.assertEqual([(row.name, row.pname) for row in query],
                         [(row.name, row.pname) for row in query2])

    @skip_if(IS_SQLITE_OLD or (IS_MYSQL and not IS_MYSQL_ADVANCED_FEATURES))
    def test_cte_join(self):
        cte = (Category
               .select(Category.name)
               .cte('parents', columns=('name',)))

        query = (Category
                 .select(Category.name, cte.c.name.alias('pname'))
                 .join(cte, on=(Category.parent == cte.c.name))
                 .order_by(Category.name)
                 .with_cte(cte))

        self.assertSQL(query, (
            'WITH "parents" ("name") AS ('
            'SELECT "t1"."name" FROM "category" AS "t1") '
            'SELECT "t2"."name", "parents"."name" AS "pname" '
            'FROM "category" AS "t2" '
            'INNER JOIN "parents" ON ("t2"."parent_id" = "parents"."name") '
            'ORDER BY "t2"."name"'), [])
        # The joined CTE column is read through the "parents" attribute.
        # The inner join drops "root", which has no parent.
        self.assertEqual([(c.name, c.parents['pname']) for c in query], [
            ('c11', 'p1'),
            ('c12', 'p1'),
            ('c31', 'p3'),
            ('p1', 'root'),
            ('p2', 'root'),
            ('p3', 'root'),
        ])

    @skip_if(IS_SQLITE_OLD or IS_MYSQL or IS_CRDB, 'requires recursive cte')
    def test_recursive_cte(self):
        def get_parents(cname):
            # Build (and SQL-check) a query that walks from cname to root.
            C1 = Category.alias()
            C2 = Category.alias()

            level = SQL('1').cast('integer').alias('level')
            path = C1.name.cast('text').alias('path')

            base = (C1
                    .select(C1.name, C1.parent, level, path)
                    .where(C1.name == cname)
                    .cte('parents', recursive=True))

            rlevel = (base.c.level + 1).alias('level')
            rpath = base.c.path.concat('->').concat(C2.name).alias('path')
            recursive = (C2
                         .select(C2.name, C2.parent, rlevel, rpath)
                         .from_(base)
                         .join(C2, on=(C2.name == base.c.parent_id)))

            cte = base + recursive
            query = (cte
                     .select_from(cte.c.name, cte.c.level, cte.c.path)
                     .order_by(cte.c.level))
            self.assertSQL(query, (
                'WITH RECURSIVE "parents" AS ('
                'SELECT "t1"."name", "t1"."parent_id", '
                'CAST(1 AS integer) AS "level", '
                'CAST("t1"."name" AS text) AS "path" '
                'FROM "category" AS "t1" '
                'WHERE ("t1"."name" = ?) '
                'UNION ALL '
                'SELECT "t2"."name", "t2"."parent_id", '
                '("parents"."level" + ?) AS "level", '
                '(("parents"."path" || ?) || "t2"."name") AS "path" '
                'FROM "parents" '
                'INNER JOIN "category" AS "t2" '
                'ON ("t2"."name" = "parents"."parent_id")) '
                'SELECT "parents"."name", "parents"."level", "parents"."path" '
                'FROM "parents" '
                'ORDER BY "parents"."level"'), [cname, 1, '->'])
            return query

        # list() copies the rows directly; no comprehension needed.
        data = list(get_parents('c31').tuples())
        self.assertEqual(data, [
            ('c31', 1, 'c31'),
            ('p3', 2, 'c31->p3'),
            ('root', 3, 'c31->p3->root')])

        data = [(c.name, c.level, c.path)
                for c in get_parents('c12').namedtuples()]
        self.assertEqual(data, [
            ('c12', 1, 'c12'),
            ('p1', 2, 'c12->p1'),
            ('root', 3, 'c12->p1->root')])

        query = get_parents('root')
        data = [(r.name, r.level, r.path) for r in query]
        self.assertEqual(data, [('root', 1, 'root')])

    @skip_if(IS_SQLITE_OLD or IS_MYSQL or IS_CRDB, 'requires recursive cte')
    def test_recursive_cte2(self):
        # Base case: categories without a parent, at level 0.
        hierarchy = (Category
                     .select(Category.name, Value(0).alias('level'))
                     .where(Category.parent.is_null(True))
                     .cte(name='hierarchy', recursive=True))

        # Recursive term: children sit one level below their parent.
        C = Category.alias()
        recursive = (C
                     .select(C.name, (hierarchy.c.level + 1).alias('level'))
                     .join(hierarchy, on=(C.parent == hierarchy.c.name)))

        cte = hierarchy.union_all(recursive)
        query = (cte
                 .select_from(cte.c.name, cte.c.level)
                 .order_by(cte.c.name))
        self.assertEqual([(r.name, r.level) for r in query], [
            ('c11', 2),
            ('c12', 2),
            ('c31', 2),
            ('p1', 1),
            ('p2', 1),
            ('p3', 1),
            ('root', 0)])

    @skip_if(IS_SQLITE_OLD or IS_MYSQL or IS_CRDB, 'requires recursive cte')
    def test_recursive_cte_docs_example(self):
        # Define the base case of our recursive CTE. This will be categories
        # that have a null parent foreign-key.
        Base = Category.alias()
        level = Value(1).cast('integer').alias('level')
        path = Base.name.cast('text').alias('path')
        base_case = (Base
                     .select(Base.name, Base.parent, level, path)
                     .where(Base.parent.is_null())
                     .cte('base', recursive=True))

        # Define the recursive terms.
        RTerm = Category.alias()
        rlevel = (base_case.c.level + 1).alias('level')
        rpath = base_case.c.path.concat('->').concat(RTerm.name).alias('path')
        recursive = (RTerm
                     .select(RTerm.name, RTerm.parent, rlevel, rpath)
                     .join(base_case, on=(RTerm.parent == base_case.c.name)))

        # The recursive CTE is created by taking the base case and UNION ALL
        # with the recursive term.
        cte = base_case.union_all(recursive)

        # We will now query from the CTE to get the categories, their
        # levels, and their paths.
        query = (cte
                 .select_from(cte.c.name, cte.c.level, cte.c.path)
                 .order_by(cte.c.path))
        data = [(obj.name, obj.level, obj.path) for obj in query]
        self.assertEqual(data, [
            ('root', 1, 'root'),
            ('p1', 2, 'root->p1'),
            ('c11', 3, 'root->p1->c11'),
            ('c12', 3, 'root->p1->c12'),
            ('p2', 2, 'root->p2'),
            ('p3', 2, 'root->p3'),
            ('c31', 3, 'root->p3->c31')])

    @requires_models(Sample)
    @skip_if(IS_SQLITE_OLD or IS_MYSQL, 'sqlite too old for ctes, mysql flaky')
    def test_cte_reuse_aggregate(self):
        data = (
            (1, (1.25, 1.5, 1.75)),
            (2, (2.1, 2.3, 2.5, 2.7, 2.9)),
            (3, (3.5, 3.5)))
        with self.database.atomic():
            for counter, values in data:
                (Sample
                 .insert_many([(counter, value) for value in values],
                              fields=[Sample.counter, Sample.value])
                 .execute())

        # Average value per counter.
        cte = (Sample
               .select(Sample.counter, fn.AVG(Sample.value).alias('avg_value'))
               .group_by(Sample.counter)
               .cte('count_to_avg', columns=('counter', 'avg_value')))

        # Samples above their counter's average, with the difference; the
        # CTE is used in the select list, the join and the WHERE clause.
        query = (Sample
                 .select(Sample.counter,
                         (Sample.value - cte.c.avg_value).alias('diff'))
                 .join(cte, on=(Sample.counter == cte.c.counter))
                 .where(Sample.value > cte.c.avg_value)
                 .order_by(Sample.value)
                 .with_cte(cte))
        # Differences are rounded before comparing the floats.
        self.assertEqual([(a, round(b, 2)) for a, b in query.tuples()], [
            (1, .25),
            (2, .2),
            (2, .4)])

    @skip_if(IS_SQLITE_OLD or IS_MYSQL)
    @requires_models(Sample)
    def test_cte_with_aggregate_filter(self):
        for i in range(1, 11):
            Sample.create(counter=i, value=float(i * i))

        cte = (Sample
               .select(Sample.counter, Sample.value)
               .where(Sample.counter <= 5)
               .cte('small'))
        query = (cte
                 .select_from(fn.SUM(cte.c.value).alias('total'))
                 .where(cte.c.counter > 2))
        result = query.scalar()
        # sum of 3^2 + 4^2 + 5^2 = 9 + 16 + 25 = 50
        self.assertEqual(result, 50.0)
|
|
|
|
|
|
class C_Product(TestModel):
    """Product row used as the source table in the data-modifying CTE tests."""
    name = CharField()
    # Price defaults to zero.
    price = IntegerField(default=0)
|
|
|
|
class C_Archive(TestModel):
    """Archive table with the same name/price fields as C_Product."""
    name = CharField()
    # Price defaults to zero.
    price = IntegerField(default=0)
|
|
|
|
class C_Part(TestModel):
    """Part keyed by name that may point at one sub-part."""
    part = CharField(primary_key=True)
    # Self-referential and nullable: a part may have no sub-part.
    sub_part = ForeignKeyField('self', null=True)
|
|
|
|
@skip_unless(IS_POSTGRESQL)
class TestDataModifyingCTEIntegration(ModelTestCase):
    """CTEs whose body is a DELETE/UPDATE/INSERT ... RETURNING statement."""
    requires = [C_Product, C_Archive, C_Part]

    def setUp(self):
        # Five products p0..p4 priced 0..4, plus two three-level part
        # chains (mp1 -> mp1-c -> mp1-c-g, and the same for mp2). The chain
        # heads are created without binding an unused name.
        super(TestDataModifyingCTEIntegration, self).setUp()
        for i in range(5):
            C_Product.create(name='p%s' % i, price=i)
        mp1_c_g = C_Part.create(part='mp1-c-g')
        mp1_c = C_Part.create(part='mp1-c', sub_part=mp1_c_g)
        C_Part.create(part='mp1', sub_part=mp1_c)
        mp2_c_g = C_Part.create(part='mp2-c-g')
        mp2_c = C_Part.create(part='mp2-c', sub_part=mp2_c_g)
        C_Part.create(part='mp2', sub_part=mp2_c)

    def test_data_modifying_cte_delete(self):
        # Delete the cheap products and insert the deleted rows into the
        # archive in a single statement.
        query = (C_Product.delete()
                 .where(C_Product.price < 3)
                 .returning(C_Product))
        cte = query.cte('moved_rows')

        src = Select((cte,), (cte.c.id, cte.c.name, cte.c.price))
        res = (C_Archive
               .insert_from(src, (C_Archive.id, C_Archive.name, C_Archive.price))
               .with_cte(cte)
               .execute())
        self.assertEqual(len(list(res)), 3)

        self.assertEqual(
            sorted([(p.name, p.price) for p in C_Product.select()]),
            [('p3', 3), ('p4', 4)])
        self.assertEqual(
            sorted([(p.name, p.price) for p in C_Archive.select()]),
            [('p0', 0), ('p1', 1), ('p2', 2)])

        # Recursive CTE collecting mp1 and everything beneath it, used to
        # drive a DELETE.
        base = (C_Part
                .select(C_Part.sub_part, C_Part.part)
                .where(C_Part.part == 'mp1')
                .cte('included_parts', recursive=True,
                     columns=('sub_part', 'part')))
        PA = C_Part.alias('p')
        recursive = (PA
                     .select(PA.sub_part, PA.part)
                     .join(base, on=(PA.part == base.c.sub_part)))
        cte = base.union_all(recursive)

        sq = Select((cte,), (cte.c.part,))
        (C_Part.delete()
         .where(C_Part.part.in_(sq))
         .with_cte(cte)
         .execute())

        self.assertEqual(sorted([p.part for p in C_Part.select()]),
                         ['mp2', 'mp2-c', 'mp2-c-g'])

    def test_data_modifying_cte_update(self):
        # Populate archive table w/copy of data in product.
        C_Archive.insert_from(
            C_Product.select(),
            (C_Product.id, C_Product.name, C_Product.price)).execute()

        query = (C_Product
                 .update(price=C_Product.price * 2)
                 .returning(C_Product.id, C_Product.name, C_Product.price))
        cte = query.cte('t')

        # Selecting from the CTE runs the UPDATE and yields the new rows.
        sq = cte.select_from(cte.c.id, cte.c.name, cte.c.price)
        self.assertEqual(sorted([(x.name, x.price) for x in sq]), [
            ('p0', 0), ('p1', 2), ('p2', 4), ('p3', 6), ('p4', 8)])

        # Ensure changes were persisted.
        self.assertEqual(sorted([(x.name, x.price) for x in C_Product]), [
            ('p0', 0), ('p1', 2), ('p2', 4), ('p3', 6), ('p4', 8)])

        # Re-using the CTE runs the UPDATE a second time (prices double
        # again) while copying the returned prices into the archive.
        sq = Select((cte,), (cte.c.id, cte.c.price))
        (C_Archive
         .update(price=sq.c.price)
         .from_(sq)
         .where(C_Archive.id == sq.c.id)
         .with_cte(cte)
         .execute())

        self.assertEqual(sorted([(x.name, x.price) for x in C_Product]), [
            ('p0', 0), ('p1', 4), ('p2', 8), ('p3', 12), ('p4', 16)])
        self.assertEqual(sorted([(x.name, x.price) for x in C_Archive]), [
            ('p0', 0), ('p1', 4), ('p2', 8), ('p3', 12), ('p4', 16)])

    def test_data_modifying_cte_insert(self):
        query = (C_Product
                 .insert({'name': 'p5', 'price': 5})
                 .returning(C_Product.id, C_Product.name, C_Product.price))
        cte = query.cte('t')

        # Selecting from the CTE runs the INSERT and yields the new row.
        sq = cte.select_from(cte.c.id, cte.c.name, cte.c.price)
        self.assertEqual([(p.name, p.price) for p in sq], [('p5', 5)])

        # Insert another product and copy the returned row into the archive.
        query = (C_Product
                 .insert({'name': 'p6', 'price': 6})
                 .returning(C_Product.id, C_Product.name, C_Product.price))
        cte = query.cte('t')
        sq = Select((cte,), (cte.c.id, cte.c.name, cte.c.price))
        (C_Archive
         .insert_from(sq, (sq.c.id, sq.c.name, sq.c.price))
         .with_cte(cte)
         .execute())
        self.assertEqual([(p.name, p.price) for p in C_Archive], [('p6', 6)])

        self.assertEqual(sorted([(p.name, p.price) for p in C_Product]), [
            ('p0', 0), ('p1', 1), ('p2', 2), ('p3', 3), ('p4', 4), ('p5', 5),
            ('p6', 6)])
|
|
|
|
# ===========================================================================
|
|
# INSERT conflict handling / upsert (per-dialect)
|
|
# ===========================================================================
|
|
|
|
|
|
|
|
class OnConflictTests(object):
    """Shared fixtures and checks for the INSERT conflict-handling tests."""
    requires = [Emp]
    test_data = (
        ('huey', 'cat', '123'),
        ('zaizee', 'cat', '124'),
        ('mickey', 'dog', '125'),
    )

    def setUp(self):
        # Seed one Emp row per entry in test_data.
        super(OnConflictTests, self).setUp()
        for row in self.test_data:
            first, last, empno = row
            Emp.create(first=first, last=last, empno=empno)

    def assertData(self, expected):
        # Compare (first, last, empno) for all rows, ordered by id.
        rows = Emp.select(Emp.first, Emp.last, Emp.empno).order_by(Emp.id)
        self.assertEqual(list(rows.tuples()), expected)

    def test_ignore(self):
        # empno '123' is already taken by the seeded "huey" row; with
        # on_conflict('ignore') the table must be left untouched.
        insert = Emp.insert(first='foo', last='bar', empno='123')
        insert.on_conflict('ignore').execute()
        self.assertData(list(self.test_data))
|
|
|
|
|
|
def requires_upsert(m):
    """Decorate *m* so it is skipped unless an upsert-capable flag is set."""
    upsert_capable = IS_SQLITE_24 or IS_POSTGRESQL or IS_CRDB
    decorate = skip_unless(upsert_capable, 'requires upsert')
    return decorate(m)
|
|
|
|
|
|
class KV(TestModel):
    """Key/value pair with a unique key."""
    # Unique, so inserting an existing key raises a conflict.
    key = CharField(unique=True)
    value = IntegerField()
|
|
|
|
|
|
class PGOnConflictTests(OnConflictTests):
|
|
@requires_upsert
def test_update(self):
    # Conflict on empno '125' (the seeded "mickey" row): the incoming
    # first/last names are preserved and empno is rewritten to '125.1'.
    by_empno = (Emp
                .insert(first='foo', last='bar', empno='125')
                .on_conflict(
                    conflict_target=(Emp.empno,),
                    preserve=(Emp.first, Emp.last),
                    update={Emp.empno: '125.1'}))
    by_empno.execute()
    self.assertData([
        ('huey', 'cat', '123'),
        ('zaizee', 'cat', '124'),
        ('foo', 'bar', '125.1')])

    # Conflict on first/last name: the first name is preserved, the last
    # name is updated, and the new empno is thrown out.
    by_name = (Emp
               .insert(first='foo', last='bar', empno='126')
               .on_conflict(
                   conflict_target=(Emp.first, Emp.last),
                   preserve=(Emp.first,),
                   update={Emp.last: 'baze'}))
    by_name.execute()
    self.assertData([
        ('huey', 'cat', '123'),
        ('zaizee', 'cat', '124'),
        ('foo', 'baze', '125.1')])
|
|
|
|
@requires_upsert
@requires_models(OCTest)
def test_update_ignore_with_conflict_target(self):
    # IGNORE combined with an explicit conflict target.
    query = OCTest.insert(a='foo', b=1).on_conflict(
        action='IGNORE',
        conflict_target=(OCTest.a,))
    rowid1 = query.execute()
    # assertIsNotNone gives a clearer failure than assertTrue(... is not None).
    self.assertIsNotNone(rowid1)

    query.clone().execute()  # Nothing happens, insert is ignored.
    self.assertEqual(OCTest.select().count(), 1)

    # Same value for "a": ignored again, this time via on_conflict_ignore().
    OCTest.insert(a='foo', b=2).on_conflict_ignore().execute()
    self.assertEqual(OCTest.select().count(), 1)

    # A new value for "a" does not conflict and is inserted.
    OCTest.insert(a='bar', b=1).on_conflict_ignore().execute()
    self.assertEqual(OCTest.select().count(), 2)
|
|
|
|
@requires_upsert
@requires_models(OCTest)
def test_update_atomic(self):
    # Insert a row for a='foo'; on conflict, bump the stored b by 2.
    bump_by_two = OCTest.insert(a='foo', b=1).on_conflict(
        conflict_target=(OCTest.a,),
        update={OCTest.b: OCTest.b + 2})

    # The first execution inserts. The second hits the conflict handler
    # and moves "b" from 1 to 3, reporting the same row id.
    first_id = bump_by_two.execute()
    second_id = bump_by_two.clone().execute()
    self.assertEqual(first_id, second_id)

    row = OCTest.get()
    self.assertEqual(row.a, 'foo')
    self.assertEqual(row.b, 3)

    # On conflict: take "c" from the incoming row and add 100 to "b".
    bump_by_hundred = OCTest.insert(a='foo', b=4, c=5).on_conflict(
        conflict_target=[OCTest.a],
        preserve=[OCTest.c],
        update={OCTest.b: OCTest.b + 100})
    third_id = bump_by_hundred.execute()
    self.assertEqual(third_id, second_id)

    row = OCTest.get()
    self.assertEqual(row.a, 'foo')
    self.assertEqual(row.b, 103)
    self.assertEqual(row.c, 5)
|
|
|
|
@requires_upsert
@requires_models(OCTest)
def test_update_where_clause(self):
    # Add a new row with the given "a" value. If a conflict occurs,
    # re-insert with b=b+2 so long as the original b < 3.
    query = OCTest.insert(a='foo', b=1).on_conflict(
        conflict_target=(OCTest.a,),
        update={OCTest.b: OCTest.b + 2},
        where=(OCTest.b < 3))

    # First execution returns rowid=1. Second execution hits the conflict-
    # resolution, and will update the value in "b" from 1 -> 3.
    rowid1 = query.execute()
    rowid2 = query.clone().execute()
    self.assertEqual(rowid1, rowid2)

    obj = OCTest.get()
    self.assertEqual(obj.a, 'foo')
    self.assertEqual(obj.b, 3)

    # Third execution also returns rowid=1. The WHERE clause prevents us
    # from updating "b" again. If this is SQLite, we get the rowid back, if
    # this is Postgresql we get None (since nothing happened).
    rowid3 = query.clone().execute()
    if IS_SQLITE:
        self.assertEqual(rowid1, rowid3)
    else:
        # assertIsNone reports the unexpected value if this ever fails.
        self.assertIsNone(rowid3)

    # Because we didn't satisfy the WHERE clause, the value in "b" is
    # not incremented again.
    obj = OCTest.get()
    self.assertEqual(obj.a, 'foo')
    self.assertEqual(obj.b, 3)
|
|
|
|
@requires_upsert
@requires_models(Emp)  # Has unique on first/last, unique on empno.
def test_conflict_update_excluded(self):
    # Both the conflict update and its WHERE clause refer to the incoming
    # row through EXCLUDED.
    Emp.create(first='huey', last='c', empno='10')
    Emp.create(first='zaizee', last='c', empno='20')

    # Conflicts with huey/c on (first, last). The stored and incoming empno
    # differ, so the update applies and the expected empno is '1030'.
    upsert = (Emp
              .insert(first='huey', last='c', empno='30')
              .on_conflict(
                  conflict_target=(Emp.first, Emp.last),
                  update={Emp.empno: Emp.empno + EXCLUDED.empno},
                  where=(EXCLUDED.empno != Emp.empno)))
    upsert.execute()

    rows = Emp.select(Emp.first, Emp.last, Emp.empno).tuples()
    self.assertEqual(sorted(rows),
                     [('huey', 'c', '1030'), ('zaizee', 'c', '20')])
|
|
|
|
@requires_upsert
@requires_models(KV)
def test_conflict_update_excluded2(self):
    # The update only applies while the incoming value (EXCLUDED.value) is
    # greater than the stored one.
    def stored_rows():
        return KV.select(KV.key, KV.value).tuples()[:]

    KV.create(key='k1', value=1)

    upsert = (KV
              .insert(key='k1', value=10)
              .on_conflict(
                  conflict_target=[KV.key],
                  update={KV.value: KV.value + EXCLUDED.value},
                  where=(EXCLUDED.value > KV.value)))

    # 10 > 1, so the stored value becomes 1 + 10.
    upsert.execute()
    self.assertEqual(stored_rows(), [('k1', 11)])

    # Running the very same query again has no effect: the incoming value
    # (10) is not greater than the stored value (11).
    upsert.execute()
    self.assertEqual(stored_rows(), [('k1', 11)])
|
|
|
|
@requires_upsert
@skip_if(IS_CRDB, 'crdb does not support the WHERE clause')
@requires_models(UKVP)
def test_conflict_target_constraint_where(self):
    # Upsert against a conflict target that carries its own WHERE clause.
    UKVP.create(key='k1', value=1, extra=1)
    k2_row = UKVP.create(key='k2', value=2, extra=2)

    columns = [UKVP.key, UKVP.value, UKVP.extra]
    incoming = [('k1', 1, 2), ('k2', 2, 3)]

    # XXX: SQLite does not seem to accept parameterized values in the
    # conflict target WHERE clause (e.g., the partial index), so there the
    # condition is spelled literally as ("extra" > 1) instead of using an
    # expression that would be parameterized. Hopefully SQLite's authors
    # decide this is a bug and fix it.
    threshold = SQL('1') if IS_SQLITE else 1
    conflict_where = UKVP.extra > threshold

    upsert = (UKVP
              .insert_many(incoming, columns)
              .on_conflict(conflict_target=(UKVP.key, UKVP.value),
                           conflict_where=conflict_where,
                           preserve=(UKVP.extra,)))
    upsert.execute()

    # How many rows exist? The first incoming row does not trigger the
    # conflict resolution: the existing k1/1 row has an "extra" that is not
    # greater than 1, so it does not satisfy the index condition and a new
    # row is added. The second incoming row (k2/2/3) does trigger it.
    self.assertEqual(UKVP.select().count(), 3)

    stored = (UKVP
              .select(UKVP.key, UKVP.value, UKVP.extra)
              .order_by(UKVP.key, UKVP.value, UKVP.extra)
              .tuples())
    expected = [
        ('k1', 1, 1),
        ('k1', 1, 2),
        ('k2', 2, 3),
    ]
    self.assertEqual(list(stored), expected)

    # The upsert must not have changed the primary key of k2.
    k2_db = UKVP.get(UKVP.key == 'k2')
    self.assertEqual(k2_db.id, k2_row.id)
|
|
|
|
|
|
@requires_mysql
class TestUpsertMySQL(OnConflictTests, ModelTestCase):
    # MySQL-only on-conflict tests: 'replace', plus update/preserve without
    # a conflict target.

    def test_replace(self):
        # NOTE(review): the starting rows (huey/cat/123, zaizee/cat/124)
        # are presumably created by the OnConflictTests fixture -- confirm.
        def replace_into(first, last, empno):
            (Emp
             .insert(first=first, last=last, empno=empno)
             .on_conflict('replace')
             .execute())

        # First replace: mickey is expected alongside the existing rows.
        replace_into('mickey', 'dog', '1337')
        self.assertData([
            ('huey', 'cat', '123'),
            ('zaizee', 'cat', '124'),
            ('mickey', 'dog', '1337')])

        # empno '123' already belongs to huey, so that row is replaced.
        replace_into('nuggie', 'dog', '123')
        self.assertData([
            ('zaizee', 'cat', '124'),
            ('mickey', 'dog', '1337'),
            ('nuggie', 'dog', '123')])

        # No conflict at all: the row is simply added.
        replace_into('beanie', 'cat', '126')
        self.assertData([
            ('zaizee', 'cat', '124'),
            ('mickey', 'dog', '1337'),
            ('nuggie', 'dog', '123'),
            ('beanie', 'cat', '126')])

    @requires_models(OCTest)
    def test_update(self):
        def stored_b(a_value):
            return OCTest.get(OCTest.a == a_value).b

        # Nothing to conflict with yet, so the inserted b=3 is stored and
        # the update value (1337) is not used.
        first_pk = (OCTest
                    .insert(a='a', b=3)
                    .on_conflict(update={OCTest.b: 1337})
                    .execute())
        self.assertEqual(stored_b('a'), 3)

        # Conflict on a='a': the stored b goes from 3 to 13, same row.
        second_pk = (OCTest
                     .insert(a='a', b=4)
                     .on_conflict(update={OCTest.b: OCTest.b + 10})
                     .execute())
        self.assertEqual(first_pk, second_pk)
        self.assertEqual(OCTest.select().count(), 1)
        self.assertEqual(stored_b('a'), 13)

        # A new "a" value yields a new row holding its inserted b.
        third_pk = (OCTest
                    .insert(a='a2', b=5)
                    .on_conflict(update={OCTest.b: 1337})
                    .execute())
        self.assertNotEqual(third_pk, second_pk)
        self.assertEqual(OCTest.select().count(), 2)
        self.assertEqual(stored_b('a2'), 5)

    @requires_models(OCTest)
    def test_update_preserve(self):
        OCTest.create(a='a', b=3)

        # Preserving "b" means the conflicting row takes the inserted b=4.
        first_pk = (OCTest
                    .insert(a='a', b=4)
                    .on_conflict(preserve=[OCTest.b])
                    .execute())
        row = OCTest.get(OCTest.a == 'a')
        self.assertEqual(row.b, 4)

        # Preserve "c" from the insert and bump the stored "b" by 100.
        second_pk = (OCTest
                     .insert(a='a', b=5, c=6)
                     .on_conflict(preserve=[OCTest.c],
                                  update={OCTest.b: OCTest.b + 100})
                     .execute())
        self.assertEqual(first_pk, second_pk)
        self.assertEqual(OCTest.select().count(), 1)

        row = OCTest.get(OCTest.a == 'a')
        self.assertEqual(row.b, 104)
        self.assertEqual(row.c, 6)
|
|
|
|
|
|
class TestReplaceSqlite(OnConflictTests, ModelTestCase):
    # 'replace' conflict handling, run against an in-memory database.
    database = get_in_memory_db()

    def test_replace(self):
        # NOTE(review): the starting rows (huey/cat/123, zaizee/cat/124)
        # are presumably created by the OnConflictTests fixture -- confirm.
        def replace_into(first, last, empno):
            (Emp
             .insert(first=first, last=last, empno=empno)
             .on_conflict('replace')
             .execute())

        # First replace: mickey is expected alongside the existing rows.
        replace_into('mickey', 'dog', '1337')
        self.assertData([
            ('huey', 'cat', '123'),
            ('zaizee', 'cat', '124'),
            ('mickey', 'dog', '1337')])

        # empno '123' already belongs to huey, so that row is replaced.
        replace_into('nuggie', 'dog', '123')
        self.assertData([
            ('zaizee', 'cat', '124'),
            ('mickey', 'dog', '1337'),
            ('nuggie', 'dog', '123')])

        # No conflict at all: the row is simply added.
        replace_into('beanie', 'cat', '126')
        self.assertData([
            ('zaizee', 'cat', '124'),
            ('mickey', 'dog', '1337'),
            ('nuggie', 'dog', '123'),
            ('beanie', 'cat', '126')])

    def test_model_replace(self):
        # Same behaviour through the Model.replace()/replace_many() helpers.
        mickey = Emp.replace(first='mickey', last='dog', empno='1337')
        mickey.execute()
        self.assertData([
            ('huey', 'cat', '123'),
            ('zaizee', 'cat', '124'),
            ('mickey', 'dog', '1337')])

        beanie = Emp.replace(first='beanie', last='cat', empno='999')
        beanie.execute()
        self.assertData([
            ('huey', 'cat', '123'),
            ('zaizee', 'cat', '124'),
            ('mickey', 'dog', '1337'),
            ('beanie', 'cat', '999')])

        # empno '123' and '124' collide with huey and zaizee, whose rows
        # are expected to be replaced; '125' is new.
        replacements = [
            ('h', 'cat', '123'),
            ('z', 'cat', '124'),
            ('b', 'cat', '125'),
        ]
        bulk = Emp.replace_many(replacements,
                                fields=[Emp.first, Emp.last, Emp.empno])
        bulk.execute()
        self.assertData([
            ('mickey', 'dog', '1337'),
            ('beanie', 'cat', '999'),
            ('h', 'cat', '123'),
            ('z', 'cat', '124'),
            ('b', 'cat', '125')])
|
|
|
|
|
|
@requires_sqlite
class TestUpsertSqlite(PGOnConflictTests, ModelTestCase):
    # Runs the shared PGOnConflictTests against an in-memory SQLite
    # database and adds SQLite version-specific checks.
    database = get_in_memory_db()

    @skip_if(IS_SQLITE_24, 'requires sqlite < 3.24')
    def test_no_preserve_update_where(self):
        # On SQLite < 3.24 preserve, update and where must all be rejected
        # with a ValueError when the query is executed.
        base = Emp.insert(first='foo', last='bar', empno='125')

        unsupported = (
            {'preserve': [Emp.last]},
            {'update': {Emp.empno: 'xxx'}},
            {'where': (Emp.id > 10)},
        )
        for kwargs in unsupported:
            query = base.on_conflict(**kwargs)
            with self.assertRaises(ValueError):
                query.execute()

    @skip_unless(IS_SQLITE_24, 'requires sqlite >= 3.24')
    def test_update_meets_requirements(self):
        # On >= 3.24 an upsert must satisfy the minimum criteria.
        base = Emp.insert(first='foo', last='bar', empno='125')

        # A conflict target without update or preserve is rejected.
        target_only = base.on_conflict(conflict_target=(Emp.empno,))
        with self.assertRaises(ValueError):
            target_only.execute()

        # An update without a conflict target is rejected as well.
        update_only = base.on_conflict(update={Emp.empno: '125.1'})
        with self.assertRaises(ValueError):
            update_only.execute()

    @skip_unless(IS_SQLITE_24, 'requires sqlite >= 3.24')
    def test_do_nothing(self):
        do_nothing = (Emp
                      .insert(first='foo', last='bar', empno='123')
                      .on_conflict('nothing'))
        expected_sql = (
            'INSERT INTO "emp" ("first", "last", "empno") '
            'VALUES (?, ?, ?) ON CONFLICT DO NOTHING')
        self.assertSQL(do_nothing, expected_sql, ['foo', 'bar', '123'])

        # empno='123' conflicts, so the stored data must be unchanged.
        do_nothing.execute()
        self.assertData(list(self.test_data))
|
|
|
|
|
|
class UKV(TestModel):
    # Key/value model whose (key, value) uniqueness is declared as a named
    # table constraint, "ukv_key_value".
    key = TextField()
    value = TextField()
    extra = TextField(default='')

    class Meta:
        constraints = [SQL('constraint ukv_key_value unique(key, value)')]
|
|
|
|
|
|
class UKVRel(TestModel):
    # Key/value model with an index on (key, value) declared through
    # Meta.indexes; the True flag is presumably "unique" -- confirm
    # against the peewee Meta.indexes documentation.
    key = TextField()
    value = TextField()
    extra = TextField()

    class Meta:
        indexes = ((('key', 'value'), True),)
|
|
|
|
|
|
@requires_pglike
class TestUpsertPostgresql(PGOnConflictTests, ModelTestCase):
    # Upsert tests for Postgres-like databases, on top of the shared
    # PGOnConflictTests.

    @requires_postgresql
    @requires_models(UKV)
    def test_conflict_target_constraint(self):
        k1_row = UKV.create(key='k1', value='v1')
        k2_row = UKV.create(key='k2', value='v2')

        # Conflict identified by the (key, value) columns.
        upserted_id = (UKV
                       .insert(key='k1', value='v1', extra='e1')
                       .on_conflict(conflict_target=(UKV.key, UKV.value),
                                    preserve=(UKV.extra,))
                       .execute())
        self.assertEqual(upserted_id, k1_row.id)

        # The existing k1 row received the new "extra"; no row was added.
        k1_db = UKV.get(UKV.key == 'k1')
        self.assertEqual(k1_db.key, 'k1')
        self.assertEqual(k1_db.value, 'v1')
        self.assertEqual(k1_db.extra, 'e1')
        self.assertEqual(UKV.select().count(), 2)

        # Conflict identified by the name of the constraint instead.
        upserted_id = (UKV
                       .insert(key='k2', value='v2', extra='e2')
                       .on_conflict(conflict_constraint='ukv_key_value',
                                    preserve=(UKV.extra,))
                       .execute())
        self.assertEqual(upserted_id, k2_row.id)

        # The existing k2 row received the new "extra"; no row was added.
        k2_db = UKV.get(UKV.key == 'k2')
        self.assertEqual(k2_db.key, 'k2')
        self.assertEqual(k2_db.value, 'v2')
        self.assertEqual(k2_db.extra, 'e2')
        self.assertEqual(UKV.select().count(), 2)

        # No conflict: a third row with a larger id is created.
        inserted_id = (UKV
                       .insert(key='k3', value='v3', extra='e3')
                       .on_conflict(conflict_target=[UKV.key, UKV.value],
                                    preserve=[UKV.extra])
                       .execute())
        self.assertGreater(inserted_id, k2_db.id)
        self.assertEqual(UKV.select().count(), 3)

    @requires_models(UKV, UKVRel)
    def test_conflict_ambiguous_column(self):
        # Source rows: k1/v1/e1, k2/v2/e0, k3/v3/e1
        for i in range(1, 4):
            UKV.create(key='k%s' % i, value='v%s' % i, extra='e%s' % (i % 2))

        UKVRel.create(key='k1', value='v1', extra='x1')
        UKVRel.create(key='k2', value='v2', extra='x2')

        # INSERT ... SELECT with an upsert whose WHERE names "key", a
        # column present in both tables. The expected SQL qualifies it
        # with the target table ("ukv_rel").
        source = UKV.select(UKV.key, UKV.value, UKV.extra)
        target_columns = [UKVRel.key, UKVRel.value, UKVRel.extra]
        upsert = (UKVRel
                  .insert_from(source, target_columns)
                  .on_conflict(conflict_target=[UKVRel.key, UKVRel.value],
                               preserve=[UKVRel.extra],
                               where=(UKVRel.key != 'k2')))
        expected_sql = (
            'INSERT INTO "ukv_rel" ("key", "value", "extra") '
            'SELECT "t1"."key", "t1"."value", "t1"."extra" FROM "ukv" AS "t1" '
            'ON CONFLICT ("key", "value") DO UPDATE '
            'SET "extra" = EXCLUDED."extra" '
            'WHERE ("ukv_rel"."key" != ?) RETURNING "ukv_rel"."id"')
        self.assertSQL(upsert, expected_sql, ['k2'])

        upsert.execute()

        # k1 takes the incoming extra, k2 is shielded by the WHERE clause
        # and keeps 'x2', k3 is newly inserted.
        stored = (UKVRel
                  .select(UKVRel.key, UKVRel.value, UKVRel.extra)
                  .order_by(UKVRel.key))
        self.assertEqual(list(stored.tuples()), [
            ('k1', 'v1', 'e1'),
            ('k2', 'v2', 'x2'),
            ('k3', 'v3', 'e1')])

    @requires_models(Emp)
    def test_upsert_preserves_existing(self):
        Emp.create(first='beanie', last='cat', empno='999')

        upsert = (Emp
                  .insert(first='huey', last='kitten', empno='999')
                  .on_conflict(conflict_target=(Emp.empno,),
                               preserve=(Emp.last,)))
        upsert.execute()

        row = Emp.get(Emp.empno == '999')
        # "first" is not listed in preserve, so the stored value stays.
        self.assertEqual(row.first, 'beanie')
        # "last" is listed in preserve, so it takes the inserted value.
        self.assertEqual(row.last, 'kitten')

    @requires_models(Emp)
    def test_upsert_update_expression(self):
        Emp.create(first='huey', last='cat', empno='999')

        # The update expressions read the stored row: the incoming
        # 'hueky'/'kitten' values do not appear in the expected result.
        upsert = (Emp
                  .insert(first='hueky', last='kitten', empno='999')
                  .on_conflict(conflict_target=(Emp.empno,),
                               update={Emp.first: Emp.first + 'yyy',
                                       Emp.last: Emp.last + 'lands'}))
        upsert.execute()

        row = Emp.get(Emp.empno == '999')
        self.assertEqual(row.first, 'hueyyyy')
        self.assertEqual(row.last, 'catlands')
|
|
|
|
# ===========================================================================
|
|
# FOR UPDATE, RETURNING, UPDATE FROM, and LATERAL
|
|
# ===========================================================================
|
|
|
|
|
|
|
|
@skip_if(IS_SQLITE or (IS_MYSQL and not IS_MYSQL_ADVANCED_FEATURES))
@skip_unless(db.for_update, 'requires for update')
class TestForUpdateIntegration(ModelTestCase):
    # SELECT ... FOR UPDATE tests. Every test gets a second connection
    # (self.alt_db) plus AltUser/AltTweet models bound to it, so one
    # connection can hold row locks while the other probes them.
    requires = [User, Tweet]

    def setUp(self):
        super(TestForUpdateIntegration, self).setUp()
        self.alt_db = alt_db = new_connection()

        # Subclasses that use the same table names as User and Tweet but
        # talk to the second connection.
        class AltUser(User):
            class Meta:
                database = alt_db
                table_name = User._meta.table_name

        class AltTweet(Tweet):
            class Meta:
                database = alt_db
                table_name = Tweet._meta.table_name

        self.AltUser = AltUser
        self.AltTweet = AltTweet

    def tearDown(self):
        self.alt_db.close()
        super(TestForUpdateIntegration, self).tearDown()

    @skip_if(IS_CRDB, 'crdb locks-up on this test, blocking reads')
    def test_for_update(self):
        with self.database.atomic():
            User.create(username='huey')
            zaizee = User.create(username='zaizee')

        AltUser = self.AltUser

        with self.database.manual_commit():
            self.database.begin()

            # Lock zaizee's row, then rename it in the open transaction.
            locked_rows = (User
                           .select()
                           .where(User.username == 'zaizee')
                           .for_update()
                           .execute())
            n_updated = (User
                         .update(username='ziggy')
                         .where(User.username == 'zaizee')
                         .execute())
            self.assertEqual(n_updated, 1)

            if IS_POSTGRESQL:
                # The other connection can update huey's row meanwhile.
                n_other = (AltUser
                           .update(username='huey-x')
                           .where(AltUser.username == 'huey')
                           .execute())
                self.assertEqual(n_other, 1)

            # The other connection sees the old name until the commit.
            zaizee_via_alt = (AltUser
                              .select(AltUser.username)
                              .where(AltUser.id == zaizee.id))
            self.assertEqual(zaizee_via_alt.get().username, 'zaizee')

            self.database.commit()
            self.assertEqual(zaizee_via_alt.get().username, 'ziggy')

    def test_for_update_blocking(self):
        User.create(username='u1')

        AltUser = self.AltUser
        row_locked = threading.Event()

        def run_in_thread():
            # NOTE(review): an assertion failure raised inside this worker
            # thread does not propagate to the test's main thread.
            with self.alt_db.atomic():
                row_locked.wait()
                n_alt = (AltUser
                         .update(username='u1-y')
                         .where(AltUser.username == 'u1')
                         .execute())
                self.assertEqual(n_alt, 0)

        worker = threading.Thread(target=run_in_thread)
        worker.daemon = True
        worker.start()

        with self.database.atomic() as txn:
            locked_rows = (User
                           .select()
                           .where(User.username == 'u1')
                           .for_update()
                           .execute())
            # Only release the worker once the FOR UPDATE select has run.
            row_locked.set()

            n_main = (User
                      .update(username='u1-x')
                      .where(User.username == 'u1')
                      .execute())
            self.assertEqual(n_main, 1)

        worker.join(timeout=5)
        user = User.get()
        self.assertEqual(user.username, 'u1-x')

    def test_for_update_nested(self):
        # A FOR UPDATE select used as an IN (...) subquery of a DELETE.
        User.insert_many([(name,) for name in 'abc']).execute()
        not_b = User.select().where(User.username != 'b').for_update()
        delete_query = User.delete().where(User.id.in_(not_b))
        self.assertEqual(delete_query.execute(), 2)

    def test_for_update_nowait(self):
        User.create(username='huey')
        zaizee = User.create(username='zaizee')

        AltUser = self.AltUser

        with self.database.manual_commit():
            self.database.begin()
            locked_rows = (User
                           .select(User.username)
                           .where(User.username == 'zaizee')
                           .for_update(nowait=True)
                           .execute())

            # The row is locked above, so a NOWAIT lock attempt from the
            # other connection is expected to raise.
            with self.assertRaises((OperationalError, InternalError)):
                (AltUser
                 .select()
                 .where(AltUser.username == 'zaizee')
                 .for_update(nowait=True)
                 .get())

            self.database.commit()

    @requires_postgresql
    @requires_models(User, Tweet)
    def test_for_update_of(self):
        huey = User.create(username='huey')
        zaizee = User.create(username='zaizee')
        Tweet.create(user=huey, content='h')
        Tweet.create(user=zaizee, content='z')

        AltUser, AltTweet = self.AltUser, self.AltTweet

        with self.database.manual_commit():
            self.database.begin()

            # Lock tweets by huey (FOR UPDATE OF the tweet table only).
            lock_query = (Tweet
                          .select()
                          .join(User)
                          .where(User.username == 'huey')
                          .for_update(of=Tweet, nowait=True))
            locked_rows = lock_query.execute()

            # No problem updating zaizee's tweet or huey's user.
            n_tweets = (AltTweet
                        .update(content='zx')
                        .where(AltTweet.user == zaizee.id)
                        .execute())
            self.assertEqual(n_tweets, 1)

            n_users = (AltUser
                       .update(username='huey-x')
                       .where(AltUser.username == 'huey')
                       .execute())
            self.assertEqual(n_users, 1)

            # Huey's tweet is locked, so NOWAIT is expected to raise.
            with self.assertRaises((OperationalError, InternalError)):
                (AltTweet
                 .select()
                 .where(AltTweet.user == huey)
                 .for_update(nowait=True)
                 .get())

            self.database.commit()

        tweets = Tweet.select(Tweet, User).join(User).order_by(Tweet.id)
        self.assertEqual([(t.content, t.user.username) for t in tweets],
                         [('h', 'huey-x'), ('zx', 'zaizee')])
|
|
|
|
|
|
class ServerDefault(TestModel):
    # The timestamp default is declared as a column constraint in SQL
    # ("default (now())") rather than as a Python-side field default.
    timestamp = DateTimeField(
        constraints=[SQL('default (now())')])
|
|
|
|
|
|
@requires_postgresql
class TestReturningIntegration(ModelTestCase):
    # RETURNING-clause behaviour of INSERT, UPDATE and DELETE queries.
    requires = [User]

    def test_simple_returning(self):
        insert_charlie = User.insert(username='charlie')
        self.assertSQL(
            insert_charlie,
            ('INSERT INTO "users" ("username") VALUES (?) '
             'RETURNING "users"."id"'),
            ['charlie'])
        self.assertEqual(insert_charlie.execute(), 1)

        # execute() gives the new id; iterating the query gives tuples.
        insert_huey = User.insert(username='huey')
        self.assertEqual(insert_huey.execute(), 2)
        self.assertEqual(list(insert_huey), [(2,)])

        # With an explicit returning(User) we get user instances.
        insert_snoobie = User.insert(username='snoobie').returning(User)
        insert_snoobie.execute()
        self.assertEqual([u.username for u in insert_snoobie], ['snoobie'])

        # Selected columns, returned as dictionaries.
        as_dicts = (User
                    .insert(username='zaizee')
                    .returning(User.id, User.username)
                    .dicts())
        self.assertSQL(
            as_dicts,
            ('INSERT INTO "users" ("username") VALUES (?) '
             'RETURNING "users"."id", "users"."username"'),
            ['zaizee'])
        [row] = list(as_dicts.execute())
        self.assertEqual(row, {'id': 4, 'username': 'zaizee'})

        # Whole model, returned as objects.
        as_objects = (User
                      .insert(username='mickey')
                      .returning(User)
                      .objects())
        self.assertSQL(
            as_objects,
            ('INSERT INTO "users" ("username") VALUES (?) '
             'RETURNING "users"."id", "users"."username"'),
            ['mickey'])
        [row] = list(as_objects.execute())
        self.assertEqual(row.id, 5)
        self.assertEqual(row.username, 'mickey')

        # Returned columns can be aliased.
        aliased = (User
                   .insert(username='sipp')
                   .returning(User.username.alias('new_username')))
        new_names = [u.new_username for u in aliased.execute()]
        self.assertEqual(new_names, ['sipp'])

        # Minimal test with insert_many.
        bulk = User.insert_many([('u7',), ('u8',)])
        self.assertEqual([pk for (pk,) in bulk.execute()], [7, 8])

        # insert_many with on-conflict: id 7 exists and gets an 'x'
        # appended to its username, id 9 is new.
        upsert = (User
                  .insert_many([(7, 'u7',), (9, 'u9',)],
                               [User.id, User.username])
                  .on_conflict(conflict_target=[User.id],
                               update={User.username: User.username + 'x'})
                  .returning(User))
        self.assertEqual([(u.id, u.username) for u in upsert],
                         [(7, 'u7x'), (9, 'u9')])

    def test_simple_returning_insert_update_delete(self):
        inserted = User.insert(username='charlie').returning(User).execute()
        self.assertEqual([u.username for u in inserted], ['charlie'])

        updated = (User
                   .update(username='charlie2')
                   .where(User.id == 1)
                   .returning(User)
                   .execute())
        # Iterate twice: the second pass must yield the same (cached) rows.
        for _ in range(2):
            self.assertEqual([u.username for u in updated], ['charlie2'])

        deleted = (User
                   .delete()
                   .where(User.id == 1)
                   .returning(User)
                   .execute())
        # Iterate twice: the second pass must yield the same (cached) rows.
        for _ in range(2):
            self.assertEqual([u.username for u in deleted], ['charlie2'])

    def test_simple_insert_update_delete_no_returning(self):
        # Without returning(): inserts give the new id, update and delete
        # give the number of affected rows.
        self.assertEqual(User.insert(username='charlie').execute(), 1)
        self.assertEqual(User.insert(username='huey').execute(), 2)

        rename = User.update(username='huey2').where(User.username == 'huey')
        self.assertEqual(rename.execute(), 1)
        self.assertEqual(rename.execute(), 0)  # No rows updated!

        remove = User.delete().where(User.username == 'huey2')
        self.assertEqual(remove.execute(), 1)
        self.assertEqual(remove.execute(), 0)  # No rows deleted!

    @requires_models(ServerDefault)
    def test_returning_server_defaults(self):
        insert_defaults = (ServerDefault
                           .insert()
                           .returning(ServerDefault.id,
                                      ServerDefault.timestamp))
        self.assertSQL(
            insert_defaults,
            ('INSERT INTO "server_default" '
             'DEFAULT VALUES '
             'RETURNING "server_default"."id", "server_default"."timestamp"'),
            [])

        # One query both inserts the row and reads back the timestamp.
        with self.assertQueryCount(1):
            cursor = insert_defaults.dicts().execute()
            [row] = list(cursor)

        self.assertIsNotNone(row['timestamp'])

        stored = ServerDefault.get(ServerDefault.id == row['id'])
        self.assertEqual(stored.timestamp, row['timestamp'])

    def test_no_return(self):
        # An empty returning() makes execute() yield None; the row is
        # still inserted.
        no_return = User.insert(username='huey').returning()
        self.assertIsNone(no_return.execute())

        huey = User.get(User.username == 'huey')
        self.assertEqual(huey.username, 'huey')
        self.assertGreaterEqual(huey.id, 1)

    @requires_models(Category)
    def test_non_int_pk_returning(self):
        # The expected SQL returns the "name" column and execute() gives
        # back the inserted name.
        insert_root = Category.insert(name='root')
        self.assertSQL(
            insert_root,
            ('INSERT INTO "category" ("name") VALUES (?) '
             'RETURNING "category"."name"'),
            ['root'])

        self.assertEqual(insert_root.execute(), 'root')

    def test_returning_multi(self):
        rows = [{'username': 'huey'}, {'username': 'mickey'}]
        bulk = User.insert_many(rows)
        self.assertSQL(
            bulk,
            ('INSERT INTO "users" ("username") VALUES (?), (?) '
             'RETURNING "users"."id"'),
            ['huey', 'mickey'])

        result = bulk.execute()

        # Check that the result wrapper is correctly set up.
        self.assertTrue(len(result.select) == 1 and
                        result.select[0] is User.id)
        self.assertEqual(list(result), [(1,), (2,)])

        as_namedtuples = (User
                          .insert_many([{'username': 'foo'},
                                        {'username': 'bar'},
                                        {'username': 'baz'}])
                          .returning(User.id, User.username)
                          .namedtuples())
        result = as_namedtuples.execute()
        self.assertEqual([(row.id, row.username) for row in result], [
            (3, 'foo'),
            (4, 'bar'),
            (5, 'baz')])

    @requires_models(Category)
    def test_returning_query(self):
        for name in ('huey', 'mickey', 'zaizee'):
            Category.create(name=name)

        # INSERT ... SELECT: copy the category names into users.
        names = Category.select(Category.name).order_by(Category.name)
        copy_names = User.insert_from(names, (User.username,))
        self.assertSQL(
            copy_names,
            ('INSERT INTO "users" ("username") '
             'SELECT "t1"."name" FROM "category" AS "t1" ORDER BY "t1"."name" '
             'RETURNING "users"."id"'),
            [])

        result = copy_names.execute()

        # Check that the result wrapper is correctly set up.
        self.assertTrue(len(result.select) == 1 and
                        result.select[0] is User.id)
        self.assertEqual(list(result), [(1,), (2,), (3,)])

    def test_update_returning(self):
        new_ids = User.insert_many([{'username': 'huey'},
                                    {'username': 'zaizee'}]).execute()
        huey_id, zaizee_id = [pk for (pk,) in new_ids]

        rename = (User
                  .update(username='ziggy')
                  .where(User.username == 'zaizee')
                  .returning(User.id, User.username))
        self.assertSQL(
            rename,
            ('UPDATE "users" SET "username" = ? '
             'WHERE ("users"."username" = ?) '
             'RETURNING "users"."id", "users"."username"'),
            ['ziggy', 'zaizee'])

        result = rename.execute()
        user = result[0]
        self.assertEqual(user.username, 'ziggy')
        self.assertEqual(user.id, zaizee_id)

    def test_delete_returning(self):
        new_ids = User.insert_many([{'username': 'huey'},
                                    {'username': 'zaizee'}]).execute()
        huey_id, zaizee_id = [pk for (pk,) in new_ids]

        remove = (User
                  .delete()
                  .where(User.username == 'zaizee')
                  .returning(User.id, User.username))
        self.assertSQL(
            remove,
            ('DELETE FROM "users" WHERE ("users"."username" = ?) '
             'RETURNING "users"."id", "users"."username"'),
            ['zaizee'])

        result = remove.execute()
        user = result[0]
        self.assertEqual(user.username, 'zaizee')
        self.assertEqual(user.id, zaizee_id)
|
|
|
|
|
|
class Reg(TestModel):
    # Small model for the RETURNING tests; the tests use (k, v) as their
    # conflict target, backed by the index declared in Meta.
    k = CharField()
    v = IntegerField()
    x = IntegerField()

    class Meta:
        indexes = ((('k', 'v'), True),)
|
|
|
|
|
|
# Truthy when the test database can run a RETURNING clause: either the
# database object reports it (db.returning_clause) or the IS_SQLITE_35 flag
# imported from .base is set.
returning_support = db.returning_clause or IS_SQLITE_35
|
|
|
|
|
|
@skip_unless(returning_support, 'database does not support RETURNING')
class TestReturningClauseIntegration(ModelTestCase):
    # RETURNING tests for any database that supports the clause.
    requires = [Reg]

    def test_crud(self):
        insert_rows = (Reg
                       .insert_many([('k1', 1, 0), ('k2', 2, 0)])
                       .returning(Reg))
        inserted = [(r.id is not None, r.k, r.v)
                    for r in insert_rows.execute()]
        self.assertEqual(inserted, [(True, 'k1', 1), (True, 'k2', 2)])

        # k1 conflicts but is excluded by the WHERE clause, so it is not
        # returned. k2 conflicts and is updated (v + 1, x preserved from
        # the insert). k3 is new.
        upsert = (Reg
                  .insert_many([('k1', 1, 1), ('k2', 2, 1), ('k3', 3, 0)])
                  .on_conflict(conflict_target=[Reg.k, Reg.v],
                               preserve=[Reg.x],
                               update={Reg.v: Reg.v + 1},
                               where=(Reg.k != 'k1'))
                  .returning(Reg))
        cursor = upsert.execute()
        upserted = [(r.id is not None, r.k, r.v, r.x) for r in cursor]
        self.assertEqual(upserted, [
            (True, 'k2', 3, 1),
            (True, 'k3', 3, 0)])

        update_rows = (Reg
                       .update({Reg.v: Reg.v - 1, Reg.x: Reg.x + 1})
                       .where(Reg.k != 'k1')
                       .returning(Reg))
        updated = [(r.k, r.v, r.x) for r in update_rows.execute()]
        self.assertEqual(updated, [('k2', 2, 2), ('k3', 2, 1)])

        delete_rows = Reg.delete().where(Reg.k != 'k1').returning(Reg)
        deleted = [(r.k, r.v, r.x) for r in delete_rows.execute()]
        self.assertEqual(deleted, [('k2', 2, 2), ('k3', 2, 1)])

    def test_returning_expression(self):
        # A computed, aliased expression can be part of RETURNING.
        v_plus_x = (Reg.v + Reg.x).alias('s')

        def as_tuples(query):
            return [(r.k, r.v, r.s) for r in query.execute()]

        insert_rows = (Reg
                       .insert_many([('k1', 1, 10), ('k2', 2, 20)])
                       .returning(Reg.k, Reg.v, v_plus_x))
        self.assertEqual(as_tuples(insert_rows),
                         [('k1', 1, 11), ('k2', 2, 22)])

        update_rows = (Reg
                       .update({Reg.k: Reg.k + 'x', Reg.v: Reg.v + 1})
                       .returning(Reg.k, Reg.v, v_plus_x))
        self.assertEqual(as_tuples(update_rows),
                         [('k1x', 2, 12), ('k2x', 3, 23)])

        delete_rows = Reg.delete().returning(Reg.k, Reg.v, v_plus_x)
        self.assertEqual(as_tuples(delete_rows),
                         [('k1x', 2, 12), ('k2x', 3, 23)])

    def test_returning_types(self):
        # Same insert/update/delete round trip for each row type: model
        # instances (default), dicts, tuples and namedtuples.
        v_plus_x = (Reg.v + Reg.x).alias('s')

        def from_attrs(r):
            return (r.k, r.v, r.s)

        def from_keys(r):
            return (r['k'], r['v'], r['s'])

        # Pairs of (query converter, row -> tuple converter).
        conversions = (
            ((lambda q: q), from_attrs),
            ((lambda q: q.dicts()), from_keys),
            ((lambda q: q.tuples()), (lambda r: r)),
            ((lambda q: q.namedtuples()), from_attrs))

        for convert_query, row_to_tuple in conversions:
            def as_tuples(query):
                cursor = convert_query(query).execute()
                return [row_to_tuple(r) for r in cursor]

            insert_rows = (Reg
                           .insert_many([('k1', 1, 10), ('k2', 2, 20)])
                           .returning(Reg.k, Reg.v, v_plus_x))
            self.assertEqual(as_tuples(insert_rows),
                             [('k1', 1, 11), ('k2', 2, 22)])

            update_rows = (Reg
                           .update({Reg.k: Reg.k + 'x', Reg.v: Reg.v + 1})
                           .returning(Reg.k, Reg.v, v_plus_x))
            self.assertEqual(as_tuples(update_rows),
                             [('k1x', 2, 12), ('k2x', 3, 23)])

            delete_rows = Reg.delete().returning(Reg.k, Reg.v, v_plus_x)
            self.assertEqual(as_tuples(delete_rows),
                             [('k1x', 2, 12), ('k2x', 3, 23)])
|
|
|
|
|
|
@requires_postgresql
class TestUpdateFromIntegration(ModelTestCase):
    # UPDATE ... FROM tests: values list, subselect and plain table.
    requires = [User]

    def test_update_from(self):
        u1 = User.create(username='u1')
        u2 = User.create(username='u2')

        # New usernames keyed by user id, supplied as a values list.
        renames = [(u1.id, 'u1-x'), (u2.id, 'u2-x')]
        tmp = ValuesList(renames, columns=('id', 'username'), alias='tmp')
        update = (User
                  .update({User.username: tmp.c.username})
                  .from_(tmp)
                  .where(User.id == tmp.c.id))
        update.execute()

        by_name = User.select().order_by(User.username)
        self.assertEqual([u.username for u in by_name], ['u1-x', 'u2-x'])

    def test_update_from_subselect(self):
        u1 = User.create(username='u1')
        u2 = User.create(username='u2')

        # As above, but the values list is wrapped in a subselect.
        renames = [(u1.id, 'u1-y'), (u2.id, 'u2-y')]
        tmp = ValuesList(renames, columns=('id', 'username'), alias='tmp')
        subselect = tmp.select(tmp.c.id, tmp.c.username)
        update = (User
                  .update({User.username: subselect.c.username})
                  .from_(subselect)
                  .where(User.id == subselect.c.id))
        update.execute()

        by_name = User.select().order_by(User.username)
        self.assertEqual([u.username for u in by_name], ['u1-y', 'u2-y'])

    @requires_models(User, Tweet)
    def test_update_from_simple(self):
        user = User.create(username='u1')
        Tweet.create(user=user, content='t1')
        Tweet.create(user=user, content='t2')

        # Take the new username from the tweet whose content is 't2'.
        update = (User
                  .update({User.username: Tweet.content})
                  .from_(Tweet)
                  .where(Tweet.content == 't2'))
        update.execute()

        self.assertEqual(User.get(User.id == user.id).username, 't2')
|
|
|
|
|
|
@requires_postgresql
class TestLateralJoin(ModelTestCase):
    # LEFT JOIN LATERAL assembled by hand from a NodeList.
    requires = [User, Tweet]

    def test_lateral_join(self):
        # Three users (u0..u2) with four tweets each (t0..t3).
        with self.database.atomic():
            for i in range(3):
                user = User.create(username='u%s' % i)
                for j in range(4):
                    Tweet.create(user=user, content='u%s-t%s' % (i, j))

        # GOAL: query users and their 2 most-recent tweets (by ID).
        TA = Tweet.alias()

        # Outer side: the users whose tweets we are trying to find.
        users = (User
                 .select(User.id, User.username)
                 .order_by(User.id)
                 .alias('uq'))

        # Inner side: tweets, correlated to the outer side through the
        # WHERE clause and capped with LIMIT 2.
        recent_tweets = (TA
                         .select(TA.id, TA.content)
                         .where(TA.user == users.c.id)
                         .order_by(TA.id.desc())
                         .limit(2)
                         .alias('pq'))

        lateral = NodeList((
            users,
            SQL('LEFT JOIN LATERAL'),
            recent_tweets,
            SQL('ON %s', [True])))
        query = (Tweet
                 .select(users.c.username, recent_tweets.c.content)
                 .from_(lateral)
                 .dicts())

        expected = [
            {'username': 'u0', 'content': 'u0-t3'},
            {'username': 'u0', 'content': 'u0-t2'},
            {'username': 'u1', 'content': 'u1-t3'},
            {'username': 'u1', 'content': 'u1-t2'},
            {'username': 'u2', 'content': 'u2-t3'},
            {'username': 'u2', 'content': 'u2-t2'},
        ]
        self.assertEqual(list(query), expected)
|
|
|
|
# ===========================================================================
|
|
# Bulk operations
|
|
# ===========================================================================
|
|
|
|
|
|
class BCUser(TestModel):
    # User model for the bulk-create tests; the unique username is the
    # column that BCTweet's foreign key points at.
    username = CharField(unique=True)
|
|
|
|
class BCTweet(TestModel):
    # Tweet model for the bulk-create tests.
    # The foreign key targets BCUser.username rather than BCUser's primary key.
    user = ForeignKeyField(BCUser, field=BCUser.username)
    content = TextField()
|
|
|
|
|
|
class TestBulkCreateWithFK(ModelTestCase):
    # bulk_create() on models that carry foreign keys.

    @requires_models(BCUser, BCTweet)
    def test_bulk_create_with_fk(self):
        for username in ('u1', 'u2'):
            BCUser.create(username=username)

        # Foreign key supplied as the raw username; one INSERT is expected.
        with self.assertQueryCount(1):
            BCTweet.bulk_create([
                BCTweet(user='u1', content='t%s' % idx)
                for idx in range(4)])

        owned_by_u1 = BCTweet.select().where(BCTweet.user == 'u1')
        self.assertEqual(owned_by_u1.count(), 4)
        owned_by_others = BCTweet.select().where(BCTweet.user != 'u1')
        self.assertEqual(owned_by_others.count(), 0)

        # Foreign key supplied as an unsaved instance which is itself
        # bulk-created immediately before the tweet.
        new_user = BCUser(username='u3')
        new_tweet = BCTweet(user=new_user, content='tx')
        with self.assertQueryCount(2):
            BCUser.bulk_create([new_user])
            BCTweet.bulk_create([new_tweet])

        with self.assertQueryCount(1):
            fetched = (BCTweet
                       .select(BCTweet, BCUser)
                       .join(BCUser)
                       .where(BCUser.username == 'u3')
                       .get())
            self.assertEqual(fetched.content, 'tx')
            self.assertEqual(fetched.user.username, 'u3')

    @requires_postgresql
    @requires_models(User, Tweet)
    def test_bulk_create_related_objects(self):
        # Same pattern with an ordinary primary-key foreign key.
        new_user = User(username='u1')
        new_tweet = Tweet(user=new_user, content='t1')
        with self.assertQueryCount(2):
            User.bulk_create([new_user])
            Tweet.bulk_create([new_tweet])

        with self.assertQueryCount(1):
            fetched = Tweet.select(Tweet, User).join(User).get()
            self.assertEqual(fetched.content, 't1')
            self.assertEqual(fetched.user.username, 'u1')
|
|
|
|
class UUIDReg(TestModel):
    # Register keyed by a UUID primary key; uuid.uuid4 supplies the default.
    id = UUIDField(primary_key=True, default=uuid.uuid4)
    key = TextField()
|
|
|
|
class CharPKKV(TestModel):
    # Key/value model whose primary key is a character column.
    id = CharField(primary_key=True)
    key = TextField()
    value = IntegerField(default=0)
|
|
|
|
|
|
class TestBulkUpdateNonIntegerPK(ModelTestCase):
    # bulk_update() on models whose primary key is not an integer.

    @requires_models(UUIDReg)
    def test_bulk_update_uuid_pk(self):
        first = UUIDReg.create(key='k1')
        second = UUIDReg.create(key='k2')
        first.key = 'k1-x'
        second.key = 'k2-x'
        UUIDReg.bulk_update((first, second), (UUIDReg.key,))

        # Both rows must come back carrying the updated keys.
        stored = UUIDReg.select().order_by(UUIDReg.key)
        first_db, second_db = stored
        self.assertEqual(first_db.key, 'k1-x')
        self.assertEqual(second_db.key, 'k2-x')

    @requires_models(CharPKKV)
    def test_bulk_update_non_integer_pk(self):
        row_a, row_b, row_c = [
            CharPKKV.create(id=letter, key='k%s' % letter)
            for letter in 'abc']

        # Row "b" only changes its value; "a" and "c" change key and value.
        row_a.key = 'ka-x'
        row_a.value = 1
        row_b.value = 2
        row_c.key = 'kc-x'
        row_c.value = 3
        CharPKKV.bulk_update(
            (row_a, row_b, row_c),
            (CharPKKV.key, CharPKKV.value))

        stored = CharPKKV.select().order_by(CharPKKV.id).tuples()
        expected = [
            ('a', 'ka-x', 1),
            ('b', 'kb', 2),
            ('c', 'kc-x', 3)]
        self.assertEqual(list(stored), expected)
|
|
|
|
class NDF(TestModel):
    # Model with a nullable datetime column, used by the all-NULL
    # bulk-update test.
    key = CharField(primary_key=True)
    date = DateTimeField(null=True)
|
|
|
|
class TestBulkUpdateAllNull(ModelTestCase):
    # bulk_update() where every new value for the column is NULL.
    requires = [NDF]

    @skip_unless(IS_SQLITE or IS_MYSQL, 'postgres cannot do this properly')
    def test_bulk_update_all_null(self):
        # Seed two rows with distinct, non-NULL dates.
        for day, key in enumerate(('n1', 'n2'), 1):
            NDF.create(key=key, date=datetime.datetime(2021, 1, day))

        # Unsaved instances carrying just the primary key and the new value.
        blanked = [NDF(key=key, date=None) for key in ('n1', 'n2')]
        NDF.bulk_update(blanked, fields=['date'])

        stored = NDF.select().order_by(NDF.key).tuples()
        self.assertEqual(list(stored), [('n1', None), ('n2', None)])
|
|
|
|
class IMC(TestModel):
    # Two-integer model for the chunked insert_many test; "b" may be NULL.
    a = IntegerField()
    b = IntegerField(null=True)
|
|
|
|
class TestChunkedInsertMany(ModelTestCase):
    # insert_many() fed from chunked() batches, for tuple rows and dict rows.
    requires = [IMC]

    def test_chunked_insert_many(self):
        def even_or_none(i):
            # Column "b" holds i for even i and NULL otherwise.
            return i if i % 2 == 0 else None

        # Rows as tuples, inserted ten at a time.
        tuple_rows = [(i, even_or_none(i)) for i in range(100)]
        for batch in chunked(tuple_rows, 10):
            IMC.insert_many(batch).execute()

        stored = IMC.select(IMC.a, IMC.b).order_by(IMC.id).tuples()
        self.assertEqual(list(stored), tuple_rows)
        IMC.delete().execute()

        # Rows as dicts, inserted five at a time.
        dict_rows = [{'a': i, 'b': even_or_none(i)} for i in range(100)]
        for batch in chunked(dict_rows, 5):
            IMC.insert_many(batch).execute()
        stored = IMC.select(IMC.a, IMC.b).order_by(IMC.id).dicts()
        self.assertEqual(list(stored), dict_rows)
        IMC.delete().execute()
|
|
|
|
# ===========================================================================
|
|
# Model metadata and configuration
|
|
# ===========================================================================
|
|
|
|
|
|
|
|
class TestModelGraph(BaseTestCase):
    # Binding a database to one model also reaches the models that are
    # connected to it through foreign keys.

    def test_bind_model_database(self):
        class User(Model):
            pass

        class Tweet(Model):
            user = ForeignKeyField(User)

        class Relationship(Model):
            from_user = ForeignKeyField(User, backref='relationships')
            to_user = ForeignKeyField(User, backref='related_to')

        class Flag(Model):
            tweet = ForeignKeyField(Tweet)

        class Unrelated(Model):
            pass

        placeholder_db = SqliteDatabase(None)
        User.bind(placeholder_db)
        for connected in (User, Tweet, Relationship, Flag):
            self.assertIs(connected._meta.database, placeholder_db)
        self.assertIsNone(Unrelated._meta.database)
        User.bind(None)

        # Inside the context the binding is active; afterwards it is undone.
        with User.bind_ctx(placeholder_db) as (FUser,):
            self.assertIs(FUser._meta.database, placeholder_db)
            self.assertIsNone(Unrelated._meta.database)

        self.assertIsNone(User._meta.database)
|
|
|
|
|
|
class TestFieldInheritance(BaseTestCase):
    # Field, foreign-key and backref behavior when models are subclassed.

    def test_field_inheritance(self):
        # Subclasses list the inherited fields first, then their own, and
        # each subclass holds its own field objects.
        class BaseModel(Model):
            class Meta:
                database = get_in_memory_db()

        class BasePost(BaseModel):
            content = TextField()
            timestamp = TimestampField()

        class Photo(BasePost):
            image = TextField()

        class Note(BasePost):
            category = TextField()

        self.assertEqual(BasePost._meta.sorted_field_names,
                         ['id', 'content', 'timestamp'])
        self.assertEqual(BasePost._meta.sorted_fields, [
            BasePost.id,
            BasePost.content,
            BasePost.timestamp])

        self.assertEqual(Photo._meta.sorted_field_names,
                         ['id', 'content', 'timestamp', 'image'])
        self.assertEqual(Photo._meta.sorted_fields, [
            Photo.id,
            Photo.content,
            Photo.timestamp,
            Photo.image])

        self.assertEqual(Note._meta.sorted_field_names,
                         ['id', 'content', 'timestamp', 'category'])
        self.assertEqual(Note._meta.sorted_fields, [
            Note.id,
            Note.content,
            Note.timestamp,
            Note.category])

        # The inherited "id" field is a distinct object on each subclass.
        self.assertTrue(id(Photo.id) != id(Note.id))

    def test_foreign_key_field_inheritance(self):
        # An inherited foreign key registers a separate backref/ref pair for
        # the base class and for every subclass.
        class BaseModel(Model):
            class Meta:
                database = get_in_memory_db()

        class Category(BaseModel):
            name = TextField()

        class BasePost(BaseModel):
            category = ForeignKeyField(Category)
            timestamp = TimestampField()

        class Photo(BasePost):
            image = TextField()

        class Note(BasePost):
            content = TextField()

        self.assertEqual(BasePost._meta.sorted_field_names,
                         ['id', 'category', 'timestamp'])
        self.assertEqual(BasePost._meta.sorted_fields, [
            BasePost.id,
            BasePost.category,
            BasePost.timestamp])

        self.assertEqual(Photo._meta.sorted_field_names,
                         ['id', 'category', 'timestamp', 'image'])
        self.assertEqual(Photo._meta.sorted_fields, [
            Photo.id,
            Photo.category,
            Photo.timestamp,
            Photo.image])

        self.assertEqual(Note._meta.sorted_field_names,
                         ['id', 'category', 'timestamp', 'content'])
        self.assertEqual(Note._meta.sorted_fields, [
            Note.id,
            Note.category,
            Note.timestamp,
            Note.content])

        self.assertEqual(Category._meta.backrefs, {
            BasePost.category: BasePost,
            Photo.category: Photo,
            Note.category: Note})
        self.assertEqual(BasePost._meta.refs, {BasePost.category: Category})
        self.assertEqual(Photo._meta.refs, {Photo.category: Category})
        self.assertEqual(Note._meta.refs, {Note.category: Category})

        # With no explicit backref, the name is "<lowercased model>_set".
        self.assertEqual(BasePost.category.backref, 'basepost_set')
        self.assertEqual(Photo.category.backref, 'photo_set')
        self.assertEqual(Note.category.backref, 'note_set')

    def test_foreign_key_pk_inheritance(self):
        # A foreign key that is also the primary key is inherited intact,
        # including in the generated CREATE TABLE statements.
        class BaseModel(Model):
            class Meta:
                database = get_in_memory_db()
        class Account(BaseModel): pass
        class BaseUser(BaseModel):
            account = ForeignKeyField(Account, primary_key=True)
        class User(BaseUser):
            username = TextField()
        class Admin(BaseUser):
            role = TextField()

        self.assertEqual(Account._meta.backrefs, {
            Admin.account: Admin,
            User.account: User,
            BaseUser.account: BaseUser})

        self.assertEqual(BaseUser.account.backref, 'baseuser_set')
        self.assertEqual(User.account.backref, 'user_set')
        self.assertEqual(Admin.account.backref, 'admin_set')
        self.assertTrue(Account.user_set.model is Account)
        self.assertTrue(Account.admin_set.model is Account)
        self.assertTrue(Account.user_set.rel_model is User)
        self.assertTrue(Account.admin_set.rel_model is Admin)

        self.assertSQL(Account._schema._create_table(), (
            'CREATE TABLE IF NOT EXISTS "account" ('
            '"id" INTEGER NOT NULL PRIMARY KEY)'), [])

        self.assertSQL(User._schema._create_table(), (
            'CREATE TABLE IF NOT EXISTS "user" ('
            '"account_id" INTEGER NOT NULL PRIMARY KEY, '
            '"username" TEXT NOT NULL, '
            'FOREIGN KEY ("account_id") REFERENCES "account" ("id"))'), [])

        self.assertSQL(Admin._schema._create_table(), (
            'CREATE TABLE IF NOT EXISTS "admin" ('
            '"account_id" INTEGER NOT NULL PRIMARY KEY, '
            '"role" TEXT NOT NULL, '
            'FOREIGN KEY ("account_id") REFERENCES "account" ("id"))'), [])

    def test_backref_inheritance(self):
        # A callable backref is re-evaluated for each subclass; a literal
        # string backref applies only to the class that declared it.
        class Category(TestModel): pass
        def backref(fk_field):
            return '%ss' % fk_field.model._meta.name
        class BasePost(TestModel):
            category = ForeignKeyField(Category, backref=backref)
        class Note(BasePost): pass
        class Photo(BasePost): pass

        self.assertEqual(Category._meta.backrefs, {
            BasePost.category: BasePost,
            Note.category: Note,
            Photo.category: Photo})
        self.assertEqual(BasePost.category.backref, 'baseposts')
        self.assertEqual(Note.category.backref, 'notes')
        self.assertEqual(Photo.category.backref, 'photos')
        self.assertTrue(Category.baseposts.rel_model is BasePost)
        self.assertTrue(Category.baseposts.model is Category)
        self.assertTrue(Category.notes.rel_model is Note)
        self.assertTrue(Category.notes.model is Category)
        self.assertTrue(Category.photos.rel_model is Photo)
        self.assertTrue(Category.photos.model is Category)

        class BaseItem(TestModel):
            category = ForeignKeyField(Category, backref='items')
        class ItemA(BaseItem): pass
        class ItemB(BaseItem): pass

        # Subclasses fall back to "<lowercased model>_set" instead of
        # reusing the parent's literal 'items'.
        self.assertEqual(BaseItem.category.backref, 'items')
        self.assertEqual(ItemA.category.backref, 'itema_set')
        self.assertEqual(ItemB.category.backref, 'itemb_set')
        self.assertTrue(Category.items.rel_model is BaseItem)
        self.assertTrue(Category.itema_set.rel_model is ItemA)
        self.assertTrue(Category.itema_set.model is Category)
        self.assertTrue(Category.itemb_set.rel_model is ItemB)
        self.assertTrue(Category.itemb_set.model is Category)

    @skip_if(IS_SQLITE, 'sqlite is not supported')
    @skip_if(IS_MYSQL, 'mysql is not raising this error(?)')
    @skip_if(IS_CRDB, 'crdb is not raising the error in this test(?)')
    def test_deferred_fk_creation(self):
        # B declares its foreign key to A before A exists; the constraint is
        # added later through create_foreign_key().
        class B(TestModel):
            a = DeferredForeignKey('A', null=True)
            b = TextField()
        class A(TestModel):
            a = TextField()

        db.create_tables([A, B])

        try:
            # Test that we can create B with null "a_id" column:
            a = A.create(a='a')
            b = B.create(b='b')

            # Test that we can create B that has no corresponding A:
            fake_a = A(id=31337)
            b2 = B.create(a=fake_a, b='b2')
            b2_db = B.get(B.a == fake_a)
            self.assertEqual(b2_db.b, 'b2')

            # Ensure error occurs trying to create_foreign_key.
            with db.atomic():
                self.assertRaises(
                    IntegrityError,
                    B._schema.create_foreign_key,
                    B.a)

            # Remove the dangling row so the constraint can be added.
            b2_db.delete_instance()

            # We can now create the foreign key.
            B._schema.create_foreign_key(B.a)

            # The foreign-key is enforced:
            with db.atomic():
                self.assertRaises(IntegrityError, B.create, a=fake_a, b='b3')
        finally:
            # Drop the ad-hoc tables whether or not the test passed.
            db.drop_tables([A, B])
|
|
|
|
|
|
class TestMetaTableName(BaseTestCase):
    # Table names derived from model class names when legacy_table_names
    # is disabled.

    def test_table_name_behavior(self):
        def build_model(class_name, explicit_table=None):
            # Create a throwaway model class with the given name and an
            # optional explicit Meta.table_name.
            class Meta:
                legacy_table_names = False
                table_name = explicit_table
            return type(class_name, (Model,), {'Meta': Meta})

        # (expected table name, model class name, explicit table_name)
        cases = (
            ('users', 'User', 'users'),
            ('tweet', 'Tweet', None),
            ('user_profile', 'UserProfile', None),
            ('activity_log_status', 'ActivityLogStatus', None),

            ('camel_case', 'CamelCase', None),
            ('camel_camel_case', 'CamelCamelCase', None),
            ('camel2_camel2_case', 'Camel2Camel2Case', None),
            ('http_request', 'HTTPRequest', None),
            ('api_response', 'APIResponse', None),
            ('api_response', 'API_Response', None),
            ('web_http_request', 'WebHTTPRequest', None),
            ('get_http_response_code', 'getHTTPResponseCode', None),
            ('foo_bar', 'foo_Bar', None),
            ('foo_bar', 'Foo__Bar', None),
        )
        for expected, class_name, explicit_table in cases:
            model_class = build_model(class_name, explicit_table)
            self.assertEqual(model_class._meta.table_name, expected)
|
|
|
|
|
|
class TestMetaInheritance(BaseTestCase):
    # Which Meta options are passed from a model to its subclasses.

    def test_table_name(self):
        # table_function is inherited; an explicit table_name is not (Biz
        # falls back to the function inherited from Foo).
        class Foo(Model):
            class Meta:
                def table_function(klass):
                    return 'xxx_%s' % klass.__name__.lower()

        class Bar(Foo): pass
        class Baze(Foo):
            class Meta:
                table_name = 'yyy_baze'
        class Biz(Baze): pass
        class Nug(Foo):
            class Meta:
                def table_function(klass):
                    return 'zzz_%s' % klass.__name__.lower()

        self.assertEqual(Foo._meta.table_name, 'xxx_foo')
        self.assertEqual(Bar._meta.table_name, 'xxx_bar')
        self.assertEqual(Baze._meta.table_name, 'yyy_baze')
        self.assertEqual(Biz._meta.table_name, 'xxx_biz')
        self.assertEqual(Nug._meta.table_name, 'zzz_nug')

    def test_composite_key_inheritance(self):
        # The composite key is inherited, even when a subclass redefines one
        # of the member fields with a different type.
        class Foo(Model):
            key = TextField()
            value = TextField()

            class Meta:
                primary_key = CompositeKey('key', 'value')

        class Bar(Foo): pass
        class Baze(Foo):
            value = IntegerField()

        foo = Foo(key='k1', value='v1')
        self.assertEqual(foo.__composite_key__, ('k1', 'v1'))

        bar = Bar(key='k2', value='v2')
        self.assertEqual(bar.__composite_key__, ('k2', 'v2'))

        baze = Baze(key='k3', value=3)
        self.assertEqual(baze.__composite_key__, ('k3', 3))

    def test_no_primary_key_inheritable(self):
        # primary_key = False is inherited unless the subclass declares its
        # own primary-key field.
        class Foo(Model):
            data = TextField()

            class Meta:
                primary_key = False

        class Bar(Foo): pass
        class Baze(Foo):
            pk = AutoField()
        class Zai(Foo):
            zee = TextField(primary_key=True)

        self.assertFalse(Foo._meta.primary_key)
        self.assertEqual(Foo._meta.sorted_field_names, ['data'])
        self.assertFalse(Bar._meta.primary_key)
        self.assertEqual(Bar._meta.sorted_field_names, ['data'])

        # A declared primary key sorts ahead of the inherited field.
        self.assertTrue(Baze._meta.primary_key is Baze.pk)
        self.assertEqual(Baze._meta.sorted_field_names, ['pk', 'data'])

        self.assertTrue(Zai._meta.primary_key is Zai.zee)
        self.assertEqual(Zai._meta.sorted_field_names, ['zee', 'data'])

    def test_inheritance(self):
        # Generic Meta options flow to children and grandchildren, and can
        # be overridden (including with None) in a subclass.
        db = SqliteDatabase(':memory:')

        class Base(Model):
            class Meta:
                constraints = ['c1', 'c2']
                database = db
                indexes = (
                    (('username',), True),
                )
                only_save_dirty = True
                options = {'key': 'value'}
                schema = 'magic'

        class Child(Base): pass
        class GrandChild(Child): pass

        for ModelClass in (Child, GrandChild):
            self.assertEqual(ModelClass._meta.constraints, ['c1', 'c2'])
            self.assertTrue(ModelClass._meta.database is db)
            self.assertEqual(ModelClass._meta.indexes, [(('username',), True)])
            self.assertEqual(ModelClass._meta.options, {'key': 'value'})
            self.assertTrue(ModelClass._meta.only_save_dirty)
            self.assertEqual(ModelClass._meta.schema, 'magic')

        class Overrides(Base):
            class Meta:
                constraints = None
                indexes = None
                only_save_dirty = False
                options = {'foo': 'bar'}
                schema = None

        self.assertTrue(Overrides._meta.constraints is None)
        # indexes = None is exposed as an empty list.
        self.assertEqual(Overrides._meta.indexes, [])
        self.assertFalse(Overrides._meta.only_save_dirty)
        self.assertEqual(Overrides._meta.options, {'foo': 'bar'})
        self.assertTrue(Overrides._meta.schema is None)

    def test_temporary_inheritance(self):
        # The "temporary" flag is inherited and may be switched off again.
        class T0(TestModel): pass
        class T1(TestModel):
            class Meta:
                temporary = True

        class T2(T1): pass
        class T3(T1):
            class Meta:
                temporary = False

        self.assertFalse(T0._meta.temporary)
        self.assertTrue(T1._meta.temporary)
        self.assertTrue(T2._meta.temporary)
        self.assertFalse(T3._meta.temporary)
|
|
|
|
|
|
class TestModelMetadataMisc(BaseTestCase):
    # Custom metadata class built on SubclassAwareMetadata.
    database = get_in_memory_db()

    def test_subclass_aware_metadata(self):
        # Setting "schema" through any one model's metadata must change it
        # for every model that uses the same metadata class.
        class SchemaPropagateMetadata(SubclassAwareMetadata):
            @property
            def schema(self):
                return self._schema
            @schema.setter
            def schema(self, value):
                # self.models is a singleton, essentially, shared among all
                # classes that use this metadata implementation.
                for model in self.models:
                    model._meta._schema = value

        class Base(Model):
            class Meta:
                database = self.database
                model_metadata_class = SchemaPropagateMetadata

        class User(Base):
            username = TextField()
        class Tweet(Base):
            user = ForeignKeyField(User, backref='tweets')
            content = TextField()

        self.assertTrue(User._meta.schema is None)
        self.assertTrue(Tweet._meta.schema is None)

        # Assigning on the base class reaches the subclasses...
        Base._meta.schema = 'temp'
        self.assertEqual(User._meta.schema, 'temp')
        self.assertEqual(Tweet._meta.schema, 'temp')

        # ...and assigning on a subclass reaches the base and its siblings.
        User._meta.schema = None
        for model in (Base, User, Tweet):
            self.assertTrue(model._meta.schema is None)
|
|
|
|
|
|
class TestModelSetDatabase(BaseTestCase):
    # Re-pointing a model at a different database with _meta.set_database().

    def test_set_database(self):
        class Register(Model):
            value = IntegerField()

        db_a = get_in_memory_db()
        db_b = get_in_memory_db()
        try:
            # Create the table in db_a only.
            Register._meta.set_database(db_a)
            Register.create_table()

            # After switching, the model consults db_b, which has no tables,
            # while db_a keeps the table that was created there.
            Register._meta.set_database(db_b)
            self.assertFalse(Register.table_exists())
            self.assertEqual(db_a.get_tables(), ['register'])
            self.assertEqual(db_b.get_tables(), [])
        finally:
            # Release both connections even when an assertion above fails,
            # so a failing run does not leak open databases.
            db_a.close()
            db_b.close()
|
|
|
|
|
|
class TestForeignKeyFieldDescriptors(BaseTestCase):
    # How a foreign key's object-id accessor name is derived from the field
    # name, column_name and object_id_name.

    def test_foreign_key_field_descriptors(self):
        class User(Model): pass
        class T0(Model):
            user = ForeignKeyField(User)
        class T1(Model):
            user = ForeignKeyField(User, column_name='uid')
        class T2(Model):
            user = ForeignKeyField(User, object_id_name='uid')
        class T3(Model):
            user = ForeignKeyField(User, column_name='x', object_id_name='uid')
        class T4(Model):
            foo = ForeignKeyField(User, column_name='user')
        class T5(Model):
            foo = ForeignKeyField(User, object_id_name='uid')

        # Default is "<field>_id"; a differing column_name is used as the
        # accessor name; an explicit object_id_name wins over both.
        self.assertEqual(T0.user.object_id_name, 'user_id')
        self.assertEqual(T1.user.object_id_name, 'uid')
        self.assertEqual(T2.user.object_id_name, 'uid')
        self.assertEqual(T3.user.object_id_name, 'uid')
        self.assertEqual(T4.foo.object_id_name, 'user')
        self.assertEqual(T5.foo.object_id_name, 'uid')

        # The accessor exposes the related object's primary-key value.
        user = User(id=1337)
        self.assertEqual(T0(user=user).user_id, 1337)
        self.assertEqual(T1(user=user).uid, 1337)
        self.assertEqual(T2(user=user).uid, 1337)
        self.assertEqual(T3(user=user).uid, 1337)
        self.assertEqual(T4(foo=user).user, 1337)
        self.assertEqual(T5(foo=user).uid, 1337)

        # An object_id_name equal to the field name is rejected when the
        # class is defined.
        def conflicts_with_field():
            class TE(Model):
                user = ForeignKeyField(User, object_id_name='user')

        self.assertRaises(ValueError, conflicts_with_field)

    def test_column_name(self):
        # When column_name equals the field name, the object-id accessor
        # still gets the "<field>_id" name.
        class User(Model): pass
        class T1(Model):
            user = ForeignKeyField(User, column_name='user')

        self.assertEqual(T1.user.column_name, 'user')
        self.assertEqual(T1.user.object_id_name, 'user_id')
|
|
|
|
class NoPK(TestModel):
    # Model declared without any primary key (Meta.primary_key = False).
    data = IntegerField()
    class Meta:
        primary_key = False
|
|
|
|
|
|
class TestModelFieldReprs(BaseTestCase):
    # repr() output for model classes, model instances and fields.

    def test_model_reprs(self):
        class User(Model):
            username = TextField(primary_key=True)
        class Tweet(Model):
            user = ForeignKeyField(User, backref='tweets')
            content = TextField()
            timestamp = TimestampField()
        class EAV(Model):
            entity = TextField()
            attribute = TextField()
            value = TextField()
            class Meta:
                primary_key = CompositeKey('entity', 'attribute')
        class NoPK(Model):
            key = TextField()
            class Meta:
                primary_key = False

        # Model classes.
        self.assertEqual(repr(User), '<Model: User>')
        self.assertEqual(repr(Tweet), '<Model: Tweet>')
        self.assertEqual(repr(EAV), '<Model: EAV>')
        self.assertEqual(repr(NoPK), '<Model: NoPK>')

        # Instances without a primary-key value.
        self.assertEqual(repr(User()), '<User: None>')
        self.assertEqual(repr(Tweet()), '<Tweet: None>')
        self.assertEqual(repr(EAV()), '<EAV: (None, None)>')
        self.assertEqual(repr(NoPK()), '<NoPK: n/a>')

        # Instances with a primary-key value; NoPK always shows "n/a".
        self.assertEqual(repr(User(username='huey')), '<User: huey>')
        self.assertEqual(repr(Tweet(id=1337)), '<Tweet: 1337>')
        self.assertEqual(repr(EAV(entity='e', attribute='a')),
                         "<EAV: ('e', 'a')>")
        self.assertEqual(repr(NoPK(key='k')), '<NoPK: n/a>')

        # Bound fields.
        self.assertEqual(repr(User.username), '<TextField: User.username>')
        self.assertEqual(repr(Tweet.user), '<ForeignKeyField: Tweet.user>')
        self.assertEqual(repr(EAV.entity), '<TextField: EAV.entity>')

        # A field not attached to any model.
        self.assertEqual(repr(TextField()), '<TextField: (unbound)>')

    def test_model_str_method(self):
        # A model-defined __str__ supplies the text shown inside the repr.
        class User(Model):
            username = TextField(primary_key=True)

            def __str__(self):
                return self.username.title()

        u = User(username='charlie')
        self.assertEqual(repr(u), '<User: Charlie>')
|
|
|
|
|
|
class ColAlias(TestModel):
    # The field is called "name" but is stored in the column "pname".
    name = TextField(column_name='pname')
|
|
|
|
|
|
class CARef(TestModel):
    # Reference to ColAlias stored in column "ca"; the raw-id accessor is
    # explicitly named "colalias_id".
    colalias = ForeignKeyField(ColAlias, backref='carefs', column_name='ca',
                               object_id_name='colalias_id')
|
|
|
|
|
|
class TestQueryAliasToColumnName(ModelTestCase):
    # Aliasing a field to a name that coincides with a column name.
    requires = [ColAlias, CARef]

    def setUp(self):
        super().setUp()
        with self.database.atomic():
            for name in ('huey', 'mickey'):
                target = ColAlias.create(name=name)
                CARef.create(colalias=target)

    def test_alias_to_column_name(self):
        # The issue here occurs when we take a field whose name differs from
        # its underlying column name, then alias that field to its column
        # name. In this case, peewee was *not* respecting the alias and using
        # the field name instead.
        expected = ['huey', 'mickey']
        aliased = (ColAlias
                   .select(ColAlias.name.alias('pname'))
                   .order_by(ColAlias.name))
        self.assertEqual([row.pname for row in aliased], expected)

        # Ensure that when using dicts the logic is preserved.
        as_dicts = aliased.dicts()
        self.assertEqual([row['pname'] for row in as_dicts], expected)

    def test_alias_overlap_with_join(self):
        expected = ['huey', 'mickey']
        joined = (CARef
                  .select(CARef, ColAlias.name.alias('pname'))
                  .join(ColAlias)
                  .order_by(ColAlias.name))
        with self.assertQueryCount(1):
            self.assertEqual([ref.colalias.pname for ref in joined],
                             expected)

        # Note: the join can be aliased to "ca" (the column name) because
        # the object-id descriptor of CARef.colalias is named "colalias_id".
        join_on = (CARef.colalias == ColAlias.id).alias('ca')
        joined = (CARef
                  .select(CARef, ColAlias.name.alias('pname'))
                  .join(ColAlias, on=join_on)
                  .order_by(ColAlias.name))
        with self.assertQueryCount(1):
            self.assertEqual([ref.ca.pname for ref in joined], expected)

    def test_cannot_alias_join_to_object_id_name(self):
        # Aliasing the join to the object-id descriptor name is rejected.
        base_query = CARef.select(CARef, ColAlias.name.alias('pname'))
        clashing = (CARef.colalias == ColAlias.id).alias('colalias_id')
        with self.assertRaises(ValueError):
            base_query.join(ColAlias, on=clashing)
|
|
|
|
|
|
class TestOverrideModelRepr(BaseTestCase):
    def test_custom_reprs(self):
        # In 3.5.0, Peewee included a new implementation and semantics for
        # customizing model reprs. This introduced a regression where model
        # classes that defined a __repr__() method had this override
        # silently ignored. This test ensures that a model can completely
        # override its repr.
        class Foo(Model):
            def __repr__(self):
                return 'FOO: %s' % self.id

        instance = Foo(id=1337)
        rendered = repr(instance)
        self.assertEqual(rendered, 'FOO: 1337')
|
|
|
|
|
|
class Product(TestModel):
    # Model combining column-level and table-level constraints.
    name = TextField()
    price = IntegerField()
    # Column constraint providing a server-side default of 99.
    flags = IntegerField(constraints=[SQL('DEFAULT 99')])
    # Column CHECK restricting status to 'a', 'b' or 'c'.
    status = CharField(constraints=[Check("status IN ('a', 'b', 'c')")])

    class Meta:
        # Table-level CHECK: price must be strictly positive.
        constraints = [Check('price > 0')]
|
|
|
|
|
|
class TestModelConstraints(ModelTestCase):
    # Column and table constraints declared on Product are enforced by the
    # database.
    requires = [Product]

    def test_model_constraints(self):
        product = Product.create(name='p1', price=1, status='a')
        # The local instance never learns the server-side default.
        self.assertTrue(product.flags is None)

        # Price was saved successfully, flags got server-side default value.
        stored = Product.get(Product.id == product.id)
        self.assertEqual(stored.price, 1)
        self.assertEqual(stored.flags, 99)
        self.assertEqual(stored.status, 'a')

        # Cannot update price with invalid value, must be > 0.
        with self.database.atomic():
            product.price = -1
            with self.assertRaises(DatabaseError):
                product.save()

        # Nor can we create a new product with an invalid price.
        with self.database.atomic():
            with self.assertRaises(DatabaseError):
                Product.create(name='p2', price=0, status='a')

        # Cannot set status to a value other than 'a', 'b' or 'c'.
        with self.database.atomic():
            product.price = 1
            product.status = 'd'
            with self.assertRaises(DatabaseError):
                product.save()

        # Cannot create a new product with invalid status.
        with self.database.atomic():
            with self.assertRaises(DatabaseError):
                Product.create(name='p3', price=1, status='x')
|
|
|
|
|
|
class SequenceModel(TestModel):
    # seq_id is populated from the database sequence "seq_id_sequence".
    seq_id = IntegerField(sequence='seq_id_sequence')
    key = TextField()
|
|
|
|
|
|
@requires_pglike
class TestSequence(ModelTestCase):
    # Sequence-backed integer column: generated DDL and generated values.
    requires = [SequenceModel]

    def test_create_table(self):
        ddl = SequenceModel._schema._create_table()
        expected_sql = (
            'CREATE TABLE IF NOT EXISTS "sequence_model" ('
            '"id" SERIAL NOT NULL PRIMARY KEY, '
            '"seq_id" INTEGER NOT NULL DEFAULT NEXTVAL(\'seq_id_sequence\'), '
            '"key" TEXT NOT NULL)')
        self.assertSQL(ddl, expected_sql, [])

    def test_sequence(self):
        for key in ('k1', 'k2', 'k3'):
            SequenceModel.create(key=key)

        # Rows come back in key order carrying consecutive sequence values.
        first, second, third = (SequenceModel
                                .select()
                                .order_by(SequenceModel.key))

        self.assertEqual(first.seq_id, 1)
        self.assertEqual(second.seq_id, 2)
        self.assertEqual(third.seq_id, 3)
|
|
|
|
# ===========================================================================
|
|
# Database integration
|
|
# ===========================================================================
|
|
|
|
|
|
|
|
class TestBindTo(ModelTestCase):
    # bind_to() attaches a computed column to the joined model instance.
    requires = [User, Tweet]

    def test_bind_to(self):
        for num in (1, 2, 3):
            author = User.create(username='u%s' % num)
            Tweet.create(user=author, content='t%s' % num)

        def display_name(username_field):
            # CASE expression mapping known usernames to a display label.
            return Case(username_field, [
                ('u1', 'user 1'),
                ('u2', 'user 2')], 'someone else')

        expected = [
            ('t1', 'user 1'),
            ('t2', 'user 2'),
            ('t3', 'someone else')]

        # Alias to a particular field-name.
        label = display_name(User.username)
        by_field_name = (Tweet
                         .select(Tweet.content,
                                 label.alias('username').bind_to(User))
                         .join(User)
                         .order_by(Tweet.content))
        with self.assertQueryCount(1):
            pairs = [(t.content, t.user.username) for t in by_field_name]
            self.assertEqual(pairs, expected)

        # Use a different alias.
        by_other_name = (Tweet
                         .select(Tweet.content,
                                 label.alias('display').bind_to(User))
                         .join(User)
                         .order_by(Tweet.content))
        with self.assertQueryCount(1):
            pairs = [(t.content, t.user.display) for t in by_other_name]
            self.assertEqual(pairs, expected)

        # Ensure works with model and field aliases.
        TA, UA = Tweet.alias(), User.alias()
        label = display_name(UA.username)
        with_aliases = (TA
                        .select(TA.content,
                                label.alias('display').bind_to(UA))
                        .join(UA, on=(UA.id == TA.user))
                        .order_by(TA.content))
        with self.assertQueryCount(1):
            pairs = [(t.content, t.user.display) for t in with_aliases]
            self.assertEqual(pairs, expected)
|
|
|
|
|
|
class TestGetWithSecondDatabase(ModelTestCase):
    # Query.get() accepts an explicit database that overrides the one the
    # model is bound to.
    database = get_in_memory_db()
    requires = [User]

    def test_get_with_second_database(self):
        User.create(username='huey')
        query = User.select().where(User.username == 'huey')
        self.assertEqual(query.get().username, 'huey')

        alt_db = get_in_memory_db()
        try:
            with User.bind_ctx(alt_db):
                User.create_table()

            # "huey" lives only in the primary database.
            self.assertRaises(User.DoesNotExist, query.get, alt_db)
            with User.bind_ctx(alt_db):
                User.create(username='zaizee')

            # "zaizee" lives only in the alternate database.
            query = User.select().where(User.username == 'zaizee')
            self.assertRaises(User.DoesNotExist, query.get)
            self.assertEqual(query.get(alt_db).username, 'zaizee')
        finally:
            # The alternate database is not managed by the test case, so
            # close it here even when an assertion above fails.
            alt_db.close()
|
|
|
|
|
|
class TestMixModelsTables(ModelTestCase):
    # Model classes and their underlying Table objects can be mixed freely
    # when building queries.
    database = get_in_memory_db()
    requires = [User]

    def test_mix_models_tables(self):
        user_table = User._meta.table
        insert = user_table.insert({user_table.username: 'huey'})
        self.assertEqual(insert.execute(), 1)

        # Table query selecting a model field: the row is a plain dict.
        row = user_table.select(User.username).get()
        self.assertEqual(row, {'username': 'huey'})

        # Model query selecting a table column: the row is a model instance.
        instance = User.select(user_table.username).get()
        self.assertEqual(instance.username, 'huey')

        (user_table
         .update(username='huey-x')
         .where(user_table.username == 'huey')
         .execute())
        self.assertEqual(User.select().get().username, 'huey-x')

        # Table DELETE filtered by a model-field expression.
        user_table.delete().where(User.username == 'huey-x').execute()
        remaining = user_table.select().count()
        self.assertEqual(remaining, 0)
|
|
|
|
|
|
class TestDatabaseExecuteQuery(ModelTestCase):
    # Database.execute() accepts a query object and yields raw rows.
    database = get_in_memory_db()
    requires = [User]

    def test_execute_query(self):
        for username in ('huey', 'zaizee'):
            User.create(username=username)

        newest_first = User.select().order_by(User.username.desc())
        cursor = self.database.execute(newest_first)
        # The test reads index 1 of each raw row as the username.
        usernames = [row[1] for row in cursor]
        self.assertEqual(usernames, ['zaizee', 'huey'])
|
|
|
|
class ConflictDetectedException(Exception):
    """Raised by save_optimistic() when the versioned UPDATE matches no rows."""
|
|
|
|
class BaseVersionedModel(TestModel):
    # Base model demonstrating optimistic locking with a version counter.
    # The counter is bumped by save_optimistic() on each successful UPDATE.
    version = IntegerField(default=1, index=True)

    def save_optimistic(self):
        # Save the instance, guarding UPDATEs with the version column.
        #
        # Returns the result of save() for a new row, otherwise True.
        # Raises ValueError when there is nothing to write, and
        # ConflictDetectedException when the UPDATE matches no rows.
        if not self.id:
            # This is a new record, so the default logic is to perform an
            # INSERT. Ideally your model would also have a unique
            # constraint that made it impossible for two INSERTs to happen
            # at the same time.
            return self.save()

        # Update any data that has changed and bump the version counter.
        field_data = dict(self.__data__)
        # The version is taken out so it is never written as a plain value;
        # it is used only in the WHERE clause below.
        current_version = field_data.pop('version', 1)
        self._populate_unsaved_relations(field_data)
        # NOTE(review): presumably narrows field_data to the dirty fields;
        # confirm against Model._prune_fields.
        field_data = self._prune_fields(field_data, self.dirty_fields)
        if not field_data:
            raise ValueError('No changes have been made.')

        ModelClass = type(self)
        field_data['version'] = ModelClass.version + 1  # Atomic increment.

        # Only a row that still carries the version we loaded is updated.
        query = ModelClass.update(**field_data).where(
            (ModelClass.version == current_version) &
            (ModelClass.id == self.id))
        if query.execute() == 0:
            # No rows were updated, indicating another process has saved
            # a new version. How you handle this situation is up to you,
            # but for simplicity I'm just raising an exception.
            raise ConflictDetectedException()
        else:
            # Increment local version to match what is now in the db.
            self.version += 1
            return True
|
|
|
|
class VUser(BaseVersionedModel):
    # Versioned user model for the optimistic-locking tests.
    username = TextField()
|
|
|
|
class VTweet(BaseVersionedModel):
    # Versioned tweet model; the user is optional (nullable foreign key).
    user = ForeignKeyField(VUser, null=True)
    content = TextField()
|
|
|
|
|
|
class TestOptimisticLockingDemo(ModelTestCase):
    # Demonstrates save_optimistic() with concurrent edits to the same row.
    requires = [VUser, VTweet]

    def test_optimistic_locking(self):
        author = VUser(username='u1')
        author.save_optimistic()
        stale = VTweet(user=author, content='t1')
        stale.save_optimistic()

        # Update the same row through a second instance, which bumps the
        # version counter stored in the db.
        fresh = VTweet.get(VTweet.id == stale.id)
        fresh.content = 't1-x'
        fresh.save_optimistic()

        # Since no data was modified on the first instance, saving it raises
        # a ValueError.
        with self.assertRaises(ValueError):
            stale.save_optimistic()

        # If we do make an update and attempt to save, a conflict is detected
        # and the local version is left untouched.
        stale.content = 't1-y'
        with self.assertRaises(ConflictDetectedException):
            stale.save_optimistic()
        self.assertEqual(stale.version, 1)

        # The db still holds the second instance's change.
        stored = VTweet.get(VTweet.id == stale.id)
        self.assertEqual(stored.content, 't1-x')
        self.assertEqual(stored.version, 2)
        self.assertEqual(stored.user.username, 'u1')

    def test_optimistic_locking_populate_fks(self):
        tweet = VTweet(content='t1')
        tweet.save_optimistic()

        # Attach a user that has not been saved yet, then save it before
        # the tweet.
        author = VUser(username='u1')
        tweet.user = author

        author.save_optimistic()
        tweet.save_optimistic()
        stored = VTweet.get(VTweet.content == 't1')
        self.assertEqual(stored.version, 2)
        self.assertEqual(stored.user.username, 'u1')
|
|
|
|
# ===========================================================================
|
|
# Regressions and bug-fix tests
|
|
# ===========================================================================
|
|
|
|
|
|
|
|
class DiA(TestModel):
    # Root model of the delete_instance() regression graph.
    a = TextField(unique=True)
|
|
class DiB(TestModel):
    # Child of DiA, linked through DiA's primary key.
    a = ForeignKeyField(DiA)
    b = TextField()
|
|
class DiC(TestModel):
    # Child of DiB (grandchild of DiA).
    b = ForeignKeyField(DiB)
    c = TextField()
|
|
class DiD(TestModel):
    # Child of DiC; the deepest level of the DiA -> DiB -> DiC -> DiD chain.
    c = ForeignKeyField(DiC)
    d = TextField()
|
|
class DiBA(TestModel):
    # Child of DiA that references the "a" text column (to_field) rather
    # than DiA's primary key.
    a = ForeignKeyField(DiA, to_field=DiA.a)
    b = TextField()
|
|
|
|
|
|
class TestDeleteInstanceRegression(ModelTestCase):
    """Recursive delete_instance() across a multi-level dependency graph."""
    database = get_in_memory_db()
    requires = [DiA, DiB, DiC, DiD, DiBA]

    def test_delete_instance_regression(self):
        # Build three DiA roots; each gets two DiB -> DiC -> DiD chains and
        # two DiBA rows (DiBA references DiA through the "a" text column).
        with self.database.atomic():
            a1, a2, a3 = [DiA.create(a=a) for a in ('a1', 'a2', 'a3')]
            for a in (a1, a2, a3):
                for j in (1, 2):
                    b = DiB.create(a=a, b='%s-b%s' % (a.a, j))
                    c = DiC.create(b=b, c='%s-c' % (b.b))
                    DiD.create(c=c, d='%s-d' % (c.c))

                    DiBA.create(a=a, b='%s-b%s' % (a.a, j))

        # (a1 (b1 (c (d))), (b2 (c (d)))), (a2 ...), (a3 ...)
        with self.assertQueryCount(5):
            a2.delete_instance(recursive=True)

        # Dependents are removed deepest-first; note that the DiBA delete is
        # parameterized with the text key ('a2'), the others with the id.
        self.assertHistory(5, [
            ('DELETE FROM "di_d" WHERE ("di_d"."c_id" IN ('
             'SELECT "t1"."id" FROM "di_c" AS "t1" WHERE ("t1"."b_id" IN ('
             'SELECT "t2"."id" FROM "di_b" AS "t2" WHERE ("t2"."a_id" = ?)'
             '))))', [2]),
            ('DELETE FROM "di_c" WHERE ("di_c"."b_id" IN ('
             'SELECT "t1"."id" FROM "di_b" AS "t1" WHERE ("t1"."a_id" = ?)'
             '))', [2]),
            ('DELETE FROM "di_ba" WHERE ("di_ba"."a_id" = ?)', ['a2']),
            ('DELETE FROM "di_b" WHERE ("di_b"."a_id" = ?)', [2]),
            ('DELETE FROM "di_a" WHERE ("di_a"."id" = ?)', [2])
        ])

        # a1 & a3 exist, plus their relations. assertEqual is required
        # here: assertTrue(count, n) would treat n as the failure message
        # and pass for any non-zero count.
        self.assertEqual(DiA.select().count(), 2)
        for rel in (DiB, DiBA, DiC, DiD):
            self.assertEqual(rel.select().count(), 4)  # 2x2

        with self.assertQueryCount(5):
            a1.delete_instance(recursive=True)

        # Only the objects related to a3 exist still.
        self.assertEqual(DiA.select().count(), 1)
        self.assertEqual(DiA.get(DiA.a == 'a3').id, a3.id)
        self.assertEqual([d.d for d in DiD.select().order_by(DiD.d)],
                         ['a3-b1-c-d', 'a3-b2-c-d'])
        self.assertEqual([c.c for c in DiC.select().order_by(DiC.c)],
                         ['a3-b1-c', 'a3-b2-c'])
        self.assertEqual([b.b for b in DiB.select().order_by(DiB.b)],
                         ['a3-b1', 'a3-b2'])
        self.assertEqual([ba.b for ba in DiBA.select().order_by(DiBA.b)],
                         ['a3-b1', 'a3-b2'])
|
|
|
|
class User2(TestModel):
    # Minimal user model referenced by Category2.
    username = TextField()
|
|
|
|
class Category2(TestModel):
    # Category with an optional self-referential parent and an owning user.
    name = TextField()
    parent = ForeignKeyField('self', backref='children', null=True)
    user = ForeignKeyField(User2)
|
|
|
|
class TestCountUnionRegression(ModelTestCase):
    """count() on a UNION compound query, with and without a LIMIT."""

    @requires_mysql
    @requires_models(User)
    def test_count_union(self):
        with self.database.atomic():
            for n in range(5):
                User.create(username='user-%d' % n)

        # Union the full user table with itself.
        left = User.select()
        right = User.select()
        union = (left | right)
        self.assertSQL(union, (
            'SELECT "t1"."id", "t1"."username" FROM "users" AS "t1" '
            'UNION '
            'SELECT "t2"."id", "t2"."username" FROM "users" AS "t2"'), [])

        self.assertEqual(union.count(), 5)

        # Applying a limit to the compound query must be reflected in the
        # count as well.
        limited = union.limit(3)
        self.assertSQL(limited, (
            'SELECT "t1"."id", "t1"."username" FROM "users" AS "t1" '
            'UNION '
            'SELECT "t2"."id", "t2"."username" FROM "users" AS "t2" '
            'LIMIT ?'), [3])
        self.assertEqual(limited.count(), 3)
|
|
|
|
|
|
class TestGithub1354(ModelTestCase):
    """get_or_create() on a model with a self-referential foreign key."""

    @requires_models(Category2, User2)
    def test_get_or_create_self_referential_fk2(self):
        owner = User2.create(username='huey')
        root = Category2.create(name='parent', user=owner)
        child, created = Category2.get_or_create(
            parent=root,
            name='child',
            user=owner)

        # Re-read the child through its parent link and verify every
        # relation resolves to the expected row.
        fetched = Category2.get(Category2.parent == root)
        self.assertEqual(fetched.user.username, 'huey')
        self.assertEqual(fetched.parent.name, 'parent')
        self.assertEqual(fetched.name, 'child')
|
|
|
|
class TestInsertFromSQL(ModelTestCase):
    """insert_from() fed by a raw SQL fragment instead of a Select query."""

    def setUp(self):
        super().setUp()

        # Create a plain (non-model) source table holding a single row.
        self.database.execute_sql('create table if not exists user_src '
                                  '(name TEXT);')
        source = Table('user_src').bind(self.database)
        source.insert(name='foo').execute()

    def tearDown(self):
        super().tearDown()
        self.database.execute_sql('drop table if exists user_src')

    @requires_models(User)
    def test_insert_from_sql(self):
        select_names = SQL('SELECT name FROM user_src')
        insert = User.insert_from(query=select_names, fields=[User.username])
        insert.execute()
        self.assertEqual([u.username for u in User.select()], ['foo'])
|
|
|
|
@requires_postgresql
class TestReturningIntegrationRegressions(ModelTestCase):
    """RETURNING clauses on INSERT / UPDATE / DELETE queries."""
    requires = [User, Tweet]

    def test_returning_integration_subqueries(self):
        _create_users_tweets(self.database)

        # We can use a correlated subquery in the RETURNING clause.
        tweet_count = (Tweet
                       .select(fn.COUNT(Tweet.id).alias('ct'))
                       .where(Tweet.user == User.id))
        update = (User
                  .update(username=(User.username + '-x'))
                  .returning(tweet_count.alias('ct'), User.username))
        rows = update.execute()
        self.assertEqual(sorted([(r.ct, r.username) for r in rows]), [
            (0, 'zaizee-x'), (2, 'mickey-x'), (3, 'huey-x')])

        # We can use a correlated subquery via UPDATE...FROM, and reference
        # the FROM table in both the update and the RETURNING clause.
        counts = (User
                  .select(User.id, fn.COUNT(Tweet.id).alias('ct'))
                  .join(Tweet, JOIN.LEFT_OUTER)
                  .group_by(User.id))
        update = (User
                  .update(username=User.username + counts.c.ct)
                  .from_(counts)
                  .where(User.id == counts.c.id)
                  .returning(counts.c.ct, User.username))
        rows = update.execute()
        self.assertEqual(sorted([(r.ct, r.username) for r in rows]), [
            (0, 'zaizee-x0'), (2, 'mickey-x2'), (3, 'huey-x3')])

    def test_returning_integration(self):
        # INSERT ... RETURNING yields the new rows as model objects.
        insert = (User
                  .insert_many([('huey',), ('mickey',), ('zaizee',)],
                               fields=[User.username])
                  .returning(User.id, User.username)
                  .objects())
        rows = insert.execute()
        self.assertEqual([(r.id, r.username) for r in rows], [
            (1, 'huey'), (2, 'mickey'), (3, 'zaizee')])

        # DELETE ... RETURNING reports the rows that were removed.
        delete = (User
                  .delete()
                  .where(~User.username.startswith('h'))
                  .returning(User.id, User.username)
                  .objects())
        rows = delete.execute()
        self.assertEqual(sorted([(r.id, r.username) for r in rows]), [
            (2, 'mickey'), (3, 'zaizee')])
|
|
|
|
class TestUpdateIntegrationRegressions(ModelTestCase):
    """UPDATE queries that involve expressions and subqueries."""
    requires = [User, Tweet, Sample]

    def setUp(self):
        super().setUp()
        _create_users_tweets(self.database)
        # Four samples whose counter and value both equal their index.
        for n in range(4):
            Sample.create(counter=n, value=n)

    @skip_if(IS_MYSQL)
    def test_update_examples(self):
        users = User.select().order_by(User.username)

        def usernames():
            # Clone so each check runs the query again.
            return [u.username for u in users.clone()]

        # Do a simple update.
        (User
         .update(username=(User.username + '-cat'))
         .where(User.username != 'mickey')
         .execute())
        self.assertEqual(usernames(), ['huey-cat', 'mickey', 'zaizee-cat'])

        # Do an update using a subquery..
        mickey = User.select(User.username).where(User.username == 'mickey')
        (User
         .update(username=(User.username + '-dog'))
         .where(User.username.in_(mickey))
         .execute())
        self.assertEqual(usernames(),
                         ['huey-cat', 'mickey-dog', 'zaizee-cat'])

        # Subquery referring to a different table.
        mickey = User.select().where(User.username == 'mickey-dog')
        (Tweet
         .update(content=(Tweet.content + '-x'))
         .where(Tweet.user.in_(mickey))
         .execute())

        contents = [t.content for t in Tweet.select().order_by(Tweet.id)]
        self.assertEqual(contents,
                         ['meow', 'hiss', 'purr', 'woof-x', 'bark-x'])

        # Subquery on the right-hand of the assignment.
        tweet_count = (Tweet
                       .select(fn.COUNT(Tweet.id).cast('text'))
                       .where(Tweet.user == User.id))
        User.update(username=(User.username + '-' + tweet_count)).execute()

        self.assertEqual(usernames(),
                         ['huey-cat-3', 'mickey-dog-2', 'zaizee-cat-0'])

    def test_update_examples_2(self):
        SA = Sample.alias()

        # Add value to counter for the rows whose value is found by a
        # subquery against an alias of the same table.
        matching = (SA
                    .select(SA.value)
                    .where(SA.value.in_([1.0, 3.0])))
        (Sample
         .update(counter=(Sample.counter + Sample.value.cast('int')))
         .where(Sample.value.in_(matching))
         .execute())

        rows = (Sample
                .select(Sample.counter, Sample.value)
                .order_by(Sample.id)
                .tuples())
        self.assertEqual(list(rows.clone()),
                         [(0, 0.), (2, 1.), (2, 2.), (6, 3.)])

        # Undo the change by assigning a correlated subquery to counter.
        restored = (SA
                    .select(SA.counter - SA.value.cast('int'))
                    .where(SA.value == Sample.value))
        (Sample
         .update(counter=restored)
         .where(Sample.value.in_([1., 3.]))
         .execute())
        self.assertEqual(list(rows.clone()),
                         [(0, 0.), (1, 1.), (2, 2.), (3, 3.)])
|
|
|
|
|
|
class MGProject(TestModel):
    # Project referenced twice by MGTask (primary and alternate).
    name = TextField()
|
|
|
|
class MGTask(TestModel):
    # Task holding two separate foreign keys to the same MGProject model.
    name = TextField()
    mgproject = ForeignKeyField(MGProject, backref='tasks')
    alt = ForeignKeyField(MGProject, backref='alt_tasks')
|
|
|
|
|
|
class TestModelGraphMultiFK(ModelTestCase):
    """Model-graph reconstruction when two FKs point at the same model."""
    requires = [MGProject, MGTask]

    def test_model_graph_multi_fk(self):
        pa, pb, pc = [MGProject.create(name=name) for name in 'abc']
        MGTask.create(name='t1', mgproject=pa, alt=pc)
        MGTask.create(name='t2', mgproject=pb, alt=pb)

        P1 = MGProject.alias('p1')
        P2 = MGProject.alias('p2')
        LO = JOIN.LEFT_OUTER

        # Query using join expression.
        by_expression = (
            MGTask
            .select(MGTask, P1, P2)
            .join_from(MGTask, P1, LO, on=(MGTask.mgproject == P1.id))
            .join_from(MGTask, P2, LO, on=(MGTask.alt == P2.id))
            .order_by(MGTask.name))

        # Query specifying target field.
        by_field = (
            MGTask
            .select(MGTask, P1, P2)
            .join_from(MGTask, P1, LO, on=MGTask.mgproject)
            .join_from(MGTask, P2, LO, on=MGTask.alt)
            .order_by(MGTask.name))

        # Query specifying with missing target field.
        by_default = (
            MGTask
            .select(MGTask, P1, P2)
            .join_from(MGTask, P1, LO)  # Implicitly selects mgproject.
            .join_from(MGTask, P2, LO, on=MGTask.alt)
            .order_by(MGTask.name))

        for query in (by_expression, by_field, by_default):
            # Reading the related projects must not trigger extra queries.
            with self.assertQueryCount(1):
                t1, t2 = list(query)
                self.assertEqual(t1.mgproject.name, 'a')
                self.assertEqual(t1.alt.name, 'c')
                self.assertEqual(t2.mgproject.name, 'b')
                self.assertEqual(t2.alt.name, 'b')
|
|
|
|
|
|
class RS(TestModel):
    # Parent model for RD rows (reachable through the "rds" backref).
    name = TextField()
|
|
|
|
class RD(TestModel):
    # Key/value row belonging to an RS instance.
    key = TextField()
    value = IntegerField()
    rs = ForeignKeyField(RS, backref='rds')
|
|
|
|
class RKV(TestModel):
    # Key/value model whose primary key is the composite (key, value).
    key = CharField(max_length=10)
    value = IntegerField()
    extra = IntegerField()

    class Meta:
        primary_key = CompositeKey('key', 'value')
|
|
|
|
|
|
class TestRegressionCountDistinct(ModelTestCase):
    """count() on DISTINCT queries, including re-selected columns."""

    @requires_models(RS, RD)
    def test_regression_count_distinct(self):
        rs = RS.create(name='rs')

        # Seven rows, but only four distinct key/value pairs.
        nums = [0, 1, 2, 3, 2, 1, 0]
        rows = [('k%s' % i, i, rs) for i in nums]
        RD.insert_many(rows).execute()

        distinct_keys = RD.select(RD.key).distinct()
        self.assertEqual(distinct_keys.count(), 4)

        # Try re-selecting using the id/key, which are all distinct.
        with_ids = distinct_keys.select(RD.id, RD.key)
        self.assertEqual(with_ids.count(), 7)

        # Re-select the key/value, of which there are 4 distinct.
        key_values = with_ids.select(RD.key, RD.value)
        self.assertEqual(key_values.count(), 4)

        # The same counts must hold when going through the backref.
        via_backref = rs.rds.select(RD.key).distinct()
        self.assertEqual(via_backref.count(), 4)

        via_backref = rs.rds.select(RD.key, RD.value).distinct()
        self.assertEqual(via_backref.count(), 4)  # Was returning 7!

    @requires_models(RKV)
    def test_regression_count_distinct_cpk(self):
        rows = [('k%s' % i, i, i) for i in range(5)]
        RKV.insert_many(rows).execute()
        self.assertEqual(RKV.select().distinct().count(), 5)
|
|
|
|
class TestReselectModelRegression(ModelTestCase):
    """Calling select() again on a query to widen the selected columns."""
    requires = [User]

    def test_reselect_model_regression(self):
        u1, u2, u3 = [User.create(username='u%s' % i) for i in '123']

        names_desc = User.select(User.username).order_by(User.username.desc())
        self.assertEqual(list(names_desc.tuples()),
                         [('u3',), ('u2',), ('u1',)])

        # Re-select the whole model; the ordering is retained.
        full_rows = names_desc.select(User)
        expected = [
            (u3.id, 'u3',),
            (u2.id, 'u2',),
            (u1.id, 'u1',)]
        self.assertEqual(list(full_rows.tuples()), expected)
|
|
|
|
class RU(TestModel):
    # User model referenced twice by Recipe.
    username = TextField()
|
|
|
|
|
|
class Recipe(TestModel):
    # Recipe with two foreign keys to RU: its creator and its last editor.
    name = TextField()
    created_by = ForeignKeyField(RU, backref='recipes')
    changed_by = ForeignKeyField(RU, backref='recipes_modified')
|
|
|
|
class TestJoinCorrelatedSubquery(ModelTestCase):
    """JOIN whose ON clause contains an IN (subquery) predicate."""
    requires = [User, Tweet]

    def test_join_correlated_subquery(self):
        # User "uN" gets N + 1 tweets.
        for i in range(3):
            author = User.create(username='u%s' % i)
            for j in range(i + 1):
                Tweet.create(user=author, content='u%s-%s' % (i, j))

        UA = User.alias()
        wanted_names = (UA
                        .select(UA.username)
                        .where(UA.username.in_(('u0', 'u2'))))

        join_predicate = ((Tweet.user == User.id) &
                          (User.username.in_(wanted_names)))
        query = (Tweet
                 .select(Tweet, User)
                 .join(User, on=join_predicate)
                 .order_by(Tweet.id))

        # The related user must come from the joined row, not a new query.
        with self.assertQueryCount(1):
            data = [(t.content, t.user.username) for t in query]
            self.assertEqual(data, [
                ('u0-0', 'u0'),
                ('u2-0', 'u2'),
                ('u2-1', 'u2'),
                ('u2-2', 'u2')])
|
|
|
|
|
|
class TestMultiFKJoinRegression(ModelTestCase):
    """Joining the same model twice using aliased join predicates."""
    requires = [RU, Recipe]

    def test_multi_fk_join_regression(self):
        u1, u2 = [RU.create(username=u) for u in ('u1', 'u2')]
        recipes = (('r11', u1, u1), ('r12', u1, u2), ('r21', u2, u1))
        for name, creator, editor in recipes:
            Recipe.create(name=name, created_by=creator, changed_by=editor)

        Change = RU.alias()
        # The alias given to each join predicate ('a' / 'b') is the
        # attribute under which the joined instance is read back below.
        query = (Recipe
                 .select(Recipe, RU, Change)
                 .join(RU, on=(RU.id == Recipe.created_by).alias('a'))
                 .switch(Recipe)
                 .join(Change, on=(Change.id == Recipe.changed_by).alias('b'))
                 .order_by(Recipe.name))
        with self.assertQueryCount(1):
            data = [(r.name, r.a.username, r.b.username) for r in query]
            self.assertEqual(data, [
                ('r11', 'u1', 'u1'),
                ('r12', 'u1', 'u2'),
                ('r21', 'u2', 'u1')])
|
|
|
|
class TestCompoundExistsRegression(ModelTestCase):
    """exists() and count() on a compound (UNION) select query."""
    requires = [User]

    def test_compound_regressions_1961(self):
        UA = User.alias()
        compound = (User.select(User.id) | UA.select(UA.id))

        # Regression: calling .exists() failed with AttributeError, no
        # attribute "columns".
        self.assertFalse(compound.exists())
        self.assertEqual(compound.count(), 0)

        # With one row present both halves of the union yield the same id.
        User.create(username='u1')
        self.assertTrue(compound.exists())
        self.assertEqual(compound.count(), 1)
|
|
|
|
class TestLikeColumnValue(ModelTestCase):
    """LIKE-style matching where the pattern comes from another column."""
    requires = [User, Tweet]

    def test_like_column_value(self):
        # e.g., find all tweets that contain the user's own username.
        u1, u2, u3 = [User.create(username='u%s' % i) for i in (1, 2, 3)]
        tweets_by_user = (
            (u1, ('nada', 'i am u1', 'u1 is my name')),
            (u2, ('nothing', 'he is u1')),
            (u3, ('she is u2', 'hey u3 is me', 'xx')))
        for author, contents in tweets_by_user:
            rows = [(author, content) for content in contents]
            Tweet.insert_many(
                rows,
                fields=[Tweet.user, Tweet.content]).execute()

        # Two spellings of the same predicate: the "**" operator with an
        # explicit wildcard pattern, and the contains() helper.
        expressions = (
            (Tweet.content ** ('%' + User.username + '%')),
            Tweet.content.contains(User.username))

        expected = [
            ('u1', 'i am u1'),
            ('u1', 'u1 is my name'),
            ('u3', 'hey u3 is me')]
        for expr in expressions:
            query = (Tweet
                     .select(Tweet, User)
                     .join(User)
                     .where(expr)
                     .order_by(Tweet.id))

            matches = [(t.user.username, t.content) for t in query]
            self.assertEqual(matches, expected)
|
|
|
|
class TestUnionParenthesesRegression(ModelTestCase):
    """A UNION ALL query used directly and as an IN (...) subquery."""
    requires = [User]

    def test_union_parentheses_regression(self):
        ua, ub, uc = [User.create(username=u) for u in 'abc']
        only_a = User.select(User.id).where(User.username == 'a')
        only_c = User.select(User.id).where(User.username == 'c')
        union = only_a.union_all(only_c)
        self.assertEqual(sorted([u.id for u in union]), [ua.id, uc.id])

        # Embed the same compound query as the right-hand side of IN.
        query = User.select().where(User.id.in_(union)).order_by(User.id)
        self.assertEqual([u.username for u in query], ['a', 'c'])
|
|
|
|
class Site(TestModel):
    # Top of the Site -> Page -> PageItem hierarchy.
    url = TextField()
|
|
|
|
class Page(TestModel):
    # Page belonging to a Site.
    site = ForeignKeyField(Site, backref='pages')
    title = TextField()
|
|
|
|
class PageItem(TestModel):
    # Item belonging to a Page.
    page = ForeignKeyField(Page, backref='items')
    content = TextField()
|
|
|
|
|
|
class TestNoPKHashRegression(ModelTestCase):
    """Equality and hashing of instances of a model without a primary key."""
    requires = [NoPK]

    def test_no_pk_hash_regression(self):
        created = NoPK.create(data=1)
        fetched = NoPK.get(NoPK.data == 1)

        # When a model does not define a primary key, we cannot test
        # equality.
        self.assertNotEqual(created, fetched)

        # Their hash is the same, though they are not equal.
        self.assertEqual(hash(created), hash(fetched))
|
|
|
|
|
|
class TestModelFilterJoinOrdering(ModelTestCase):
    """filter() across two levels of foreign keys, in varying orders."""
    requires = [Site, Page, PageItem]

    def setUp(self):
        super().setUp()
        # Two sites; s1 has two pages (three items + one item), s2 has a
        # single page with one item.
        with self.database.atomic():
            s1, s2 = [Site.create(url=s) for s in ('s1', 's2')]
            p11, p12, p21 = [
                Page.create(site=s, title=t) for s, t in
                ((s1, 'p1-1'), (s1, 'p1-2'), (s2, 'p2-1'))]
            items = (
                (p11, 's1p1i1'),
                (p11, 's1p1i2'),
                (p11, 's1p1i3'),
                (p12, 's1p2i1'),
                (p21, 's2p1i1'))
            PageItem.insert_many(items).execute()

    def test_model_filter_join_ordering(self):
        q = PageItem.filter(page__site__url='s1').order_by(PageItem.content)
        self.assertSQL(q, (
            'SELECT "t1"."id", "t1"."page_id", "t1"."content" '
            'FROM "page_item" AS "t1" '
            'INNER JOIN "page" AS "t2" ON ("t1"."page_id" = "t2"."id") '
            'INNER JOIN "site" AS "t3" ON ("t2"."site_id" = "t3"."id") '
            'WHERE ("t3"."url" = ?) ORDER BY "t1"."content"'), ['s1'])

        s1_items = ['s1p1i1', 's1p1i2', 's1p1i3', 's1p2i1']

        def assertQ(q):
            # Every variant must return all of s1's items in one query.
            with self.assertQueryCount(1):
                self.assertEqual([pi.content for pi in q], s1_items)

        assertQ(q)

        # Two lookups through the same join path in a single filter() call.
        sid = Site.get(Site.url == 's1').id
        q = (PageItem
             .filter(page__site__url='s1', page__site__id=sid)
             .order_by(PageItem.content))
        assertQ(q)

        # The same lookups split across chained filter() calls.
        q = (PageItem
             .filter(page__site__id=sid)
             .filter(page__site__url='s1')
             .order_by(PageItem.content))
        assertQ(q)

        # A DQ expression on the intermediate model between the two.
        q = (PageItem
             .filter(page__site__id=sid)
             .filter(DQ(page__title='p1-1') | DQ(page__title='p1-2'))
             .filter(page__site__url='s1')
             .order_by(PageItem.content))
        assertQ(q)
|
|
|
|
class TestCountSubqueryEquals(ModelTestCase):
    """Comparing a correlated COUNT subquery against a scalar value."""
    requires = [User, Tweet]

    def test_count_subquery_equals(self):
        a, b, c = [User.create(username=u) for u in 'abc']
        Tweet.insert_many([(a, 'a1'), (b, 'b1')]).execute()

        # Only user "c" has no tweets.
        tweet_count = (Tweet
                       .select(fn.COUNT(Tweet.id))
                       .where(Tweet.user == User.id))
        without_tweets = User.select().where(tweet_count == 0)
        self.assertEqual([u.username for u in without_tweets], ['c'])
|
|
|
|
class TestChainWhere(ModelTestCase):
    """Chained where() calls combine their conditions."""
    requires = [User]

    def test_chain_where(self):
        for name in 'abcd':
            User.create(username=name)

        # Two exclusions leave the middle two users.
        not_a_or_d = (User.select()
                      .where(User.username != 'a')
                      .where(User.username != 'd')
                      .order_by(User.username))
        self.assertEqual([u.username for u in not_a_or_d], ['b', 'c'])

        # A third condition narrows the result to a single user.
        only_b = (User.select()
                  .where(User.username != 'a')
                  .where(User.username != 'd')
                  .where(User.username == 'b'))
        self.assertEqual([u.username for u in only_b], ['b'])
|
|
|
|
class TestSaveClearingPK(ModelTestCase):
    """Saving an instance after its primary key has been reset to None."""
    requires = [User, Tweet]

    def test_save_clear_pk(self):
        author = User.create(username='u1')
        tweet = Tweet.create(content='t1', user=author)

        # Remember the original id, then clear it before saving again.
        orig_id = tweet.id
        tweet.id = None
        tweet.content = 't2'
        tweet.save()

        # A fresh id was assigned and both rows are present afterwards.
        self.assertIsNotNone(tweet.id)
        self.assertNotEqual(tweet.id, orig_id)
        contents = [t.content for t in author.tweets.order_by(Tweet.id)]
        self.assertEqual(contents, ['t1', 't2'])
|
|
|
|
class TestWeirdAliases(ModelTestCase):
    """Result-column naming for aliases containing quotes and parentheses."""
    requires = [User]

    @skip_if(IS_MYSQL)  # mysql can't do anything normally.
    def test_weird_aliases(self):
        User.create(username='huey')

        def assertAlias(s, expected):
            # The first key of the first dict row is the column's name.
            query = User.select(s).dicts()
            row = query[0]
            self.assertEqual(list(row)[0], expected)

        # When we explicitly provide an alias, use that.
        explicit_aliases = (
            '"username"',
            '(username)',
            'user(name)',
            '(username"',
            '"username)')
        for alias in explicit_aliases:
            assertAlias(User.username.alias(alias), alias)
        assertAlias(fn.LOWER(User.username).alias('user (name)'),
                    'user (name)')

        # Here peewee cannot tell that an alias was given, so it will
        # attempt to clean-up the column name returned by the cursor
        # description.
        raw_aliases = (
            ('"t1"."username" AS "user name"', 'user name'),
            ('"t1"."username" AS "user (name)"', 'user (name'),
            ('"t1"."username" AS "(username)"', 'username'),
            ('"t1"."username" AS "x.y.(username)"', 'username'))
        for fragment, expected in raw_aliases:
            assertAlias(SQL(fragment), expected)
        if IS_SQLITE:
            assertAlias(SQL('LOWER("t1"."username")'), 'username')
|
|
|
|
class CQA(TestModel):
    # Two plain text columns, unioned together by TestSelectFromUnion.
    a = TextField()
    b = TextField()
|
|
|
|
|
|
class TestSelectFromUnion(ModelTestCase):
    """UNION of two limited queries, each wrapped via select_from()."""
    requires = [CQA]

    def test_select_from_union(self):
        rows = [('a%d' % i, 'b%d' % i) for i in range(10)]
        CQA.insert_many(rows).execute()

        first_a = CQA.select(CQA.a).order_by(CQA.id).limit(3)
        first_b = CQA.select(CQA.b).order_by(CQA.id).limit(3)

        # Wrap each ordered / limited query in an outer "SELECT *" before
        # combining them.
        wrapped_a = first_a.select_from(SQL('*'))
        wrapped_b = first_b.select_from(SQL('*'))
        union = wrapped_a | wrapped_b
        data = [val for val, in union.tuples()]
        self.assertEqual(sorted(data), ['a0', 'a1', 'a2', 'b0', 'b1', 'b2'])
|
|
|
|
class DF(TestModel):
    # Top level of the DF -> DFC -> DFGC hierarchy.
    name = TextField()
    value = IntegerField()
|
|
class DFC(TestModel):
    # Child of DF.
    df = ForeignKeyField(DF)
    name = TextField()
    value = IntegerField()
|
|
class DFGC(TestModel):
    # Child of DFC (grandchild of DF).
    dfc = ForeignKeyField(DFC)
    name = TextField()
    value = IntegerField()
|
|
|
|
|
|
class TestDjangoFilterRegression(ModelTestCase):
    """filter() lookups across forward and back-referenced foreign keys."""
    requires = [DF, DFC, DFGC]

    def test_django_filter_regression(self):
        # a has two children (a1 with two grandchildren, a2 with one),
        # b has one child, c has none.
        a, b, c = [DF.create(name=n, value=i) for i, n in enumerate('abc')]
        ca1 = DFC.create(df=a, name='a1', value=11)
        ca2 = DFC.create(df=a, name='a2', value=12)
        DFC.create(df=b, name='b1', value=21)

        DFGC.create(dfc=ca1, name='a1-1', value=101)
        DFGC.create(dfc=ca1, name='a1-2', value=101)
        DFGC.create(dfc=ca2, name='a2-1', value=111)

        def assertNames(q, expected):
            # Compare names order-independently.
            self.assertEqual(sorted(row.name for row in q), expected)

        # Lookups starting at the top-level model, following backrefs.
        assertNames(DF.filter(name='a'), ['a'])
        assertNames(DF.filter(name='a', id=a.id), ['a'])
        assertNames(DF.filter(name__in=['a', 'c']), ['a', 'c'])
        assertNames(DF.filter(name__in=['a', 'c'], id=a.id), ['a'])
        assertNames(DF.filter(dfc_set__name='a1'), ['a'])
        assertNames(DF.filter(dfc_set__name__in=['a1', 'b1']), ['a', 'b'])
        assertNames(
            DF.filter(DQ(dfc_set__name='a1') | DQ(dfc_set__name='b1')),
            ['a', 'b'])
        assertNames(DF.filter(dfc_set__dfgc_set__name='a1-1'), ['a'])
        assertNames(
            DF.filter(
                DQ(dfc_set__dfgc_set__name='a1-1') |
                DQ(dfc_set__dfgc_set__name__in=['x', 'y'])),
            ['a'])

        # Lookups starting at the middle model, in both directions.
        assertNames(DFC.filter(df__name='a'), ['a1', 'a2'])
        assertNames(DFC.filter(df__name='a', value=11), ['a1'])
        assertNames(
            DFC.filter(DQ(df__name='a') | DQ(df__name='b')),
            ['a1', 'a2', 'b1'])
        assertNames(
            DFC.filter(
                DQ(df__name='a') | DQ(dfgc_set__name='a1-1')).distinct(),
            ['a1', 'a2'])

        # Lookups starting at the bottom model, following forward FKs.
        assertNames(DFGC.filter(dfc__df__name='a'), ['a1-1', 'a1-2', 'a2-1'])
        assertNames(DFGC.filter(dfc__df__name='a', dfc__name='a2'), ['a2-1'])
        assertNames(
            DFGC.filter(
                DQ(dfc__df__value__lte=0) |
                DQ(dfc__df__name='a', dfc__name='a1') |
                DQ(dfc__name='a2')),
            ['a1-1', 'a1-2', 'a2-1'])

        assertNames(
            (DFGC.filter(DQ(dfc__df__value__lte=10) | DQ(dfc__value__lte=101))
             .filter(DQ(name__ilike='a1%') | DQ(dfc__value=101))),
            ['a1-1', 'a1-2'])

        # A foreign key may be matched by instance or by raw id.
        assertNames(DFGC.filter(dfc__df=a), ['a1-1', 'a1-2', 'a2-1'])
        assertNames(DFGC.filter(dfc__df=a.id), ['a1-1', 'a1-2', 'a2-1'])

        # filter() applied to queries that already contain the join.
        q = DFC.select().join(DF)
        assertNames(q.filter(df=a), ['a1', 'a2'])
        assertNames(q.filter(df__name='a'), ['a1', 'a2'])

        # The same, using model aliases.
        DFA = DF.alias()
        DFCA = DFC.alias()
        DFGCA = DFGC.alias()
        q = DFCA.select().join(DFA)
        assertNames(q.filter(df=a), ['a1', 'a2'])
        assertNames(q.filter(df__name='a'), ['a1', 'a2'])

        q = DFGC.select().join(DFC).join(DF)
        assertNames(q.filter(dfc__df=a), ['a1-1', 'a1-2', 'a2-1'])

        q = DFGCA.select().join(DFCA).join(DFA)
        assertNames(q.filter(dfc__df=a), ['a1-1', 'a1-2', 'a2-1'])

        q = DF.select().join(DFC).join(DFGC)
        assertNames(q.filter(dfc_set__dfgc_set__name='a1-1'), ['a'])
|
|
|
|
class I(TestModel):
    # Root of the dependency graph used by test_delete_instance_dfs.
    name = TextField()
|
|
class S(TestModel):
    # Direct dependent of I.
    i = ForeignKeyField(I)
|
|
class P(TestModel):
    # Direct dependent of I.
    i = ForeignKeyField(I)
|
|
class PS(TestModel):
    # Depends on both P and S, so it is reachable from I along two paths.
    p = ForeignKeyField(P)
    s = ForeignKeyField(S)
|
|
class PP(TestModel):
    # Dependent of PS.
    ps = ForeignKeyField(PS)
|
|
class O(TestModel):
    # Depends on both PS and S.
    ps = ForeignKeyField(PS)
    s = ForeignKeyField(S)
|
|
class OX(TestModel):
    # Nullable dependent of O.
    o = ForeignKeyField(O, null=True)
|
|
|
|
class Character(TestModel):
    # Root model for the nullable-dependency delete_instance() test.
    name = TextField()
|
|
class Shape(TestModel):
    # Optionally linked to a Character (the foreign key is nullable).
    character = ForeignKeyField(Character, null=True)
|
|
class ShapeDetail(TestModel):
    # Required (non-null) dependent of Shape.
    shape = ForeignKeyField(Shape)
|
|
|
|
|
|
class TestSumCaseSubquery(ModelTestCase):
    """SUM over a CASE expression whose condition is an IN (subquery)."""
    requires = [Sample]

    def test_sum_case_subquery(self):
        Sample.insert_many([(i, i) for i in range(5)]).execute()

        # Only counters 1 and 3 exist among the inserted rows, so the
        # summed values are 1 + 3.
        matching = Sample.select().where(Sample.counter.in_([1, 3, 5]))
        value_or_zero = Case(
            None,
            [(Sample.id.in_(matching), Sample.value)],
            0)
        total = Sample.select(fn.SUM(value_or_zero))
        self.assertEqual(total.scalar(), 4.0)
|
|
|
|
class TestDeleteInstanceDFS(ModelTestCase):
    """Query sequence issued by recursive delete_instance() calls."""

    @requires_models(Character, Shape, ShapeDetail)
    def test_delete_instance_dfs_nullable(self):
        c1, c2 = [Character.create(name=name) for name in ('c1', 'c2')]
        for character in (c1, c2):
            shape = Shape.create(character=character)
            ShapeDetail.create(shape=shape)

        # Update nullables.
        with self.assertQueryCount(2):
            c1.delete_instance(True)

        self.assertHistory(2, [
            ('UPDATE "shape" SET "character_id" = ? WHERE '
             '("shape"."character_id" = ?)', [None, c1.id]),
            ('DELETE FROM "character" WHERE ("character"."id" = ?)',
             [c1.id])])

        # Both shapes are still present; one merely lost its character.
        self.assertEqual(Shape.select().count(), 2)

        # Delete nullables as well.
        with self.assertQueryCount(3):
            c2.delete_instance(True, True)

        self.assertHistory(3, [
            ('DELETE FROM "shape_detail" WHERE '
             '("shape_detail"."shape_id" IN '
             '(SELECT "t1"."id" FROM "shape" AS "t1" WHERE '
             '("t1"."character_id" = ?)))', [c2.id]),
            ('DELETE FROM "shape" WHERE ("shape"."character_id" = ?)',
             [c2.id]),
            ('DELETE FROM "character" WHERE ("character"."id" = ?)',
             [c2.id])])

        self.assertEqual(Shape.select().count(), 1)

    @requires_models(I, S, P, PS, PP, O, OX)
    def test_delete_instance_dfs(self):
        # Build one complete dependency graph under each of two roots.
        i1, i2 = [I.create(name=n) for n in ('i1', 'i2')]
        for root in (i1, i2):
            s = S.create(i=root)
            p = P.create(i=root)
            ps = PS.create(p=p, s=s)
            PP.create(ps=ps)
            o = O.create(ps=ps, s=s)
            OX.create(o=o)

        with self.assertQueryCount(9):
            i1.delete_instance(recursive=True)

        self.assertHistory(9, [
            ('DELETE FROM "pp" WHERE ('
             '"pp"."ps_id" IN (SELECT "t1"."id" FROM "ps" AS "t1" WHERE ('
             '"t1"."p_id" IN (SELECT "t2"."id" FROM "p" AS "t2" WHERE ('
             '"t2"."i_id" = ?)))))', [i1.id]),
            ('UPDATE "ox" SET "o_id" = ? WHERE ('
             '"ox"."o_id" IN (SELECT "t1"."id" FROM "o" AS "t1" WHERE ('
             '"t1"."ps_id" IN (SELECT "t2"."id" FROM "ps" AS "t2" WHERE ('
             '"t2"."p_id" IN (SELECT "t3"."id" FROM "p" AS "t3" WHERE ('
             '"t3"."i_id" = ?)))))))', [None, i1.id]),
            ('DELETE FROM "o" WHERE ('
             '"o"."ps_id" IN (SELECT "t1"."id" FROM "ps" AS "t1" WHERE ('
             '"t1"."p_id" IN (SELECT "t2"."id" FROM "p" AS "t2" WHERE ('
             '"t2"."i_id" = ?)))))', [i1.id]),
            ('DELETE FROM "o" WHERE ('
             '"o"."s_id" IN (SELECT "t1"."id" FROM "s" AS "t1" WHERE ('
             '"t1"."i_id" = ?)))', [i1.id]),
            ('DELETE FROM "ps" WHERE ('
             '"ps"."p_id" IN (SELECT "t1"."id" FROM "p" AS "t1" WHERE ('
             '"t1"."i_id" = ?)))', [i1.id]),
            ('DELETE FROM "ps" WHERE ('
             '"ps"."s_id" IN (SELECT "t1"."id" FROM "s" AS "t1" WHERE ('
             '"t1"."i_id" = ?)))', [i1.id]),
            ('DELETE FROM "s" WHERE ("s"."i_id" = ?)', [i1.id]),
            ('DELETE FROM "p" WHERE ("p"."i_id" = ?)', [i1.id]),
            ('DELETE FROM "i" WHERE ("i"."id" = ?)', [i1.id]),
        ])

        # OX rows are nulled-out rather than deleted, so both remain; every
        # other model keeps only the row belonging to i2.
        expected_counts = {OX: 2}
        for model in [I, S, P, PS, PP, O, OX]:
            self.assertEqual(model.select().count(),
                             expected_counts.get(model, 1))
|
|
|
|
class TestQueryCountList(ModelTestCase):
    """Repeated evaluation of an empty query must execute it only once."""
    requires = [User]

    def test_iteration_single_query(self):
        # Model-instance rows.
        with self.assertQueryCount(1):
            users = User.select()
            for _ in range(3):
                self.assertEqual(list(users), [])
                self.assertFalse(bool(users))

        # Tuple rows.
        with self.assertQueryCount(1):
            rows = User.select().tuples()
            for _ in range(3):
                self.assertEqual(list(rows), [])
                self.assertFalse(bool(rows))

        with self.assertQueryCount(1):
            self.assertEqual(User.select().count(), 0)
|
|
|
|
|
|
class TestModelSelectFromSubquery(ModelTestCase):
    """Model select whose FROM is a subquery joined against a VALUES CTE."""
    requires = [User]

    def test_model_select_from_subquery(self):
        for i in range(5):
            User.create(username='u%s' % i)

        UA = User.alias()
        even_users = (UA.select()
                      .where(UA.username.in_(('u0', 'u2', 'u4'))))

        # The CTE names only u0 and u4, so joining on username drops u2.
        names_cte = (ValuesList([('u0',), ('u4',)], columns=['username'])
                     .cte('user_cte', columns=['username']))

        on_username = (even_users.c.username == names_cte.c.username)
        query = (User
                 .select(even_users.c.id, even_users.c.username)
                 .from_(even_users)
                 .join(names_cte, on=on_username)
                 .with_cte(names_cte)
                 .order_by(even_users.c.username.desc()))
        self.assertEqual([u.username for u in query], ['u4', 'u0'])
        # Rows are still User instances.
        self.assertIsInstance(query[0], User)
|
|
|