mirror of
https://github.com/coleifer/peewee.git
synced 2026-05-06 07:56:41 -04:00
1aeef460a4
C implementations of ranking functions are moved to _sqlite_udf. Capabilites exclusive to the CSqliteExtDatabase implementation are being migrated over to cysqlite_ext.CySqliteDatabase, which supports them natively without relying on hacks.
2311 lines
85 KiB
Python
2311 lines
85 KiB
Python
from decimal import Decimal as D
|
|
import datetime
|
|
import os
|
|
import sys
|
|
|
|
from peewee import *
|
|
from peewee import sqlite3
|
|
from playhouse.sqlite_ext import *
|
|
|
|
from .base import BaseTestCase
|
|
from .base import IS_SQLITE_37
|
|
from .base import IS_SQLITE_9
|
|
from .base import ModelTestCase
|
|
from .base import TestModel
|
|
from .base import db_loader
|
|
from .base import get_in_memory_db
|
|
from .base import requires_models
|
|
from .base import skip_if
|
|
from .base import skip_unless
|
|
from .base_models import Person
|
|
from .base_models import Tweet
|
|
from .base_models import User
|
|
from .sqlite_helpers import compile_option
|
|
from .sqlite_helpers import json_installed
|
|
from .sqlite_helpers import json_patch_installed
|
|
from .sqlite_helpers import json_text_installed
|
|
from .sqlite_helpers import jsonb_installed
|
|
|
|
|
|
# Shared test database: in-memory SQLite with playhouse extension support.
database = SqliteExtDatabase(':memory:', timeout=100)


# Optional loadable extensions, configured via environment variables.
# Fall back to a shared library in the current directory when present.
CLOSURE_EXTENSION = os.environ.get('PEEWEE_CLOSURE_EXTENSION')
if not CLOSURE_EXTENSION and os.path.exists('closure.so'):
    CLOSURE_EXTENSION = './closure.so'

LSM_EXTENSION = os.environ.get('LSM_EXTENSION')
if not LSM_EXTENSION and os.path.exists('lsm.so'):
    LSM_EXTENSION = './lsm.so'

# Detect whether the optional Cython helper module was built; tests that
# require the C implementations are skipped when it is unavailable.
try:
    from playhouse._sqlite_udf import peewee_rank
    CYTHON_EXTENSION = True
except ImportError:
    CYTHON_EXTENSION = False
|
|
|
|
|
class WeightedAverage(object):
    """SQLite aggregate implementing a weighted average.

    SQLite drives the aggregate protocol: step() is invoked once per row
    and finalize() produces the result.
    """
    def __init__(self):
        # Running sum of weights, and of weight-scaled values.
        self._weight_sum = 0.
        self._weighted_values = 0.

    def step(self, value, weight=None):
        # A missing (or falsy) weight counts as 1.
        w = weight or 1.
        self._weight_sum += w
        self._weighted_values += w * value

    def finalize(self):
        # Guard against division by zero when no rows were aggregated.
        if self._weight_sum == 0.:
            return 0.
        return self._weighted_values / self._weight_sum
|
|
|
|
def _cmp(l, r):
|
|
if l < r:
|
|
return -1
|
|
return 1 if r < l else 0
|
|
|
|
def collate_reverse(s1, s2):
    """Collation producing descending order (the comparison is reversed)."""
    return _cmp(s2, s1)
|
|
|
|
@database.collation()
def collate_case_insensitive(s1, s2):
    # Case-insensitive ordering: compare the lower-cased operands.
    return _cmp(s1.lower(), s2.lower())
|
|
|
|
def title_case(s):
    """Return *s* title-cased (registered below as a SQL function)."""
    return s.title()
|
|
|
|
@database.func()
def rstrip(s, n):
    # SQL function: strip characters contained in *n* from the right of *s*.
    return s.rstrip(n)
|
|
|
|
# Register the aggregate under two names: a 1-arg variant (implicit weight
# of 1) and a 2-arg variant taking an explicit weight.
database.register_aggregate(WeightedAverage, 'weighted_avg', 1)
database.register_aggregate(WeightedAverage, 'weighted_avg2', 2)
# collate_reverse / title_case are registered manually (the decorated
# functions above are registered by their decorators).
database.register_collation(collate_reverse)
database.register_function(title_case)
|
|
|
|
|
|
class Post(TestModel):
    """Plain table; serves as the external-content source for FTS models."""
    message = TextField()
|
|
|
|
|
|
class ContentPost(FTSModel, Post):
    """FTS table whose content is backed by the Post table."""
    class Meta:
        options = {
            'content': Post,
            'tokenize': 'porter'}
|
|
|
|
|
|
class ContentPostMessage(FTSModel, TestModel):
    """FTS table sourcing its content from a single column of Post."""
    message = TextField()
    class Meta:
        options = {'tokenize': 'porter', 'content': Post.message}
|
|
|
|
|
|
class Document(FTSModel, TestModel):
    """Stand-alone FTS table using the porter stemming tokenizer."""
    message = TextField()
    class Meta:
        options = {'tokenize': 'porter'}
|
|
|
|
|
|
class MultiColumn(FTSModel, TestModel):
    """Multi-column FTS table; c4 is unindexed numeric payload data."""
    c1 = SearchField()
    c2 = SearchField()
    c3 = SearchField()
    c4 = IntegerField()
    class Meta:
        options = {'tokenize': 'porter'}
|
|
|
|
|
|
class RowIDModel(TestModel):
    """Model exposing the implicit SQLite rowid as an explicit field."""
    rowid = RowIDField()
    data = IntegerField()
|
|
|
|
|
|
class KeyData(TestModel):
    """Key/value table storing JSON documents (json1 text format)."""
    key = TextField()
    data = JSONField()
|
|
|
|
class JBData(TestModel):
    """Key/value table storing documents in SQLite's binary jsonb format."""
    key = TextField()
    data = JSONBField()
|
|
|
|
|
|
class Values(TestModel):
    """Numeric fixture rows used by the aggregate-function tests."""
    klass = IntegerField()
    value = FloatField()
    weight = FloatField()
|
|
|
|
|
|
class FTS5Test(FTS5Model):
    """FTS5 virtual table; *misc* is stored but excluded from the index."""
    title = SearchField()
    data = SearchField()
    misc = SearchField(unindexed=True)

    class Meta:
        legacy_table_names = False
|
|
|
|
|
|
class FTS5Document(FTS5Model):
    """Single-column FTS5 table using the porter tokenizer."""
    message = SearchField()
    class Meta:
        options = {'tokenize': 'porter'}
|
|
|
|
|
|
class DT(TestModel):
    """Fixture comparing the default datetime field with the ISO variant."""
    key = TextField(primary_key=True)
    d = DateTimeField()
    iso = ISODateTimeField()
|
|
|
|
|
|
@skip_unless(json_installed(), 'requires sqlite json1')
class TestJSONField(ModelTestCase):
    """Round-trip, query and bulk-update tests for JSONField (json1)."""
    database = database
    requires = [KeyData]

    def test_schema(self):
        # JSONField is declared with a JSON column type.
        self.assertSQL(KeyData._schema._create_table(), (
            'CREATE TABLE IF NOT EXISTS "key_data" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"key" TEXT NOT NULL, '
            '"data" JSON NOT NULL)'), [])

    def test_create_read_update(self):
        # Scalars, lists and nested dicts all round-trip through the field.
        test_values = (
            'simple string',
            '',
            1337,
            0.0,
            True,
            False,
            ['foo', 'bar', ['baz', 'nug']],
            {'k1': 'v1', 'k2': {'x1': 'y1', 'x2': 'y2'}},
            {'a': 1, 'b': 0.0, 'c': True, 'd': False, 'e': None, 'f': [0, 1],
             'g': {'h': 'ijkl'}},
        )

        # Create a row using the given test value. Verify we can read the value
        # back from the database, and also that we can query for the row using
        # the value in the WHERE clause.
        for i, value in enumerate(test_values):
            # We can create and re-read values.
            KeyData.create(key='k%s' % i, data=value)
            kd_db = KeyData.get(KeyData.key == 'k%s' % i)
            self.assertEqual(kd_db.data, value)

            # We can read the data back using the value in the WHERE clause.
            kd_db = KeyData.get(KeyData.data == value)
            self.assertEqual(kd_db.key, 'k%s' % i)

        # Verify we can use values in UPDATE query.
        kd = KeyData.create(key='kx', data='')
        for value in test_values:
            nrows = (KeyData
                     .update(data=value)
                     .where(KeyData.key == 'kx')
                     .execute())
            self.assertEqual(nrows, 1)
            kd_db = KeyData.get(KeyData.key == 'kx')
            self.assertEqual(kd_db.data, value)

    def test_json_unicode(self):
        with self.database.atomic():
            KeyData.delete().execute()

        # Two Chinese characters.
        unicode_str = b'\xe4\xb8\xad\xe6\x96\x87'.decode('utf8')
        data = {'foo': unicode_str}
        kd = KeyData.create(key='k1', data=data)

        kd_db = KeyData.get(KeyData.key == 'k1')
        self.assertEqual(kd_db.data, {'foo': unicode_str})

    def test_json_to_json(self):
        kd1 = KeyData.create(key='k1', data={'k1': 'v1', 'k2': 'v2'})
        subq = (KeyData
                .select(KeyData.data)
                .where(KeyData.key == 'k1'))

        # Assign value using a subquery.
        KeyData.create(key='k2', data=subq)
        kd2_db = KeyData.get(KeyData.key == 'k2')
        self.assertEqual(kd2_db.data, {'k1': 'v1', 'k2': 'v2'})

    def test_json_bulk_update_top_level_list(self):
        kd1 = KeyData.create(key='k1', data=['a', 'b', 'c'])
        kd2 = KeyData.create(key='k2', data=['d', 'e', 'f'])

        kd1.data = ['g', 'h', 'i']
        kd2.data = ['j', 'k', 'l']
        KeyData.bulk_update([kd1, kd2], fields=[KeyData.data])
        kd1_db = KeyData.get(KeyData.key == 'k1')
        kd2_db = KeyData.get(KeyData.key == 'k2')
        self.assertEqual(kd1_db.data, ['g', 'h', 'i'])
        self.assertEqual(kd2_db.data, ['j', 'k', 'l'])

    def test_json_bulk_update_top_level_dict(self):
        kd1 = KeyData.create(key='k1', data={'x': 'y1'})
        kd2 = KeyData.create(key='k2', data={'x': 'y2'})

        kd1.data = {'x': 'z1'}
        kd2.data = {'X': 'Z2'}
        KeyData.bulk_update([kd1, kd2], fields=[KeyData.data])
        kd1_db = KeyData.get(KeyData.key == 'k1')
        kd2_db = KeyData.get(KeyData.key == 'k2')
        self.assertEqual(kd1_db.data, {'x': 'z1'})
        self.assertEqual(kd2_db.data, {'X': 'Z2'})

    def test_json_multi_ops(self):
        data = (
            ('k1', [0, 1]),
            ('k2', [1, 2]),
            ('k3', {'x3': 'y3'}),
            ('k4', {'x4': 'y4'}))
        res = KeyData.insert_many(data).execute()
        # Depending on sqlite version, execute() yields RETURNING rows or
        # the number of rows inserted.
        if database.returning_clause:
            self.assertEqual([r for r, in res], [1, 2, 3, 4])
        else:
            self.assertEqual(res, 4)

        # JSON values in an IN() list are wrapped in json() calls.
        vals = [[1, 2], [2, 3], {'x3': 'y3'}, {'x5': 'y5'}]
        pw_vals = [Value(v, unpack=False) for v in vals]

        query = KeyData.select().where(KeyData.data.in_(pw_vals))
        self.assertSQL(query, (
            'SELECT "t1"."id", "t1"."key", "t1"."data" '
            'FROM "key_data" AS "t1" '
            'WHERE ("t1"."data" IN (json(?), json(?), json(?), json(?)))'),
            ['[1, 2]', '[2, 3]', '{"x3": "y3"}', '{"x5": "y5"}'])

        self.assertEqual(query.count(), 2)
        self.assertEqual(sorted([k.key for k in query]), ['k2', 'k3'])

        query = KeyData.select().where(KeyData.data == [1, 2])
        self.assertEqual(query.count(), 1)
        self.assertEqual(query.get().key, 'k2')

        query = KeyData.select().where(KeyData.data == {'x3': 'y3'})
        self.assertEqual(query.count(), 1)
        self.assertEqual(query.get().key, 'k3')
|
|
|
|
|
@skip_unless(json_installed(), 'requires sqlite json1')
class TestJSONFieldFunctions(ModelTestCase):
    """Tests for JSON path extraction/mutation helpers on JSONField.

    The model under test is taken from the class attribute ``M`` so the
    JSONB subclass below can re-run the suite against a jsonb column.
    """
    database = database
    requires = [KeyData]
    test_data = [
        ('a', {'k1': 'v1', 'x1': {'y1': 'z1'}}),
        ('b', {'k2': 'v2', 'x2': {'y2': 'z2'}}),
        ('c', {'k1': 'v1', 'k2': 'v2'}),
        ('d', {'x1': {'y1': 'z1', 'y2': 'z2'}}),
        ('e', {'l1': [0, 1, 2], 'l2': [1, [3, 3], 7]}),
    ]
    # Model under test; overridden by the JSONB subclass.
    M = KeyData

    def setUp(self):
        super(TestJSONFieldFunctions, self).setUp()
        KeyData = self.M
        with self.database.atomic():
            for key, data in self.test_data:
                KeyData.create(key=key, data=data)

        # Base query ordered by key, reused by assertRows().
        self.Q = KeyData.select().order_by(KeyData.key)

    def assertRows(self, where, expected):
        # Assert which keys match the given WHERE expression.
        self.assertEqual([kd.key for kd in self.Q.where(where)], expected)

    def assertData(self, key, expected):
        # Assert the stored document for *key* equals *expected*.
        KeyData = self.M
        self.assertEqual(KeyData.get(KeyData.key == key).data, expected)

    def test_json_group_functions(self):
        KeyData = self.M
        with self.database.atomic():
            KeyData.delete().execute()
            for i in range(10):
                # e.g., {v: 0, v0: {items: []}}, {v: 2, v2: {items: [0, 1]}}
                KeyData.create(key='k%s' % i, data={'v': i, 'v%s' % i: {
                    'items': list(range(i))}})

        jga_key = fn.json_group_array(KeyData.key)
        query = (KeyData
                 .select(jga_key)
                 .where(KeyData.data['v'] < 4)
                 .order_by(KeyData.key))
        self.assertEqual(json.loads(query.scalar()), ['k0', 'k1', 'k2', 'k3'])

        # Can specify json.loads as the converter for the function.
        query = (KeyData
                 .select(jga_key.python_value(json.loads))
                 .where(KeyData.data['v'] > 6)
                 .order_by(KeyData.key))
        self.assertEqual(query.scalar(), ['k7', 'k8', 'k9'])

        # Aggregating a list of ints?
        jga_id = fn.json_group_array(KeyData.id)
        query = (KeyData
                 .select(jga_id)
                 .where(KeyData.data['v'] < 4)
                 .order_by(KeyData.id))
        self.assertEqual(json.loads(query.scalar()), [1, 2, 3, 4])

        query = (KeyData
                 .select(jga_id.python_value(json.loads))
                 .where(KeyData.data['v'] > 6)
                 .order_by(KeyData.id))
        self.assertEqual(query.scalar(), [8, 9, 10])

        # Using json_group_object.
        jgo_key = fn.json_group_object(KeyData.key, KeyData.data['v'])
        res = (KeyData
               .select(jgo_key)
               .where(KeyData.data['v'] < 4)
               .scalar())
        self.assertEqual(json.loads(res), {'k0': 0, 'k1': 1, 'k2': 2, 'k3': 3})

        query = (KeyData
                 .select(jgo_key.python_value(json.loads))
                 .where(KeyData.data['v'] < 4))
        self.assertEqual(query.scalar(), {'k0': 0, 'k1': 1, 'k2': 2, 'k3': 3})

    def test_extract(self):
        # Item-style access generates json_extract() paths.
        KeyData = self.M
        self.assertRows((KeyData.data['k1'] == 'v1'), ['a', 'c'])
        self.assertRows((KeyData.data['k2'] == 'v2'), ['b', 'c'])
        self.assertRows((KeyData.data['x1']['y1'] == 'z1'), ['a', 'd'])
        self.assertRows((KeyData.data['l1'][1] == 1), ['e'])
        self.assertRows((KeyData.data['l2'][1][1] == 3), ['e'])

    @skip_unless(json_text_installed())
    def test_extract_text_json(self):
        # extract() vs the explicit ->> (text) and -> (json) variants.
        KeyData = self.M
        D = KeyData.data
        self.assertRows((D.extract('$.k1') == 'v1'), ['a', 'c'])
        self.assertRows((D.extract_text('$.k1') == 'v1'), ['a', 'c'])
        self.assertRows((D.extract_json('$.k1') == '"v1"'), ['a', 'c'])
        self.assertRows((D.extract_text('k2') == 'v2'), ['b', 'c'])
        self.assertRows((D.extract_json('k2') == '"v2"'), ['b', 'c'])
        self.assertRows((D.extract_text('$.x1.y1') == 'z1'), ['a', 'd'])
        self.assertRows((D.extract_json('$.x1.y1') == '"z1"'), ['a', 'd'])
        self.assertRows((D.extract_text('$.l1[1]') == 1), ['e'])
        self.assertRows((D.extract_text('$.l2[1][1]') == 3), ['e'])
        self.assertRows((D.extract_json('x1') == '{"y1":"z1"}'), ['a'])

    def test_extract_multiple(self):
        # Extracting several paths at once returns a list per row.
        KeyData = self.M
        query = KeyData.select(
            KeyData.key,
            KeyData.data.extract('$.k1', '$.k2').alias('keys'))
        self.assertEqual(sorted((k.key, k.keys) for k in query), [
            ('a', ['v1', None]),
            ('b', [None, 'v2']),
            ('c', ['v1', 'v2']),
            ('d', [None, None]),
            ('e', [None, None])])

    def test_insert(self):
        KeyData = self.M
        # Existing values are not overwritten.
        query = KeyData.update(data=KeyData.data['k1'].insert('v1-x'))
        self.assertEqual(query.execute(), 5)

        self.assertData('a', {'k1': 'v1', 'x1': {'y1': 'z1'}})
        self.assertData('b', {'k1': 'v1-x', 'k2': 'v2', 'x2': {'y2': 'z2'}})
        self.assertData('c', {'k1': 'v1', 'k2': 'v2'})
        self.assertData('d', {'k1': 'v1-x', 'x1': {'y1': 'z1', 'y2': 'z2'}})
        self.assertData('e', {'k1': 'v1-x', 'l1': [0, 1, 2],
                              'l2': [1, [3, 3], 7]})

    def test_insert_json(self):
        # insert() with a non-scalar value stores it as JSON.
        KeyData = self.M
        set_json = KeyData.data['k1'].insert([0])
        query = KeyData.update(data=set_json)
        self.assertEqual(query.execute(), 5)

        self.assertData('a', {'k1': 'v1', 'x1': {'y1': 'z1'}})
        self.assertData('b', {'k1': [0], 'k2': 'v2', 'x2': {'y2': 'z2'}})
        self.assertData('c', {'k1': 'v1', 'k2': 'v2'})
        self.assertData('d', {'k1': [0], 'x1': {'y1': 'z1', 'y2': 'z2'}})
        self.assertData('e', {'k1': [0], 'l1': [0, 1, 2],
                              'l2': [1, [3, 3], 7]})

    def test_replace(self):
        KeyData = self.M
        # Only existing values are overwritten.
        query = KeyData.update(data=KeyData.data['k1'].replace('v1-x'))
        self.assertEqual(query.execute(), 5)

        self.assertData('a', {'k1': 'v1-x', 'x1': {'y1': 'z1'}})
        self.assertData('b', {'k2': 'v2', 'x2': {'y2': 'z2'}})
        self.assertData('c', {'k1': 'v1-x', 'k2': 'v2'})
        self.assertData('d', {'x1': {'y1': 'z1', 'y2': 'z2'}})
        self.assertData('e', {'l1': [0, 1, 2], 'l2': [1, [3, 3], 7]})

    def test_replace_json(self):
        # replace() with a non-scalar value stores it as JSON.
        KeyData = self.M
        set_json = KeyData.data['k1'].replace([0])
        query = KeyData.update(data=set_json)
        self.assertEqual(query.execute(), 5)

        self.assertData('a', {'k1': [0], 'x1': {'y1': 'z1'}})
        self.assertData('b', {'k2': 'v2', 'x2': {'y2': 'z2'}})
        self.assertData('c', {'k1': [0], 'k2': 'v2'})
        self.assertData('d', {'x1': {'y1': 'z1', 'y2': 'z2'}})
        self.assertData('e', {'l1': [0, 1, 2], 'l2': [1, [3, 3], 7]})

    def test_set(self):
        # set() inserts or overwrites unconditionally.
        KeyData = self.M
        query = (KeyData
                 .update({KeyData.data: KeyData.data['k1'].set('v1-x')})
                 .where(KeyData.data['k1'] == 'v1'))
        self.assertEqual(query.execute(), 2)
        self.assertRows((KeyData.data['k1'] == 'v1-x'), ['a', 'c'])

        self.assertData('a', {'k1': 'v1-x', 'x1': {'y1': 'z1'}})

    def test_set_json(self):
        # set() with a dict replaces the sub-object wholesale.
        KeyData = self.M
        set_json = KeyData.data['x1'].set({'y1': 'z1-x', 'y3': 'z3'})
        query = (KeyData
                 .update({KeyData.data: set_json})
                 .where(KeyData.data['x1']['y1'] == 'z1'))
        self.assertEqual(query.execute(), 2)
        self.assertRows((KeyData.data['x1']['y1'] == 'z1-x'), ['a', 'd'])

        self.assertData('a', {'k1': 'v1', 'x1': {'y1': 'z1-x', 'y3': 'z3'}})
        self.assertData('d', {'x1': {'y1': 'z1-x', 'y3': 'z3'}})

    def test_append(self):
        # append() works for every JSON value type, at the top level and on
        # a nested array.
        KeyData = self.M
        for value in ('ix', [], ['c1'], ['c1', 'c2'], {}, {'k1': 'v1'},
                      {'k1': 'v1', 'k2': 'v2'}, None, 1):
            KeyData.delete().execute()
            KeyData.create(key='a0', data=[])
            KeyData.create(key='a1', data=['i1'])
            KeyData.create(key='a2', data=['i1', 'i2'])
            KeyData.create(key='n0', data={'arr': []})
            KeyData.create(key='n1', data={'arr': ['i1']})
            KeyData.create(key='n2', data={'arr': ['i1', 'i2']})

            query = (KeyData
                     .update(data=KeyData.data.append(value))
                     .where(KeyData.key.startswith('a')))
            self.assertEqual(query.execute(), 3)

            query = (KeyData
                     .select(KeyData.key, fn.json(KeyData.data))
                     .where(KeyData.key.startswith('a')))
            self.assertEqual(sorted((row.key, row.data) for row in query),
                             [('a0', [value]), ('a1', ['i1', value]),
                              ('a2', ['i1', 'i2', value])])

            query = (KeyData
                     .update(data=KeyData.data['arr'].append(value))
                     .where(KeyData.key.startswith('n')))
            self.assertEqual(query.execute(), 3)

            query = (KeyData
                     .select(KeyData.key, fn.json(KeyData.data))
                     .where(KeyData.key.startswith('n')))
            self.assertEqual(sorted((row.key, row.data) for row in query),
                             [('n0', {'arr': [value]}),
                              ('n1', {'arr': ['i1', value]}),
                              ('n2', {'arr': ['i1', 'i2', value]})])

    @skip_unless(json_patch_installed())
    def test_update(self):
        # update() performs an RFC-7396 style merge (json_patch).
        KeyData = self.M
        merged = KeyData.data.update({'x1': {'y1': 'z1-x', 'y3': 'z3'}})
        query = (KeyData
                 .update({KeyData.data: merged})
                 .where(KeyData.data['x1']['y1'] == 'z1'))
        self.assertEqual(query.execute(), 2)
        self.assertRows((KeyData.data['x1']['y1'] == 'z1-x'), ['a', 'd'])

        self.assertData('a', {'k1': 'v1', 'x1': {'y1': 'z1-x', 'y3': 'z3'}})
        self.assertData('d', {'x1': {'y1': 'z1-x', 'y2': 'z2', 'y3': 'z3'}})

    @skip_unless(json_patch_installed())
    def test_update_with_removal(self):
        # A None value in the patch removes the corresponding key.
        KeyData = self.M
        m = KeyData.data.update({'k1': None, 'x1': {'y1': None, 'y3': 'z3'}})
        query = KeyData.update(data=m).where(KeyData.data['x1']['y1'] == 'z1')
        self.assertEqual(query.execute(), 2)
        self.assertRows((KeyData.data['x1']['y3'] == 'z3'), ['a', 'd'])

        self.assertData('a', {'x1': {'y3': 'z3'}})
        self.assertData('d', {'x1': {'y2': 'z2', 'y3': 'z3'}})

    @skip_unless(json_patch_installed())
    def test_update_nested(self):
        # Merging can target a nested path directly.
        KeyData = self.M
        merged = KeyData.data['x1'].update({'y1': 'z1-x', 'y3': 'z3'})
        query = (KeyData
                 .update(data=merged)
                 .where(KeyData.data['x1']['y1'] == 'z1'))
        self.assertEqual(query.execute(), 2)
        self.assertRows((KeyData.data['x1']['y1'] == 'z1-x'), ['a', 'd'])

        self.assertData('a', {'k1': 'v1', 'x1': {'y1': 'z1-x', 'y3': 'z3'}})
        self.assertData('d', {'x1': {'y1': 'z1-x', 'y2': 'z2', 'y3': 'z3'}})

    @skip_unless(json_patch_installed())
    def test_updated_nested_with_removal(self):
        KeyData = self.M
        merged = KeyData.data['x1'].update({'o1': 'p1', 'y1': None})
        # NOTE(review): nrows is not asserted here, only the resulting data.
        nrows = (KeyData
                 .update(data=merged)
                 .where(KeyData.data['x1']['y1'] == 'z1')
                 .execute())
        self.assertRows((KeyData.data['x1']['o1'] == 'p1'), ['a', 'd'])
        self.assertData('a', {'k1': 'v1', 'x1': {'o1': 'p1'}})
        self.assertData('d', {'x1': {'o1': 'p1', 'y2': 'z2'}})

    def test_remove(self):
        KeyData = self.M
        query = (KeyData
                 .update(data=KeyData.data['k1'].remove())
                 .where(KeyData.data['k1'] == 'v1'))
        self.assertEqual(query.execute(), 2)

        self.assertData('a', {'x1': {'y1': 'z1'}})
        self.assertData('c', {'k2': 'v2'})

        # Remove a single element from a nested array.
        nrows = (KeyData
                 .update(data=KeyData.data['l2'][1][1].remove())
                 .where(KeyData.key == 'e')
                 .execute())
        self.assertData('e', {'l1': [0, 1, 2], 'l2': [1, [3], 7]})

    def test_simple_update(self):
        # Assigning a plain dict replaces the whole document.
        KeyData = self.M
        nrows = (KeyData
                 .update(data={'foo': 'bar'})
                 .where(KeyData.key.in_(['a', 'b']))
                 .execute())
        self.assertData('a', {'foo': 'bar'})
        self.assertData('b', {'foo': 'bar'})

    def test_children(self):
        # json_each()-based helper exposing direct children as a table.
        KeyData = self.M
        children = KeyData.data.children().alias('children')
        query = (KeyData
                 .select(KeyData.key, children.c.fullkey.alias('fullkey'))
                 .from_(KeyData, children)
                 .where(~children.c.fullkey.contains('k'))
                 .order_by(KeyData.id, SQL('fullkey')))
        accum = [(row.key, row.fullkey) for row in query]
        self.assertEqual(accum, [
            ('a', '$.x1'),
            ('b', '$.x2'),
            ('d', '$.x1'),
            ('e', '$.l1'), ('e', '$.l2')])

    def test_tree(self):
        # json_tree()-based helper walking the full document recursively.
        KeyData = self.M
        tree = KeyData.data.tree().alias('tree')
        query = (KeyData
                 .select(tree.c.fullkey.alias('fullkey'))
                 .from_(KeyData, tree)
                 .where(KeyData.key == 'd')
                 .order_by(SQL('1'))
                 .tuples())
        self.assertEqual([fullkey for fullkey, in query], [
            '$',
            '$.x1',
            '$.x1.y1',
            '$.x1.y2'])
|
|
|
@skip_unless(jsonb_installed(), 'requires sqlite jsonb support')
class TestJSONBFieldFunctions(TestJSONFieldFunctions):
    """Re-run the JSON function tests against a binary jsonb column."""
    requires = [JBData]
    M = JBData

    def assertData(self, key, expected):
        # jsonb is a binary format; convert to json text before comparing.
        q = JBData.select(fn.json(JBData.data)).where(JBData.key == key)
        self.assertEqual(q.get().data, expected)

    def test_extract_multiple(self):
        # We need to override this, otherwise we end up with jsonb returned.
        expr = fn.json(JBData.data.extract('$.k1', '$.k2'))
        query = JBData.select(
            JBData.key,
            expr.python_value(json.loads).alias('keys'))
        self.assertEqual(sorted((k.key, k.keys) for k in query), [
            ('a', ['v1', None]),
            ('b', [None, 'v2']),
            ('c', ['v1', 'v2']),
            ('d', [None, None]),
            ('e', [None, None])])
|
|
|
|
|
|
class TestSqliteExtensions(BaseTestCase):
    """Schema-generation tests for VirtualModel and AutoIncrementField."""

    def test_virtual_model(self):
        class Test(VirtualModel):
            class Meta:
                database = database
                extension_module = 'ext1337'
                legacy_table_names = False
                options = {'huey': 'cat', 'mickey': 'dog'}
                primary_key = False

        class SubTest(Test): pass

        # Options are rendered, sorted by key, in the USING clause.
        self.assertSQL(Test._schema._create_table(), (
            'CREATE VIRTUAL TABLE IF NOT EXISTS "test" '
            'USING ext1337 '
            '(huey=cat, mickey=dog)'), [])
        self.assertSQL(SubTest._schema._create_table(), (
            'CREATE VIRTUAL TABLE IF NOT EXISTS "sub_test" '
            'USING ext1337 '
            '(huey=cat, mickey=dog)'), [])
        # Keyword arguments to _create_table() override/extend the options.
        self.assertSQL(
            Test._schema._create_table(huey='kitten', zaizee='cat'),
            ('CREATE VIRTUAL TABLE IF NOT EXISTS "test" '
             'USING ext1337 (huey=kitten, mickey=dog, zaizee=cat)'), [])

    def test_autoincrement_field(self):
        class AutoIncrement(TestModel):
            id = AutoIncrementField()
            data = TextField()
            class Meta:
                database = database

        # AutoIncrementField emits the AUTOINCREMENT keyword.
        self.assertSQL(AutoIncrement._schema._create_table(), (
            'CREATE TABLE IF NOT EXISTS "auto_increment" '
            '("id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, '
            '"data" TEXT NOT NULL)'), [])
|
|
|
|
|
|
class BaseFTSTestCase(object):
    """Shared fixture data and helpers for the full-text-search tests."""
    messages = (
        ('A faith is a necessity to a man. Woe to him who believes in '
         'nothing.'),
        ('All who call on God in true faith, earnestly from the heart, will '
         'certainly be heard, and will receive what they have asked and '
         'desired.'),
        ('Be faithful in small things because it is in them that your '
         'strength lies.'),
        ('Faith consists in believing when it is beyond the power of reason '
         'to believe.'),
        ('Faith has to do with things that are not seen and hope with things '
         'that are not at hand.'))
    # Rows for the MultiColumn fixture: (c1, c2, c3, c4).
    values = (
        ('aaaaa bbbbb ccccc ddddd', 'aaaaa ccccc', 'zzzzz zzzzz', 1),
        ('bbbbb ccccc ddddd eeeee', 'bbbbb', 'zzzzz', 2),
        ('ccccc ccccc ddddd fffff', 'ccccc', 'yyyyy', 3),
        ('ddddd', 'ccccc', 'xxxxx', 4))

    def assertMessages(self, query, indexes):
        # Compare query results against the fixture messages by index.
        self.assertEqual([obj.message for obj in query],
                         [self.messages[idx] for idx in indexes])
|
|
|
|
|
|
class TestFullTextSearch(BaseFTSTestCase, ModelTestCase):
    """FTS3/FTS4 tests: matching, rank/bm25 scoring, content tables."""
    database = database
    requires = [
        Post,
        ContentPost,
        ContentPostMessage,
        Document,
        MultiColumn]
|
|
|
|
    @requires_models(Document)
    def test_fts_insert_or_replace(self):
        # We can use replace to create a new row.
        n = Document.replace(docid=100, message='m100').execute()
        self.assertEqual(n, 100)  # execute() returns the affected rowid.
        self.assertEqual(Document.select().count(), 1)

        # We can use replace to update an existing row.
        n = Document.replace(docid=100, message='x100').execute()
        self.assertEqual(n, 100)
        self.assertEqual(Document.select().count(), 1)

        # Adds a new row.
        n = Document.replace(docid=101, message='x101').execute()
        self.assertEqual(n, 101)
        self.assertEqual(Document.select().count(), 2)

        query = Document.select().order_by(Document.message)
        self.assertEqual(list(query.tuples()), [(100, 'x100'), (101, 'x101')])
|
|
|
    @requires_models(Document)
    def test_fts_manual(self):
        messages = [Document.create(message=message)
                    for message in self.messages]
        query = (Document
                 .select()
                 .where(Document.match('believe'))
                 .order_by(Document.docid))
        self.assertMessages(query, [0, 3])

        # search() orders by rank; more negative scores sort first.
        query = Document.search('believe')
        self.assertMessages(query, [3, 0])

        # Test peewee's "rank" algorithm, as presented in the SQLite FTS3 docs.
        query = Document.search('things', with_score=True)
        self.assertEqual([(row.message, row.score) for row in query], [
            (self.messages[4], -2. / 3),
            (self.messages[2], -1. / 3)])

        # Test peewee's bm25 ranking algorithm.
        query = Document.search_bm25('things', with_score=True)
        self.assertEqual([(d.message, round(d.score, 2)) for d in query], [
            (self.messages[4], -0.45),
            (self.messages[2], -0.36)])

        # Another test of bm25 ranking.
        query = Document.search_bm25('believe', with_score=True)
        self.assertEqual([(d.message, round(d.score, 2)) for d in query], [
            (self.messages[3], -0.49),
            (self.messages[0], -0.35)])

        query = Document.search_bm25('god faith', with_score=True)
        self.assertEqual([(d.message, round(d.score, 2)) for d in query], [
            (self.messages[1], -0.92)])

        # Phrase query.
        query = Document.search_bm25('"it is"', with_score=True)
        self.assertEqual([(d.message, round(d.score, 2)) for d in query], [
            (self.messages[2], -0.36),
            (self.messages[3], -0.36)])
|
|
|
    def test_fts_delete_row(self):
        posts = [Post.create(message=msg) for msg in self.messages]
        ContentPost.rebuild()
        query = (ContentPost
                 .select(ContentPost, ContentPost.rank().alias('score'))
                 .where(ContentPost.match('believe'))
                 .order_by(ContentPost.docid))
        self.assertMessages(query, [0, 3])

        query = (ContentPost
                 .select(ContentPost.docid)
                 .order_by(ContentPost.docid))
        for content_post in query:
            self.assertEqual(content_post.delete_instance(), 1)

        for post in posts:
            self.assertEqual(
                (ContentPost
                 .delete()
                 .where(ContentPost.message == post.message)
                 .execute()), 1)

        # None of the deletes were processed since the table is managed.
        self.assertEqual(ContentPost.select().count(), 5)

        # A stand-alone FTS table, by contrast, honors deletes.
        documents = [Document.create(message=message) for message in
                     self.messages]
        self.assertEqual(Document.select().count(), 5)

        for document in documents:
            self.assertEqual(
                (Document
                 .delete()
                 .where(Document.message == document.message)
                 .execute()), 1)

        self.assertEqual(Document.select().count(), 0)
|
|
|
|
    def _create_multi_column(self):
        # Populate MultiColumn from the shared fixture rows.
        for c1, c2, c3, c4 in self.values:
            MultiColumn.create(c1=c1, c2=c2, c3=c3, c4=c4)
|
|
|
|
    @requires_models(MultiColumn)
    def test_fts_multi_column(self):
        def assertResults(term, expected):
            # expected: list of (c4, rounded rank score).
            results = [(x.c4, round(x.score, 2))
                       for x in MultiColumn.search(term, with_score=True)]
            self.assertEqual(results, expected)

        self._create_multi_column()
        assertResults('bbbbb', [
            (2, -1.5),  # 1/2 + 1/1
            (1, -0.5)])  # 1/2

        # `ccccc` appears four times in `c1`, three times in `c2`.
        assertResults('ccccc', [
            (3, -.83),  # 2/4 + 1/3
            (1, -.58),  # 1/4 + 1/3
            (4, -.33),  # 1/3
            (2, -.25),  # 1/4
        ])

        # `zzzzz` appears three times in c3.
        assertResults('zzzzz', [(1, -.67), (2, -.33)])

        self.assertEqual(
            [x.score for x in MultiColumn.search('ddddd', with_score=True)],
            [-.25, -.25, -.25, -.25])
|
|
|
|
    @requires_models(MultiColumn)
    def test_bm25(self):
        def assertResults(term, expected):
            # Weight only the first column.
            query = MultiColumn.search_bm25(term, [1.0, 0, 0, 0], True)
            self.assertEqual(
                [(mc.c4, round(mc.score, 2)) for mc in query],
                expected)

        self._create_multi_column()
        MultiColumn.create(c1='aaaaa fffff', c4=5)

        assertResults('aaaaa', [(5, -0.39), (1, -0.3)])
        assertResults('fffff', [(5, -0.39), (3, -0.3)])
        assertResults('eeeee', [(2, -0.97)])

        # No column specified, use the first text field.
        query = MultiColumn.search_bm25('fffff', [1.0, 0, 0, 0], True)
        self.assertEqual([(mc.c4, round(mc.score, 2)) for mc in query], [
            (5, -0.39),
            (3, -0.3)])

        # Use helpers.
        query = (MultiColumn
                 .select(
                     MultiColumn.c4,
                     MultiColumn.bm25(1.0).alias('score'))
                 .where(MultiColumn.match('aaaaa'))
                 .order_by(SQL('score')))
        self.assertEqual([(mc.c4, round(mc.score, 2)) for mc in query], [
            (5, -0.39),
            (1, -0.3)])

        def assertAllColumns(term, expected):
            # Default weighting across all indexed columns.
            query = MultiColumn.search_bm25(term, with_score=True)
            self.assertEqual(
                [(mc.c4, round(mc.score, 2)) for mc in query],
                expected)

        assertAllColumns('aaaaa ddddd', [(1, -1.08)])
        assertAllColumns('zzzzz ddddd', [(1, -0.36), (2, -0.34)])
        assertAllColumns('ccccc bbbbb ddddd', [(2, -1.39), (1, -0.3)])
|
|
|
|
    @requires_models(Document)
    def test_bm25_alt_corpus(self):
        for message in self.messages:
            Document.create(message=message)

        query = Document.search_bm25('things', with_score=True)
        self.assertEqual([(d.message, round(d.score, 2)) for d in query], [
            (self.messages[4], -0.45),
            (self.messages[2], -0.36)])

        query = Document.search_bm25('believe', with_score=True)
        self.assertEqual([(d.message, round(d.score, 2)) for d in query], [
            (self.messages[3], -0.49),
            (self.messages[0], -0.35)])

        # Indeterminate order since all are 0.0. All phrases contain the word
        # faith, so there is no meaningful score.
        query = Document.search_bm25('faith', with_score=True)
        self.assertEqual([round(d.score, 2) for d in query], [-0.] * 5)
|
|
|
|
    def _test_fts_auto(self, ModelClass):
        # Shared body for the content-table FTS models (whole-table and
        # single-column variants).
        posts = []
        for message in self.messages:
            posts.append(Post.create(message=message))

        # Nothing matches, index is not built.
        pq = ModelClass.select().where(ModelClass.match('faith'))
        self.assertEqual(list(pq), [])

        ModelClass.rebuild()
        ModelClass.optimize()

        # it will stem faithful -> faith b/c we use the porter tokenizer
        pq = (ModelClass
              .select()
              .where(ModelClass.match('faith'))
              .order_by(ModelClass.docid))
        self.assertMessages(pq, range(len(self.messages)))

        pq = (ModelClass
              .select()
              .where(ModelClass.match('believe'))
              .order_by(ModelClass.docid))
        self.assertMessages(pq, [0, 3])

        # Prefix query.
        pq = (ModelClass
              .select()
              .where(ModelClass.match('thin*'))
              .order_by(ModelClass.docid))
        self.assertMessages(pq, [2, 4])

        # Phrase query.
        pq = (ModelClass
              .select()
              .where(ModelClass.match('"it is"'))
              .order_by(ModelClass.docid))
        self.assertMessages(pq, [2, 3])

        pq = ModelClass.search('things', with_score=True)
        self.assertEqual([(x.message, x.score) for x in pq], [
            (self.messages[4], -2.0 / 3),
            (self.messages[2], -1.0 / 3),
        ])

        pq = (ModelClass
              .select(ModelClass.rank())
              .where(ModelClass.match('faithful'))
              .tuples())
        self.assertEqual([x[0] for x in pq], [-.2] * 5)

        pq = (ModelClass
              .search('faithful', with_score=True)
              .dicts())
        self.assertEqual([x['score'] for x in pq], [-.2] * 5)
|
|
|
|
    def test_fts_auto_model(self):
        # Content table specified as a whole model.
        self._test_fts_auto(ContentPost)
|
|
|
|
    def test_fts_auto_field(self):
        # Content table specified as a single column.
        self._test_fts_auto(ContentPostMessage)
|
|
|
|
    def test_weighting(self):
        self._create_multi_column()

        def assertResults(method, term, weights, expected):
            # expected: list of (c4, rounded score) under the given weights.
            results = [
                (x.c4, round(x.score, 2))
                for x in method(term, weights=weights, with_score=True)]
            self.assertEqual(results, expected)

        assertResults(MultiColumn.search, 'bbbbb', None, [
            (2, -1.5),  # 1/2 + 1/1
            (1, -0.5),  # 1/2
        ])
        assertResults(MultiColumn.search, 'bbbbb', [1., 5., 0.], [
            (2, -5.5),  # 1/2 + (5 * 1/1)
            (1, -0.5),  # 1/2 + (5 * 0)
        ])
        assertResults(MultiColumn.search, 'bbbbb', [1., .5, 0.], [
            (2, -1.),  # 1/2 + (.5 * 1/1)
            (1, -0.5),  # 1/2 + (.5 * 0)
        ])
        assertResults(MultiColumn.search, 'bbbbb', [1., -1., 0.], [
            (1, -0.5),  # 1/2 + (-1 * 0)
            (2, 0.5),  # 1/2 + (-1 * 1/1)
        ])

        # BM25
        assertResults(MultiColumn.search_bm25, 'bbbbb', None, [
            (2, -0.85),
            (1, -0.)])
        assertResults(MultiColumn.search_bm25, 'bbbbb', [1., 5., 0.], [
            (2, -4.24),
            (1, -0.)])
        assertResults(MultiColumn.search_bm25, 'bbbbb', [1., .5, 0.], [
            (2, -0.42),
            (1, -0.)])
        assertResults(MultiColumn.search_bm25, 'bbbbb', [1., -1., 0.], [
            (1, -0.),
            (2, 0.85)])
|
|
|
|
def test_fts_match_single_column(self):
    """MATCH applied to a single column only hits terms in that column."""
    rows = (
        ('m1c1 aaaa', 'm1c2 bbbb', 'm1c3 cccc'),
        ('m2c1 dddd', 'm2c2 eeee', 'm2c3 ffff'),
        ('m3c1 cccc', 'm3c2 bbbb', 'm3c3 aaaa'),
    )
    for col1, col2, col3 in rows:
        MultiColumn.create(c1=col1, c2=col2, c3=col3, c4=0)

    def check(field, term, prefixes):
        matches = (MultiColumn
                   .select()
                   .where(field.match(term))
                   .order_by(MultiColumn.c1))
        # Compare on the "mN" row-identifying prefix of c1.
        self.assertEqual([row.c1[:2] for row in matches], prefixes)

    check(MultiColumn.c1, 'aaaa', ['m1'])
    check(MultiColumn.c1, 'bbbb', [])
    check(MultiColumn.c1, 'cccc', ['m3'])
    check(MultiColumn.c2, 'bbbb', ['m1', 'm3'])
    check(MultiColumn.c2, 'eeee', ['m2'])
    check(MultiColumn.c3, 'cccc', ['m1'])
    check(MultiColumn.c3, 'aaaa', ['m3'])
|
|
|
|
def test_fts_score_single_column(self):
    """bm25 scoring when the MATCH is restricted to a single column,
    with and without explicit per-column weights."""
    data = (
        ('m1c1 aaaa', 'm1c2 bbbb', 'm1c3 cccc'),
        ('m2c1 dddd', 'm2c2 eeee', 'm2c3 ffff'),
        ('m3c1 cccc', 'm3c2 bbbb aaaa', 'm3c3 aaaa aaaa'),
    )
    for c1, c2, c3 in data:
        MultiColumn.create(c1=c1, c2=c2, c3=c3, c4=0)

    def assertQueryScore(field, search_term, expected, *weights):
        # Score with bm25, optionally weighting columns, and order by rank.
        rank = MultiColumn.bm25(*weights)
        query = (MultiColumn
                 .select(MultiColumn, rank.alias('score'))
                 .where(field.match(search_term))
                 .order_by(rank))
        results = [(r.c1[:2], round(r.score, 2)) for r in query]
        self.assertEqual(results, expected)

    # Default (unweighted) bm25 scores.
    assertQueryScore(MultiColumn.c1, 'aaaa', [('m1', -0.51)])
    assertQueryScore(MultiColumn.c1, 'dddd', [('m2', -0.51)])
    assertQueryScore(MultiColumn.c2, 'bbbb', [('m1', -0.), ('m3', -0.)])
    assertQueryScore(MultiColumn.c2, 'eeee', [('m2', -0.51)])
    assertQueryScore(MultiColumn.c3, 'aaaa', [('m3', -0.62)])

    # Explicit weights scale the scores proportionally.
    assertQueryScore(MultiColumn.c1, 'aaaa', [('m1', -1.02)], 2., 0., 0.)
    assertQueryScore(MultiColumn.c2, 'bbbb', [('m1', -0.), ('m3', -0.)],
                     0., 1.0, 0.)
    assertQueryScore(MultiColumn.c2, 'eeee', [('m2', -1.02)], 0., 2., 0.)
    assertQueryScore(MultiColumn.c3, 'aaaa', [('m3', -0.31)], 0., 1., 0.5)
|
|
|
|
@skip_unless(compile_option('enable_fts4'))
@requires_models(MultiColumn)
def test_match_column_queries(self):
    """FTS4 column-prefixed MATCH expressions ("col:term") combined with
    AND / OR / NOT operators, verified via bm25 score ordering."""
    data = (
        ('alpha one', 'apple aspires to ace artsy beta launch'),
        ('beta two', 'beta boasts better broadcast over apple'),
        ('gamma three', 'gold gray green gamma ray delta data'),
        ('delta four', 'delta data indicates downturn for apple beta'),
    )
    MC = MultiColumn

    # c4 holds the row index so results can be identified by number.
    for i, (title, message) in enumerate(data):
        MC.create(c1=title, c2=message, c3='', c4=i)

    def assertQ(expr, idxscore):
        # Run the match and compare (row-index, rounded bm25 score) pairs.
        q = (MC
             .select(MC, MC.bm25().alias('score'))
             .where(expr)
             .order_by(SQL('score'), MC.c4))
        self.assertEqual([(r.c4, round(r.score, 2)) for r in q], idxscore)

    # Single whitespace does not affect the mapping of col->term. We can
    # also store the column value in quotes if single-quotes are used.
    assertQ(MC.match('beta'), [(1, -0.85), (0, -0.), (3, -0.)])
    assertQ(MC.match('c1:beta'), [(1, -0.85)])
    assertQ(MC.match('c1: beta'), [(1, -0.85)])
    assertQ(MC.match('c1: ^bet*'), [(1, -0.85)])
    assertQ(MC.match('c1: \'beta\''), [(1, -0.85)])
    assertQ(MC.match('"beta"'), [(1, -0.85), (0, -0.), (3, -0.)])

    # Alternatively, just specify the column explicitly.
    assertQ(MC.c1.match('beta'), [(1, -0.85)])
    assertQ(MC.c1.match(' beta '), [(1, -0.85)])
    assertQ(MC.c1.match('"beta"'), [(1, -0.85)])
    assertQ(MC.c1.match('"^bet*"'), [(1, -0.85)])

    #             apple beta delta gamma
    # 0 | alpha |   X    X
    # 1 | beta  |   X    X
    # 2 | gamma |              X     X
    # 3 | delta |   X    X     X
    #
    assertQ(MC.match('delta NOT gamma'), [(3, -0.85)])
    assertQ(MC.match('delta NOT c2:gamma'), [(3, -0.85)])
    assertQ(MC.match('"delta"'), [(3, -0.85), (2, -0.)])
    assertQ(MC.match('c1:delta OR c2:delta'), [(3, -0.85), (2, -0.)])
    assertQ(MC.match('"^delta"'), [(3, -1.69)])

    # Boolean combinations of column-restricted terms.
    assertQ(MC.match('(delta AND c2:apple) OR c1:alpha'),
            [(3, -0.85), (0, -0.85)])
    assertQ(MC.match('(c2:delta AND c2:apple) OR c1:alpha'),
            [(0, -0.85), (3, -0.)])
    assertQ(MC.match('c2:delta c2:apple OR c1:alpha'),
            [(0, -0.85), (3, -0.)])
    assertQ(MC.match('(c2:delta AND c2:apple) OR beta'),
            [(1, -0.85), (3, -0.), (0, -0.)])
    assertQ(MC.match('c2:delta AND (c2:apple OR c1:alpha)'),
            [(3, -0.)])

    # c2 apple (0,1,3) OR (...irrelevant...).
    assertQ(MC.match('c2:apple OR c1:alpha NOT delta'),
            [(0, -0.85), (1, -0.), (3, -0.)])
    assertQ(MC.match('c2:apple OR (c1:alpha NOT c2:delta)'),
            [(0, -0.85), (1, -0.), (3, -0.)])
    # c2 apple OR c1 alpha (0, 1, 3) AND NOT delta (2, 3) -> (0, 1).
    assertQ(MC.match('(c2:apple OR c1:alpha) NOT delta'),
            [(0, -0.85), (1, -0.)])
|
|
|
|
|
|
@skip_unless(CYTHON_EXTENSION, 'requires _sqlite_udf c extension')
class TestFullTextSearchCython(TestFullTextSearch):
    """Re-run the FTS tests using the C (Cython) ranking implementations,
    plus tests for bm25f and lucene scoring, which only exist there."""

    def test_bm25f(self):
        def assertResults(term, expected):
            # bm25f with all weight on the first column; compare
            # (row-index, rounded score) pairs.
            query = MultiColumn.search_bm25f(term, [1.0, 0, 0, 0], True)
            self.assertEqual(
                [(mc.c4, round(mc.score, 2)) for mc in query],
                expected)

        self._create_multi_column()
        MultiColumn.create(c1='aaaaa fffff', c4=5)

        assertResults('aaaaa', [(5, -0.76), (1, -0.62)])
        assertResults('fffff', [(5, -0.76), (3, -0.65)])
        assertResults('eeeee', [(2, -2.13)])

        # No column specified, use the first text field.
        query = MultiColumn.search_bm25f('aaaaa OR fffff', [1., 3., 0, 0], 1)
        self.assertEqual([(mc.c4, round(mc.score, 2)) for mc in query], [
            (1, -14.18),
            (5, -12.01),
            (3, -11.48)])

    def test_lucene(self):
        for message in self.messages:
            Document.create(message=message)

        def assertResults(term, expected, sort_cleaned=False):
            # Compare (rounded score, first-two-words) pairs; optionally
            # sort when the ordering of equal scores is indeterminate.
            query = Document.search_lucene(term, with_score=True)
            cleaned = [
                (round(doc.score, 3), ' '.join(doc.message.split()[:2]))
                for doc in query]
            if sort_cleaned:
                cleaned = sorted(cleaned)
            self.assertEqual(cleaned, expected)

        assertResults('things', [
            (-0.166, 'Faith has'),
            (-0.137, 'Be faithful')])

        assertResults('faith', [
            (0.036, 'All who'),
            (0.042, 'Faith has'),
            (0.047, 'A faith'),
            (0.049, 'Be faithful'),
            (0.049, 'Faith consists')], sort_cleaned=True)
|
|
|
|
|
|
@skip_unless(FTS5Model.fts5_installed(), 'requires fts5')
class TestFTS5(BaseFTSTestCase, ModelTestCase):
    """Tests for the FTS5 virtual-table model: DDL generation, special
    INSERT-based commands, MATCH queries and the built-in bm25 ranking,
    plus the highlight()/snippet() auxiliary functions."""
    database = database
    requires = [FTS5Test]
    # (title, data, misc) rows; misc is an UNINDEXED column.
    test_corpus = (
        ('foo aa bb', 'aa bb cc ' * 10, 1),
        ('bar bb cc', 'bb cc dd ' * 9, 2),
        ('baze cc dd', 'cc dd ee ' * 8, 3),
        ('nug aa dd', 'bb cc ' * 7, 4))

    def setUp(self):
        super(TestFTS5, self).setUp()
        for title, data, misc in self.test_corpus:
            FTS5Test.create(title=title, data=data, misc=misc)

    def test_create_table(self):
        # UNINDEXED columns must be flagged in the generated DDL.
        query = FTS5Test._schema._create_table()
        self.assertSQL(query, (
            'CREATE VIRTUAL TABLE IF NOT EXISTS "fts5_test" USING fts5 '
            '("title", "data", "misc" UNINDEXED)'), [])

    def test_custom_fts5_command(self):
        # FTS5 control commands are issued as special INSERTs into the
        # virtual table itself.
        merge_sql = FTS5Test._fts_cmd_sql('merge', rank=4)
        self.assertSQL(merge_sql, (
            'INSERT INTO "fts5_test" ("fts5_test", "rank") VALUES (?, ?)'),
            ['merge', 4])
        FTS5Test.merge(4)  # Runs without error.

        FTS5Test.insert_many([{'title': 'k%08d' % i, 'data': 'v%08d' % i}
                              for i in range(100)]).execute()

        FTS5Test.integrity_check(rank=0)
        FTS5Test.optimize()

    def test_create_table_options(self):
        # Meta.options are rendered as FTS5 table arguments.
        class Test1(FTS5Model):
            f1 = SearchField()
            f2 = SearchField(unindexed=True)
            f3 = SearchField()

            class Meta:
                database = self.database
                options = {
                    'prefix': (2, 3),
                    'tokenize': 'porter unicode61',
                    'content': Post,
                    'content_rowid': Post.id}

        query = Test1._schema._create_table()
        self.assertSQL(query, (
            'CREATE VIRTUAL TABLE IF NOT EXISTS "test1" USING fts5 ('
            '"f1", "f2" UNINDEXED, "f3", '
            'content="post", content_rowid="id", '
            'prefix=\'2,3\', tokenize="porter unicode61")'), [])

    def assertResults(self, query, expected, scores=False, alias='score'):
        # Helper: compare titles, or (title, rounded-score) pairs.
        if scores:
            results = [(obj.title, round(getattr(obj, alias), 7))
                       for obj in query]
        else:
            results = [obj.title for obj in query]
        self.assertEqual(results, expected)

    def test_search(self):
        # search() orders by the FTS5 built-in "rank" pseudo-column.
        query = FTS5Test.search('bb')
        self.assertSQL(query, (
            'SELECT "t1"."rowid", "t1"."title", "t1"."data", "t1"."misc" '
            'FROM "fts5_test" AS "t1" '
            'WHERE ("fts5_test" MATCH ?) ORDER BY rank'), ['bb'])
        self.assertResults(query, ['nug aa dd', 'foo aa bb', 'bar bb cc'])

        self.assertResults(FTS5Test.search('baze OR dd'),
                           ['baze cc dd', 'bar bb cc', 'nug aa dd'])

    @requires_models(FTS5Document)
    def test_fts_manual(self):
        messages = [FTS5Document.create(message=message)
                    for message in self.messages]
        query = (FTS5Document
                 .select()
                 .where(FTS5Document.match('believe'))
                 .order_by(FTS5Document.rowid))
        self.assertMessages(query, [0, 3])

        query = FTS5Document.search('believe')
        self.assertMessages(query, [3, 0])

        # Test SQLite's built-in ranking algorithm (bm25). The results should
        # be comparable to our user-defined implementation.
        query = FTS5Document.search('things', with_score=True)
        self.assertEqual([(d.message, round(d.score, 2)) for d in query], [
            (self.messages[4], -0.45),
            (self.messages[2], -0.37)])

        # Another test of bm25 ranking.
        query = FTS5Document.search_bm25('believe', with_score=True)
        self.assertEqual([(d.message, round(d.score, 2)) for d in query], [
            (self.messages[3], -0.49),
            (self.messages[0], -0.36)])

        query = FTS5Document.search_bm25('god faith', with_score=True)
        self.assertEqual([(d.message, round(d.score, 2)) for d in query], [
            (self.messages[1], -0.93)])

        query = FTS5Document.search_bm25('"it is"', with_score=True)
        self.assertEqual([(d.message, round(d.score, 2)) for d in query], [
            (self.messages[2], -0.37),
            (self.messages[3], -0.37)])

    def test_match_column_queries(self):
        # FTS5 analogue of the FTS4 column-query test; note several scores
        # and wildcard behaviors differ from FTS4.
        data = (
            ('alpha one', 'apple aspires to ace artsy beta launch'),
            ('beta two', 'beta boasts better broadcast over apple'),
            ('gamma three', 'gold gray green gamma ray delta data'),
            ('delta four', 'delta data indicates downturn for apple beta'),
        )
        FT = FTS5Test

        for i, (title, message) in enumerate(data):
            FT.create(title=title, data=message, misc=str(i))

        def assertQ(expr, idxscore):
            q = (FT
                 .select(FT, FT.bm25().alias('score'))
                 .where(expr)
                 .order_by(SQL('score'), FT.misc.cast('int')))
            self.assertEqual([(int(r.misc), round(r.score, 2)) for r in q],
                             idxscore)

        # Single whitespace does not affect the mapping of col->term. We can
        # also store the column value in quotes if single-quotes are used.
        assertQ(FT.match('beta'), [(1, -0.74), (0, -0.57), (3, -0.57)])
        assertQ(FT.match('title: beta'), [(1, -2.08)])
        assertQ(FT.match('title: ^bet*'), [(1, -2.08)])
        assertQ(FT.match('title: "beta"'), [(1, -2.08)])
        assertQ(FT.match('"beta"'), [(1, -0.74), (0, -0.57), (3, -0.57)])

        # Alternatively, just specify the column explicitly.
        assertQ(FT.title.match('beta'), [(1, -2.08)])
        assertQ(FT.title.match(' beta '), [(1, -2.08)])
        assertQ(FT.title.match('"beta"'), [(1, -2.08)])
        assertQ(FT.title.match('^bet*'), [(1, -2.08)])
        assertQ(FT.title.match('"^bet*"'), [])  # No wildcards in quotes!

        #             apple beta delta gamma
        # 0 | alpha |   X    X
        # 1 | beta  |   X    X
        # 2 | gamma |              X     X
        # 3 | delta |   X    X     X
        #
        assertQ(FT.match('delta NOT gamma'), [(3, -1.53)])
        assertQ(FT.match('delta NOT data:gamma'), [(3, -1.53)])
        assertQ(FT.match('"delta"'), [(3, -1.53), (2, -1.2)])
        assertQ(FT.match('title:delta OR data:delta'), [(3, -3.21), (2, -1.2)])
        assertQ(FT.match('"^delta"'), [(3, -1.53), (2, -1.2)])  # Different.
        assertQ(FT.match('^delta'), [(3, -2.57)])  # Different from FTS4.

        assertQ(FT.match('(delta AND data:apple) OR title:alpha'),
                [(3, -2.09), (0, -2.02)])
        assertQ(FT.match('(data:delta AND data:apple) OR title:alpha'),
                [(0, -2.02), (3, -1.76)])
        assertQ(FT.match('data:delta data:apple OR title:alpha'),
                [(0, -2.02), (3, -1.76)])
        assertQ(FT.match('(data:delta AND data:apple) OR beta'),
                [(3, -2.33), (1, -0.74), (0, -0.57)])
        assertQ(FT.match('data:delta AND (data:apple OR title:alpha)'),
                [(3, -1.76)])

        # data apple (0,1,3) OR (...irrelevant...).
        assertQ(FT.match('data:apple OR title:alpha NOT delta'),
                [(0, -2.58), (1, -0.58), (3, -0.57)])
        assertQ(FT.match('data:apple OR (title:alpha NOT data:delta)'),
                [(0, -2.58), (1, -0.58), (3, -0.57)])
        # data apple OR title alpha (0, 1, 3) AND NOT delta (2, 3) -> (0, 1).
        assertQ(FT.match('(data:apple OR title:alpha) NOT delta'),
                [(0, -2.58), (1, -0.58)])

    def test_highlight_function(self):
        # highlight() wraps matched terms in the given start/end markers.
        query = (FTS5Test
                 .search('dd')
                 .select(FTS5Test.title.highlight('[', ']').alias('hi')))
        accum = [row.hi for row in query]
        self.assertEqual(accum, ['baze cc [dd]', 'bar bb cc', 'nug aa [dd]'])

        query = (FTS5Test
                 .search('bb')
                 .select(FTS5Test.data.highlight('[', ']').alias('hi')))
        accum = [row.hi[:7] for row in query]
        self.assertEqual(accum, ['[bb] cc', 'aa [bb]', '[bb] cc'])

    def test_snippet_function(self):
        # snippet() truncates to max_tokens, marking matches and adding
        # an ellipsis where text was elided.
        snip = FTS5Test.data.snippet('[', ']', max_tokens=5).alias('snip')
        query = FTS5Test.search('dd').select(snip)
        accum = [row.snip for row in query]
        self.assertEqual(accum, [
            'cc [dd] ee cc [dd]...',
            'bb cc [dd] bb cc...',
            'bb cc bb cc bb...'])

    def test_clean_query(self):
        # clean_query() replaces characters that are invalid in FTS5 query
        # syntax (outside quoted phrases) with the \x1a substitute char.
        cases = (
            ('test', 'test'),
            ('"test"', '"test"'),
            ('"test\u2022"', '"test\u2022"'),
            ('test\u2022', 'test\u2022'),
            ('test-', 'test\x1a'),
            ('"test-"', '"test-"'),
            ('\\"test-', '\x1a test\x1a'),
            ('--test--', '\x1a\x1atest\x1a\x1a'),
            ('-test- "-test-"', '\x1atest\x1a "-test-"'),
        )
        for a, b in cases:
            self.assertEqual(FTS5Test.clean_query(a), b)
|
|
|
|
|
|
class TestUserDefinedCallbacks(ModelTestCase):
    """Tests for user-defined aggregates, collations and scalar functions
    registered on the SqliteExtDatabase."""
    database = database
    requires = [Post, Values]

    def test_custom_agg(self):
        # (klass, value, weight) rows used for the aggregate tests.
        data = (
            (1, 3.4, 1.0),
            (1, 6.4, 2.3),
            (1, 4.3, 0.9),
            (2, 3.4, 1.4),
            (3, 2.7, 1.1),
            (3, 2.5, 1.1),
        )
        for klass, value, wt in data:
            Values.create(klass=klass, value=value, weight=wt)

        # Single-argument form: equivalent to a plain average.
        vq = (Values
              .select(
                  Values.klass,
                  fn.weighted_avg(Values.value).alias('wtavg'),
                  fn.avg(Values.value).alias('avg'))
              .group_by(Values.klass))
        q_data = [(v.klass, v.wtavg, v.avg) for v in vq]
        self.assertEqual(q_data, [
            (1, 4.7, 4.7),
            (2, 3.4, 3.4),
            (3, 2.6, 2.6)])

        # Two-argument form: weights shift the average where they differ.
        vq = (Values
              .select(
                  Values.klass,
                  fn.weighted_avg2(Values.value, Values.weight).alias('wtavg'),
                  fn.avg(Values.value).alias('avg'))
              .group_by(Values.klass))
        q_data = [(v.klass, str(v.wtavg)[:4], v.avg) for v in vq]
        self.assertEqual(q_data, [
            (1, '5.23', 4.7),
            (2, '3.4', 3.4),
            (3, '2.6', 2.6)])

    def test_custom_collation(self):
        # The collate_reverse collation sorts in descending order.
        for i in [1, 4, 3, 5, 2]:
            Post.create(message='p%d' % i)

        pq = Post.select().order_by(NodeList((Post.message, SQL('collate collate_reverse'))))
        self.assertEqual([p.message for p in pq], ['p5', 'p4', 'p3', 'p2', 'p1'])

    def test_collation_decorator(self):
        # Case-insensitive ordering via a @collation-decorated callable.
        posts = [Post.create(message=m) for m in ['aaa', 'Aab', 'ccc', 'Bba', 'BbB']]
        pq = Post.select().order_by(collate_case_insensitive.collation(Post.message))
        self.assertEqual([p.message for p in pq], [
            'aaa',
            'Aab',
            'Bba',
            'BbB',
            'ccc'])

    def test_custom_function(self):
        # User-defined scalar function usable in both WHERE and SELECT.
        p1 = Post.create(message='this is a test')
        p2 = Post.create(message='another TEST')

        sq = Post.select().where(fn.title_case(Post.message) == 'This Is A Test')
        self.assertEqual(list(sq), [p1])

        sq = Post.select(fn.title_case(Post.message)).tuples()
        self.assertEqual([x[0] for x in sq], [
            'This Is A Test',
            'Another Test',
        ])

    def test_function_decorator(self):
        # The SQL rstrip() here is the user-defined function, which strips
        # the given suffix characters from the right of the string.
        [Post.create(message=m) for m in ['testing', 'chatting ', ' foo']]
        pq = Post.select(fn.rstrip(Post.message, 'ing')).order_by(Post.id)
        self.assertEqual([x[0] for x in pq.tuples()], [
            'test', 'chatting ', ' foo'])

        pq = Post.select(fn.rstrip(Post.message, ' ')).order_by(Post.id)
        self.assertEqual([x[0] for x in pq.tuples()], [
            'testing', 'chatting', ' foo'])

    def test_use_across_connections(self):
        # Functions registered via @db.func() must survive reconnects.
        db = get_in_memory_db()
        @db.func()
        def rev(s):
            return s[::-1]

        db.connect(); db.close(); db.connect()
        curs = db.execute_sql('select rev(?)', ('hello',))
        self.assertEqual(curs.fetchone(), ('olleh',))
|
|
|
|
|
|
class TestRowIDField(ModelTestCase):
    """Tests for RowIDField, which exposes SQLite's implicit rowid as the
    model's auto-incrementing primary key."""
    database = database
    requires = [RowIDModel]

    def test_model_meta(self):
        # rowid sorts first and behaves as the auto-increment primary key.
        self.assertEqual(RowIDModel._meta.sorted_field_names, ['rowid', 'data'])
        self.assertEqual(RowIDModel._meta.primary_key.name, 'rowid')
        self.assertTrue(RowIDModel._meta.auto_increment)

    def test_rowid_field(self):
        r1 = RowIDModel.create(data=10)
        self.assertEqual(r1.rowid, 1)
        self.assertEqual(r1.data, 10)

        r2 = RowIDModel.create(data=20)
        self.assertEqual(r2.rowid, 2)
        self.assertEqual(r2.data, 20)

        # rowid is selectable and usable in the WHERE clause like any field.
        query = RowIDModel.select().where(RowIDModel.rowid == 2)
        self.assertSQL(query, (
            'SELECT "t1"."rowid", "t1"."data" '
            'FROM "row_id_model" AS "t1" '
            'WHERE ("t1"."rowid" = ?)'), [2])
        r_db = query.get()
        self.assertEqual(r_db.rowid, 2)
        self.assertEqual(r_db.data, 20)

        r_db2 = query.columns(RowIDModel.rowid, RowIDModel.data).get()
        self.assertEqual(r_db2.rowid, 2)
        self.assertEqual(r_db2.data, 20)

    def test_insert_with_rowid(self):
        # An explicit rowid value is honored on insert.
        RowIDModel.insert({RowIDModel.rowid: 5, RowIDModel.data: 1}).execute()
        self.assertEqual(5, RowIDModel.select(RowIDModel.rowid).first().rowid)

    def test_insert_many_with_rowid_without_field_validation(self):
        RowIDModel.insert_many([{RowIDModel.rowid: 5, RowIDModel.data: 1}]).execute()
        self.assertEqual(5, RowIDModel.select(RowIDModel.rowid).first().rowid)

    def test_insert_many_with_rowid_with_field_validation(self):
        # Fix: previously identical to the test above — the fields= kwarg
        # was missing, so the field-validation code path was never hit.
        RowIDModel.insert_many(
            [{RowIDModel.rowid: 5, RowIDModel.data: 1}],
            fields=[RowIDModel.rowid, RowIDModel.data]).execute()
        self.assertEqual(5, RowIDModel.select(RowIDModel.rowid).first().rowid)
|
|
|
|
|
|
class TestTransitiveClosure(BaseTestCase):
    """Tests for the ClosureTable() model factory (no extension needed —
    only metadata generation is verified here)."""

    def test_model_factory(self):
        class Category(TestModel):
            name = CharField()
            parent = ForeignKeyField('self', null=True)

        # The closure model is a virtual table with no real columns or pk;
        # id/parent columns are inferred from the source model.
        Closure = ClosureTable(Category)
        self.assertEqual(Closure._meta.extension_module, 'transitive_closure')
        self.assertEqual(Closure._meta.columns, {})
        self.assertEqual(Closure._meta.fields, {})
        self.assertFalse(Closure._meta.primary_key)
        self.assertEqual(Closure._meta.options, {
            'idcolumn': 'id',
            'parentcolumn': 'parent_id',
            'tablename': 'category',
        })

        # Non-default pk and fk names are picked up as well.
        class Alt(TestModel):
            pk = AutoField()
            ref = ForeignKeyField('self', null=True)

        Closure = ClosureTable(Alt)
        self.assertEqual(Closure._meta.columns, {})
        self.assertEqual(Closure._meta.fields, {})
        self.assertFalse(Closure._meta.primary_key)
        self.assertEqual(Closure._meta.options, {
            'idcolumn': 'pk',
            'parentcolumn': 'ref_id',
            'tablename': 'alt',
        })

        # A model without a self-referential FK cannot form a closure table.
        class NoForeignKey(TestModel):
            pass
        self.assertRaises(ValueError, ClosureTable, NoForeignKey)
|
|
|
|
|
|
class BaseExtModel(TestModel):
    # Base model binding subclasses to the module-level extension database.
    class Meta:
        database = database
|
|
|
|
|
|
@skip_unless(CLOSURE_EXTENSION, 'requires closure table extension')
class TestTransitiveClosureManyToMany(BaseTestCase):
    """Closure-table traversal over a many-to-many relationship, using a
    separate referencing (junction) model."""

    @staticmethod
    def _extension_name():
        # Fix: CLOSURE_EXTENSION.rstrip('.so') strips the *character set*
        # {'.', 's', 'o'}, not the suffix (e.g. "closures.so" -> "closure").
        # Strip the ".so" suffix explicitly instead.
        if CLOSURE_EXTENSION.endswith('.so'):
            return CLOSURE_EXTENSION[:-len('.so')]
        return CLOSURE_EXTENSION

    def setUp(self):
        super(TestTransitiveClosureManyToMany, self).setUp()
        database.load_extension(self._extension_name())
        database.close()

    def tearDown(self):
        super(TestTransitiveClosureManyToMany, self).tearDown()
        database.unload_extension(self._extension_name())
        database.close()

    def test_manytomany(self):
        class Person(BaseExtModel):
            name = CharField()

        class Relationship(BaseExtModel):
            person = ForeignKeyField(Person)
            relation = ForeignKeyField(Person, backref='related_to')

        # The closure is built over the junction table rather than a
        # self-referential FK on Person itself.
        PersonClosure = ClosureTable(
            Person,
            referencing_class=Relationship,
            foreign_key=Relationship.relation,
            referencing_key=Relationship.person)

        database.drop_tables([Person, Relationship, PersonClosure], safe=True)
        database.create_tables([Person, Relationship, PersonClosure])

        c = Person.create(name='charlie')
        m = Person.create(name='mickey')
        h = Person.create(name='huey')
        z = Person.create(name='zaizee')
        Relationship.create(person=c, relation=h)
        Relationship.create(person=c, relation=m)
        Relationship.create(person=h, relation=z)
        Relationship.create(person=h, relation=m)

        def assertPeople(query, expected):
            self.assertEqual(sorted([p.name for p in query]), expected)

        PC = PersonClosure
        assertPeople(PC.descendants(c), [])
        assertPeople(PC.ancestors(c), ['huey', 'mickey', 'zaizee'])
        assertPeople(PC.siblings(c), ['huey'])

        assertPeople(PC.descendants(h), ['charlie'])
        assertPeople(PC.ancestors(h), ['mickey', 'zaizee'])
        assertPeople(PC.siblings(h), ['charlie'])

        assertPeople(PC.descendants(z), ['charlie', 'huey'])
        assertPeople(PC.ancestors(z), [])
        assertPeople(PC.siblings(z), [])
|
|
|
|
|
|
@skip_unless(CLOSURE_EXTENSION and os.path.exists(CLOSURE_EXTENSION),
             'requires closure extension')
class TestTransitiveClosureIntegration(BaseTestCase):
    """End-to-end tests for the transitive_closure extension against a
    small category tree."""

    # Nested {name: [children]} description of the test hierarchy.
    tree = {
        'books': [
            {'fiction': [
                {'scifi': [
                    {'hard scifi': []},
                    {'dystopian': []}]},
                {'westerns': []},
                {'classics': []},
            ]},
            {'non-fiction': [
                {'biographies': []},
                {'essays': []},
            ]},
        ]
    }

    @staticmethod
    def _extension_name():
        # Fix: CLOSURE_EXTENSION.rstrip('.so') strips the *character set*
        # {'.', 's', 'o'}, not the suffix (e.g. "closures.so" -> "closure").
        # Strip the ".so" suffix explicitly instead.
        if CLOSURE_EXTENSION.endswith('.so'):
            return CLOSURE_EXTENSION[:-len('.so')]
        return CLOSURE_EXTENSION

    def setUp(self):
        super(TestTransitiveClosureIntegration, self).setUp()
        database.load_extension(self._extension_name())
        database.close()

    def tearDown(self):
        super(TestTransitiveClosureIntegration, self).tearDown()
        database.unload_extension(self._extension_name())
        database.close()

    def initialize_models(self):
        """Create the Category model plus its closure table and populate
        them from self.tree; returns (Category, Closure)."""
        class Category(BaseExtModel):
            name = CharField()
            parent = ForeignKeyField('self', null=True)
            @classmethod
            def g(cls, name):
                # Shorthand lookup-by-name used throughout the tests.
                return cls.get(cls.name == name)

        Closure = ClosureTable(Category)
        database.drop_tables([Category, Closure], safe=True)
        database.create_tables([Category, Closure])

        def build_tree(nodes, parent=None):
            for name, subnodes in nodes.items():
                category = Category.create(name=name, parent=parent)
                if subnodes:
                    for subnode in subnodes:
                        build_tree(subnode, category)

        build_tree(self.tree)
        return Category, Closure

    def assertNodes(self, query, *expected):
        # Order-insensitive comparison of category names.
        self.assertEqual(
            set([category.name for category in query]),
            set(expected))

    def test_build_tree(self):
        Category, Closure = self.initialize_models()
        self.assertEqual(Category.select().count(), 10)

    def test_descendants(self):
        Category, Closure = self.initialize_models()
        books = Category.g('books')
        self.assertNodes(
            Closure.descendants(books),
            'fiction', 'scifi', 'hard scifi', 'dystopian',
            'westerns', 'classics', 'non-fiction', 'biographies', 'essays')

        # Depth argument restricts to a single level; 0 is the node itself.
        self.assertNodes(Closure.descendants(books, 0), 'books')
        self.assertNodes(
            Closure.descendants(books, 1), 'fiction', 'non-fiction')
        self.assertNodes(
            Closure.descendants(books, 2),
            'scifi', 'westerns', 'classics', 'biographies', 'essays')
        self.assertNodes(
            Closure.descendants(books, 3), 'hard scifi', 'dystopian')

        fiction = Category.g('fiction')
        self.assertNodes(
            Closure.descendants(fiction),
            'scifi', 'hard scifi', 'dystopian', 'westerns', 'classics')
        self.assertNodes(
            Closure.descendants(fiction, 1),
            'scifi', 'westerns', 'classics')
        self.assertNodes(
            Closure.descendants(fiction, 2), 'hard scifi', 'dystopian')

        self.assertNodes(
            Closure.descendants(Category.g('scifi')),
            'hard scifi', 'dystopian')
        self.assertNodes(
            Closure.descendants(Category.g('scifi'), include_node=True),
            'scifi', 'hard scifi', 'dystopian')
        # A leaf has no descendants at any depth.
        self.assertNodes(Closure.descendants(Category.g('hard scifi'), 1))

    def test_ancestors(self):
        Category, Closure = self.initialize_models()

        hard_scifi = Category.g('hard scifi')
        self.assertNodes(
            Closure.ancestors(hard_scifi),
            'scifi', 'fiction', 'books')
        self.assertNodes(
            Closure.ancestors(hard_scifi, include_node=True),
            'hard scifi', 'scifi', 'fiction', 'books')
        self.assertNodes(Closure.ancestors(hard_scifi, 2), 'fiction')
        self.assertNodes(Closure.ancestors(hard_scifi, 3), 'books')

        non_fiction = Category.g('non-fiction')
        self.assertNodes(Closure.ancestors(non_fiction), 'books')
        self.assertNodes(Closure.ancestors(non_fiction, include_node=True),
                         'non-fiction', 'books')
        self.assertNodes(Closure.ancestors(non_fiction, 1), 'books')

        # The root has no ancestors.
        books = Category.g('books')
        self.assertNodes(Closure.ancestors(books, include_node=True),
                         'books')
        self.assertNodes(Closure.ancestors(books))
        self.assertNodes(Closure.ancestors(books, 1))

    def test_siblings(self):
        Category, Closure = self.initialize_models()

        self.assertNodes(
            Closure.siblings(Category.g('hard scifi')), 'dystopian')
        self.assertNodes(
            Closure.siblings(Category.g('hard scifi'), include_node=True),
            'hard scifi', 'dystopian')
        self.assertNodes(
            Closure.siblings(Category.g('classics')), 'scifi', 'westerns')
        self.assertNodes(
            Closure.siblings(Category.g('classics'), include_node=True),
            'scifi', 'westerns', 'classics')
        self.assertNodes(
            Closure.siblings(Category.g('fiction')), 'non-fiction')

    def test_tree_changes(self):
        # The closure table must track reparenting, deletion and inserts.
        Category, Closure = self.initialize_models()
        books = Category.g('books')
        fiction = Category.g('fiction')
        dystopian = Category.g('dystopian')
        essays = Category.g('essays')
        new_root = Category.create(name='products')
        Category.create(name='magazines', parent=new_root)
        books.parent = new_root
        books.save()
        dystopian.delete_instance()
        essays.parent = books
        essays.save()
        Category.create(name='rants', parent=essays)
        Category.create(name='poetry', parent=books)

        query = (Category
                 .select(Category.name, Closure.depth)
                 .join(Closure, on=(Category.id == Closure.id))
                 .where(Closure.root == new_root)
                 .order_by(Closure.depth, Category.name)
                 .tuples())
        self.assertEqual(list(query), [
            ('products', 0),
            ('books', 1),
            ('magazines', 1),
            ('essays', 2),
            ('fiction', 2),
            ('non-fiction', 2),
            ('poetry', 2),
            ('biographies', 3),
            ('classics', 3),
            ('rants', 3),
            ('scifi', 3),
            ('westerns', 3),
            ('hard scifi', 4),
        ])

    def test_id_not_overwritten(self):
        # Joining through the closure must not clobber the model's own id.
        class Node(BaseExtModel):
            parent = ForeignKeyField('self', null=True)
            name = CharField()

        NodeClosure = ClosureTable(Node)
        database.create_tables([Node, NodeClosure], safe=True)

        root = Node.create(name='root')
        c1 = Node.create(name='c1', parent=root)
        c2 = Node.create(name='c2', parent=root)

        query = NodeClosure.descendants(root)
        self.assertEqual(sorted([(n.id, n.name) for n in query]),
                         [(c1.id, 'c1'), (c2.id, 'c2')])
        database.drop_tables([Node, NodeClosure])
|
|
|
|
|
|
@skip_unless(json_installed(), 'requires json1 sqlite extension')
class TestJsonContains(ModelTestCase):
    """Tests for the user-defined json_contains() SQL function: key lookup,
    partial-dict containment, ordered-list containment, and nesting."""
    database = SqliteExtDatabase(':memory:', json_contains=True)
    requires = [KeyData]
    # (key, json-serializable data) fixture rows.
    test_data = (
        ('a', {'k1': 'v1', 'k2': 'v2', 'k3': 'v3'}),
        ('b', {'k2': 'v2', 'k3': 'v3', 'k4': 'v4'}),
        ('c', {'k3': 'v3', 'x1': {'y1': 'z1', 'y2': 'z2'}}),
        ('d', {'k4': 'v4', 'x1': {'y2': 'z2', 'y3': [0, 1, 2]}}),
        ('e', ['foo', 'bar', [0, 1, 2]]),
    )

    def setUp(self):
        super(TestJsonContains, self).setUp()
        with self.database.atomic():
            for key, data in self.test_data:
                KeyData.create(key=key, data=data)

    def assertContains(self, obj, expected):
        # Assert which fixture keys have data containing obj.
        contains = fn.json_contains(KeyData.data, json.dumps(obj))
        query = (KeyData
                 .select(KeyData.key)
                 .where(contains)
                 .order_by(KeyData.key)
                 .namedtuples())
        self.assertEqual([m.key for m in query], expected)

    def test_json_contains(self):
        # Simple checks for key.
        self.assertContains('k1', ['a'])
        self.assertContains('k2', ['a', 'b'])
        self.assertContains('k3', ['a', 'b', 'c'])
        self.assertContains('kx', [])
        self.assertContains('y1', [])

        # Partial dictionary.
        self.assertContains({'k1': 'v1'}, ['a'])
        self.assertContains({'k2': 'v2'}, ['a', 'b'])
        self.assertContains({'k3': 'v3'}, ['a', 'b', 'c'])
        self.assertContains({'k2': 'v2', 'k3': 'v3'}, ['a', 'b'])

        self.assertContains({'k2': 'vx'}, [])
        self.assertContains({'k2': 'v2', 'k3': 'vx'}, [])
        self.assertContains({'y1': 'z1'}, [])

        # List, interpreted as list of keys.
        self.assertContains(['k1', 'k2'], ['a'])
        self.assertContains(['k4'], ['b', 'd'])
        self.assertContains(['kx'], [])
        self.assertContains(['y1'], [])

        # List, interpreted as ordered list of items.
        self.assertContains(['foo'], ['e'])
        self.assertContains(['foo', 'bar'], ['e'])
        self.assertContains(['bar', 'foo'], [])

        # Nested dictionaries.
        self.assertContains({'x1': 'y1'}, ['c'])
        self.assertContains({'x1': ['y1']}, ['c'])
        self.assertContains({'x1': {'y1': 'z1'}}, ['c'])
        self.assertContains({'x1': {'y2': 'z2'}}, ['c', 'd'])
        self.assertContains({'x1': {'y2': 'z2'}, 'k4': 'v4'}, ['d'])

        self.assertContains({'x1': {'yx': 'z1'}}, [])
        self.assertContains({'x1': {'y1': 'z1', 'y3': 'z3'}}, [])
        self.assertContains({'x1': {'y2': 'zx'}}, [])
        self.assertContains({'x1': {'k4': 'v4'}}, [])

        # Mixing dictionaries and lists.
        self.assertContains({'x1': {'y2': 'z2', 'y3': [0]}}, ['d'])
        self.assertContains({'x1': {'y2': 'z2', 'y3': [0, 1, 2]}}, ['d'])

        self.assertContains({'x1': {'y2': 'z2', 'y3': [0, 1, 2, 4]}}, [])
        self.assertContains({'x1': {'y2': 'z2', 'y3': [0, 2]}}, [])
|
|
|
|
|
|
class CalendarMonth(TestModel):
    # Month name, e.g. 'january'.
    name = TextField()
    # 1-based month number.
    value = IntegerField()
|
|
|
|
class CalendarDay(TestModel):
    """A single day belonging to a CalendarMonth; `value` holds the
    1-based day-of-month (see TestIntWhereChain's fixture data)."""
    month = ForeignKeyField(CalendarMonth, backref='days')
    value = IntegerField()
|
|
|
|
|
|
class TestIntWhereChain(ModelTestCase):
    """Chained .where() calls must AND the new predicate onto a *clone*,
    leaving the earlier query objects unmodified."""
    database = database
    requires = [CalendarMonth, CalendarDay]

    def test_int_where_chain(self):
        # Populate january (31 days) and february (28 days) atomically.
        with self.database.atomic():
            jan = CalendarMonth.create(name='january', value=1)
            feb = CalendarMonth.create(name='february', value=2)
            for month, ndays in ((jan, 31), (feb, 28)):
                CalendarDay.insert_many(
                    [{'month': month, 'value': day}
                     for day in range(1, ndays + 1)]).execute()

        def check(query, expected):
            # Order-insensitive comparison of the matched day values.
            self.assertEqual(sorted(d.value for d in query), list(expected))

        base = CalendarDay.select().join(CalendarMonth)

        january = base.where(CalendarMonth.name == 'january')
        jan_tail = january.where(CalendarDay.value >= 25)
        check(jan_tail, range(25, 32))

        # Further narrowing does not disturb jan_tail itself.
        check(jan_tail.where(CalendarDay.value < 30), range(25, 30))

        february = base.where(CalendarMonth.name == 'february')
        feb_tail = february.where(CalendarDay.value >= 25)
        check(feb_tail, range(25, 29))

        # February only has 28 days, so the extra bound is a no-op.
        check(feb_tail.where(CalendarDay.value < 30), range(25, 29))
|
|
|
|
|
|
class Datum(TestModel):
    """Model whose columns differ only in declared collation, used to
    exercise SQLite's collating-sequence resolution rules."""
    a = BareField()                    # No explicit collation declared.
    b = BareField(collation='BINARY')  # Byte-wise comparison.
    c = BareField(collation='RTRIM')   # Trailing spaces ignored.
    d = BareField(collation='NOCASE')  # ASCII case-insensitive.
|
|
|
|
|
|
class TestCollatedFieldDefinitions(ModelTestCase):
    """Verify which collating sequence SQLite selects for comparisons,
    grouping and ordering when columns declare different collations.

    The expectations follow SQLite's documented resolution rules: an
    explicit COLLATE operator wins; otherwise the left operand's column
    collation applies; otherwise the right operand's.
    """
    database = get_in_memory_db()
    requires = [Datum]

    def test_collated_fields(self):
        # Fixture rows: (id, a, b, c, d). Note the trailing spaces and
        # case differences, which the collations treat differently.
        rows = (
            (1, 'abc', 'abc', 'abc ', 'abc'),
            (2, 'abc', 'abc', 'abc', 'ABC'),
            (3, 'abc', 'abc', 'abc ', 'Abc'),
            (4, 'abc', 'abc ', 'ABC', 'abc'))
        for pk, a, b, c, d in rows:
            Datum.create(id=pk, a=a, b=b, c=c, d=d)

        def assertC(query, expected):
            # Compare matched row ids (in query order) against `expected`.
            self.assertEqual([r.id for r in query], expected)

        base = Datum.select().order_by(Datum.id)

        # Text comparison a=b is performed using binary collating sequence.
        assertC(base.where(Datum.a == Datum.b), [1, 2, 3])

        # Text comparison a=b is performed using the RTRIM collating sequence.
        assertC(base.where(Datum.a == Datum.b.collate('RTRIM')), [1, 2, 3, 4])

        # Text comparison d=a is performed using the NOCASE collating sequence.
        assertC(base.where(Datum.d == Datum.a), [1, 2, 3, 4])

        # Text comparison a=d is performed using the BINARY collating sequence.
        assertC(base.where(Datum.a == Datum.d), [1, 4])

        # Text comparison 'abc'=c is performed using RTRIM collating sequence.
        assertC(base.where('abc' == Datum.c), [1, 2, 3])

        # Text comparison c='abc' is performed using RTRIM collating sequence.
        assertC(base.where(Datum.c == 'abc'), [1, 2, 3])

        # Grouping is performed using the NOCASE collating sequence (Values
        # 'abc', 'ABC', and 'Abc' are placed in the same group).
        query = Datum.select(fn.COUNT(Datum.id)).group_by(Datum.d)
        self.assertEqual(query.scalar(), 4)

        # Grouping is performed using the BINARY collating sequence. 'abc' and
        # 'ABC' and 'Abc' form different groups.
        query = Datum.select(fn.COUNT(Datum.id)).group_by(Datum.d.concat(''))
        self.assertEqual([r[0] for r in query.tuples()], [1, 1, 2])

        # Sorting of column c is performed using the RTRIM collating sequence.
        assertC(base.order_by(Datum.c, Datum.id), [4, 1, 2, 3])

        # Sorting of (c||'') is performed using the BINARY collating sequence.
        assertC(base.order_by(Datum.c.concat(''), Datum.id), [4, 2, 3, 1])

        # Sorting of column c is performed using the NOCASE collating sequence.
        assertC(base.order_by(Datum.c.collate('NOCASE'), Datum.id),
                [2, 4, 3, 1])
|
|
|
|
|
|
class TestReadOnly(ModelTestCase):
    """Verify that a SQLite connection opened via a mode=ro URI can read
    but cannot write, and cannot create a missing database file."""
    database = db_loader('sqlite3')

    # Fix: skip message typo 'requres' -> 'requires'.
    @skip_if(sys.version_info < (3, 4, 0), 'requires python >= 3.4.0')
    @requires_models(User)
    def test_read_only(self):
        User.create(username='foo')

        # Re-open the same database file read-only via a URI connection.
        db_filename = self.database.database
        db = SqliteDatabase('file:%s?mode=ro' % db_filename, uri=True)
        cursor = db.execute_sql('select username from users')
        self.assertEqual(cursor.fetchone(), ('foo',))

        # Any write on the read-only connection must fail.
        self.assertRaises(OperationalError, db.execute_sql,
                          'insert into users (username) values (?)', ('huey',))

        # We cannot create a database if in read-only mode.
        db = SqliteDatabase('file:xx_not_exists.db?mode=ro', uri=True)
        self.assertRaises(OperationalError, db.connect)
|
|
|
|
|
|
class TDecModel(TestModel):
    """Model exercising TDecimalField: 24 total digits, 16 decimal
    places, with over-precise values rounded automatically."""
    value = TDecimalField(max_digits=24, decimal_places=16, auto_round=True)
|
|
|
|
|
|
class TestTDecimalField(ModelTestCase):
    """Round-trip Decimal values through TDecimalField and verify that
    auto_round trims excess precision to 16 decimal places."""
    database = get_in_memory_db()
    requires = [TDecModel]

    def test_tdecimal_field(self):
        in_range = D('12345678.0123456789012345')
        over_precise = D('12345678.012345678901234567890123456789')

        stored = [TDecModel.create(value=v)
                  for v in (in_range, over_precise)]

        # The in-range value round-trips unchanged.
        fetched = TDecModel.get(TDecModel.id == stored[0].id)
        self.assertEqual(fetched.value, in_range)

        # The over-precise value comes back rounded at 16 places.
        fetched = TDecModel.get(TDecModel.id == stored[1].id)
        self.assertEqual(fetched.value, D('12345678.0123456789012346'))
|
|
|
|
|
|
class KVR(TestModel):
    """Key/value model with a TEXT primary key, used to check RETURNING
    behavior when the PK is not an auto-incrementing integer."""
    key = TextField(primary_key=True)
    value = IntegerField()
|
|
|
|
|
|
@skip_unless(database.server_version >= (3, 35, 0), 'sqlite returning clause required')
class TestSqliteReturning(ModelTestCase):
    """Exercise explicit RETURNING clauses on insert/update/delete
    (requires SQLite 3.35+). The database here is NOT configured with
    returning_clause=True, so plain inserts still return rowids."""
    database = database
    requires = [Person, User, KVR]

    def test_sqlite_returning(self):
        # insert_many(...).returning(col) yields a row per inserted record.
        iq = (User
              .insert_many([{'username': 'u%s' % i} for i in range(3)])
              .returning(User.id))
        self.assertEqual([r.id for r in iq.execute()], [1, 2, 3])

        # Returning the full model gives hydrated instances.
        res = (User
               .insert_many([{'username': 'u%s' % i} for i in (4, 5)])
               .returning(User)
               .execute())
        self.assertEqual([(r.id, r.username) for r in res],
                         [(4, 'u4'), (5, 'u5')])

        # Simple insert returns the ID.
        res = User.insert(username='u6').execute()
        self.assertEqual(res, 6)

        # Row type can be customized, e.g. namedtuples.
        iq = (User
              .insert_many([{'username': 'u%s' % i} for i in (7, 8, 9)])
              .returning(User)
              .namedtuples())
        curs = iq.execute()
        self.assertEqual([u.id for u in curs], [7, 8, 9])

    def test_sqlite_on_conflict_returning(self):
        p = Person.create(first='f1', last='l1', dob='1990-01-01')
        self.assertEqual(p.id, 1)

        # f1/l1 conflicts with the existing row and is updated; f2/l2 is
        # a fresh insert. RETURNING reflects the post-upsert values.
        iq = Person.insert_many([
            {'first': 'f%s' % i, 'last': 'l%s' %i, 'dob': '1990-01-%02d' % i}
            for i in range(1, 3)])
        iq = iq.on_conflict(conflict_target=[Person.first, Person.last],
                            update={'dob': '2000-01-01'})
        p1, p2 = iq.returning(Person).execute()

        self.assertEqual((p1.first, p1.last), ('f1', 'l1'))
        self.assertEqual(p1.dob, datetime.date(2000, 1, 1))
        self.assertEqual((p2.first, p2.last), ('f2', 'l2'))
        self.assertEqual(p2.dob, datetime.date(1990, 1, 2))

        # Without .returning(), insert still reports the new rowid.
        p3 = Person.insert(first='f3', last='l3', dob='1990-01-03').execute()
        self.assertEqual(p3, 3)

    def test_text_pk(self):
        res = KVR.create(key='k1', value=1)
        self.assertEqual((res.key, res.value), ('k1', 1))

        # Without returning_clause=True the insert reports the rowid,
        # not the TEXT primary key.
        res = KVR.insert(key='k2', value=2).execute()
        self.assertEqual(res, 2)
        #self.assertEqual(res, 'k2')

        # insert_many() returns the primary-key as usual.
        iq = (KVR
              .insert_many([{'key': 'k%s' % i, 'value': i} for i in (3, 4)])
              .returning(KVR.key))
        self.assertEqual([r.key for r in iq.execute()], ['k3', 'k4'])

        # Upsert: k4 exists and is bumped by 10; k5 is new.
        iq = KVR.insert_many([{'key': 'k%s' % i, 'value': i} for i in (4, 5)])
        iq = iq.on_conflict(conflict_target=[KVR.key],
                            update={KVR.value: KVR.value + 10})
        res = iq.returning(KVR).execute()
        self.assertEqual([(r.key, r.value) for r in res],
                         [('k4', 14), ('k5', 5)])

        # UPDATE ... RETURNING yields only the rows actually modified
        # ('kx' matches nothing).
        res = (KVR
               .update(value=KVR.value + 10)
               .where(KVR.key.in_(['k1', 'k3', 'kx']))
               .returning(KVR)
               .execute())
        self.assertEqual([(r.key, r.value) for r in res],
                         [('k1', 11), ('k3', 13)])

        # DELETE ... RETURNING yields the deleted rows.
        res = (KVR.delete()
               .where(KVR.key.not_in(['k2', 'k3', 'k4']))
               .returning(KVR)
               .execute())
        self.assertEqual([(r.key, r.value) for r in res],
                         [('k1', 11), ('k5', 5)])
|
|
|
|
|
|
@skip_unless(database.server_version >= (3, 35, 0), 'sqlite returning clause required')
class TestSqliteReturningConfig(ModelTestCase):
    """Behavior of a database configured with returning_clause=True: all
    writes implicitly use RETURNING, so inserts report primary keys
    (including TEXT PKs) rather than rowids."""
    database = SqliteExtDatabase(':memory:', returning_clause=True)
    requires = [KVR, User]

    def test_pk_set_properly(self):
        user = User.create(username='u1')
        self.assertEqual(user.id, 1)

        # With returning_clause=True the TEXT primary key is populated.
        kvr = KVR.create(key='k1', value=1)
        self.assertEqual(kvr.key, 'k1')

    def test_insert_behavior(self):
        iq = User.insert({'username': 'u1'})
        self.assertEqual(iq.execute(), 1)

        # insert_many yields a row of returned PK(s) per inserted record.
        iq = User.insert_many([{'username': 'u2'}, {'username': 'u3'}])
        self.assertEqual(list(iq.execute()), [(2,), (3,)])

        # NOTE: sqlite3_changes() does not return the inserted rowcount until
        # the statement has been consumed. The fact that it returned 2 is a
        # side-effect of the statement cache and our having consumed the query
        # in the previous test assertion. So this test is invalid.
        #iq = User.insert_many([('u4',), ('u5',)]).as_rowcount()
        #self.assertEqual(iq.execute(), 2)

        # TEXT primary key is returned directly.
        iq = KVR.insert({'key': 'k1', 'value': 1})
        self.assertEqual(iq.execute(), 'k1')

        iq = KVR.insert_many([('k2', 2), ('k3', 3)])
        self.assertEqual(list(iq.execute()), [('k2',), ('k3',)])

        # See note above.
        #iq = KVR.insert_many([('k4', 4), ('k5', 5)]).as_rowcount()
        #self.assertEqual(iq.execute(), 2)

    def test_insert_on_conflict(self):
        KVR.create(key='k1', value=1)
        # Upsert of an existing key returns that key and applies the update.
        iq = (KVR.insert({'key': 'k1', 'value': 100})
              .on_conflict(conflict_target=[KVR.key],
                           update={KVR.value: KVR.value + 10}))
        self.assertEqual(iq.execute(), 'k1')
        self.assertEqual(KVR.get(KVR.key == 'k1').value, 11)

        KVR.create(key='k2', value=2)
        # Mixed upsert: k1/k2 conflict (bumped by 10), k3 is inserted.
        iq = (KVR.insert_many([
            {'key': 'k1', 'value': 100},
            {'key': 'k2', 'value': 200},
            {'key': 'k3', 'value': 300}])
              .on_conflict(conflict_target=[KVR.key],
                           update={KVR.value: KVR.value + 10}))
        self.assertEqual(list(iq.execute()), [('k1',), ('k2',), ('k3',)])
        self.assertEqual(sorted(KVR.select().tuples()),
                         [('k1', 21), ('k2', 12), ('k3', 300)])

    def test_update_delete_rowcounts(self):
        # Even with returning_clause=True, update/delete without an
        # explicit .returning() still report modified-row counts.
        users = [User.create(username=u) for u in 'abc']
        kvrs = [KVR.create(key='k%s' % i, value=i) for i in (1, 2, 3)]

        uq = User.update(username='c2').where(User.username == 'c')
        self.assertEqual(uq.execute(), 1)
        uq = User.update(username=User.username.concat('x'))
        self.assertEqual(uq.execute(), 3)

        dq = User.delete().where(User.username.in_(['bx', 'c2x']))
        self.assertEqual(dq.execute(), 2)

        uq = KVR.update(value=KVR.value + 10).where(KVR.key == 'k3')
        self.assertEqual(uq.execute(), 1)
        uq = KVR.update(value=KVR.value + 100)
        self.assertEqual(uq.execute(), 3)

        dq = KVR.delete().where(KVR.value.in_([102, 113]))
        self.assertEqual(dq.execute(), 2)

    def test_update_delete_explicit_returning(self):
        users = [User.create(username=u) for u in 'abc']

        uq = (User.update(username='c2')
              .where(User.username == 'c')
              .returning(User.id, User.username))
        for _ in range(2):
            # Re-executing returns the cached result; a clone re-runs the
            # statement (and now matches nothing).
            self.assertEqual([u.username for u in uq.execute()], ['c2'])
            self.assertEqual(list(uq.clone().execute()), [])

        uq = (User.update(username=User.username.concat('x'))
              .where(~User.username.endswith('x'))  # For idempotency.
              .returning(User.id, User.username)
              .tuples())
        for _ in range(2):
            self.assertEqual(sorted(uq.execute()),
                             [(1, 'ax'), (2, 'bx'), (3, 'c2x')])
            self.assertEqual(list(uq.clone().execute()), [])

        dq = User.delete().where(User.username == 'c2x').returning(User)
        for _ in range(2):
            # The result is cached to support multiple iterations.
            self.assertEqual([u.username for u in dq.execute()], ['c2x'])
            self.assertEqual(list(dq.clone().execute()), [])

        dq = User.delete().returning(User).tuples()
        for _ in range(2):
            # The result is cached to support multiple iterations.
            self.assertEqual(sorted(dq.execute()), [(1, 'ax'), (2, 'bx')])
            self.assertEqual(list(dq.clone().execute()), [])

    def test_bulk_create_update(self):
        # bulk_create should issue a single multi-value INSERT.
        users = [User(username='u%s' % i) for i in range(5)]
        with self.assertQueryCount(1):
            User.bulk_create(users)

        self.assertEqual(User.select().count(), 5)
        self.assertEqual(sorted(User.select().tuples()), [
            (1, 'u0'), (2, 'u1'), (3, 'u2'), (4, 'u3'), (5, 'u4')])

        # bulk_update should issue a single UPDATE and report 5 rows
        # touched (all rows, even those whose value did not change).
        users[0].username = 'u0x'
        users[2].username = 'u2x'
        users[4].username = 'u4x'
        with self.assertQueryCount(1):
            n = User.bulk_update(users, ['username'])
        self.assertEqual(n, 5)

        self.assertEqual(sorted(User.select().tuples()), [
            (1, 'u0x'), (2, 'u1'), (3, 'u2x'), (4, 'u3'), (5, 'u4x')])

    @requires_models(User, Tweet)
    def test_fk_set_correctly(self):
        # Ensure FK can be set lazily.
        user = User(username='u1')
        tweet = Tweet(user=user, content='t1')
        user.save()
        tweet.save()
|
|
|
|
|
|
@skip_unless(database.server_version >= (3, 20, 0), 'sqlite deterministic requires >= 3.20')
@skip_unless(sys.version_info >= (3, 8, 0), 'sqlite deterministic requires Python >= 3.8')
class TestDeterministicFunction(ModelTestCase):
    """A user-defined function registered with deterministic=True can be
    used inside an index expression (SQLite only allows deterministic
    functions in indexes)."""
    database = get_in_memory_db()

    def test_deterministic(self):
        db = self.database
        # Register a deterministic scalar function usable in DDL.
        @db.func(deterministic=True)
        def pylower(s):
            if s is not None:
                return s.lower()

        class Reg(db.Model):
            key = TextField()
            class Meta:
                # Unique index over the function expression, so keys that
                # differ only in case collide.
                indexes = [
                    SQL('create unique index "reg_pylower_key" '
                        'on "reg" (pylower("key"))')]

        db.create_tables([Reg])
        Reg.create(key='k1')
        # 'K1' lowercases to 'k1' -> unique-index violation.
        with self.assertRaises(IntegrityError):
            with db.atomic():
                Reg.create(key='K1')
|
|
|
|
@skip_unless(sys.version_info >= (3, 7, 0), 'isoformat (":") works 3.7+')
class TestISODateTimeField(ModelTestCase):
    """Timezone-aware datetimes round-trip through the ISO field and are
    stored in isoformat(), while the plain field stores str(dt)."""
    database = get_in_memory_db()
    requires = [DT]

    def test_aware_datetimes(self):
        # Minimal fixed-offset UTC tzinfo (avoids depending on
        # datetime.timezone for this test).
        class _UTC(datetime.tzinfo):
            def utcoffset(self, dt): return datetime.timedelta(0)
            def tzname(self, dt): return "UTC"
            def dst(self, dt): return datetime.timedelta(0)

        UTC = _UTC()

        d1 = datetime.datetime(2026, 1, 2, 3, 4, 5)
        d2 = d1.astimezone(UTC)  # Aware equivalent of the naive d1.

        dt = DT.create(key='k1', d=d1, iso=d2)
        self.assertEqual(dt.d, d1)
        self.assertEqual(dt.iso, d2)

        # Values survive a round-trip through the database.
        dt = DT['k1']
        self.assertEqual(dt.d, d1)
        self.assertEqual(dt.iso, d2)

        # On-disk representation: str() for `d`, isoformat() for `iso`.
        raw = self.database.execute_sql('select * from dt').fetchone()
        self.assertEqual(raw, ('k1', str(d1), d2.isoformat()))
|