Files
peewee/tests/sqlite_udf.py
2026-02-23 15:26:14 -06:00

413 lines
14 KiB
Python

import datetime
import json
import random
from peewee import *
from peewee import sqlite3
from playhouse.sqlite_udf import register_all
from .base import IS_SQLITE_9
from .base import ModelTestCase
from .base import TestModel
from .base import get_sqlite_db
from .base import skip_unless
try:
from playhouse import _sqlite_udf as cython_udf
except ImportError:
cython_udf = None
def requires_cython(method):
    """Apply ``skip_unless`` to *method*, keyed on the Cython UDF import.

    ``cython_udf`` is ``None`` when importing ``playhouse._sqlite_udf``
    raised ``ImportError`` at module load time.
    """
    extension_available = cython_udf is not None
    conditional_skip = skip_unless(extension_available,
                                   'requires sqlite udf c extension')
    return conditional_skip(method)
# Module-level SQLite database; BaseTestUDF exposes it to the test cases below.
database = get_sqlite_db()
# Register the functions provided by playhouse.sqlite_udf on that database.
register_all(database)
class User(TestModel):
    """Test model with a single text ``username`` column."""

    username = TextField()
class APIResponse(TestModel):
    """Test model holding a URL, a text payload and a timestamp."""

    # Both text columns default to the empty string.
    url = TextField(default='')
    data = TextField(default='')
    # The callable is passed (not called) so the default is evaluated later.
    timestamp = DateTimeField(default=datetime.datetime.now)
class Generic(TestModel):
    """Test model with an integer ``value`` and an untyped, nullable ``x``."""

    value = IntegerField(default=0)
    # Plain ``Field``: the tests store both datetimes and integers in it.
    x = Field(null=True)
# Every model declared in this module; TestScalarFunctions uses it as `requires`.
MODELS = [User, APIResponse, Generic]
class FixedOffset(datetime.tzinfo):
    """``tzinfo`` with a constant UTC offset, name and DST offset.

    ``offset`` and ``dstoffset`` given as ``int`` are interpreted as minutes
    and converted to ``timedelta``; any other value is stored unchanged.
    """

    def __init__(self, offset, name, dstoffset=42):
        self.__offset = (datetime.timedelta(minutes=offset)
                         if isinstance(offset, int) else offset)
        self.__dstoffset = (datetime.timedelta(minutes=dstoffset)
                            if isinstance(dstoffset, int) else dstoffset)
        self.__name = name

    def utcoffset(self, dt):
        """Return the stored UTC offset; *dt* is ignored."""
        return self.__offset

    def tzname(self, dt):
        """Return the stored zone name; *dt* is ignored."""
        return self.__name

    def dst(self, dt):
        """Return the stored DST offset; *dt* is ignored."""
        return self.__dstoffset
class BaseTestUDF(ModelTestCase):
    """Common base for the UDF test cases, bound to the module database."""

    database = database

    def sql1(self, sql, *params):
        """Run *sql* with *params*; return column 0 of the first row."""
        first_row = self.database.execute_sql(sql, params).fetchone()
        return first_row[0]
class TestAggregates(BaseTestUDF):
    """Checks for the aggregate UDFs, evaluated over ``generic.x``."""

    requires = [Generic]

    def _store_values(self, *values):
        """Create one ``Generic`` row per value, inside one atomic block."""
        with self.database.atomic():
            for item in values:
                Generic.create(x=item)

    def mts(self, seconds):
        """Return 2015-01-01 00:00:00 plus *seconds* seconds."""
        base = datetime.datetime(2015, 1, 1)
        return base + datetime.timedelta(seconds=seconds)

    def test_min_avg_tdiff(self):
        min_sql = 'select mintdiff(x) from generic;'
        avg_sql = 'select avgtdiff(x) from generic;'

        def check(expected_min, expected_avg):
            self.assertEqual(self.sql1(min_sql), expected_min)
            self.assertEqual(self.sql1(avg_sql), expected_avg)

        # Empty table.
        check(None, None)

        # A single timestamp.
        self._store_values(self.mts(10))
        check(None, 0)

        # Two timestamps five seconds apart.
        self._store_values(self.mts(15))
        check(5, 5)

        # Unordered batch of further timestamps.
        self._store_values(
            *[self.mts(offset) for offset in (22, 52, 18, 41, 2, 33)])
        self.assertEqual(self.sql1(min_sql), 3)
        self.assertEqual(round(self.sql1(avg_sql), 1), 7.1)

        # A duplicate of an existing timestamp.
        self._store_values(self.mts(22))
        self.assertEqual(self.sql1(min_sql), 0)

    def test_duration(self):
        duration_sql = 'select duration(x) from generic;'
        self.assertEqual(self.sql1(duration_sql), None)

        # Each step adds timestamps (as second offsets), then checks duration.
        steps = (
            ((10,), 0),
            ((15,), 5),
            ((22, 11, 52, 18, 41, 2, 33), 50),
        )
        for offsets, expected in steps:
            self._store_values(*[self.mts(offset) for offset in offsets])
            self.assertEqual(self.sql1(duration_sql), expected)

    @requires_cython
    def test_median(self):
        median_sql = 'select median(x) from generic;'
        self.assertEqual(self.sql1(median_sql), None)

        self._store_values(1)
        self.assertEqual(self.sql1(median_sql), 1)

        self._store_values(3, 6, 6, 6, 7, 7, 7, 7, 12, 12, 17)
        self.assertEqual(self.sql1(median_sql), 7)

        # The remaining data sets each start from an emptied table.
        fresh_sets = (
            ((9, 2, 2, 3, 3, 1), 3),
            ((4, 4, 1, 8, 2, 2, 5, 8, 1), 4),
        )
        for numbers, expected in fresh_sets:
            Generic.delete().execute()
            self._store_values(*numbers)
            self.assertEqual(self.sql1(median_sql), expected)

    def test_mode(self):
        mode_sql = 'select mode(x) from generic;'
        self.assertEqual(self.sql1(mode_sql), None)

        self._store_values(1)
        self.assertEqual(self.sql1(mode_sql), 1)

        self._store_values(4, 5, 6, 1, 3, 4, 1, 4, 9, 3, 4)
        self.assertEqual(self.sql1(mode_sql), 4)

    def test_ranges(self):
        range_sql = 'select range(x) from generic'

        def check(expected_min, expected_avg, expected_range):
            self.assertEqual(
                self.sql1('select minrange(x) from generic'), expected_min)
            self.assertEqual(
                self.sql1('select avgrange(x) from generic'), expected_avg)
            self.assertEqual(self.sql1(range_sql), expected_range)

        # Empty table.
        check(None, None, None)

        self._store_values(1)
        check(0, 0, 0)

        self._store_values(4, 8, 13, 19)
        check(3, 4.5, 18)

        # Start over with unordered values that include a duplicate.
        Generic.delete().execute()
        self._store_values(19, 4, 5, 20, 5, 8)
        self.assertEqual(self.sql1(range_sql), 16)
class TestScalarFunctions(BaseTestUDF):
requires = MODELS
def test_if_then_else(self):
    # Create users u1..u4, then label each row via the if_then_else() UDF.
    for number in range(1, 5):
        User.create(username='u%d' % number)

    is_one_or_two = User.username << ['u1', 'u2']
    label = fn.if_then_else(is_one_or_two, 'one or two', 'other')

    with self.assertQueryCount(1):
        query = (User
                 .select(User.username, label.alias('name_type'))
                 .order_by(User.id))
        # Iterating the query is what executes it, so do it inside the block.
        labels = [row.name_type for row in query]

    self.assertEqual(labels, [
        'one or two',
        'one or two',
        'other',
        'other'])
def test_strip_tz(self):
    naive = datetime.datetime(2015, 1, 1, 12, 0)
    # Same wall-clock time with a fixed offset of 13 hours, 37 minutes.
    aware = naive.replace(tzinfo=FixedOffset(13 * 60 + 37, 'US/LFK'))

    naive_row = APIResponse.create(timestamp=naive)
    aware_row = APIResponse.create(timestamp=aware)

    # Re-fetch both rows from the database. Only the naive row's round trip
    # is asserted; the aware row is fetched but its value is not compared.
    naive_from_db = APIResponse.get(APIResponse.id == naive_row.id)
    aware_from_db = APIResponse.get(APIResponse.id == aware_row.id)
    self.assertEqual(naive_from_db.timestamp, naive)

    stripped = fn.strip_tz(APIResponse.timestamp).alias('ts')
    query = (APIResponse
             .select(APIResponse.id, stripped)
             .order_by(APIResponse.id))
    first, second = query[:]

    # Both rows are expected to come back as the naive datetime.
    self.assertEqual(first.ts, naive)
    self.assertEqual(second.ts, naive)
def test_human_delta(self):
    # (seconds, expected human-readable text), in ascending order of seconds.
    expected = [
        (0, '0 seconds'),
        (1, '1 second'),
        (30, '30 seconds'),
        (300, '5 minutes'),
        (3600, '1 hour'),
        (7530, '2 hours, 5 minutes, 30 seconds'),
        (300000, '3 days, 11 hours, 20 minutes'),
    ]
    for seconds, _ in expected:
        Generic.create(value=seconds)

    human = fn.human_delta(Generic.value).coerce(False)
    rows = (Generic
            .select(Generic.value, human.alias('delta'))
            .order_by(Generic.value)
            .tuples())
    self.assertEqual(rows[:], expected)
def test_file_ext(self):
    # (path, expected extension) pairs; the extension includes the dot.
    cases = [
        ('test.py', '.py'),
        ('test.x.py', '.py'),
        ('test', ''),
        ('test.', '.'),
        ('/foo.bar/test/nug.py', '.py'),
        ('/foo.bar/test/nug', ''),
    ]
    for path, expected_ext in cases:
        actual = self.sql1('SELECT file_ext(?)', path)
        self.assertEqual(actual, expected_ext)
def test_gz(self):
    # Fixed seed so the random payloads are the same on every run.
    random.seed(1)
    low, high = ord('A'), ord('z')

    def random_text(length):
        return ''.join(chr(random.randint(low, high)) for _ in range(length))

    with self.database.atomic():
        payloads = (
            'a',
            'a' * 1024,
            random_text(1024),
            random_text(4096),
            random_text(1024 * 64))
        # gzip() followed by gunzip() must give back the original text.
        for original in payloads:
            packed = self.sql1('select gzip(?)', original)
            unpacked = self.sql1('select gunzip(?)', packed)
            self.assertEqual(unpacked.decode('utf-8'), original)
def test_hostname(self):
    # Group stored URLs by the hostname() UDF and check the per-host counts.
    r = json.dumps({'success': True})
    data = (
        ('https://charlesleifer.com/api/', r),
        ('https://a.charlesleifer.com/api/foo', r),
        ('www.nugget.com', r),
        ('nugz.com', r),
        ('http://a.b.c.peewee/foo', r),
        ('https://charlesleifer.com/xx', r),
        ('https://charlesleifer.com/xx', r),
    )
    with self.database.atomic():
        for url, response in data:
            # Each row stores its own response body.
            APIResponse.create(url=url, data=response)

    with self.assertQueryCount(1):
        query = (APIResponse
                 .select(
                     fn.hostname(APIResponse.url).alias('host'),
                     fn.COUNT(APIResponse.id).alias('count'))
                 .group_by(fn.hostname(APIResponse.url))
                 .order_by(
                     fn.COUNT(APIResponse.id).desc(),
                     fn.hostname(APIResponse.url)))
        # Slicing executes the query, so it has to happen inside the block.
        results = query.tuples()[:]

    # The two scheme-less URLs are expected to land in the '' bucket.
    self.assertEqual(results, [
        ('charlesleifer.com', 3),
        ('', 2),
        ('a.b.c.peewee', 1),
        ('a.charlesleifer.com', 1)])
@skip_unless(IS_SQLITE_9, 'requires sqlite >= 3.9')
def test_toggle(self):
    toggle_sql = 'select toggle(?)'

    # Each call flips the named toggle; keys are tracked independently.
    sequence = (
        ('foo', 1),
        ('bar', 1),
        ('foo', 0),
        ('foo', 1),
        ('bar', 0),
    )
    for key, expected in sequence:
        self.assertEqual(self.sql1(toggle_sql, key), expected)

    # After clearing, 'foo' (last seen as 1) reports 1 again.
    self.assertEqual(self.sql1('select clear_toggles()'), None)
    self.assertEqual(self.sql1(toggle_sql, 'foo'), 1)
def test_setting(self):
    def store(key, value):
        return self.sql1('select setting(?, ?)', key, value)

    def fetch(key):
        return self.sql1('select setting(?)', key)

    # Storing returns the stored value.
    self.assertEqual(store('k1', 'v1'), 'v1')
    self.assertEqual(store('k2', 'v2'), 'v2')
    self.assertEqual(fetch('k1'), 'v1')

    # Storing again overwrites the previous value.
    self.assertEqual(store('k2', 'v2-x'), 'v2-x')
    self.assertEqual(fetch('k2'), 'v2-x')

    # Unknown keys, and all keys after a clear, read back as NULL.
    self.assertEqual(fetch('kx'), None)
    self.assertEqual(self.sql1('select clear_settings()'), None)
    self.assertEqual(fetch('k1'), None)
def test_random_range(self):
    arg_sets = ((1, 10), (1, 100), (0, 2), (1, 5, 2))

    def seeded_randrange(args):
        random.seed(1)
        return random.randrange(*args)

    # Reference values from Python's own randrange under a fixed seed.
    expected_values = [seeded_randrange(args) for args in arg_sets]

    for args, expected in zip(arg_sets, expected_values):
        # Re-seed so the UDF call is compared under the same seed.
        random.seed(1)
        placeholders = '?, ?, ?' if len(args) == 3 else '?, ?'
        actual = self.sql1('select randomrange(%s)' % placeholders, *args)
        self.assertEqual(actual, expected)
def test_sqrt(self):
    sqrt_sql = 'select sqrt(?)'

    root_of_four = self.sql1(sqrt_sql, 4)
    self.assertEqual(root_of_four, 2)

    # Irrational result: compare after rounding to two decimals.
    root_of_two = self.sql1(sqrt_sql, 2)
    self.assertEqual(round(root_of_two, 2), 1.41)
def test_tonumber(self):
    # (text, expected number); text that is not numeric is expected as NULL.
    cases = [
        ('123', 123),
        ('1.23', 1.23),
        ('1e4', 10000),
        ('-10', -10),
        ('x', None),
        ('13d', None),
    ]
    for text, expected in cases:
        actual = self.sql1('select tonumber(?)', text)
        self.assertEqual(actual, expected)
@requires_cython
def test_leven(self):
    # (left, right, expected distance) triples.
    cases = (
        ('abc', 'ba', 2),
        ('abcde', 'eba', 4),
        ('abcde', 'abcde', 0),
    )
    for left, right, expected in cases:
        distance = self.sql1('select levenshtein_dist(?, ?)', left, right)
        self.assertEqual(distance, expected)
@requires_cython
def test_str_dist(self):
    # (left, right, expected distance) triples.
    cases = (
        ('abc', 'ba', 3),
        ('abcde', 'eba', 6),
        ('abcde', 'abcde', 0),
    )
    for left, right, expected in cases:
        distance = self.sql1('select str_dist(?, ?)', left, right)
        self.assertEqual(distance, expected)
def test_substr_count(self):
    # (haystack, needle, expected count); an empty needle counts as zero.
    cases = (
        ('foo bar baz', 'a', 2),
        ('foo bor baz', 'o', 3),
        ('foodooboope', 'oo', 3),
        ('xx', '', 0),
        ('', '', 0),
    )
    for haystack, needle, expected in cases:
        count = self.sql1('select substr_count(?, ?)', haystack, needle)
        self.assertEqual(count, expected)
def test_strip_chars(self):
self.assertEqual(
self.sql1('select strip_chars(?, ?)', ' hey foo ', ' '),
'hey foo')