Files
2026-04-23 07:53:46 -05:00

1274 lines
46 KiB
Python

"""
DDL generation tests (CREATE TABLE, indexes, constraints, views).
Test case ordering:
* Core DDL SQL generation (TestModelDDL)
* CREATE TABLE AS (SQL generation and integration)
* View field mapping
* Table name and truncation
* Named constraints integration
* Gap coverage: truncate SQL variants, sequence error paths
"""
import datetime
from peewee import *
from peewee import NodeList
from .base import BaseTestCase
from .base import IS_CRDB
from .base import IS_SQLITE
from .base import ModelDatabaseTestCase
from .base import ModelTestCase
from .base import TestModel
from .base import get_in_memory_db
from .base import requires_postgresql
from .base_models import Category
from .base_models import Note
from .base_models import Person
from .base_models import Relationship
from .base_models import User
# ---------------------------------------------------------------------------
# Module-local models for DDL generation tests.
# Each exercises a specific schema feature (unique, sequence, indexes,
# constraints, schema namespace, etc.).
# ---------------------------------------------------------------------------
class TMUnique(TestModel):
    # Single column declared unique=True; the DDL tests expect the uniqueness
    # to be emitted as a separate CREATE UNIQUE INDEX statement.
    data = TextField(unique=True)
class TMSequence(TestModel):
    # Column backed by a named sequence; the DDL tests expect a
    # DEFAULT NEXTVAL('test_seq') clause on the column.
    value = IntegerField(sequence='test_seq')
class TMIndexes(TestModel):
    # Exercises multi-column indexes declared through Meta.indexes.
    alpha = IntegerField()
    beta = IntegerField()
    gamma = IntegerField()

    class Meta:
        # Each entry is (column names, unique?): the first yields a UNIQUE
        # index, the second a plain one.
        indexes = (
            (('alpha', 'beta'), True),
            (('beta', 'gamma'), False))
class TMConstraints(TestModel):
    # Exercises column-level extras: a CHECK constraint on a nullable column,
    # a COLLATE clause, and a DEFAULT emitted as part of the column DDL.
    data = IntegerField(null=True, constraints=[Check('data < 5')])
    value = TextField(collation='NOCASE')
    added = DateTimeField(constraints=[Default('CURRENT_TIMESTAMP')])
class TMNamedConstraints(TestModel):
    # Exercises named constraints: a self-referential foreign key with an
    # explicit constraint name, an unnamed column CHECK, and a named
    # table-level CHECK.
    fk = ForeignKeyField('self', null=True, constraint_name='tmc_fk')
    k = TextField()
    v = IntegerField(constraints=[Check('v in (1, 2)')])

    class Meta:
        # Table-level constraint, rendered as CONSTRAINT "chk_k" CHECK (...).
        constraints = [Check('k != \'kx\'', name='chk_k')]
class CacheData(TestModel):
    # Model placed in a non-default schema ("cache") to test schema-qualified
    # table and index names.
    key = TextField(unique=True)
    value = TextField()

    class Meta:
        schema = 'cache'
class Article(TestModel):
    # Model whose extra indexes are registered after the class is defined,
    # via add_index(); test_model_indexes asserts the resulting SQL and the
    # order in which the indexes are emitted.
    name = TextField(unique=True)
    timestamp = TimestampField()
    status = IntegerField()
    flags = IntegerField()


# Two-column index with a descending "timestamp" column.
Article.add_index(Article.timestamp.desc(), Article.status)
# Partial index (WHERE status = 1) that also covers the expression flags & 4.
idx = (Article
       .index(Article.name, Article.timestamp, Article.flags.bin_and(4))
       .where(Article.status == 1))
Article.add_index(idx)
# Index supplied verbatim as raw SQL.
Article.add_index(SQL('CREATE INDEX "article_foo" ON "article" ("flags" & 3)'))
# ===========================================================================
# Core DDL SQL generation
# ===========================================================================
class TestModelDDL(ModelDatabaseTestCase):
    """SQL-generation tests for CREATE TABLE, CREATE INDEX and constraints.

    Most tests render DDL through ``Model._schema`` and compare the SQL text
    against an expected string; only a few actually execute statements
    against the database returned by ``get_in_memory_db()``.
    """
    database = get_in_memory_db()
    requires = [Article, CacheData, Category, Note, Person, Relationship,
                TMUnique, TMSequence, TMIndexes, TMConstraints,
                TMNamedConstraints, User]

    def test_database_required(self):
        # A model that is not bound to a database cannot create its table.
        class MissingDB(Model):
            data = TextField()

        self.assertRaises(ImproperlyConfigured, MissingDB.create_table)

    def assertCreateTable(self, model_class, expected):
        # Assert that the CREATE TABLE statement followed by every
        # CREATE INDEX statement for `model_class` equals `expected` (a list
        # of SQL strings). Every statement must be free of bound parameters.
        sql, params = model_class._schema._create_table(False).query()
        self.assertEqual(params, [])

        indexes = []
        for create_index in model_class._schema._create_indexes(False):
            isql, params = create_index.query()
            self.assertEqual(params, [])
            indexes.append(isql)

        self.assertEqual([sql] + indexes, expected)

    def assertIndexes(self, model_class, expected):
        # Assert the CREATE INDEX statements for `model_class`; `expected` is
        # a list of (sql, params) pairs, so parameterized indexes are allowed.
        indexes = []
        for create_index in model_class._schema._create_indexes(False):
            indexes.append(create_index.query())

        self.assertEqual(indexes, expected)

    def test_model_fk_schema(self):
        # A foreign key across schemas references the schema-qualified table.
        class Base(TestModel):
            class Meta:
                database = self.database

        class User(Base):
            username = TextField()

            class Meta:
                schema = 'foo'

        class Tweet(Base):
            user = ForeignKeyField(User)
            content = TextField()

            class Meta:
                schema = 'bar'

        self.assertCreateTable(User, [
            ('CREATE TABLE "foo"."user" ("id" INTEGER NOT NULL PRIMARY KEY, '
             '"username" TEXT NOT NULL)')])
        self.assertCreateTable(Tweet, [
            ('CREATE TABLE "bar"."tweet" ("id" INTEGER NOT NULL PRIMARY KEY, '
             '"user_id" INTEGER NOT NULL, "content" TEXT NOT NULL, '
             'FOREIGN KEY ("user_id") REFERENCES "foo"."user" ("id"))'),
            ('CREATE INDEX "bar"."tweet_user_id" ON "tweet" ("user_id")')])

    def test_bigauto_and_fk(self):
        # A foreign key pointing at a BigAutoField uses the BIGINT column type.
        class CustomDB(SqliteDatabase):
            field_types = {
                'BIGAUTO': 'BIGAUTO',
                'BIGINT': 'BIGINT'}

        db = CustomDB(None)

        class User(db.Model):
            id = BigAutoField()

        class Tweet(db.Model):
            user = ForeignKeyField(User)

        self.assertCreateTable(User, [
            ('CREATE TABLE "user" ("id" BIGAUTO NOT NULL PRIMARY KEY)')])
        self.assertCreateTable(Tweet, [
            ('CREATE TABLE "tweet" ("id" INTEGER NOT NULL PRIMARY KEY, '
             '"user_id" BIGINT NOT NULL, FOREIGN KEY ("user_id") REFERENCES '
             '"user" ("id"))'),
            ('CREATE INDEX "tweet_user_id" ON "tweet" ("user_id")')])

    def test_model_indexes_with_schema(self):
        # Attach cache database so we can reference "cache." as the schema.
        self.database.execute_sql("attach database ':memory:' as cache;")
        self.assertCreateTable(CacheData, [
            ('CREATE TABLE "cache"."cache_data" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, "key" TEXT NOT NULL, '
             '"value" TEXT NOT NULL)'),
            ('CREATE UNIQUE INDEX "cache"."cache_data_key" ON "cache_data" '
             '("key")')])

        # Actually create the table to verify it works correctly.
        CacheData.create_table()

        # Introspect the database and get indexes for the "cache" schema.
        indexes = self.database.get_indexes('cache_data', 'cache')
        self.assertEqual(len(indexes), 1)
        index_metadata = indexes[0]
        self.assertEqual(index_metadata.name, 'cache_data_key')

        # Verify the index does not exist in the main schema.
        self.assertEqual(len(self.database.get_indexes('cache_data')), 0)

        class TestDatabase(Database):
            index_schema_prefix = False

        # When "index_schema_prefix == False", the index name is not prefixed
        # with the schema, and the schema is referenced via the table name.
        with CacheData.bind_ctx(TestDatabase(None)):
            self.assertCreateTable(CacheData, [
                ('CREATE TABLE "cache"."cache_data" ('
                 '"id" INTEGER NOT NULL PRIMARY KEY, "key" TEXT NOT NULL, '
                 '"value" TEXT NOT NULL)'),
                ('CREATE UNIQUE INDEX "cache_data_key" ON "cache"."cache_data"'
                 ' ("key")')])

    def test_model_indexes(self):
        # Field-level index first, then the add_index() entries in the order
        # they were registered on Article.
        self.assertIndexes(Article, [
            ('CREATE UNIQUE INDEX "article_name" ON "article" ("name")', []),
            ('CREATE INDEX "article_timestamp_status" ON "article" ('
             '"timestamp" DESC, "status")', []),
            ('CREATE INDEX "article_name_timestamp" ON "article" ('
             '"name", "timestamp", ("flags" & 4)) '
             'WHERE ("status" = 1)', []),
            ('CREATE INDEX "article_foo" ON "article" ("flags" & 3)', []),
        ])

    def test_model_index_types(self):
        class Event(TestModel):
            key = TextField()
            timestamp = TimestampField(index=True, index_type='BRIN')

            class Meta:
                database = self.database

        self.assertIndexes(Event, [
            ('CREATE INDEX "event_timestamp" ON "event" '
             'USING BRIN ("timestamp")', [])])

        # Check that we support MySQL-style USING clause.
        idx, = Event._meta.fields_to_index()
        self.assertSQL(idx, (
            'CREATE INDEX IF NOT EXISTS "event_timestamp" '
            'USING BRIN ON "event" ("timestamp")'), [],
            index_using_precedes_table=True)

    def test_model_indexes_custom_tablename(self):
        # Index names are derived from the explicit table name.
        class KV(TestModel):
            key = TextField()
            value = TextField()
            timestamp = TimestampField(index=True)

            class Meta:
                database = self.database
                indexes = (
                    (('key', 'value'), True),
                )
                table_name = 'kvs'

        self.assertIndexes(KV, [
            ('CREATE INDEX "kvs_timestamp" ON "kvs" ("timestamp")', []),
            ('CREATE UNIQUE INDEX "kvs_key_value" ON "kvs" ("key", "value")',
             [])])

    def test_model_indexes_computed_columns(self):
        # A function expression may appear among the indexed columns; the
        # generated name only uses the plain field columns.
        class FuncIdx(TestModel):
            a = IntegerField()
            b = IntegerField()

            class Meta:
                database = self.database

        i = FuncIdx.index(FuncIdx.a, FuncIdx.b, fn.SUM(FuncIdx.a + FuncIdx.b))
        FuncIdx.add_index(i)

        self.assertIndexes(FuncIdx, [
            ('CREATE INDEX "func_idx_a_b" ON "func_idx" '
             '("a", "b", SUM("a" + "b"))', []),
        ])

    def test_model_indexes_complex_columns(self):
        # An indexed "column" may be a NodeList (expression + operator class),
        # and the WHERE clause of a partial index is parameterized.
        class Taxonomy(TestModel):
            name = CharField()
            name_class = CharField()

            class Meta:
                database = self.database

        name = NodeList((fn.LOWER(Taxonomy.name), SQL('varchar_pattern_ops')))
        index = (Taxonomy
                 .index(name, Taxonomy.name_class)
                 .where(Taxonomy.name_class == 'scientific name'))
        Taxonomy.add_index(index)

        self.assertIndexes(Taxonomy, [
            ('CREATE INDEX "taxonomy_name_class" ON "taxonomy" ('
             'LOWER("name") varchar_pattern_ops, "name_class") '
             'WHERE ("name_class" = ?)', ['scientific name']),
        ])

    def test_add_index_with_fields(self):
        # add_index() given fields builds and stores a ModelIndex.
        class IdxModel(TestModel):
            name = CharField()
            value = IntegerField()

            class Meta:
                database = self.database

        self.assertEqual(len(IdxModel._meta.indexes), 0)
        IdxModel.add_index(IdxModel.name, IdxModel.value, unique=True)
        self.assertEqual(len(IdxModel._meta.indexes), 1)
        idx = IdxModel._meta.indexes[0]
        self.assertIsInstance(idx, ModelIndex)
        self.assertTrue(idx._unique)

        self.assertIndexes(IdxModel, [
            ('CREATE UNIQUE INDEX "idx_model_name_value" ON "idx_model" ('
             '"name", "value")', []),
        ])

    def test_add_index_with_sql(self):
        # add_index() given a SQL node stores it and emits it verbatim.
        class IdxModel(TestModel):
            name = CharField()

            class Meta:
                database = self.database

        raw = SQL('CREATE INDEX test_idx ON idxmodel2 (name)')
        self.assertEqual(len(IdxModel._meta.indexes), 0)
        IdxModel.add_index(raw)
        self.assertEqual(len(IdxModel._meta.indexes), 1)

        self.assertIndexes(IdxModel, [
            ('CREATE INDEX test_idx ON idxmodel2 (name)', []),
        ])

    def test_index_nulls_distinct(self):
        # nulls_distinct=True/False appends NULLS [NOT] DISTINCT, placed
        # before any WHERE clause.
        class A(self.database.Model):
            key = CharField()

        A.add_index(A.key, unique=True, nulls_distinct=True)
        self.assertIndexes(A, [
            ('CREATE UNIQUE INDEX "a_key" ON "a" ("key") NULLS DISTINCT', []),
        ])

        class B(self.database.Model):
            key = CharField()

        # Index B on its own field (the original passed A.key here).
        B.add_index(B.key, unique=True, nulls_distinct=False)
        self.assertIndexes(B, [
            ('CREATE UNIQUE INDEX "b_key" ON "b" ("key") NULLS NOT DISTINCT',
             []),
        ])

        class C(self.database.Model):
            key = CharField()

        C.add_index(C.key, unique=True, nulls_distinct=True, where=(
            fn.LOWER(fn.SUBSTR(C.key, 1, 1)) == 'c'))
        self.assertIndexes(C, [
            ('CREATE UNIQUE INDEX "c_key" ON "c" ("key") '
             'NULLS DISTINCT '
             'WHERE (LOWER(SUBSTR("key", ?, ?)) = ?)',
             [1, 1, 'c']),
        ])

    def test_legacy_model_table_and_indexes(self):
        # Base derives from Model (not TestModel), so legacy table naming is
        # in effect unless a subclass turns it off.
        class Base(Model):
            class Meta:
                database = self.database

        class WebHTTPRequest(Base):
            timestamp = DateTimeField(index=True)
            data = TextField()

        self.assertTrue(WebHTTPRequest._meta.legacy_table_names)
        self.assertCreateTable(WebHTTPRequest, [
            ('CREATE TABLE "webhttprequest" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"timestamp" DATETIME NOT NULL, "data" TEXT NOT NULL)'),
            ('CREATE INDEX "webhttprequest_timestamp" ON "webhttprequest" '
             '("timestamp")')])

        # Table name is explicit, but legacy table names == false, so we get
        # the new index name format.
        class FooBar(Base):
            data = IntegerField(unique=True)

            class Meta:
                legacy_table_names = False
                table_name = 'foobar_tbl'

        self.assertFalse(FooBar._meta.legacy_table_names)
        self.assertCreateTable(FooBar, [
            ('CREATE TABLE "foobar_tbl" ("id" INTEGER NOT NULL PRIMARY KEY, '
             '"data" INTEGER NOT NULL)'),
            ('CREATE UNIQUE INDEX "foobar_tbl_data" ON "foobar_tbl" ("data")'),
        ])

        # Table name is explicit and legacy table names == true, so we get
        # the old index name format.
        class FooBar2(Base):
            data = IntegerField(unique=True)

            class Meta:
                table_name = 'foobar2_tbl'

        self.assertTrue(FooBar2._meta.legacy_table_names)
        self.assertCreateTable(FooBar2, [
            ('CREATE TABLE "foobar2_tbl" ("id" INTEGER NOT NULL PRIMARY KEY, '
             '"data" INTEGER NOT NULL)'),
            ('CREATE UNIQUE INDEX "foobar2_data" ON "foobar2_tbl" ("data")')])

    def test_without_pk(self):
        # primary_key = False suppresses the implicit "id" column.
        class NoPK(TestModel):
            data = TextField()

            class Meta:
                database = self.database
                primary_key = False

        self.assertCreateTable(NoPK, [
            ('CREATE TABLE "no_pk" ("data" TEXT NOT NULL)')])

    def test_without_rowid(self):
        class NoRowid(TestModel):
            key = TextField(primary_key=True)
            value = TextField()

            class Meta:
                database = self.database
                without_rowid = True

        self.assertCreateTable(NoRowid, [
            ('CREATE TABLE "no_rowid" ('
             '"key" TEXT NOT NULL PRIMARY KEY, '
             '"value" TEXT NOT NULL) WITHOUT ROWID')])

        # Subclasses do not inherit "without_rowid" setting.
        class SubNoRowid(NoRowid): pass

        self.assertCreateTable(SubNoRowid, [
            ('CREATE TABLE "sub_no_rowid" ('
             '"key" TEXT NOT NULL PRIMARY KEY, '
             '"value" TEXT NOT NULL)')])

    def test_strict_tables(self):
        class Strict(TestModel):
            key = TextField(primary_key=True)
            value = TextField()

            class Meta:
                database = self.database
                strict_tables = True

        self.assertCreateTable(Strict, [
            ('CREATE TABLE "strict" ('
             '"key" TEXT NOT NULL PRIMARY KEY, '
             '"value" TEXT NOT NULL) STRICT')])

        # Subclasses *do* inherit "strict_tables" setting.
        class SubStrict(Strict): pass

        self.assertCreateTable(SubStrict, [
            ('CREATE TABLE "sub_strict" ('
             '"key" TEXT NOT NULL PRIMARY KEY, '
             '"value" TEXT NOT NULL) STRICT')])

    def test_without_rowid_strict(self):
        # Both options together are rendered as "STRICT, WITHOUT ROWID".
        class KV(TestModel):
            key = TextField(primary_key=True)

            class Meta:
                database = self.database
                strict_tables = True
                without_rowid = True

        self.assertCreateTable(KV, [
            ('CREATE TABLE "kv" ("key" TEXT NOT NULL PRIMARY KEY) '
             'STRICT, WITHOUT ROWID')])

        # The subclass keeps STRICT but loses WITHOUT ROWID.
        class SKV(KV):
            pass

        self.assertCreateTable(SKV, [
            ('CREATE TABLE "skv" ("key" TEXT NOT NULL PRIMARY KEY) STRICT')])

    def test_table_name(self):
        # Explicit, mixed-case table names are used as-is in table, foreign
        # key and index DDL.
        class A(TestModel):
            class Meta:
                database = self.database
                table_name = 'A_tbl'

        class B(TestModel):
            a = ForeignKeyField(A, backref='bs')

            class Meta:
                database = self.database
                table_name = 'B_tbl'

        self.assertCreateTable(A, [
            'CREATE TABLE "A_tbl" ("id" INTEGER NOT NULL PRIMARY KEY)'])
        self.assertCreateTable(B, [
            ('CREATE TABLE "B_tbl" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"a_id" INTEGER NOT NULL, '
             'FOREIGN KEY ("a_id") REFERENCES "A_tbl" ("id"))'),
            'CREATE INDEX "B_tbl_a_id" ON "B_tbl" ("a_id")'])

    def test_temporary_table(self):
        # temporary=True passed explicitly to the schema manager.
        sql, params = User._schema._create_table(temporary=True).query()
        self.assertEqual(sql, (
            'CREATE TEMPORARY TABLE IF NOT EXISTS "users" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"username" VARCHAR(255) NOT NULL)'))

    def test_model_temporary_table(self):
        # Meta.temporary = True makes create_table() emit TEMPORARY; the
        # executed statements are read back from the SQL history.
        class TempUser(User):
            class Meta:
                temporary = True

        self.reset_sql_history()
        TempUser.create_table()
        TempUser.drop_table()
        queries = [x.msg for x in self.history]
        self.assertEqual(queries, [
            ('CREATE TEMPORARY TABLE IF NOT EXISTS "temp_user" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"username" VARCHAR(255) NOT NULL)', []),
            ('DROP TABLE IF EXISTS "temp_user"', [])])

    def test_drop_table(self):
        sql, params = User._schema._drop_table().query()
        self.assertEqual(sql, 'DROP TABLE IF EXISTS "users"')

        sql, params = User._schema._drop_table(cascade=True).query()
        self.assertEqual(sql, 'DROP TABLE IF EXISTS "users" CASCADE')

        sql, params = User._schema._drop_table(restrict=True).query()
        self.assertEqual(sql, 'DROP TABLE IF EXISTS "users" RESTRICT')

    def test_table_constraints(self):
        # Meta.constraints entries are appended after the column definitions,
        # in declaration order.
        class UKV(TestModel):
            key = TextField()
            value = TextField()
            status = IntegerField()

            class Meta:
                constraints = [
                    SQL('CONSTRAINT ukv_kv_uniq UNIQUE (key, value)'),
                    Check('status > 0')]
                database = self.database
                table_name = 'ukv'

        self.assertCreateTable(UKV, [
            ('CREATE TABLE "ukv" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"key" TEXT NOT NULL, '
             '"value" TEXT NOT NULL, '
             '"status" INTEGER NOT NULL, '
             'CONSTRAINT ukv_kv_uniq UNIQUE (key, value), '
             'CHECK (status > 0))')])

    def test_table_settings(self):
        # Meta.table_settings strings follow the closing parenthesis,
        # separated by spaces.
        class KVSettings(TestModel):
            key = TextField(primary_key=True)
            value = TextField()
            timestamp = TimestampField()

            class Meta:
                database = self.database
                table_settings = ('PARTITION BY RANGE (timestamp)',
                                  'WITHOUT ROWID')

        self.assertCreateTable(KVSettings, [
            ('CREATE TABLE "kv_settings" ('
             '"key" TEXT NOT NULL PRIMARY KEY, '
             '"value" TEXT NOT NULL, '
             '"timestamp" INTEGER NOT NULL) '
             'PARTITION BY RANGE (timestamp) '
             'WITHOUT ROWID')])

    def test_table_options(self):
        # Meta.options are rendered as KEY=value pairs after the columns,
        # inside the parentheses.
        class TOpts(TestModel):
            key = TextField()

            class Meta:
                database = self.database
                options = {
                    'CHECKSUM': 1,
                    'COMPRESSION': 'lz4'}

        self.assertCreateTable(TOpts, [
            ('CREATE TABLE "t_opts" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"key" TEXT NOT NULL, '
             'CHECKSUM=1, COMPRESSION=lz4)')])

    def test_table_and_index_creation(self):
        # Full table + index DDL for the shared models and the module-local
        # TM* models.
        self.assertCreateTable(Person, [
            ('CREATE TABLE "person" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"first" VARCHAR(255) NOT NULL, '
             '"last" VARCHAR(255) NOT NULL, '
             '"dob" DATE)'),
            'CREATE INDEX "person_dob" ON "person" ("dob")',
            ('CREATE UNIQUE INDEX "person_first_last" ON '
             '"person" ("first", "last")')])
        self.assertCreateTable(Note, [
            ('CREATE TABLE "note" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"author_id" INTEGER NOT NULL, '
             '"content" TEXT NOT NULL, '
             'FOREIGN KEY ("author_id") REFERENCES "person" ("id"))'),
            'CREATE INDEX "note_author_id" ON "note" ("author_id")'])

        self.assertCreateTable(Category, [
            ('CREATE TABLE "category" ('
             '"name" VARCHAR(20) NOT NULL PRIMARY KEY, '
             '"parent_id" VARCHAR(20), '
             'FOREIGN KEY ("parent_id") REFERENCES "category" ("name"))'),
            'CREATE INDEX "category_parent_id" ON "category" ("parent_id")'])

        self.assertCreateTable(Relationship, [
            ('CREATE TABLE "relationship" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"from_person_id" INTEGER NOT NULL, '
             '"to_person_id" INTEGER NOT NULL, '
             'FOREIGN KEY ("from_person_id") REFERENCES "person" ("id"), '
             'FOREIGN KEY ("to_person_id") REFERENCES "person" ("id"))'),
            ('CREATE INDEX "relationship_from_person_id" '
             'ON "relationship" ("from_person_id")'),
            ('CREATE INDEX "relationship_to_person_id" '
             'ON "relationship" ("to_person_id")')])

        self.assertCreateTable(TMUnique, [
            ('CREATE TABLE "tm_unique" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"data" TEXT NOT NULL)'),
            'CREATE UNIQUE INDEX "tm_unique_data" ON "tm_unique" ("data")'])

        self.assertCreateTable(TMSequence, [
            ('CREATE TABLE "tm_sequence" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"value" INTEGER NOT NULL DEFAULT NEXTVAL(\'test_seq\'))')])

        self.assertCreateTable(TMIndexes, [
            ('CREATE TABLE "tm_indexes" ("id" INTEGER NOT NULL PRIMARY KEY, '
             '"alpha" INTEGER NOT NULL, "beta" INTEGER NOT NULL, '
             '"gamma" INTEGER NOT NULL)'),
            ('CREATE UNIQUE INDEX "tm_indexes_alpha_beta" '
             'ON "tm_indexes" ("alpha", "beta")'),
            ('CREATE INDEX "tm_indexes_beta_gamma" '
             'ON "tm_indexes" ("beta", "gamma")')])

        self.assertCreateTable(TMConstraints, [
            ('CREATE TABLE "tm_constraints" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"data" INTEGER CHECK (data < 5), '
             '"value" TEXT NOT NULL COLLATE NOCASE, '
             '"added" DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP)')])

        self.assertCreateTable(TMNamedConstraints, [
            ('CREATE TABLE "tm_named_constraints" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"fk_id" INTEGER, '
             '"k" TEXT NOT NULL, '
             '"v" INTEGER NOT NULL '
             'CHECK (v in (1, 2)), '
             'CONSTRAINT "tmc_fk" FOREIGN KEY ("fk_id") '
             'REFERENCES "tm_named_constraints" ("id"), '
             'CONSTRAINT "chk_k" CHECK (k != \'kx\'))'),
            ('CREATE INDEX "tm_named_constraints_fk_id" '
             'ON "tm_named_constraints" ("fk_id")')])

        # The explicit constraint name is also used by ALTER TABLE ... ADD.
        sql, params = (TMNamedConstraints
                       ._schema
                       ._create_foreign_key(TMNamedConstraints.fk)
                       .query())
        self.assertEqual(sql, (
            'ALTER TABLE "tm_named_constraints" ADD CONSTRAINT "tmc_fk" '
            'FOREIGN KEY ("fk_id") REFERENCES "tm_named_constraints" ("id")'))

    def test_index_name_truncation(self):
        # An over-long generated index name is shortened and ends with a
        # short suffix (see the expected name below).
        class LongIndex(TestModel):
            a123456789012345678901234567890 = CharField()
            b123456789012345678901234567890 = CharField()
            c123456789012345678901234567890 = CharField()

            class Meta:
                database = self.database

        # Skip the first sorted field (the primary key) to get the three
        # long-named columns.
        fields = LongIndex._meta.sorted_fields[1:]
        self.assertEqual(len(fields), 3)

        idx = ModelIndex(LongIndex, fields)
        ctx = LongIndex._schema._create_index(idx)
        self.assertSQL(ctx, (
            'CREATE INDEX IF NOT EXISTS "'
            'long_index_a123456789012345678901234567890_b123456789012_9dd2139'
            '" ON "long_index" ('
            '"a123456789012345678901234567890", '
            '"b123456789012345678901234567890", '
            '"c123456789012345678901234567890")'), [])

    def test_fk_non_pk_ddl(self):
        # A foreign key to a non-primary-key field takes that field's column
        # type and references that column.
        class A(Model):
            cf = CharField(max_length=100, unique=True)
            df = DecimalField(
                max_digits=4,
                decimal_places=2,
                auto_round=True,
                unique=True)

            class Meta:
                database = self.database

        class CF(TestModel):
            a = ForeignKeyField(A, field='cf')

            class Meta:
                database = self.database

        class DF(TestModel):
            a = ForeignKeyField(A, field='df')

            class Meta:
                database = self.database

        sql, params = CF._schema._create_table(safe=False).query()
        self.assertEqual(sql, (
            'CREATE TABLE "cf" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"a_id" VARCHAR(100) NOT NULL, '
            'FOREIGN KEY ("a_id") REFERENCES "a" ("cf"))'))

        sql, params = DF._schema._create_table(safe=False).query()
        self.assertEqual(sql, (
            'CREATE TABLE "df" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"a_id" DECIMAL(4, 2) NOT NULL, '
            'FOREIGN KEY ("a_id") REFERENCES "a" ("df"))'))

    def test_deferred_foreign_key(self):
        # A DeferredForeignKey is resolved once the target model is defined;
        # its constraint is not part of CREATE TABLE and is added separately.
        class Language(TestModel):
            name = CharField()
            selected_snippet = DeferredForeignKey('Snippet', null=True)

            class Meta:
                database = self.database

        class Snippet(TestModel):
            code = TextField()
            language = ForeignKeyField(Language, backref='snippets')

            class Meta:
                database = self.database

        self.assertEqual(Snippet._meta.fields['language'].rel_model, Language)
        self.assertEqual(Language._meta.fields['selected_snippet'].rel_model,
                         Snippet)

        sql, params = Snippet._schema._create_table(safe=False).query()
        self.assertEqual(sql, (
            'CREATE TABLE "snippet" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"code" TEXT NOT NULL, '
            '"language_id" INTEGER NOT NULL, '
            'FOREIGN KEY ("language_id") REFERENCES "language" ("id"))'))

        sql, params = Language._schema._create_table(safe=False).query()
        self.assertEqual(sql, (
            'CREATE TABLE "language" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"name" VARCHAR(255) NOT NULL, '
            '"selected_snippet_id" INTEGER)'))

        sql, params = (Language
                       ._schema
                       ._create_foreign_key(Language.selected_snippet)
                       .query())
        self.assertEqual(sql, (
            'ALTER TABLE "language" ADD CONSTRAINT '
            '"fk_language_selected_snippet_id_refs_snippet" '
            'FOREIGN KEY ("selected_snippet_id") REFERENCES "snippet" ("id")'))

        class SnippetComment(TestModel):
            snippet_long_foreign_key_identifier = ForeignKeyField(Snippet)
            comment = TextField()

            class Meta:
                database = self.database

        sql, params = SnippetComment._schema._create_table(safe=True).query()
        self.assertEqual(sql, (
            'CREATE TABLE IF NOT EXISTS "snippet_comment" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"snippet_long_foreign_key_identifier_id" INTEGER NOT NULL, '
            '"comment" TEXT NOT NULL, '
            'FOREIGN KEY ("snippet_long_foreign_key_identifier_id") '
            'REFERENCES "snippet" ("id"))'))

        # An over-long generated constraint name is shortened as well.
        sql, params = (SnippetComment._schema
                       ._create_foreign_key(
                           SnippetComment.snippet_long_foreign_key_identifier)
                       .query())
        self.assertEqual(sql, (
            'ALTER TABLE "snippet_comment" ADD CONSTRAINT "'
            'fk_snippet_comment_snippet_long_foreign_key_identifier_i_2a8b87d"'
            ' FOREIGN KEY ("snippet_long_foreign_key_identifier_id") '
            'REFERENCES "snippet" ("id")'))

    def test_deferred_foreign_key_inheritance(self):
        # A deferred FK inherited by a subclass still produces its column;
        # note that it is rendered after the other columns.
        class Base(TestModel):
            class Meta:
                database = self.database

        class WithTimestamp(Base):
            timestamp = TimestampField()

        class Tweet(Base):
            user = DeferredForeignKey('DUser')
            content = TextField()

        class TimestampTweet(Tweet, WithTimestamp): pass

        class DUser(Base):
            username = TextField()

        sql, params = Tweet._schema._create_table(safe=False).query()
        self.assertEqual(sql, (
            'CREATE TABLE "tweet" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"content" TEXT NOT NULL, '
            '"user_id" INTEGER NOT NULL)'))

        sql, params = TimestampTweet._schema._create_table(safe=False).query()
        self.assertEqual(sql, (
            'CREATE TABLE "timestamp_tweet" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"timestamp" INTEGER NOT NULL, '
            '"content" TEXT NOT NULL, '
            '"user_id" INTEGER NOT NULL)'))

    def test_identity_field(self):
        class PG10Identity(TestModel):
            id = IdentityField()
            data = TextField()

            class Meta:
                database = self.database

        self.assertCreateTable(PG10Identity, [
            ('CREATE TABLE "pg10_identity" ('
             '"id" INT GENERATED BY DEFAULT AS IDENTITY NOT NULL PRIMARY KEY, '
             '"data" TEXT NOT NULL)'),
        ])

    def test_self_fk_inheritance(self):
        # A 'self' foreign key is re-targeted at each subclass that
        # inherits it.
        class BaseCategory(TestModel):
            parent = ForeignKeyField('self', backref='children')

            class Meta:
                database = self.database

        class CatA1(BaseCategory):
            name_a1 = TextField()

        class CatA2(CatA1):
            name_a2 = TextField()

        self.assertTrue(CatA1.parent.rel_model is CatA1)
        self.assertTrue(CatA2.parent.rel_model is CatA2)

        self.assertCreateTable(CatA1, [
            ('CREATE TABLE "cat_a1" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"parent_id" INTEGER NOT NULL, '
             '"name_a1" TEXT NOT NULL, '
             'FOREIGN KEY ("parent_id") REFERENCES "cat_a1" ("id"))'),
            ('CREATE INDEX "cat_a1_parent_id" ON "cat_a1" ("parent_id")')])

        self.assertCreateTable(CatA2, [
            ('CREATE TABLE "cat_a2" ('
             '"id" INTEGER NOT NULL PRIMARY KEY, '
             '"parent_id" INTEGER NOT NULL, '
             '"name_a1" TEXT NOT NULL, '
             '"name_a2" TEXT NOT NULL, '
             'FOREIGN KEY ("parent_id") REFERENCES "cat_a2" ("id"))'),
            ('CREATE INDEX "cat_a2_parent_id" ON "cat_a2" ("parent_id")')])

    def test_field_ddl(self):
        # Column types emitted for FixedCharField, DoubleField and
        # SmallIntegerField on this database.
        class Base(self.database.Model):
            pass

        class FC(Base):
            code = FixedCharField(max_length=5)
            name = CharField()

        class Dbl(Base):
            value = DoubleField()
            label = CharField()

        class SmInt(Base):
            value = SmallIntegerField()
            label = CharField()

        self.assertSQL(FC._schema._create_table(False), (
            'CREATE TABLE "fc" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"code" CHAR(5) NOT NULL, '
            '"name" VARCHAR(255) NOT NULL)'), [])

        self.assertSQL(Dbl._schema._create_table(False), (
            'CREATE TABLE "dbl" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"value" REAL NOT NULL, '
            '"label" VARCHAR(255) NOT NULL)'), [])

        self.assertSQL(SmInt._schema._create_table(False), (
            'CREATE TABLE "smint" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"value" INTEGER NOT NULL, '
            '"label" VARCHAR(255) NOT NULL)'), [])
class TestDDLAdditionalSQL(ModelDatabaseTestCase):
    """Further SQL-generation checks: nullability, safe flags, drops, FKs."""
    database = get_in_memory_db()
    requires = [User, Note, Person]

    def test_not_null_vs_null(self):
        # Only null=True drops NOT NULL; a Python-side default adds nothing
        # to the column definition.
        class NullableModel(TestModel):
            required = CharField()
            optional = CharField(null=True)
            with_default = IntegerField(default=0)

            class Meta:
                database = self.database

        create_ctx = NullableModel._schema._create_table(False)
        expected_sql = (
            'CREATE TABLE "nullable_model" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"required" VARCHAR(255) NOT NULL, '
            '"optional" VARCHAR(255), '
            '"with_default" INTEGER NOT NULL)')
        self.assertSQL(create_ctx, expected_sql, [])

    def test_create_table_safe_values(self):
        # safe=True only adds IF NOT EXISTS; the column list is unchanged.
        columns = (
            '"users" ('
            '"id" INTEGER NOT NULL PRIMARY KEY, '
            '"username" VARCHAR(255) NOT NULL)')
        cases = (
            (False, 'CREATE TABLE '),
            (True, 'CREATE TABLE IF NOT EXISTS '))
        for safe, prefix in cases:
            create_ctx = User._schema._create_table(safe=safe)
            self.assertSQL(create_ctx, prefix + columns, [])

    def test_drop_table_safe_values(self):
        # safe=True only adds IF EXISTS.
        cases = (
            (False, 'DROP TABLE "users"'),
            (True, 'DROP TABLE IF EXISTS "users"'))
        for safe, expected_sql in cases:
            drop_ctx = User._schema._drop_table(safe=safe)
            self.assertSQL(drop_ctx, expected_sql, [])

    def test_drop_table_cascade_restrict(self):
        cascade_ctx = Note._schema._drop_table(cascade=True)
        self.assertSQL(cascade_ctx, 'DROP TABLE IF EXISTS "note" CASCADE', [])

        restrict_ctx = Note._schema._drop_table(restrict=True)
        self.assertSQL(restrict_ctx,
                       'DROP TABLE IF EXISTS "note" RESTRICT', [])

    def test_drop_indexes_sql(self):
        # One DROP INDEX statement is produced for the single declared index.
        class Indexed(TestModel):
            val = CharField()

            class Meta:
                database = self.database
                indexes = ((('val',), True),)

        drop_statements = Indexed._schema._drop_indexes(safe=True)
        self.assertEqual(len(drop_statements), 1)

        drop_sql = drop_statements[0].query()[0]
        self.assertTrue(drop_sql.startswith('DROP INDEX '))
        self.assertIn('indexed_val', drop_sql)

    def test_create_foreign_key_sql(self):
        fk_ctx = Note._schema._create_foreign_key(Note.author)
        expected_sql = (
            'ALTER TABLE "note" ADD CONSTRAINT '
            '"fk_note_author_id_refs_person" '
            'FOREIGN KEY ("author_id") REFERENCES "person" ("id")')
        self.assertSQL(fk_ctx, expected_sql, [])

    def test_truncate_table_sqlite(self):
        # SQLite truncate falls back to DELETE FROM.
        self.assertSQL(User._schema._truncate_table(),
                       'DELETE FROM "users"', [])

    def test_database_required_error(self):
        # SchemaManager raises ImproperlyConfigured when no DB set.
        class Orphan(Model):
            name = CharField()

        with self.assertRaises(ImproperlyConfigured):
            Orphan._schema.database
# ===========================================================================
# CREATE TABLE AS (SQL generation and integration)
# ===========================================================================
class TMKV(TestModel):
    # Source table for the CREATE TABLE AS tests: "key" and "value" are
    # selected, "extra" is only used in the WHERE filter.
    key = CharField()
    value = IntegerField()
    extra = IntegerField()
class TMKVNew(TestModel):
    # Read/drop handle for the "tmkv_new" table that TestCreateTableAs builds
    # from a SELECT on TMKV; the selected columns are "key" and "val".
    key = CharField()
    val = IntegerField()

    class Meta:
        # The selected columns do not include an id, so no primary key is
        # declared here.
        primary_key = False
        table_name = 'tmkv_new'
class TestCreateTableAsSQL(ModelDatabaseTestCase):
    """SQL-generation test for CREATE TABLE ... AS SELECT."""
    database = get_in_memory_db()
    requires = [TMKV]

    def test_create_table_as_sql(self):
        source = (TMKV
                  .select(TMKV.key, TMKV.value.alias('val'))
                  .where(TMKV.extra < 4))

        # The SELECT portion is the same for both targets; only the table
        # reference (plain name vs. (schema, name) tuple) differs.
        select_sql = (
            'SELECT "t1"."key", "t1"."value" AS "val" FROM "tmkv" AS "t1" '
            'WHERE ("t1"."extra" < ?)')
        targets = (
            ('tmkv_new',
             'CREATE TABLE IF NOT EXISTS "tmkv_new" AS '),
            (('alt', 'tmkv_new'),
             'CREATE TABLE IF NOT EXISTS "alt"."tmkv_new" AS '))

        for target, prefix in targets:
            create_ctx = TMKV._schema._create_table_as(target, source)
            self.assertSQL(create_ctx, prefix + select_sql, [4])
class NoteX(TestModel):
    # Source table for TestCreateAs: "status" is an integer code that the
    # test maps to a label, and "flags" is used to filter rows.
    content = TextField()
    timestamp = TimestampField()
    status = IntegerField()
    flags = IntegerField()
class TestCreateAs(ModelTestCase):
    """Integration test: build table "note2" from a SELECT on NoteX."""
    requires = [NoteX]
    test_data = (
        # id, content, timestamp, status, flags.
        (1, 'n1', datetime.datetime(2019, 1, 1), 1, 1),
        (2, 'n2', datetime.datetime(2019, 1, 2), 2, 1),
        (3, 'n3', datetime.datetime(2019, 1, 3), 9, 1),
        (4, 'nx', datetime.datetime(2019, 1, 1), 9, 0))

    def setUp(self):
        super(TestCreateAs, self).setUp()
        # Rows are inserted positionally, in the model's sorted field order.
        all_fields = NoteX._meta.sorted_fields
        insert = NoteX.insert_many(self.test_data, fields=all_fields)
        insert.execute()

    def tearDown(self):
        # "note2" is created by the test itself rather than through
        # `requires`, so it is dropped here via a throwaway model.
        class Note2(TestModel):
            class Meta:
                database = self.database

        self.database.drop_tables([Note2])
        super(TestCreateAs, self).tearDown()

    def test_create_as(self):
        # Translate the integer status code into a text label.
        status_label = Case(NoteX.status, (
            (1, 'published'),
            (2, 'draft'),
            (9, 'deleted')))

        source = (NoteX
                  .select(NoteX.id, NoteX.content, NoteX.timestamp,
                          status_label.alias('status'))
                  .where(NoteX.flags == SQL('1')))
        source.create_table('note2')

        # Model describing the freshly created table.
        class Note2(TestModel):
            id = IntegerField()
            content = TextField()
            timestamp = TimestampField()
            status = TextField()

            class Meta:
                database = self.database

        # The row with flags == 0 ('nx') must have been filtered out.
        rows = list(Note2.select().order_by(Note2.id).tuples())
        self.assertEqual(rows, [
            (1, 'n1', datetime.datetime(2019, 1, 1), 'published'),
            (2, 'n2', datetime.datetime(2019, 1, 2), 'draft'),
            (3, 'n3', datetime.datetime(2019, 1, 3), 'deleted')])
class TestCreateTableAs(ModelTestCase):
    """Integration test: materialise a filtered SELECT as table "tmkv_new"."""
    requires = [TMKV]

    def tearDown(self):
        # Best-effort cleanup: "tmkv_new" is created by the test itself, so
        # a failure to drop it must not mask the test result. Catch Exception
        # rather than using a bare except so that KeyboardInterrupt and
        # SystemExit still propagate.
        try:
            TMKVNew.drop_table(safe=True)
        except Exception:
            pass
        super(TestCreateTableAs, self).tearDown()

    def test_create_table_as(self):
        # Ten rows k00..k09 with value == extra == i.
        TMKV.insert_many([('k%02d' % i, i, i) for i in range(10)]).execute()

        query = (TMKV
                 .select(TMKV.key, TMKV.value.alias('val'))
                 .where(TMKV.extra < 4))
        query.create_table('tmkv_new', safe=True)

        # The new table has exactly the selected columns.
        expected = ['key', 'val']
        if IS_CRDB: expected.append('rowid')  # CRDB adds this.

        self.assertEqual(
            [col.name for col in self.database.get_columns('tmkv_new')],
            expected)

        # Only the rows matching extra < 4 were copied.
        query = TMKVNew.select().order_by(TMKVNew.key)
        self.assertEqual([(r.key, r.val) for r in query],
                         [('k00', 0), ('k01', 1), ('k02', 2), ('k03', 3)])
# ===========================================================================
# Table name, truncation, view field mapping, and named constraints
# ===========================================================================
class TestViewFieldMapping(ModelTestCase):
    """Check that a model subclass can read rows through a database view."""
    requires = [User]

    def tearDown(self):
        # Best-effort removal of the view created by the test.
        try:
            self.execute('drop view user_testview_fm')
        except Exception:
            pass
        super(TestViewFieldMapping, self).tearDown()

    def test_view_field_mapping(self):
        huey = User.create(username='huey')
        self.execute('create view user_testview_fm as '
                     'select id, username from users')

        # Same fields as User, but reading from the view.
        class View(User):
            class Meta:
                table_name = 'user_testview_fm'

        rows = [(row.id, row.username) for row in View.select()]
        self.assertEqual(rows, [(huey.id, 'huey')])
class TestModelSetTableName(BaseTestCase):
    """Check how table-name changes reach the model's Table object."""

    def test_set_table_name(self):
        class Foo(TestModel):
            pass

        meta = Foo._meta
        self.assertEqual(meta.table_name, 'foo')
        self.assertEqual(meta.table.__name__, 'foo')

        # Writing the attribute directly does not update the cached Table name.
        meta.table_name = 'foo2'
        self.assertEqual(meta.table.__name__, 'foo')

        # Use the helper-method.
        meta.set_table_name('foo3')
        self.assertEqual(meta.table.__name__, 'foo3')
class TestTruncateTable(ModelTestCase):
    """Check the truncate SQL and that truncate_table() empties the table."""
    requires = [User]

    def test_truncate_table(self):
        for n in range(3):
            User.create(username='u%s' % n)

        truncate_ctx = User._schema._truncate_table()
        if not IS_SQLITE:
            truncate_sql = truncate_ctx.query()[0]
            self.assertTrue(truncate_sql.startswith('TRUNCATE TABLE '))
        else:
            # On SQLite the statement is a DELETE.
            self.assertSQL(truncate_ctx, 'DELETE FROM "users"', [])

        User.truncate_table()
        remaining = User.select().count()
        self.assertEqual(remaining, 0)
class TestNamedConstraintsIntegration(ModelTestCase):
    """Verify the constraints on TMNamedConstraints are enforced by the DB."""
    requires = [TMNamedConstraints]

    def setUp(self):
        super(TestNamedConstraintsIntegration, self).setUp()
        if IS_SQLITE:
            # Turn on the foreign_keys pragma so the fk case below fails.
            self.database.pragma('foreign_keys', 'on')

    def test_named_constraints_integration(self):
        """Each constraint-violating row must be rejected on insert."""
        t = TMNamedConstraints.create(k='k1', v=1)  # Sanity test.

        fails = [
            {'fk': t.id - 1, 'k': 'k2', 'v': 1},  # Invalid fk.
            {'fk': t.id, 'k': 'k3', 'v': 0},  # Invalid val.
            {'fk': t.id, 'k': 'kx', 'v': 1}]  # Invalid key.

        for f in fails:
            # MySQL may use OperationalError.
            with self.assertRaises((IntegrityError, OperationalError)):
                # Each attempt runs inside its own atomic() block.
                with self.database.atomic():
                    TMNamedConstraints.create(**f)

        # Only the initial sanity-test row was stored.
        self.assertEqual(len(TMNamedConstraints), 1)
# ===========================================================================
# Gap coverage: Truncate SQL variants, sequence error paths
# ===========================================================================
class TestTruncateTableSQL(ModelDatabaseTestCase):
    """SQL-generation checks for the TRUNCATE TABLE option combinations."""
    database = get_in_memory_db()
    requires = [User]

    def test_truncate_options(self):
        # SQLite subclass flagged as supporting TRUNCATE, so the TRUNCATE
        # code path is used when generating SQL.
        class FakeDB(SqliteDatabase):
            truncate_table = True

        fake_db = FakeDB(':memory:')
        with fake_db:
            class FakeUser(TestModel):
                username = CharField()

                class Meta:
                    database = fake_db
                    table_name = 'users'

            # (keyword options, expected SQL) for every combination tested.
            cases = (
                ({}, 'TRUNCATE TABLE "users"'),
                ({'restart_identity': True},
                 'TRUNCATE TABLE "users" RESTART IDENTITY'),
                ({'cascade': True},
                 'TRUNCATE TABLE "users" CASCADE'),
                ({'restart_identity': True, 'cascade': True},
                 'TRUNCATE TABLE "users" RESTART IDENTITY CASCADE'),
            )
            for options, expected_sql in cases:
                generated = FakeUser._schema._truncate_table(**options)
                self.assertSQL(generated, expected_sql)
class TestSchemaSequenceErrors(ModelDatabaseTestCase):
    """Error paths of the schema manager's _check_sequences()."""
    database = get_in_memory_db()
    requires = [User]

    def test_check_sequences_no_support(self):
        # With the in-memory test database, the check raises ValueError.
        manager = User._schema
        pk_field = User._meta.primary_key
        with self.assertRaises(ValueError):
            manager._check_sequences(pk_field)

    def test_check_sequences_no_sequence_on_field(self):
        # Database flagged as supporting sequences, paired with a field that
        # declares no sequence: the check still raises ValueError.
        class SeqDB(SqliteDatabase):
            sequences = True

        seq_db = SeqDB(':memory:')
        with seq_db:
            class FakeModel(TestModel):
                data = IntegerField()

                class Meta:
                    database = seq_db

            with self.assertRaises(ValueError):
                FakeModel._schema._check_sequences(FakeModel.data)
class TestSchemaCreateAllDropAll(ModelTestCase):
    """Round-trip create_all() / drop_all() on a throwaway model."""
    requires = [User]

    def test_create_all_drop_all(self):
        class TempModel(TestModel):
            data = CharField()

        db = self.database
        TempModel._meta.set_database(db)

        # After create_all() the table exists...
        TempModel._schema.create_all()
        exists_after_create = db.table_exists('temp_model')
        self.assertTrue(exists_after_create)

        # ...and after drop_all() it is gone again.
        TempModel._schema.drop_all()
        exists_after_drop = db.table_exists('temp_model')
        self.assertFalse(exists_after_drop)
@requires_postgresql
class TestSchemaGetIndexes(ModelTestCase):
    """Check that table/index introspection honours the ``schema`` argument.

    Two schemas (s1, s2) each hold a table ``t``; both carry an index named
    ``i1`` and only s1 carries ``i2``.
    """

    def setUp(self):
        super(TestSchemaGetIndexes, self).setUp()
        queries = [
            'create schema s1', 'create schema s2',
            'create table s1.t (c1 integer, c2 integer, c3 integer)',
            'create table s2.t (c1 integer, c2 integer, c3 integer)',
            'create index i1 on s1.t (c1, c2)',
            'create index i1 on s2.t (c1, c2)',
            'create index i2 on s1.t (c1)',
        ]
        with self.database:
            for query in queries:
                self.database.execute_sql(query)

    def tearDown(self):
        # Drop both schemas (and, via cascade, their tables and indexes).
        with self.database:
            self.database.execute_sql('drop schema s1 cascade')
            self.database.execute_sql('drop schema s2 cascade')
        # Chain to the base-class tearDown. The original called setUp() here,
        # which re-ran setup and skipped the base-class cleanup.
        super(TestSchemaGetIndexes, self).tearDown()

    def test_schema_get_indexes(self):
        """Indexes are reported per schema, not merged across schemas."""
        tables = self.database.get_tables(schema='s1')
        self.assertEqual(tables, ['t'])

        idxs = self.database.get_indexes('t', schema='s1')
        self.assertEqual([(i.name, i.columns) for i in idxs],
                         [('i1', ['c1', 'c2']), ('i2', ['c1'])])

        idxs = self.database.get_indexes('t', schema='s2')
        self.assertEqual([(i.name, i.columns) for i in idxs],
                         [('i1', ['c1', 'c2'])])