Files
2026-04-23 16:12:19 -05:00

1172 lines
40 KiB
Python

#coding:utf-8
import datetime
import functools
import json
import os
import uuid
from decimal import Decimal as Dc
from types import MethodType
from peewee import *
from playhouse.postgres_ext import *
from playhouse.reflection import Introspector
from .base import BaseTestCase
from .base import DatabaseTestCase
from .base import ModelTestCase
from .base import TestModel
from .base import db_loader
from .base import requires_models
from .base import skip_if
from .base import skip_unless
from .base_models import Register
from .base_models import Tweet
from .base_models import User
from .postgres_helpers import BaseBinaryJsonFieldTestCase
from .postgres_helpers import BaseJsonFieldTestCase
# Module-wide Postgres database. PostgresqlExtDatabase is used so the
# playhouse.postgres_ext field types declared below are available.
db = db_loader('postgres', db_class=PostgresqlExtDatabase)


class HStoreModel(TestModel):
    # Row label; the hstore tests look rows up by this value.
    name = CharField()
    # Key/value mapping stored in an hstore column.
    data = HStoreField()


# Shorthand for the hstore column, used throughout TestHStoreField.
D = HStoreModel.data


class ArrayModel(TestModel):
    # One-dimensional array of strings.
    tags = ArrayField(CharField)
    # Two-dimensional array of integers.
    ints = ArrayField(IntegerField, dimensions=2)


class UUIDList(TestModel):
    key = CharField()
    # The same UUID lists are stored twice: once with BinaryUUIDField
    # elements (convert_values=True) and once with UUIDField elements.
    # Neither array column is indexed.
    id_list = ArrayField(BinaryUUIDField, convert_values=True, index=False)
    id_list_native = ArrayField(UUIDField, index=False)


class ArrayTSModel(TestModel):
    # String primary key; this model declares no ``id`` column.
    key = CharField(max_length=100, primary_key=True)
    # Array of timestamps. convert_values=True presumably routes each element
    # through TimestampField conversion -- see TestArrayFieldConvertValues.
    timestamps = ArrayField(TimestampField, convert_values=True)


class DecimalArray(TestModel):
    # field_kwargs configures the DecimalField element type; values read back
    # with one decimal place (3.14 -> 3.1 in TestArrayField.test_field_kwargs).
    values = ArrayField(DecimalField, field_kwargs={'decimal_places': 1})


class FTSModel(TestModel):
    title = CharField()
    # Raw document text.
    data = TextField()
    # Search vector for ``data``; TestTSVectorField.setUp fills it with
    # fn.to_tsvector(message).
    fts_data = TSVectorField()


class JsonModel(TestModel):
    data = JSONField()


class JsonModelNull(TestModel):
    # Nullable JSON column, used to check NULL round-trips.
    data = JSONField(null=True)


class BJson(TestModel):
    data = BinaryJSONField()


class JData(TestModel):
    # d1 keeps the field's default index (introspected as GIN in
    # TestJsonFieldRegressions); d2 explicitly opts out of indexing.
    d1 = BinaryJSONField()
    d2 = BinaryJSONField(index=False)
class UUIDEncoder(json.JSONEncoder):
    """JSON encoder that writes ``uuid.UUID`` values as their string form."""

    def default(self, obj):
        # Anything that is not a UUID goes to the base implementation, which
        # raises TypeError for unsupported types.
        if not isinstance(obj, uuid.UUID):
            return super(UUIDEncoder, self).default(obj)
        return str(obj)


def dumps(obj):
    """Serialize *obj* to a JSON string, encoding any UUIDs as strings."""
    encoded = json.dumps(obj, cls=UUIDEncoder)
    return encoded
class CustomJSONDumps(TestModel):
    # Both JSON column flavours serialize through the module-level dumps()
    # helper, so uuid.UUID values can be stored in them.
    jf = JSONField(dumps=dumps)
    jbf = BinaryJSONField(dumps=dumps)


class Normal(TestModel):
    # Plain text model; the JSON test cases below expose it as ``N``.
    data = TextField()


class Event(TestModel):
    name = CharField()
    # Interval column; TestIntervalField round-trips datetime.timedelta values.
    duration = IntervalField()


class TZModel(TestModel):
    # Timezone-aware timestamp column exercised by TestTZField.
    dt = DateTimeTZField()
class TestTZField(ModelTestCase):
    """Round-trip naive and aware datetimes through a DateTimeTZField."""
    database = db
    requires = [TZModel]

    def _assert_hours(self, value, local, utc):
        # Compare (year, month, day, hour) both in the value's own zone and
        # after conversion to UTC.
        self.assertEqual(value.timetuple()[:4], local)
        self.assertEqual(value.utctimetuple()[:4], utc)

    @skip_if(os.environ.get('CI'), 'running in ci mode, skipping')
    def test_tz_field(self):
        class _UTC(datetime.tzinfo):
            # Minimal fixed zero-offset tzinfo.
            def utcoffset(self, dt):
                return datetime.timedelta(0)

            def tzname(self, dt):
                return "UTC"

            def dst(self, dt):
                return datetime.timedelta(0)

        utc = _UTC()

        self.database.set_time_zone('us/eastern')

        # A naive datetime is interpreted in the session zone (US/Eastern).
        # The instance returned by create() still holds the naive value.
        naive = datetime.datetime(2019, 1, 1, 12)
        naive_row = TZModel.create(dt=naive)
        self.assertIsNone(naive_row.dt.tzinfo)

        # Reading the row back yields an aware datetime in US/Eastern.
        fetched = TZModel[naive_row.id]
        self.assertIsNotNone(fetched.dt.tzinfo)
        self._assert_hours(fetched.dt, (2019, 1, 1, 12), (2019, 1, 1, 17))

        # An aware value in another zone is stored as that instant and read
        # back rendered in the session zone.
        aware = datetime.datetime(2019, 1, 1, 12, tzinfo=utc)
        aware_row = TZModel.create(dt=aware)
        fetched = TZModel[aware_row.id]
        self._assert_hours(fetched.dt, (2019, 1, 1, 7), (2019, 1, 1, 12))

        # Lookups: a naive value is treated as session-local time, while an
        # aware value's tzinfo is respected.
        self.assertEqual(TZModel.get(TZModel.dt == naive).id, naive_row.id)
        self.assertEqual(TZModel.get(TZModel.dt == aware).id, aware_row.id)

        # Changing the session zone changes the rendering, not the instant.
        self.database.set_time_zone('us/central')
        self._assert_hours(TZModel[naive_row.id].dt,
                           (2019, 1, 1, 11), (2019, 1, 1, 17))
        self._assert_hours(TZModel[aware_row.id].dt,
                           (2019, 1, 1, 6), (2019, 1, 1, 12))
class TestHStoreField(ModelTestCase):
    """Storage, selection, filtering and in-place updates of HStoreField."""
    # hstore support needs its own database instance (register_hstore=True).
    database = db_loader('postgres', db_class=PostgresqlExtDatabase,
                         register_hstore=True)
    requires = [HStoreModel]

    def setUp(self):
        super(TestHStoreField, self).setUp()
        # Both rows share k2; k1 exists only in t1 and k3 only in t2.
        self.t1 = HStoreModel.create(name='t1', data={'k1': 'v1', 'k2': 'v2'})
        self.t2 = HStoreModel.create(name='t2', data={'k2': 'v2', 'k3': 'v3'})

    def by_name(self, name):
        """Fetch the hstore mapping of the row named *name*."""
        row = HStoreModel.get(HStoreModel.name == name)
        return row.data

    def test_hstore_storage(self):
        self.assertEqual(self.by_name('t1'), {'k1': 'v1', 'k2': 'v2'})
        self.assertEqual(self.by_name('t2'), {'k2': 'v2', 'k3': 'v3'})

        # Replacing the whole mapping overwrites the stored value.
        self.t1.data = {'k4': 'v4'}
        self.t1.save()
        self.assertEqual(self.by_name('t1'), {'k4': 'v4'})

        # An empty mapping round-trips as an empty dict.
        HStoreModel.create(name='t3', data={})
        self.assertEqual(self.by_name('t3'), {})

    def query(self, *cols):
        """Select ``name`` plus *cols* for every row, in primary-key order."""
        columns = (HStoreModel.name,) + cols
        return HStoreModel.select(*columns).order_by(HStoreModel.id)

    def test_hstore_selecting(self):
        # (selected expression, reader for the aliased value, expected rows)
        cases = (
            (D.keys().alias('keys'), lambda row: sorted(row.keys),
             [('t1', ['k1', 'k2']), ('t2', ['k2', 'k3'])]),
            (D.values().alias('vals'), lambda row: sorted(row.vals),
             [('t1', ['v1', 'v2']), ('t2', ['v2', 'v3'])]),
            (D.items().alias('mtx'), lambda row: sorted(row.mtx),
             [('t1', [['k1', 'v1'], ['k2', 'v2']]),
              ('t2', [['k2', 'v2'], ['k3', 'v3']])]),
            (D.slice('k2', 'k3').alias('kz'), lambda row: row.kz,
             [('t1', {'k2': 'v2'}),
              ('t2', {'k2': 'v2', 'k3': 'v3'})]),
            (D.slice('k4').alias('kz'), lambda row: row.kz,
             [('t1', {}), ('t2', {})]),
            (D.exists('k3').alias('ke'), lambda row: row.ke,
             [('t1', False), ('t2', True)]),
            (D.defined('k3').alias('ke'), lambda row: row.ke,
             [('t1', False), ('t2', True)]),
            # A missing key reads back as None.
            (D['k1'].alias('k1'), lambda row: row.k1,
             [('t1', 'v1'), ('t2', None)]),
        )
        for expr, read, expected in cases:
            actual = [(row.name, read(row)) for row in self.query(expr)]
            self.assertEqual(actual, expected)

        # Key lookup also works as a filter.
        matching = self.query().where(D['k1'] == 'v1')
        self.assertEqual([row.name for row in matching], ['t1'])

    def assertWhere(self, expr, names):
        """Assert that ``WHERE expr`` selects exactly the rows named *names*."""
        rows = HStoreModel.select().where(expr)
        self.assertEqual([row.name for row in rows], names)

    def test_hstore_filtering(self):
        expectations = (
            # Equality compares the whole mapping.
            (D == {'k1': 'v1', 'k2': 'v2'}, ['t1']),
            (D == {'k2': 'v2'}, []),
            # contains(): a key, a list of keys, or a dict of pairs.
            (D.contains('k3'), ['t2']),
            (D.contains(['k2', 'k3']), ['t2']),
            (D.contains(['k2']), ['t1', 't2']),
            (D.contains({'k2': 'v2', 'k3': 'v3'}), ['t2']),
            (D.contains({'k2': 'v2'}), ['t1', 't2']),
            (D.contains({'k2': 'v3'}), []),
            # contains_any(): at least one of the keys is present.
            (D.contains_any('k3', 'kx'), ['t2']),
            (D.contains_any('k2', 'x', 'k3'), ['t1', 't2']),
            (D.contains_any('x', 'kx', 'y'), []),
        )
        for expr, names in expectations:
            self.assertWhere(expr, names)

    def test_hstore_filter_functions(self):
        # ``== True`` is intentional: it builds a SQL comparison expression.
        expectations = (
            (D.exists('k2') == True, ['t1', 't2']),
            (D.exists('k3') == True, ['t2']),
            (D.defined('k2') == True, ['t1', 't2']),
            (D.defined('k3') == True, ['t2']),
        )
        for expr, names in expectations:
            self.assertWhere(expr, names)

    def test_hstore_update(self):
        def update_row(new_value, name):
            # UPDATE ... SET data = new_value WHERE name = <name>.
            return (HStoreModel
                    .update(data=new_value)
                    .where(HStoreModel.name == name)
                    .execute())

        def all_data():
            return [row.data for row in self.query(D)]

        # update() merges new pairs into the existing mapping.
        self.assertTrue(update_row(D.update(k4='v4'), 't1') > 0)
        self.assertEqual(self.by_name('t1'),
                         {'k1': 'v1', 'k2': 'v2', 'k4': 'v4'})

        self.assertTrue(update_row(D.update(k5='v5', k6='v6'), 't2') > 0)
        self.assertEqual(self.by_name('t2'),
                         {'k2': 'v2', 'k3': 'v3', 'k5': 'v5', 'k6': 'v6'})

        # Existing keys are overwritten; no WHERE clause touches every row.
        HStoreModel.update(data=D.update(k2='vxxx')).execute()
        self.assertEqual(all_data(), [
            {'k1': 'v1', 'k2': 'vxxx', 'k4': 'v4'},
            {'k2': 'vxxx', 'k3': 'v3', 'k5': 'v5', 'k6': 'v6'}])

        # delete() removes keys; absent keys are ignored.
        update_row(D.delete('k4'), 't1')
        self.assertEqual(self.by_name('t1'), {'k1': 'v1', 'k2': 'vxxx'})

        HStoreModel.update(data=D.delete('k5')).execute()
        self.assertEqual(all_data(), [
            {'k1': 'v1', 'k2': 'vxxx'},
            {'k2': 'vxxx', 'k3': 'v3', 'k6': 'v6'}])

        HStoreModel.update(data=D.delete('k1', 'k2')).execute()
        self.assertEqual(all_data(), [
            {},
            {'k3': 'v3', 'k6': 'v6'}])
class TestArrayField(ModelTestCase):
    """Storage, comparison, search and slicing of ArrayField columns."""
    database = db
    requires = [ArrayModel]

    def create_sample(self):
        """Create the standard four-tag / 3x2-int sample row."""
        return ArrayModel.create(
            tags=['alpha', 'beta', 'gamma', 'delta'],
            ints=[[1, 2], [3, 4], [5, 6]])

    def test_index_expression(self):
        am_ids = [ArrayModel.create(tags=tags, ints=[]).id
                  for tags in (['a', 'b', 'c'], ['b', 'c', 'd', 'e'])]

        # Indexing with array_upper(tags, 1) selects each row's final tag.
        last_index = fn.array_upper(ArrayModel.tags, 1)
        last_tag = ArrayModel.tags[last_index]

        finals = ArrayModel.select(last_tag).tuples()
        self.assertEqual(sorted(value for value, in finals), ['c', 'e'])

        below = ArrayModel.select().where(last_tag < 'd')
        self.assertEqual([row.id for row in below], [am_ids[0]])

        above = ArrayModel.select().where(last_tag > 'd')
        self.assertEqual([row.id for row in above], [am_ids[1]])

    def test_hashable_objectslice(self):
        for grid in ([[0, 1], [2, 3]], [[4, 5], [6, 7]]):
            ArrayModel.create(tags=[], ints=grid)

        # The slice itself is used as a dict key, so it must be hashable.
        cell = ArrayModel.ints[0][0]
        updated = ArrayModel.update({cell: cell + 1}).execute()
        self.assertEqual(updated, 2)

        first, second = ArrayModel.select().order_by(ArrayModel.id)
        self.assertEqual(first.ints, [[1, 1], [2, 3]])
        self.assertEqual(second.ints, [[5, 5], [6, 7]])

    def test_array_get_set(self):
        sample = self.create_sample()
        stored = ArrayModel.get(ArrayModel.id == sample.id)
        self.assertEqual(stored.tags, ['alpha', 'beta', 'gamma', 'delta'])
        self.assertEqual(stored.ints, [[1, 2], [3, 4], [5, 6]])

    def test_array_equality(self):
        first = ArrayModel.create(tags=['t1'], ints=[[1, 2]])
        second = ArrayModel.create(tags=['t2'], ints=[[3, 4]])

        found = ArrayModel.get(ArrayModel.tags == ['t1'])
        self.assertEqual(found.id, first.id)
        self.assertEqual(found.tags, ['t1'])

        found = ArrayModel.get(ArrayModel.ints == [[3, 4]])
        self.assertEqual(found.id, second.id)

        found = ArrayModel.get(ArrayModel.tags != ['t1'])
        self.assertEqual(found.id, second.id)

    def test_array_db_value(self):
        # A tuple is accepted on write and read back as a list.
        created = ArrayModel.create(tags=('foo', 'bar'), ints=[])
        stored = ArrayModel.get(ArrayModel.id == created.id)
        self.assertEqual(stored.tags, ['foo', 'bar'])

    def test_array_search(self):
        def assert_matches(where, *instances):
            rows = (ArrayModel
                    .select()
                    .where(where)
                    .order_by(ArrayModel.id))
            self.assertEqual([row.id for row in rows],
                             [obj.id for obj in instances])

        am = self.create_sample()
        am2 = ArrayModel.create(tags=['alpha', 'beta'], ints=[[1, 1]])
        am3 = ArrayModel.create(tags=['delta'], ints=[[3, 4]])
        am4 = ArrayModel.create(tags=['中文'], ints=[[3, 4]])
        am5 = ArrayModel.create(tags=['中文', '汉语'], ints=[[3, 4]])
        T = ArrayModel.tags

        cases = (
            # value = ANY(array)
            (Value('beta') == fn.ANY(T), (am, am2)),
            (Value('delta') == fn.Any(T), (am, am3)),
            (Value('omega') == fn.Any(T), ()),
            # Raw containment operator.
            (SQL("tags::text[] @> ARRAY['beta']"), (am, am2)),
            # contains(): every listed value must be present.
            (T.contains('beta'), (am, am2)),
            (T.contains('omega', 'delta'), ()),
            (T.contains('汉语'), (am5,)),
            (T.contains('alpha', 'delta'), (am,)),
            # contained_by(): every array element is among the listed values.
            (T.contained_by('alpha', 'beta', 'delta'), (am2, am3)),
            (T.contained_by('alpha', 'beta', 'gamma', 'delta'),
             (am, am2, am3)),
            # contains_any(): at least one listed value is present.
            (T.contains_any('beta'), (am, am2)),
            (T.contains_any('中文'), (am4, am5)),
            (T.contains_any('omega', 'delta'), (am, am3)),
            (T.contains_any('alpha', 'delta'), (am, am2, am3)),
        )
        for where, expected in cases:
            assert_matches(where, *expected)

    def test_array_index_slice(self):
        self.create_sample()
        ints, tags = ArrayModel.ints, ArrayModel.tags

        # Python-side indexes are zero-based: tags[1] is 'beta'.
        cases = (
            (tags[1].alias('arrtags'), 'arrtags', 'beta'),
            (tags[2:3].alias('foo'), 'foo', ['gamma']),
            (tags[:2].alias('foo'), 'foo', ['alpha', 'beta']),
            (tags[2:].alias('foo'), 'foo', ['gamma', 'delta']),
            (tags[2:4].alias('foo'), 'foo', ['gamma', 'delta']),
            (ints[1][1].alias('ints'), 'ints', 4),
            (ints[1:3][0].alias('ints'), 'ints', [[3], [5]]),
        )
        for expr, key, expected in cases:
            row = ArrayModel.select(expr).dicts().get()
            self.assertEqual(row[key], expected)

    @requires_models(DecimalArray)
    def test_field_kwargs(self):
        inputs = ([Dc('3.1'), Dc('1.3')], [Dc('3.14'), Dc('1')])
        created = [DecimalArray.create(values=vals) for vals in inputs]
        fetched = [DecimalArray.get(DecimalArray.id == obj.id)
                   for obj in created]
        # decimal_places=1 on the element field: 3.14 -> 3.1, 1 -> 1.0.
        self.assertEqual(fetched[0].values, [Dc('3.1'), Dc('1.3')])
        self.assertEqual(fetched[1].values, [Dc('3.1'), Dc('1.0')])
class TestArrayFieldConvertValues(ModelTestCase):
    """ArrayField(TimestampField, convert_values=True) round-trips and lookups."""
    database = db
    requires = [ArrayTSModel]

    def dt(self, day, hour=0, minute=0, second=0):
        """Return a naive datetime on the given day of January 2018."""
        return datetime.datetime(2018, 1, day, hour, minute, second)

    def test_value_conversion(self):
        data = {
            'k1': [self.dt(1), self.dt(2), self.dt(3)],
            'k2': [],
            'k3': [self.dt(4, 5, 6, 7), self.dt(10, 11, 12, 13)],
        }
        for key in sorted(data):
            ArrayTSModel.create(key=key, timestamps=data[key])

        # Every list, including the empty one, reads back unchanged.
        for key in sorted(data):
            am = ArrayTSModel.get(ArrayTSModel.key == key)
            self.assertEqual(am.timestamps, data[key])

        # Perform lookup using timestamp values.
        ts = ArrayTSModel.get(ArrayTSModel.timestamps.contains(self.dt(3)))
        self.assertEqual(ts.key, 'k1')
        ts = ArrayTSModel.get(
            ArrayTSModel.timestamps.contains(self.dt(4, 5, 6, 7)))
        self.assertEqual(ts.key, 'k3')
        # Matching is exact: a timestamp one second off matches nothing.
        self.assertRaises(ArrayTSModel.DoesNotExist, ArrayTSModel.get,
                          ArrayTSModel.timestamps.contains(self.dt(4, 5, 6)))

    def test_get_with_array_values(self):
        # ArrayTSModel's primary key is ``key`` (it declares no ``id`` field),
        # so rows are identified by comparing ``key``.
        a1 = ArrayTSModel.create(key='k1', timestamps=[self.dt(1)])
        a2 = ArrayTSModel.create(key='k2', timestamps=[self.dt(2), self.dt(3)])

        query = (ArrayTSModel
                 .select()
                 .where(ArrayTSModel.timestamps == [self.dt(1)]))
        a1_db = query.get()
        self.assertEqual(a1_db.key, a1.key)

        query = (ArrayTSModel
                 .select()
                 .where(ArrayTSModel.timestamps == [self.dt(2), self.dt(3)]))
        a2_db = query.get()
        self.assertEqual(a2_db.key, a2.key)

        # Keyword-argument lookup with an array value.
        a1_db = ArrayTSModel.get(timestamps=[self.dt(1)])
        self.assertEqual(a1_db.key, a1.key)
        a2_db = ArrayTSModel.get(timestamps=[self.dt(2), self.dt(3)])
        self.assertEqual(a2_db.key, a2.key)
class TestArrayUUIDField(ModelTestCase):
    """Round-trip UUID arrays stored with binary and native UUID elements."""
    database = db
    requires = [UUIDList]

    def test_array_of_uuids(self):
        u1, u2, u3, u4 = [uuid.uuid4() for _ in range(4)]
        UUIDList.create(key='a', id_list=[u1, u2, u3],
                        id_list_native=[u1, u2, u3])
        UUIDList.create(key='b', id_list=[u2, u3, u4],
                        id_list_native=[u2, u3, u4])

        # Assert against freshly fetched rows. The instances returned by
        # create() just hold the Python lists passed in, so checking them
        # would never exercise the fields' database conversion.
        a_db = UUIDList.get(UUIDList.key == 'a')
        b_db = UUIDList.get(UUIDList.key == 'b')

        self.assertEqual(a_db.id_list, [u1, u2, u3])
        self.assertEqual(b_db.id_list, [u2, u3, u4])

        self.assertEqual(a_db.id_list_native, [u1, u2, u3])
        self.assertEqual(b_db.id_list_native, [u2, u3, u4])
class TestTSVectorField(ModelTestCase):
    """Full-text search via Match() on text and TSVectorField.match()."""
    database = db
    requires = [FTSModel]
    messages = [
        'A faith is a necessity to a man. Woe to him who believes in nothing.',
        'All who call on God in true faith, earnestly from the heart, will '
        'certainly be heard, and will receive what they have asked and desired.',
        'Be faithful in small things because it is in them that your strength lies.',
        'Faith consists in believing when it is beyond the power of reason to believe.',
        'Faith has to do with things that are not seen and hope with things that are not at hand.',
    ]

    # (tsquery text, indexes of matching messages), shared by the Match()
    # helper and TSVectorField.match().
    tsquery_cases = (
        ('heart', [1]),
        ('god', [1]),
        ('faith', [0, 1, 2, 3, 4]),
        ('thing', [2, 4]),
        ('faith & things', [2, 4]),
        ('god | things', [1, 2, 4]),
        ('god & things', []),
    )

    def setUp(self):
        super(TestTSVectorField, self).setUp()
        # Titles are the message's position in ``messages``.
        for position, text in enumerate(self.messages):
            FTSModel.create(title=str(position), data=text,
                            fts_data=fn.to_tsvector(text))

    def assertMessages(self, expr, expected):
        """Assert *expr* matches exactly the messages at indexes *expected*."""
        rows = FTSModel.select().where(expr).order_by(FTSModel.id)
        self.assertEqual([int(row.title) for row in rows], expected)

    def test_sql(self):
        query = FTSModel.select().where(Match(FTSModel.data, 'foo bar'))
        self.assertSQL(query, (
            'SELECT "t1"."id", "t1"."title", "t1"."data", "t1"."fts_data" '
            'FROM "fts_model" AS "t1" '
            'WHERE (to_tsvector("t1"."data") @@ to_tsquery(?))'), ['foo bar'])

    def test_match_function(self):
        for text, expected in self.tsquery_cases:
            self.assertMessages(Match(FTSModel.data, text), expected)

    def test_tsvector_field(self):
        match = FTSModel.fts_data.match
        for text, expected in self.tsquery_cases:
            self.assertMessages(match(text), expected)

        # The plain parser cannot express "OR"; single terms behave as usual
        # and multiple terms are AND-ed together.
        plain_cases = (
            ('god | things', []),
            ('god', [1]),
            ('thing', [2, 4]),
            ('faith things', [2, 4]),
        )
        for text, expected in plain_cases:
            self.assertMessages(match(text, plain=True), expected)
def pg12():
    """Return True when the test server reports version 12 (120000) or newer."""
    # server_version is read while the database context is open.
    with db:
        version = db.server_version
    return version >= 120000
class TestJsonField(BaseJsonFieldTestCase, ModelTestCase):
    """JSONField tests; shared cases are inherited from BaseJsonFieldTestCase."""
    M = JsonModel
    N = Normal
    database = db
    requires = [JsonModel, Normal, JsonModelNull]

    def test_json_null(self):
        null_row = JsonModelNull.create(data=None)
        JsonModelNull.create(data={'k1': 'v1'})

        # NULL reads back as None alongside an ordinary document.
        ordered = JsonModelNull.select().order_by(JsonModelNull.id)
        self.assertEqual([row.data for row in ordered], [None, {'k1': 'v1'}])

        only_null = JsonModelNull.select().where(
            JsonModelNull.data.is_null(True))
        self.assertEqual(only_null.get(), null_row)
class TestBinaryJsonField(BaseBinaryJsonFieldTestCase, ModelTestCase):
    """BinaryJSONField tests on top of the shared binary JSON test base."""
    M = BJson
    N = Normal
    database = db
    requires = [BJson, Normal]

    def _select_scalar(self, expr):
        # Select *expr* as a tuple and return the first column of the first
        # row (callers ensure a single BJson row exists).
        query = BJson.select(expr).tuples()
        return query[:][0][0]

    def test_remove_data(self):
        BJson.delete().execute()  # Clear out db.
        BJson.create(data={
            'k1': 'v1',
            'k2': 'v2',
            'k3': {'x1': 'z1', 'x2': 'z2'},
            'k4': [0, 1, 2]})

        def assertData(exp_list, expected_data):
            data = self._select_scalar(BJson.data.remove(*exp_list))
            self.assertEqual(data, expected_data)

        assertData(['k3'], {'k1': 'v1', 'k2': 'v2', 'k4': [0, 1, 2]})
        assertData(['k1', 'k3'], {'k2': 'v2', 'k4': [0, 1, 2]})
        # Unknown keys are ignored.
        assertData(['k1', 'kx', 'ky', 'k3'], {'k2': 'v2', 'k4': [0, 1, 2]})
        assertData(['k4', 'k3'], {'k1': 'v1', 'k2': 'v2'})

    def test_remove_path(self):
        BJson.delete().execute()  # Clear out db.
        data = {'k1': {'x1': {'y1': 'z1', 'y2': 'z2'}, 'x2': ['i1', 'i2']}}
        BJson.create(data=data)

        def assertData(exp_list, expected_data):
            # Walk the path, then remove the node it points at.
            curr = BJson.data
            for exp in exp_list:
                curr = curr[exp]
            self.assertEqual(self._select_scalar(curr.remove()),
                             expected_data)

        assertData(['k1'], {})
        assertData(['k1', 'x1'], {'k1': {'x2': ['i1', 'i2']}})
        assertData(['k1', 'x1', 'y1'],
                   {'k1': {'x1': {'y2': 'z2'}, 'x2': ['i1', 'i2']}})
        assertData(['k1', 'x1', 'y2'],
                   {'k1': {'x1': {'y1': 'z1'}, 'x2': ['i1', 'i2']}})
        assertData(['k1', 'x2', 0],
                   {'k1': {'x1': {'y1': 'z1', 'y2': 'z2'}, 'x2': ['i2']}})
        assertData(['k1', 'x2', -1],
                   {'k1': {'x1': {'y1': 'z1', 'y2': 'z2'}, 'x2': ['i1']}})
        # Paths that do not exist leave the document unchanged.
        assertData(['kx'], data)
        assertData(['k1', 'zz'], data)

    def test_json_length(self):
        BJson.delete().execute()  # Clear out db.
        data = {'k1': {'x1': [1, 2, 3], 'x2': [1, 2], 'x3': []}}
        BJson.create(data=data)

        def assertLength(exp_list, count):
            curr = BJson.data
            for exp in exp_list:
                curr = curr[exp]
            self.assertEqual(self._select_scalar(curr.length()), count)

        assertLength(('k1', 'x1'), 3)
        assertLength(('k1', 'x2'), 2)
        assertLength(('k1', 'x3'), 0)

        # length() also works on a top-level array.
        BJson.delete().execute()  # Clear out db.
        BJson.create(data=[0, 1, 2, 3, 4, 5])
        assertLength((), 6)

    def test_json_extract(self):
        BJson.delete().execute()  # Clear out db.
        data = {'k1': {'x1': {'y1': 'z1', 'y2': 'z2'}, 'x2': ['i1', 'i2']}}
        BJson.create(data=data)

        def assertData(node, path, expected_data):
            self.assertEqual(self._select_scalar(node.extract(*path)),
                             expected_data)

        assertData(BJson.data, ('k1', 'x1', 'y1'), 'z1')
        assertData(BJson.data, ('k1', 'x1'), {'y1': 'z1', 'y2': 'z2'})
        assertData(BJson.data, ('k1', 'x2', 0), 'i1')
        assertData(BJson.data, ('k1', 'x2', -1), 'i2')
        assertData(BJson.data, ('k1',),
                   {'x1': {'y1': 'z1', 'y2': 'z2'}, 'x2': ['i1', 'i2']})
        assertData(BJson.data, ('kx',), None)
        # extract() is relative to the node it is called on.
        assertData(BJson.data['k1'], ('x1', 'y1'), 'z1')
        assertData(BJson.data['k1']['x1'], ('y1',), 'z1')
        assertData(BJson.data['k1']['x2'], (0,), 'i1')

    def test_json_contains_in_list(self):
        m1 = self.M.create(data=[{'k1': 'v1', 'k2': 'v2'}, {'a1': 'b1'}])
        self.M.create(data=[{'k3': 'v3'}, {'k4': 'v4'}])
        m3 = self.M.create(data=[{'k5': 'v5', 'k6': 'v6'}, {'k1': 'v1'}])
        query = (self.M
                 .select()
                 .where(self.M.data.contains([{'k1': 'v1'}]))
                 .order_by(self.M.id))
        self.assertEqual([m.id for m in query], [m1.id, m3.id])

    def test_integer_index_weirdness(self):
        self._create_test_data()

        def fails():
            # The atomic block presumably keeps the failed statement from
            # leaving the surrounding transaction aborted -- confirm.
            with self.database.atomic():
                expr = BJson.data.contains_any(2, 8, 12)
                list(BJson.select().where(expr))

        # Complains of a missing cast/conversion for the data-type?
        self.assertRaises(ProgrammingError, fails)
class Point(object):
    """Simple mutable 2-D point, stored via CustomJsonField in the tests."""

    # Mutable value object defining __eq__: explicitly unhashable (this
    # matches Python 3's implicit behaviour and avoids Python 2's
    # identity hash disagreeing with equality).
    __hash__ = None

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __eq__(self, other):
        # Defer to the other operand for foreign types instead of raising
        # AttributeError on a missing ``x``/``y``.
        if not isinstance(other, Point):
            return NotImplemented
        return (self.x, self.y) == (other.x, other.y)

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__; spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __repr__(self):
        # Readable output for assertEqual failure messages.
        return 'Point(%r, %r)' % (self.x, self.y)
class CustomJsonField(BinaryJSONField):
    """BinaryJSONField that stores Point objects as ``{'x': ..., 'y': ...}``."""

    def db_value(self, value):
        # Points become a plain mapping; any other value passes through.
        if isinstance(value, Point):
            as_json = {'x': value.x, 'y': value.y}
        else:
            as_json = value
        return super(CustomJsonField, self).db_value(as_json)

    def python_value(self, value):
        # Rebuild a Point from the stored mapping; NULL stays None.
        if value is None:
            return None
        return Point(**value)
class CJM(TestModel):
    name = TextField()
    # Holds Point instances through the CustomJsonField subclass.
    point = CustomJsonField()
class TestJsonFieldRegressions(ModelTestCase):
    """Regression tests for BinaryJSONField concat, introspection, subclassing."""
    database = db
    requires = [JData]

    def test_json_field_concat(self):
        JData.create(
            d1={'k1': {'x1': 'y1'}, 'k2': 'v2', 'k3': 'v3'},
            d2={'k1': {'x2': 'y2'}, 'k2': 'v2-x', 'k4': 'v4'})

        # concat() is a shallow merge: top-level keys from d2 win, and the
        # nested k1 object is replaced rather than merged.
        merged = JData.select(JData.d1.concat(JData.d2).alias('data')).get()
        self.assertEqual(merged.data, {
            'k1': {'x2': 'y2'}, 'k2': 'v2-x', 'k3': 'v3', 'k4': 'v4'})

    def test_introspect_bjson_field(self):
        generated = Introspector.from_database(self.database).generate_models(
            table_names=['j_data'])
        JD = generated['j_data']

        self.assertEqual(JD._meta.sorted_field_names, ['id', 'd1', 'd2'])
        for column in (JD.d1, JD.d2):
            self.assertIsInstance(column, BinaryJSONField)

        # Index metadata survives introspection.
        self.assertTrue(JD.d1.index)
        self.assertEqual(JD.d1.index_type, 'GIN')
        self.assertFalse(JD.d2.index)

    @requires_models(CJM)
    def test_json_field_subclass(self):
        CJM.create(name='c1', point=Point(1, 2))
        CJM.insert(name='c2', point=Point(2, 3)).execute()

        def point_of(name):
            return CJM.get(CJM.name == name).point

        self.assertEqual(point_of('c1'), Point(1, 2))
        self.assertEqual(point_of('c2'), Point(2, 3))

        # Points work as filter values in SELECT ...
        match = CJM.select().where(CJM.point == Point(2, 3)).get()
        self.assertEqual(match.name, 'c2')

        # ... and in UPDATE ... WHERE.
        CJM.update(point=Point(3, 4)).where(CJM.point == Point(1, 2)).execute()
        first, second = CJM.select().order_by(CJM.name)
        self.assertEqual(first.point, Point(3, 4))
        self.assertEqual(second.point, Point(2, 3))

        # Saving a modified instance persists the new (float) point.
        first.point = Point(1.2, 2.5)
        first.save()
        self.assertEqual(point_of('c1'), Point(1.2, 2.5))
class TestJSONFieldCustomDumps(ModelTestCase):
    """JSON fields configured with a custom ``dumps`` can store UUIDs."""
    database = db
    requires = [CustomJSONDumps]

    def test_custom_dumps(self):
        u1, u2 = uuid.uuid4(), uuid.uuid4()
        payload = {'u1': u1, 'u2': u2, 'u3': [u1, u2]}
        row = CustomJSONDumps.create(jf=payload, jbf=payload)

        # Both column types return the UUIDs as strings.
        fetched = CustomJSONDumps.get_by_id(row.id)
        expected = {
            'u1': str(u1),
            'u2': str(u2),
            'u3': [str(u1), str(u2)]}
        self.assertEqual(fetched.jf, expected)
        self.assertEqual(fetched.jbf, expected)
class TestIntervalField(ModelTestCase):
    """IntervalField round-trips and orders datetime.timedelta values."""
    database = db
    requires = [Event]

    def test_interval_field(self):
        durations = (
            ('hour', datetime.timedelta(hours=1)),
            ('mix', datetime.timedelta(days=1, hours=2, minutes=3, seconds=4)),
        )
        for name, duration in durations:
            Event.create(name=name, duration=duration)

        ordered = Event.select().order_by(Event.duration)
        self.assertEqual([(event.name, event.duration) for event in ordered],
                         list(durations))
class TestIndexedField(BaseTestCase):
    """DDL generated for fields that use IndexedFieldMixin."""

    def test_indexed_field_ddl(self):
        class FakeIndexedField(IndexedFieldMixin, CharField):
            default_index_type = 'GiST'

        class IndexedModel(TestModel):
            array_index = ArrayField(CharField)
            array_noindex = ArrayField(IntegerField, index=False)
            fake_index = FakeIndexedField()
            fake_index_with_type = FakeIndexedField(index_type='MAGIC')
            fake_noindex = FakeIndexedField(index=False)

            class Meta:
                database = db

        schema = IndexedModel._schema

        table_sql = schema._create_table(False).query()[0]
        self.assertEqual(table_sql, (
            'CREATE TABLE "indexed_model" ('
            '"id" SERIAL NOT NULL PRIMARY KEY, '
            '"array_index" VARCHAR(255)[] NOT NULL, '
            '"array_noindex" INTEGER[] NOT NULL, '
            '"fake_index" VARCHAR(255) NOT NULL, '
            '"fake_index_with_type" VARCHAR(255) NOT NULL, '
            '"fake_noindex" VARCHAR(255) NOT NULL)'))

        # Only indexed fields get a CREATE INDEX. Each uses its own index
        # type: GIN for the ArrayField, the class default 'GiST', or an
        # explicit index_type override.
        index_sql = [node.query()[0]
                     for node in schema._create_indexes(False)]
        self.assertEqual(index_sql, [
            ('CREATE INDEX "indexed_model_array_index" ON "indexed_model" '
             'USING GIN ("array_index")'),
            ('CREATE INDEX "indexed_model_fake_index" ON "indexed_model" '
             'USING GiST ("fake_index")'),
            ('CREATE INDEX "indexed_model_fake_index_with_type" '
             'ON "indexed_model" '
             'USING MAGIC ("fake_index_with_type")')])
class IDAlways(TestModel):
    # Identity primary key declared GENERATED ALWAYS; explicit ids are
    # rejected (see TestIdentityField).
    id = IdentityField(generate_always=True)
    data = CharField()


class IDByDefault(TestModel):
    # Identity primary key declared GENERATED BY DEFAULT; explicit ids are
    # accepted.
    id = IdentityField()
    data = CharField()
class TestIdentityField(ModelTestCase):
    """IdentityField declared GENERATED ALWAYS vs GENERATED BY DEFAULT."""
    database = db
    requires = [IDAlways, IDByDefault]

    def test_identity_field_always(self):
        iq = IDAlways.insert_many([(d,) for d in ('d1', 'd2', 'd3')])
        curs = iq.execute()
        # The insert yields the generated ids.
        self.assertEqual(list(curs), [(1,), (2,), (3,)])

        # Cannot specify id when generate always is true. The atomic block
        # presumably contains the failed INSERT so later queries still run.
        with self.assertRaises(ProgrammingError):
            with self.database.atomic():
                IDAlways.create(id=10, data='d10')

        query = IDAlways.select().order_by(IDAlways.id)
        self.assertEqual(list(query.tuples()), [
            (1, 'd1'), (2, 'd2'), (3, 'd3')])

    def test_identity_field_by_default(self):
        iq = IDByDefault.insert_many([(d,) for d in ('d1', 'd2', 'd3')])
        curs = iq.execute()
        self.assertEqual(list(curs), [(1,), (2,), (3,)])

        # Unlike GENERATED ALWAYS, GENERATED BY DEFAULT accepts an explicit
        # id value.
        IDByDefault.create(id=10, data='d10')

        query = IDByDefault.select().order_by(IDByDefault.id)
        self.assertEqual(list(query.tuples()), [
            (1, 'd1'), (2, 'd2'), (3, 'd3'), (10, 'd10')])

    def test_schema(self):
        sql, params = IDAlways._schema._create_table(False).query()
        self.assertEqual(sql, (
            'CREATE TABLE "id_always" ("id" INT GENERATED ALWAYS AS IDENTITY '
            'NOT NULL PRIMARY KEY, "data" VARCHAR(255) NOT NULL)'))

        sql, params = IDByDefault._schema._create_table(False).query()
        self.assertEqual(sql, (
            'CREATE TABLE "id_by_default" ("id" INT GENERATED BY DEFAULT AS '
            'IDENTITY NOT NULL PRIMARY KEY, "data" VARCHAR(255) NOT NULL)'))
class TestServerSide(ModelTestCase):
    """Iterate results through ServerSide / ServerSideQuery cursors."""
    database = db
    requires = [Register]

    def setUp(self):
        super(TestServerSide, self).setUp()
        # 100 rows with value 0..99, inserted in a single transaction.
        with self.database.atomic():
            for value in range(100):
                Register.create(value=value)

    def test_server_side_cursor(self):
        ordered = Register.select().order_by(Register.value)
        with self.database.atomic():
            # Streaming all rows still issues a single query.
            with self.assertQueryCount(1):
                values = [row.value for row in ServerSide(ordered)]
            self.assertEqual(values, list(range(100)))

            limited = ServerSide(ordered.limit(10), array_size=3)
            self.assertEqual([row.value for row in limited], list(range(10)))

            empty = ServerSide(ordered.where(SQL('1 = 0')))
            self.assertEqual(list(empty), [])

    def test_lower_level_apis(self):
        query = Register.select(Register.value).order_by(Register.value)
        with self.database.atomic():
            ssq = ServerSideQuery(query, array_size=10)
            cursor = ssq._execute(self.database).cursor
            self.assertIsInstance(cursor, FetchManyCursor)
            for expected in ((0,), (1,)):
                self.assertEqual(cursor.fetchone(), expected)
            cursor.close()

    def test_close_cursor(self):
        query = Register.select(Register.value).order_by(Register.value)
        with self.database.atomic():
            ssq = ServerSideQuery(query, array_size=10)
            collected = []
            # Stop part-way through, then close the cursor explicitly.
            for position, row in enumerate(ssq.iterator()):
                if position >= 25:
                    break
                collected.append(row.value)
            self.assertTrue(ssq.close())

        self.assertEqual(len(collected), 25)
        self.assertEqual(collected, list(range(25)))
class KX(TestModel):
    # Unique, so inserting a duplicate key raises IntegrityError in
    # TestAutocommitIntegration.
    key = CharField(unique=True)
    value = IntegerField()
class TestAutocommitIntegration(ModelTestCase):
    """How an IntegrityError interacts with autocommit, manual and atomic modes."""
    database = db
    requires = [KX]

    def setUp(self):
        super(TestAutocommitIntegration, self).setUp()
        # Baseline committed row; any later insert of key 'k1' violates the
        # unique constraint.
        with self.database.atomic():
            KX.create(key='k1', value=1)

    def force_integrity_error(self):
        # Insert a duplicate 'k1' and assert that it raises IntegrityError.
        # This helper does not inspect transaction state itself; each caller
        # checks the aftermath (row counts and values) on its own.
        self.assertRaises(IntegrityError, KX.create, key='k1', value=10)

    def test_autocommit_default(self):
        kx2 = KX.create(key='k2', value=2)  # Will be committed.
        self.assertTrue(kx2.id > 0)

        # The failed insert does not undo the already-committed k2 row.
        self.force_integrity_error()
        self.assertEqual(KX.select().count(), 2)
        self.assertEqual([(kx.key, kx.value)
                          for kx in KX.select().order_by(KX.key)],
                         [('k1', 1), ('k2', 2)])

    def test_autocommit_disabled(self):
        with self.database.manual_commit():
            self.database.begin()
            kx2 = KX.create(key='k2', value=2)  # Not committed.
            self.assertTrue(kx2.id > 0)  # Yes, we have a primary key.
            self.force_integrity_error()
            self.database.rollback()

        # Rolling back discards k2; only the setUp row remains.
        self.assertEqual(KX.select().count(), 1)
        kx1_db = KX.get(KX.key == 'k1')
        self.assertEqual(kx1_db.value, 1)

    def test_atomic_block(self):
        with self.database.atomic() as txn:
            kx2 = KX.create(key='k2', value=2)
            self.assertTrue(kx2.id > 0)
            self.force_integrity_error()
            txn.rollback(False)

        self.assertEqual(KX.select().count(), 1)
        kx1_db = KX.get(KX.key == 'k1')
        self.assertEqual(kx1_db.value, 1)

    def test_atomic_block_exception(self):
        # An exception escaping atomic() rolls back the whole block,
        # including the k2 insert that succeeded.
        with self.assertRaises(IntegrityError):
            with self.database.atomic():
                KX.create(key='k2', value=2)
                KX.create(key='k1', value=10)

        self.assertEqual(KX.select().count(), 1)
class TestPostgresIsolationLevel(DatabaseTestCase):
    """Configured isolation levels are applied to new connections."""
    database = db_loader('postgres', isolation_level=3)  # SERIALIZABLE.

    def test_isolation_level(self):
        conn = self.database.connection()
        self.assertEqual(conn.isolation_level, 3)

        # Changing the level on the raw connection only affects that
        # connection; a fresh one gets the configured level again.
        conn.set_isolation_level(2)
        self.assertEqual(conn.isolation_level, 2)
        self.database.close()

        conn = self.database.connection()
        self.assertEqual(conn.isolation_level, 3)
        self.database.close()

        # Changing it on the database applies to every later connection.
        self.database.set_isolation_level(2)
        for _ in range(2):
            conn = self.database.connection()
            self.assertEqual(conn.isolation_level, 2)
            self.database.close()

    def test_isolation_level_str(self):
        # A dedicated database object (named so it does not shadow the
        # module-level ``db``), closed in ``finally`` so a failing assertion
        # does not leak its connection.
        iso_db = db_loader('postgres', isolation_level='SERIALIZABLE')
        try:
            conn = iso_db.connection()
            self.assertEqual(
                conn.isolation_level,
                iso_db._adapter.isolation_levels_inv['SERIALIZABLE'])
            iso_db.close()

            iso_db.set_isolation_level('READ COMMITTED')
            conn = iso_db.connection()
            self.assertEqual(
                conn.isolation_level,
                iso_db._adapter.isolation_levels_inv['READ COMMITTED'])
        finally:
            iso_db.close()

    def test_isolation_level_mixed(self):
        # Named-level configuration followed by a numeric override.
        iso_db = db_loader('postgres', isolation_level='SERIALIZABLE')
        try:
            conn = iso_db.connection()
            self.assertEqual(
                conn.isolation_level,
                iso_db._adapter.isolation_levels_inv['SERIALIZABLE'])
            iso_db.close()

            rc = iso_db._adapter.isolation_levels_inv['READ COMMITTED']
            iso_db.set_isolation_level(rc)
            conn = iso_db.connection()
            self.assertEqual(conn.isolation_level, rc)
        finally:
            iso_db.close()
@skip_unless(pg12(), 'cte materialization requires pg >= 12')
class TestPostgresCTEMaterialization(ModelTestCase):
    """A CTE returns the same rows for every ``materialized`` setting."""
    database = db
    requires = [Register]

    def test_postgres_cte_materialization(self):
        Register.insert_many([(value,) for value in (1, 2, 3)]).execute()

        for materialized in (None, False, True):
            cte = Register.select().cte('t', materialized=materialized)
            rows = (cte
                    .select_from(cte.c.value)
                    .where(cte.c.value != 2)
                    .order_by(cte.c.value))
            self.assertEqual([row.value for row in rows], [1, 3])
class TestPostgresLateralJoin(ModelTestCase):
    """LATERAL joins that pick each user's two most recent tweets."""
    database = db
    # (username, ((tweet content, day of January 2019), ...))
    test_data = (
        ('a', (('a1', 1),
               ('a2', 2),
               ('a10', 10))),
        ('b', (('b3', 3),
               ('b4', 4),
               ('b7', 7))),
        ('c', ()))

    def create_data(self):
        with self.database.atomic():
            for username, tweets in self.test_data:
                user = User.create(username=username)
                for content, day in tweets:
                    Tweet.create(user=user, content=content,
                                 timestamp=datetime.datetime(2019, 1, day))

    def _latest_two(self):
        # Correlated subquery: the two newest tweets of the outer User row.
        return (Tweet
                .select(Tweet.content, Tweet.timestamp)
                .where(Tweet.user == User.id)
                .order_by(Tweet.timestamp.desc())
                .limit(2))

    @requires_models(User, Tweet)
    def test_lateral_top_n(self):
        self.create_data()
        latest = self._latest_two()

        # LEFT JOIN LATERAL keeps users without tweets (content is None).
        by_recency = (User
                      .select(User, latest.c.content)
                      .join(latest, JOIN.LEFT_LATERAL)
                      .order_by(latest.c.timestamp.desc(nulls='last')))
        self.assertEqual([(u.username, u.content) for u in by_recency], [
            ('a', 'a10'),
            ('b', 'b7'),
            ('b', 'b4'),
            ('a', 'a2'),
            ('c', None)])

        by_user = (Tweet
                   .select(User.username, latest.c.content)
                   .from_(User)
                   .join(latest, JOIN.LEFT_LATERAL)
                   .order_by(User.username, latest.c.timestamp))
        self.assertEqual([(t.username, t.content) for t in by_user], [
            ('a', 'a2'),
            ('a', 'a10'),
            ('b', 'b4'),
            ('b', 'b7'),
            ('c', None)])

    @requires_models(User, Tweet)
    def test_lateral_helper(self):
        self.create_data()
        latest = self._latest_two().lateral()

        # The plain join (on=True) drops the user who has no tweets.
        query = (User
                 .select(User, latest.c.content)
                 .join(latest, on=True)
                 .order_by(latest.c.timestamp.desc(nulls='last')))
        with self.assertQueryCount(1):
            results = [(u.username, u.tweet.content) for u in query]
        self.assertEqual(results, [
            ('a', 'a10'),
            ('b', 'b7'),
            ('b', 'b4'),
            ('a', 'a2')])