# Files
# peewee/tests/shortcuts.py
# T
# 2025-07-23 12:03:29 -05:00
#
# 877 lines
# 31 KiB
# Python

import operator
from peewee import *
from playhouse.shortcuts import *
from .base import BaseTestCase
from .base import DatabaseTestCase
from .base import ModelTestCase
from .base import TestModel
from .base import db_loader
from .base import get_in_memory_db
from .base import requires_models
from .base import requires_mysql
from .base_models import Category
class User(TestModel):
    """Test model for a user, holding a single text ``username`` column."""
    username = TextField()

    @property
    def name_hash(self):
        """Return the sum of the username's character ordinals, modulo 10."""
        ordinal_total = sum(ord(char) for char in self.username)
        return ordinal_total % 10
class Tweet(TestModel):
    """Test model for a tweet: text content that belongs to a User."""
    # The reverse accessor on User is named ``tweets``.
    user = ForeignKeyField(User, backref='tweets')
    content = TextField()
class Tag(TestModel):
    """Test model for a tag, holding a single text ``tag`` column."""
    tag = TextField()
class TweetTag(TestModel):
    """Join row linking a Tweet to a Tag."""
    tweet = ForeignKeyField(Tweet)
    tag = ForeignKeyField(Tag)

    class Meta:
        # The (tweet, tag) pair is the primary key; no separate id field is
        # declared on this model.
        primary_key = CompositeKey('tweet', 'tag')
class Owner(TestModel):
    """Test model for a gallery owner, identified by a text ``name``."""
    name = TextField()
class Label(TestModel):
    """Test model for a label that can be attached to galleries."""
    label = TextField()
class Gallery(TestModel):
    """Test model for a gallery that has one owner and many labels."""
    name = TextField()
    # Many-to-many link to Label. No through model is passed explicitly; the
    # one in use is read back via ``Gallery.labels.through_model``.
    labels = ManyToManyField(Label, backref='galleries')
    owner = ForeignKeyField(Owner, backref='galleries')
# Through model behind Gallery.labels, exposed under its own name so tests can
# list it in ``requires_models`` and join through it.
GalleryLabel = Gallery.labels.through_model
class Student(TestModel):
    """Test model for a student, identified by a text ``name``."""
    name = TextField()
# Placeholder for the Student<->Course through model; it is resolved further
# down with ``StudentCourseProxy.set_model(StudentCourse)``, after
# StudentCourse has been declared.
StudentCourseProxy = DeferredThroughModel()
class Course(TestModel):
    """Test model for a course with a many-to-many link to Student."""
    name = TextField()
    # Uses the deferred through model; the reverse accessor on Student is
    # named ``courses``.
    students = ManyToManyField(Student, through_model=StudentCourseProxy,
                               backref='courses')
class StudentCourse(TestModel):
    """Explicit through model pairing a Student with a Course."""
    student = ForeignKeyField(Student)
    course = ForeignKeyField(Course)
# Resolve the deferred through model now that StudentCourse is defined.
StudentCourseProxy.set_model(StudentCourse)
class Host(TestModel):
    """Test model for a host that services and devices point at."""
    name = TextField()
class Service(TestModel):
    """Test model for a service running on a Host."""
    # The reverse accessor on Host is named ``services``.
    host = ForeignKeyField(Host, backref='services')
    name = TextField()
class Device(TestModel):
    """Test model for a device attached to a Host."""
    # backref='+': test_model_to_dict_disabled_backref expects that no device
    # collection appears when a Host is serialized with backrefs=True.
    host = ForeignKeyField(Host, backref='+')
    name = TextField()
class Basket(TestModel):
    """Test model whose primary key is an explicitly supplied integer."""
    # Declared explicitly so tests can create a row with id=0.
    id = IntegerField(primary_key=True)
class Item(TestModel):
    """Test model with an explicit integer primary key and a Basket FK."""
    id = IntegerField(primary_key=True)
    basket = ForeignKeyField(Basket)
class NodeTag(TestModel):
    """Test model for a tag referenced by Node rows."""
    tag = TextField()
class Node(TestModel):
    """Test model for a tree node with a tag and an optional parent node."""
    name = TextField()
    tag = ForeignKeyField(NodeTag)
    # Self-referential and nullable (root nodes have no parent); the reverse
    # accessor is named ``children``.
    parent = ForeignKeyField('self', null=True, backref='children')
class TestModelToDict(ModelTestCase):
    """Tests for ``model_to_dict``.

    Covers recursion into foreign keys, back-references, many-to-many fields,
    ``max_depth``, ``only``/``exclude``, ``extra_attrs`` and
    ``fields_from_query``. Every test starts with one ``User`` row whose
    username is ``'peewee'`` (created in ``setUp``).
    """
    database = get_in_memory_db()
    requires = [User, Tweet, Tag, TweetTag]

    def setUp(self):
        super(TestModelToDict, self).setUp()
        self.user = User.create(username='peewee')

    @requires_models(Node, NodeTag)
    def test_self_referential(self):
        # A node joined to its own table (the parent) and to the tag table
        # twice must serialize each related object under the right key.
        a, b = [NodeTag.create(tag=tag) for tag in 'ab']
        root = Node.create(name='root', tag=a)
        n1 = Node.create(name='n1', parent=root, tag=a)
        n2 = Node.create(name='n2', parent=root, tag=b)
        Parent = Node.alias('parent')
        ParentTag = NodeTag.alias('parent_tag')

        def assertSerialization(n, expected):
            # Fetch the node named `n` with its tag, parent and parent's tag
            # selected in one query, then compare its dict form.
            obj = (Node
                   .select(Node, NodeTag, Parent, ParentTag)
                   .join_from(Node, NodeTag, JOIN.LEFT_OUTER)
                   .join_from(Node, Parent, JOIN.LEFT_OUTER)
                   .join_from(Parent, ParentTag, JOIN.LEFT_OUTER)
                   .where(Node.name == n)
                   .first())
            self.assertEqual(model_to_dict(obj, recurse=True), expected)

        assertSerialization('n1', {
            'id': n1.id,
            'name': 'n1',
            'parent': {'id': root.id, 'name': 'root',
                       'tag': {'id': a.id, 'tag': 'a'}},
            'tag': {'id': a.id, 'tag': 'a'}})
        assertSerialization('n2', {
            'id': n2.id,
            'name': 'n2',
            'parent': {'id': root.id, 'name': 'root',
                       'tag': {'id': a.id, 'tag': 'a'}},
            'tag': {'id': b.id, 'tag': 'b'}})
        assertSerialization('root', {
            'id': root.id,
            'name': 'root',
            'parent': None,
            'tag': {'id': a.id, 'tag': 'a'}})

    def test_simple(self):
        # A model without relations serializes without touching the database.
        with self.assertQueryCount(0):
            self.assertEqual(model_to_dict(self.user), {
                'id': self.user.id,
                'username': 'peewee'})

    def test_simple_recurse(self):
        # With the default arguments the foreign key is expanded into a
        # nested dict; with recurse=False it is the related id.
        tweet = Tweet.create(user=self.user, content='t1')
        with self.assertQueryCount(0):
            self.assertEqual(model_to_dict(tweet), {
                'id': tweet.id,
                'content': tweet.content,
                'user': {
                    'id': self.user.id,
                    'username': 'peewee'}})

        with self.assertQueryCount(0):
            self.assertEqual(model_to_dict(tweet, recurse=False), {
                'id': tweet.id,
                'content': tweet.content,
                'user': self.user.id})

    def test_simple_backref(self):
        with self.assertQueryCount(1):
            self.assertEqual(model_to_dict(self.user, backrefs=True), {
                'id': self.user.id,
                'tweets': [],
                'username': 'peewee'})

        tweet = Tweet.create(user=self.user, content='t0')

        # Two queries, one for tweets, one for tweet-tags.
        with self.assertQueryCount(2):
            self.assertEqual(model_to_dict(self.user, backrefs=True), {
                'id': self.user.id,
                'username': 'peewee',
                'tweets': [{'id': tweet.id, 'content': 't0',
                            'tweettag_set': []}]})

    def test_recurse_and_backrefs(self):
        # Starting from the tweet: the user is nested, and the expected dict
        # shows the user's own ``tweets`` back-reference is not expanded.
        tweet = Tweet.create(user=self.user, content='t0')
        with self.assertQueryCount(1):
            self.assertEqual(model_to_dict(tweet, backrefs=True), {
                'id': tweet.id,
                'content': 't0',
                'tweettag_set': [],
                'user': {'id': self.user.id, 'username': 'peewee'}})

    @requires_models(Category)
    def test_recursive_fk(self):
        root = Category.create(name='root')
        child = Category.create(name='child', parent=root)
        grandchild = Category.create(name='grandchild', parent=child)

        with self.assertQueryCount(0):
            for recurse in (True, False):
                self.assertEqual(model_to_dict(root, recurse=recurse), {
                    'name': 'root',
                    'parent': None})

        with self.assertQueryCount(1):
            self.assertEqual(model_to_dict(root, backrefs=True), {
                'name': 'root',
                'parent': None,
                'children': [{'name': 'child'}]})

        # Same call again: it is expected to cost one query once more.
        with self.assertQueryCount(1):
            self.assertEqual(model_to_dict(root, backrefs=True), {
                'name': 'root',
                'parent': None,
                'children': [{'name': 'child'}]})

        with self.assertQueryCount(1):
            self.assertEqual(model_to_dict(child, backrefs=True), {
                'name': 'child',
                'parent': {'name': 'root'},
                'children': [{'name': 'grandchild'}]})

        with self.assertQueryCount(0):
            self.assertEqual(model_to_dict(child, backrefs=False), {
                'name': 'child',
                'parent': {'name': 'root'}})

    def test_manytomany(self):
        # Many-to-many modelled by hand through the TweetTag join table.
        tweet = Tweet.create(user=self.user, content='t0')
        tag1 = Tag.create(tag='t1')
        tag2 = Tag.create(tag='t2')
        Tag.create(tag='tx')  # Never attached; must not appear in the output.
        TweetTag.create(tweet=tweet, tag=tag1)
        TweetTag.create(tweet=tweet, tag=tag2)

        with self.assertQueryCount(4):
            self.assertEqual(model_to_dict(self.user, backrefs=True), {
                'id': self.user.id,
                'username': 'peewee',
                'tweets': [{
                    'id': tweet.id,
                    'content': 't0',
                    'tweettag_set': [
                        {'tag': {'id': tag1.id, 'tag': 't1'}},
                        {'tag': {'id': tag2.id, 'tag': 't2'}}]}]})

    @requires_models(Label, Gallery, GalleryLabel, Owner)
    def test_manytomany_field(self):
        # Each entry is (owner name, gallery name, label names).
        data = (
            ('charlie', 'family', ('nuggie', 'bearbe')),
            ('charlie', 'pets', ('huey', 'zaizee', 'beanie')),
            ('peewee', 'misc', ('nuggie', 'huey')))
        for owner_name, gallery, labels in data:
            owner, _ = Owner.get_or_create(name=owner_name)
            gallery = Gallery.create(name=gallery, owner=owner)
            label_objects = [Label.get_or_create(label=label_name)[0]
                             for label_name in labels]
            gallery.labels.add(label_objects)

        # Galleries that carry the 'nuggie' label, with the owner selected.
        query = (Gallery
                 .select(Gallery, Owner)
                 .join(Owner)
                 .switch(Gallery)
                 .join(GalleryLabel)
                 .join(Label)
                 .where(Label.label == 'nuggie')
                 .order_by(Gallery.id))
        rows = [model_to_dict(gallery, backrefs=True, manytomany=True)
                for gallery in query]

        # This test specifies no ordering for the serialized labels, so sort
        # them by id before comparing, as test_manytomany_deferred does for
        # its related rows.
        for row in rows:
            row['labels'].sort(key=operator.itemgetter('id'))

        self.assertEqual(rows, [
            {
                'id': 1,
                'name': 'family',
                'owner': {'id': 1, 'name': 'charlie'},
                'labels': [{'id': 1, 'label': 'nuggie'},
                           {'id': 2, 'label': 'bearbe'}],
            },
            {
                'id': 3,
                'name': 'misc',
                'owner': {'id': 2, 'name': 'peewee'},
                'labels': [{'id': 1, 'label': 'nuggie'},
                           {'id': 3, 'label': 'huey'}],
            }])

    @requires_models(Student, Course, StudentCourse)
    def test_manytomany_deferred(self):
        # Each entry is (student name, names of the courses they take).
        data = (
            ('s1', ('ca', 'cb', 'cc')),
            ('s2', ('cb', 'cd')),
            ('s3', ()))

        c = {}  # Course name -> Course row, so shared courses are made once.
        for student, courses in data:
            s = Student.create(name=student)
            for course in courses:
                if course not in c:
                    c[course] = Course.create(name=course)
                StudentCourse.create(student=s, course=c[course])

        # Serialize from the Student side (the ``courses`` backref).
        query = Student.select().order_by(Student.name)
        data = []
        for student in query:
            student_dict = model_to_dict(student, manytomany=True)
            student_dict['courses'].sort(key=operator.itemgetter('id'))
            data.append(student_dict)

        self.assertEqual(data, [
            {'id': 1, 'name': 's1', 'courses': [
                {'id': 1, 'name': 'ca'},
                {'id': 2, 'name': 'cb'},
                {'id': 3, 'name': 'cc'}]},
            {'id': 2, 'name': 's2', 'courses': [
                {'id': 2, 'name': 'cb'},
                {'id': 4, 'name': 'cd'}]},
            {'id': 3, 'name': 's3', 'courses': []}])

        # Serialize from the Course side (the ``students`` field itself).
        query = Course.select().order_by(Course.name)
        data = []
        for course in query:
            course_dict = model_to_dict(course, manytomany=True)
            course_dict['students'].sort(key=operator.itemgetter('id'))
            data.append(course_dict)

        self.assertEqual(data, [
            {'id': 1, 'name': 'ca', 'students': [
                {'id': 1, 'name': 's1'}]},
            {'id': 2, 'name': 'cb', 'students': [
                {'id': 1, 'name': 's1'},
                {'id': 2, 'name': 's2'}]},
            {'id': 3, 'name': 'cc', 'students': [
                {'id': 1, 'name': 's1'}]},
            {'id': 4, 'name': 'cd', 'students': [
                {'id': 2, 'name': 's2'}]}])

    def test_recurse_max_depth(self):
        t0, t1, t2 = [Tweet.create(user=self.user, content='t%s' % i)
                      for i in range(3)]
        tag0, tag1 = [Tag.create(tag=t) for t in ['tag0', 'tag1']]
        TweetTag.create(tweet=t0, tag=tag0)
        TweetTag.create(tweet=t0, tag=tag1)
        TweetTag.create(tweet=t1, tag=tag1)

        # No max_depth: user -> tweets -> tweet-tags -> tag, all expanded.
        data = model_to_dict(self.user, recurse=True, backrefs=True)
        self.assertEqual(data, {
            'id': self.user.id,
            'username': 'peewee',
            'tweets': [
                {'id': t0.id, 'content': 't0', 'tweettag_set': [
                    {'tag': {'tag': 'tag0', 'id': tag0.id}},
                    {'tag': {'tag': 'tag1', 'id': tag1.id}},
                ]},
                {'id': t1.id, 'content': 't1', 'tweettag_set': [
                    {'tag': {'tag': 'tag1', 'id': tag1.id}},
                ]},
                {'id': t2.id, 'content': 't2', 'tweettag_set': []},
            ]})

        # max_depth=2: tags are reduced to their ids.
        data = model_to_dict(self.user, recurse=True, backrefs=True,
                             max_depth=2)
        self.assertEqual(data, {
            'id': self.user.id,
            'username': 'peewee',
            'tweets': [
                {'id': t0.id, 'content': 't0', 'tweettag_set': [
                    {'tag': tag0.id}, {'tag': tag1.id},
                ]},
                {'id': t1.id, 'content': 't1', 'tweettag_set': [
                    {'tag': tag1.id},
                ]},
                {'id': t2.id, 'content': 't2', 'tweettag_set': []},
            ]})

        # max_depth=1: tweets are listed without their tweet-tags.
        data = model_to_dict(self.user, recurse=True, backrefs=True,
                             max_depth=1)
        self.assertEqual(data, {
            'id': self.user.id,
            'username': 'peewee',
            'tweets': [
                {'id': t0.id, 'content': 't0'},
                {'id': t1.id, 'content': 't1'},
                {'id': t2.id, 'content': 't2'}]})

        # max_depth=0: only the user's own columns.
        self.assertEqual(model_to_dict(self.user, recurse=True, backrefs=True,
                                       max_depth=0),
                         {'id': self.user.id, 'username': 'peewee'})

    def test_only(self):
        username_dict = {'username': 'peewee'}
        self.assertEqual(model_to_dict(self.user, only=[User.username]),
                         username_dict)

        self.assertEqual(
            model_to_dict(self.user, backrefs=True, only=[User.username]),
            username_dict)

        tweet = Tweet.create(user=self.user, content='t0')
        tweet_dict = {'content': 't0', 'user': {'username': 'peewee'}}
        field_list = [Tweet.content, Tweet.user, User.username]
        self.assertEqual(model_to_dict(tweet, only=field_list),
                         tweet_dict)
        self.assertEqual(model_to_dict(tweet, backrefs=True, only=field_list),
                         tweet_dict)

        # Without recursion the user collapses to its id.
        tweet_dict['user'] = self.user.id
        self.assertEqual(model_to_dict(tweet, backrefs=True, recurse=False,
                                       only=field_list),
                         tweet_dict)

    def test_exclude(self):
        self.assertEqual(model_to_dict(self.user, exclude=[User.id]),
                         {'username': 'peewee'})

        # Exclude the foreign key using FK field and backref.
        self.assertEqual(model_to_dict(self.user, backrefs=True,
                                       exclude=[User.id, Tweet.user]),
                         {'username': 'peewee'})
        self.assertEqual(model_to_dict(self.user, backrefs=True,
                                       exclude=[User.id, User.tweets]),
                         {'username': 'peewee'})

        tweet = Tweet.create(user=self.user, content='t0')
        fields = [Tweet.tweettag_set, Tweet.id, Tweet.user]
        self.assertEqual(model_to_dict(tweet, backrefs=True, exclude=fields),
                         {'content': 't0'})

        # Swap the excluded FK for the related model's id: the user comes
        # back, minus its id.
        fields[-1] = User.id
        self.assertEqual(model_to_dict(tweet, backrefs=True, exclude=fields),
                         {'content': 't0', 'user': {'username': 'peewee'}})

    def test_extra_attrs(self):
        with self.assertQueryCount(0):
            extra = ['name_hash']
            self.assertEqual(model_to_dict(self.user, extra_attrs=extra), {
                'id': self.user.id,
                'username': 'peewee',
                'name_hash': 5})

        # An attribute the model does not have raises AttributeError.
        with self.assertQueryCount(0):
            self.assertRaises(AttributeError, model_to_dict, self.user,
                              extra_attrs=['xx'])

    def test_fields_from_query(self):
        # Start from a clean slate: users u0..u2 with 1..3 tweets each.
        User.delete().execute()
        for i in range(3):
            user = User.create(username='u%d' % i)
            for x in range(i + 1):
                Tweet.create(user=user, content='%s-%s' % (user.username, x))

        query = (User
                 .select(User.username, fn.COUNT(Tweet.id).alias('ct'))
                 .join(Tweet, JOIN.LEFT_OUTER)
                 .group_by(User.username)
                 .order_by(User.id))
        with self.assertQueryCount(1):
            u0, u1, u2 = list(query)
            self.assertEqual(model_to_dict(u0, fields_from_query=query), {
                'username': 'u0',
                'ct': 1})
            self.assertEqual(model_to_dict(u2, fields_from_query=query), {
                'username': 'u2',
                'ct': 3})

        query = (Tweet
                 .select(Tweet, User, SQL('1337').alias('magic'))
                 .join(User)
                 .order_by(Tweet.id)
                 .limit(1))
        with self.assertQueryCount(1):
            tweet, = query
            self.assertEqual(model_to_dict(tweet, fields_from_query=query), {
                'id': tweet.id,
                'content': 'u0-0',
                'magic': 1337,
                'user': {'id': tweet.user_id, 'username': 'u0'}})

            self.assertEqual(model_to_dict(tweet, fields_from_query=query,
                                           exclude=[User.id, Tweet.id]),
                             {'magic': 1337, 'content': 'u0-0',
                              'user': {'username': 'u0'}})

    def test_fields_from_query_alias(self):
        # An aliased column is keyed by its alias in the output...
        q = User.select(User.username.alias('name'))
        res = q[0]
        self.assertEqual(model_to_dict(res, fields_from_query=q),
                         {'name': 'peewee'})

        # ...including when the query is built on a model alias.
        UA = User.alias()
        q = UA.select(UA.username.alias('name'))
        res = q[0]
        self.assertEqual(model_to_dict(res, fields_from_query=q),
                         {'name': 'peewee'})

    def test_only_backref(self):
        for i in range(3):
            Tweet.create(user=self.user, content=str(i))

        data = model_to_dict(self.user, backrefs=True, only=[
            User.username,
            User.tweets,
            Tweet.content])
        # Sort when present so the comparison does not depend on row order;
        # if the key is missing the assertEqual below reports it.
        if 'tweets' in data:
            data['tweets'].sort(key=lambda t: t['content'])
        self.assertEqual(data, {
            'username': 'peewee',
            'tweets': [
                {'content': '0'},
                {'content': '1'},
                {'content': '2'}]})

    @requires_models(Host, Service, Device)
    def test_model_to_dict_disabled_backref(self):
        # Device.host is declared with backref='+', Service.host with
        # backref='services'; only the latter may show up on the host.
        host = Host.create(name='pi')
        Device.create(host=host, name='raspberry pi')
        Service.create(host=host, name='ssh')
        Service.create(host=host, name='vpn')

        data = model_to_dict(host, recurse=True, backrefs=True)
        services = sorted(data.pop('services'), key=operator.itemgetter('id'))
        self.assertEqual(data, {'id': 1, 'name': 'pi'})
        self.assertEqual(services, [
            {'id': 1, 'name': 'ssh'},
            {'id': 2, 'name': 'vpn'}])

    @requires_models(Basket, Item)
    def test_empty_vs_null_fk(self):
        # A foreign key whose value is 0 is falsy but not NULL, and must still
        # be serialized as a real reference.
        b = Basket.create(id=0)
        i = Item.create(id=0, basket=b)

        data = model_to_dict(i)
        self.assertEqual(data, {'id': 0, 'basket': {'id': 0}})

        data = model_to_dict(i, recurse=False)
        self.assertEqual(data, {'id': 0, 'basket': 0})
class TestDictToModel(ModelTestCase):
    """Tests for ``dict_to_model`` and ``update_model_from_dict``.

    Every test starts with one ``User`` row whose username is ``'peewee'``.
    """
    database = get_in_memory_db()
    requires = [User, Tweet, Tag, TweetTag]

    def setUp(self):
        super(TestDictToModel, self).setUp()
        self.user = User.create(username='peewee')

    def test_simple(self):
        # A flat dict becomes an instance carrying the same values.
        user_id = self.user.id
        inst = dict_to_model(User, {'username': 'peewee', 'id': user_id})
        self.assertIsInstance(inst, User)
        self.assertEqual(inst.username, 'peewee')
        self.assertEqual(inst.id, user_id)

    def test_update_model_from_dict(self):
        # Updating in place keeps existing ids and overwrites only the keys
        # present in the dict, including on the nested user.
        data = {'content': 'tweet', 'user': {'username': 'zaizee'}}
        with self.assertQueryCount(0):
            user = User(id=3, username='orig')
            tweet = Tweet(id=4, content='orig', user=user)
            obj = update_model_from_dict(tweet, data)
            self.assertEqual(obj.id, 4)
            self.assertEqual(obj.content, 'tweet')
            self.assertEqual(obj.user.id, 3)
            self.assertEqual(obj.user.username, 'zaizee')

    def test_related(self):
        data = {
            'id': 2,
            'content': 'tweet-1',
            'user': {'id': self.user.id, 'username': 'peewee'}}

        # A nested dict is turned into a related instance without querying.
        with self.assertQueryCount(0):
            inst = dict_to_model(Tweet, data)
            self.assertIsInstance(inst, Tweet)
            self.assertEqual(inst.id, 2)
            self.assertEqual(inst.content, 'tweet-1')
            self.assertIsInstance(inst.user, User)
            self.assertEqual(inst.user.id, self.user.id)
            self.assertEqual(inst.user.username, 'peewee')

        # With a bare id, building the instance is free and the related row
        # is only fetched (one query) when the attribute is read.
        data['user'] = self.user.id

        with self.assertQueryCount(0):
            inst = dict_to_model(Tweet, data)

        with self.assertQueryCount(1):
            self.assertEqual(inst.user, self.user)

    def test_backrefs(self):
        data = {
            'id': self.user.id,
            'username': 'peewee',
            'tweets': [
                {'id': 1, 'content': 't1'},
                {'id': 2, 'content': 't2'},
            ]}

        with self.assertQueryCount(0):
            inst = dict_to_model(User, data)
            self.assertEqual(inst.id, self.user.id)
            self.assertEqual(inst.username, 'peewee')
            self.assertIsInstance(inst.tweets, list)

            # Exactly two tweets, each pointing back at the user.
            first, second = inst.tweets
            expectations = ((first, 1, 't1'), (second, 2, 't2'))
            for tweet, tweet_id, content in expectations:
                self.assertEqual(tweet.id, tweet_id)
                self.assertEqual(tweet.content, content)
                self.assertEqual(tweet.user, self.user)

    def test_unknown_attributes(self):
        data = {
            'id': self.user.id,
            'username': 'peewee',
            'xx': 'does not exist'}
        with self.assertRaises(AttributeError):
            dict_to_model(User, data)

        # With ignore_unknown the extra key is set as a plain attribute.
        inst = dict_to_model(User, data, ignore_unknown=True)
        self.assertEqual(inst.xx, 'does not exist')

    def test_ignore_id_attribute(self):
        # The model's primary key is ``key``; it declares no ``id`` field, so
        # 'id' in the payload is an unknown attribute.
        class Register(Model):
            key = CharField(primary_key=True)

        data = {'id': 100, 'key': 'k1'}
        with self.assertRaises(AttributeError):
            dict_to_model(Register, data)

        inst = dict_to_model(Register, data, ignore_unknown=True)
        self.assertEqual(inst.__data__, {'key': 'k1'})

        # NOTE(review): Base is defined but Register2 subclasses Model, not
        # Base, so ``primary_key = False`` is never exercised here. Confirm
        # whether Register2 was meant to inherit from Base.
        class Base(Model):
            class Meta:
                primary_key = False

        class Register2(Model):
            key = CharField(primary_key=True)

        with self.assertRaises(AttributeError):
            dict_to_model(Register2, data)

        inst = dict_to_model(Register2, data, ignore_unknown=True)
        self.assertEqual(inst.__data__, {'key': 'k1'})
class ReconnectMySQLDatabase(ReconnectMixin, MySQLDatabase):
    """MySQL database double that fails queries on a schedule.

    ``cursor()`` counts how many cursors have been handed out and ``close()``
    counts how often the connection was closed, so tests can observe the
    reconnect behaviour of ``ReconnectMixin``. Call ``_reset_mock()`` to
    restart both counters.
    """
    # Class-level defaults so cursor() and close() also work on an instance
    # on which _reset_mock() has not been called yet (previously that raised
    # AttributeError).
    _close_counter = 0
    _query_counter = 0

    def cursor(self, named_cursor=None):
        """Return a cursor whose ``execute`` fails unless it is cursor #1."""
        cursor = super(ReconnectMySQLDatabase, self).cursor(named_cursor)

        # The first (0th) query fails, as do all queries after the 2nd (1st).
        if self._query_counter != 1:
            def _fake_execute(*args, **kwargs):
                # Installed on the cursor instance, so it is called without a
                # bound ``self``; whatever arguments arrive are ignored.
                raise OperationalError('2006')
            cursor.execute = _fake_execute
        self._query_counter += 1
        return cursor

    def close(self):
        """Count the close, then delegate to the real implementation."""
        self._close_counter += 1
        return super(ReconnectMySQLDatabase, self).close()

    def _reset_mock(self):
        """Reset both counters to zero."""
        self._close_counter = 0
        self._query_counter = 0
@requires_mysql
class TestReconnectMixin(DatabaseTestCase):
    """Tests for ``ReconnectMixin`` driven by ``ReconnectMySQLDatabase``.

    The mock lets only the cursor handed out while ``_query_counter == 1``
    execute; every other cursor raises ``OperationalError('2006')``. The
    statement order in these tests is therefore significant.
    """
    database = db_loader('mysql', db_class=ReconnectMySQLDatabase)

    def test_reconnect_mixin_execute_sql(self):
        # Verify initial state.
        self.database._reset_mock()
        self.assertEqual(self.database._close_counter, 0)

        sql = 'select 1 + 1'
        curs = self.database.execute_sql(sql)
        self.assertEqual(curs.fetchone(), (2,))
        self.assertEqual(self.database._close_counter, 1)

        # Due to how we configured our mock, our queries are now failing and we
        # can verify a reconnect is occurring *AND* the exception is
        # propagated.
        self.assertRaises(OperationalError, self.database.execute_sql, sql)
        self.assertEqual(self.database._close_counter, 2)

        # We reset the mock counters. The first query we execute will fail. The
        # second query will succeed (which happens automatically, thanks to the
        # retry logic).
        self.database._reset_mock()
        curs = self.database.execute_sql(sql)
        self.assertEqual(curs.fetchone(), (2,))
        self.assertEqual(self.database._close_counter, 1)

    def test_reconnect_mixin_begin(self):
        # Verify initial state.
        self.database._reset_mock()
        self.assertEqual(self.database._close_counter, 0)

        with self.database.atomic():
            self.assertTrue(self.database.in_transaction())
            self.assertEqual(self.database._close_counter, 1)

            # Prepare mock for commit call
            self.database._query_counter = 1

        # Due to how we configured our mock, our queries are now failing and we
        # can verify a reconnect is occurring *AND* the exception is
        # propagated.
        self.assertRaises(OperationalError, self.database.atomic().__enter__)
        self.assertEqual(self.database._close_counter, 2)
        self.assertFalse(self.database.in_transaction())

        # We reset the mock counters. The first query we execute will fail. The
        # second query will succeed (which happens automatically, thanks to the
        # retry logic).
        self.database._reset_mock()
        with self.database.atomic():
            self.assertTrue(self.database.in_transaction())
            self.assertEqual(self.database._close_counter, 1)

            # Do not reconnect when nesting transactions
            self.assertRaises(OperationalError, self.database.atomic().__enter__)
            self.assertEqual(self.database._close_counter, 1)

            # Prepare mock for commit call
            self.database._query_counter = 1

        self.assertFalse(self.database.in_transaction())
class MMA(TestModel):
    """First model in the multi-model query test: text key, integer value."""
    key = TextField()
    value = IntegerField()
class MMB(TestModel):
    """Second model in the multi-model query test: only a text key."""
    # There is no ``value`` column; the test selects a constant in its place.
    key = TextField()
class MMC(TestModel):
    """Third model in the multi-model query test, with an extra column."""
    key = TextField()
    value = IntegerField()
    # Nullable extra column; the test's select leaves it out.
    misc = TextField(null=True)
class TestResolveMultiModelQuery(ModelTestCase):
    """Tests for ``resolve_multimodel_query`` over a three-model union."""
    requires = [MMA, MMB, MMC]

    def test_resolve_multimodel_query(self):
        # Two rows per model, with keys chosen so that sorting by the first
        # column groups the rows by source model.
        MMA.insert_many([('k0', 0), ('k1', 1)]).execute()
        MMB.insert_many([('k10',), ('k11',)]).execute()
        MMC.insert_many([('k20', 20, 'a'), ('k21', 21, 'b')]).execute()

        select_a = MMA.select(MMA.key, MMA.value)
        # MMB has no value column, so a constant stands in for it.
        select_b = MMB.select(MMB.key, Value(99).alias('value'))
        select_c = MMC.select(MMC.key, MMC.value)
        union = (select_a | select_b | select_c).order_by(SQL('1'))

        resolved = list(resolve_multimodel_query(union))

        expected = [
            MMA(key='k0', value=0), MMA(key='k1', value=1),
            MMB(key='k10', value=99), MMB(key='k11', value=99),
            MMC(key='k20', value=20), MMC(key='k21', value=21)]
        self.assertEqual(len(resolved), len(expected))

        # Each row must come back as an instance of the model it came from.
        for actual, wanted in zip(resolved, expected):
            self.assertEqual(actual.__class__, wanted.__class__)
            self.assertEqual(actual.key, wanted.key)
            self.assertEqual(actual.value, wanted.value)
# Dedicated in-memory database for the thread-safe metadata tests.
ts_database = get_in_memory_db()
class TSBase(Model):
    """Base model bound to ``ts_database`` with thread-safe metadata."""
    class Meta:
        database = ts_database
        # Use ThreadSafeDatabaseMetadata for ``_meta``; this is asserted by
        # test_threadsafe_database_metadata.
        model_metadata_class = ThreadSafeDatabaseMetadata
class TSReg(TSBase):
    """Simple TSBase model with a single text ``key`` column."""
    key = TextField()
class TestThreadSafeDatabaseMetadata(BaseTestCase):
    """Tests for ``ThreadSafeDatabaseMetadata``.

    Checks that rebinding a model's database in one thread does not change
    the database seen by the main thread.
    """

    def setUp(self):
        super(TestThreadSafeDatabaseMetadata, self).setUp()
        ts_database.create_tables([TSReg])

    def _run_in_threads(self, target, count=1):
        # Run `target` in `count` worker threads, wait for all of them, then
        # re-raise the first exception any worker hit. An assertion that
        # fails inside a plain Thread target is not propagated to the thread
        # that started it, so without this the test would still pass.
        #
        # Imported locally because the visible top-of-file imports do not
        # include an explicit ``import threading``.
        import threading

        failures = []

        def worker():
            try:
                target()
            except Exception as exc:
                failures.append(exc)

        threads = [threading.Thread(target=worker) for _ in range(count)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        if failures:
            raise failures[0]

    def test_threadsafe_database_metadata(self):
        self.assertTrue(isinstance(TSReg._meta, ThreadSafeDatabaseMetadata))
        self.assertEqual(TSReg._meta.database, ts_database)

        # Basic round trip through the thread-safe metadata.
        t1 = TSReg.create(key='k1')
        t1_db = TSReg.get(TSReg.key == 'k1')
        self.assertEqual(t1.id, t1_db.id)

    def test_swap_database(self):
        d1 = get_in_memory_db()
        d2 = get_in_memory_db()

        class M(TSBase):
            pass

        def swap_db():
            self.assertEqual(M._meta.database, ts_database)
            d1.bind([M])
            self.assertEqual(M._meta.database, d1)
            with d2.bind_ctx([M]):
                self.assertEqual(M._meta.database, d2)
            self.assertEqual(M._meta.database, d1)

        self.assertEqual(M._meta.database, ts_database)

        # From a separate thread, swap the database and verify it works
        # correctly.
        self._run_in_threads(swap_db)

        # In the main thread the original database has not been altered.
        self.assertEqual(M._meta.database, ts_database)

    def test_preserve_original_db(self):
        d1 = get_in_memory_db()
        d2 = get_in_memory_db()

        class M(TSBase):
            class Meta:
                database = d1

        def swap_db():
            self.assertTrue(M._meta.database is d1)
            with d2.bind_ctx([M]):
                self.assertTrue(M._meta.database is d2)
            self.assertTrue(M._meta.database is d1)
            d2.bind([M])  # Now bind to d2 and leave it bound.
            self.assertTrue(M._meta.database is d2)

        # From a separate thread, swap the database and verify it works
        # correctly.
        self._run_in_threads(swap_db, count=20)

        # In the main thread the original database has not been altered.
        self.assertTrue(M._meta.database is d1)
class TIW(TestModel):
    """Model for the insert-where tests, with two kinds of field default."""
    key = CharField()
    # Plain constant default.
    value = IntegerField(default=0)
    # Callable default; test_insert_where_defaults expects it to yield 1.
    extra = IntegerField(default=lambda: 1)
class TestInsertWhere(ModelTestCase):
    """Tests for ``insert_where``: insert a row only if a condition holds."""
    requires = [User, Tweet, TIW]

    def test_insert_where(self):
        ua = User.create(username='a')
        ub = User.create(username='b')

        def _insert_where(user, content):
            # Insert the tweet unless this user already has one with the same
            # content; report 1 for a truthy execute() result, else 0.
            existing = (Tweet.select()
                        .where(Tweet.user == user, Tweet.content == content))
            not_exists = ~fn.EXISTS(existing)
            iq = insert_where(Tweet, {
                Tweet.user: user,
                Tweet.content: content},
                where=not_exists)
            if iq.execute():
                return 1
            return 0

        # (user, content, expected result), executed strictly in this order:
        # first inserts succeed, repeats are rejected, and the check is per
        # user.
        steps = (
            (ua, 't1', 1),
            (ua, 't2', 1),
            (ua, 't1', 0),
            (ua, 't2', 0),
            (ub, 't1', 1),
            (ub, 't2', 1))
        for user, content, outcome in steps:
            self.assertEqual(_insert_where(user, content), outcome)

    def test_insert_where_defaults(self):
        TIW.create(key='k1', value=1, extra=2)

        def _insert_where(key):
            # Only ``key`` is supplied; value/extra are left to the defaults.
            not_exists = ~fn.EXISTS(TIW.select().where(TIW.key == key))
            iq = insert_where(TIW, {TIW.key: key}, not_exists)
            if iq.execute():
                return 1
            return 0

        for key, outcome in (('k2', 1), ('k1', 0), ('k2', 0)):
            self.assertEqual(_insert_where(key), outcome)

        # The inserted row picked up both field defaults.
        tiw = TIW.get(TIW.key == 'k2')
        self.assertEqual(tiw.value, 0)
        self.assertEqual(tiw.extra, 1)