diff --git a/peewee.py b/peewee.py index 6728c7be..ff14834f 100644 --- a/peewee.py +++ b/peewee.py @@ -3562,7 +3562,7 @@ class Database(_callable_context_manager): return _BoundModelsContext(models, self, bind_refs, bind_backrefs) def get_noop_select(self, ctx): - return ctx.sql(Select().columns(SQL('0')).where(SQL('0'))) + return ctx.literal('SELECT 0 WHERE 0') @property def Model(self): @@ -4335,7 +4335,7 @@ class PostgresqlDatabase(Database): return fn.to_timestamp(date_field) def get_noop_select(self, ctx): - return ctx.sql(Select().columns(SQL('0')).where(SQL('false'))) + return ctx.literal('SELECT 0 WHERE false') def set_time_zone(self, timezone): self.execute_sql('set time zone \'%s\';' % timezone.replace("'", "''")) diff --git a/tests/db_tests.py b/tests/db_tests.py index c3a39e6f..a55fd848 100644 --- a/tests/db_tests.py +++ b/tests/db_tests.py @@ -3,14 +3,14 @@ Database connection, pragmas, introspection, threading, and utility tests. Test case ordering: -1. Core database features (pragmas, connection, context settings) -2. Connection semantics -3. Thread safety -4. SQLite-specific (isolation, attach) -5. Deferred database and proxy -6. Introspection -7. Exception handling -8. Utilities (model property, sort_models, chunked) +* Core database features (connection, context settings) +* Session helper and context manager usage. +* Introspection +* Thread-safety +* Deferred db / proxy +* SQLite-specific (pragma, isolation, attach) +* Exception wrappers +* Utilities """ from itertools import permutations from queue import Queue @@ -18,6 +18,7 @@ import platform import re import threading import time +import warnings from peewee import * from peewee import Database @@ -54,98 +55,6 @@ from .base_models import User # =========================================================================== class TestDatabase(DatabaseTestCase): - database = get_sqlite_db() - - def test_pragmas(self): - self.database.cache_size = -2048 - self.assertEqual(self.database.cache_size, -2048) - self.database.cache_size = -4096 - self.assertEqual(self.database.cache_size, -4096) - - self.database.foreign_keys = 'on' - self.assertEqual(self.database.foreign_keys, 1) - self.database.foreign_keys = 'off' - self.assertEqual(self.database.foreign_keys, 0) - - def test_appid_user_version(self): - self.assertEqual(self.database.application_id, 0) - self.assertEqual(self.database.user_version, 0) - self.database.application_id = 1 - self.database.user_version = 2 - self.assertEqual(self.database.application_id, 1) - self.assertEqual(self.database.user_version, 2) - self.assertTrue(self.database.close()) - self.assertTrue(self.database.connect()) - self.assertEqual(self.database.application_id, 1) - self.assertEqual(self.database.user_version, 2) - - def test_timeout_semantics(self): - self.assertEqual(self.database.timeout, 5) - self.assertEqual(self.database.pragma('busy_timeout'), 5000) - - self.database.timeout = 2.5 - self.assertEqual(self.database.timeout, 2.5) - self.assertEqual(self.database.pragma('busy_timeout'), 2500) - - self.database.close() - self.database.connect() - - self.assertEqual(self.database.timeout, 2.5) - self.assertEqual(self.database.pragma('busy_timeout'), 2500) - - def test_pragmas_deferred(self): - pragmas = (('journal_mode', 'wal'),) - db = SqliteDatabase(None, pragmas=pragmas) - self.assertEqual(db._pragmas, pragmas) - - # Test pragmas preserved after initializing. - db.init(':memory:') - self.assertEqual(db._pragmas, pragmas) - - db = SqliteDatabase(None) - self.assertEqual(db._pragmas, ()) - - # Test pragmas are set and subsequently overwritten. - db.init(':memory:', pragmas=pragmas) - self.assertEqual(db._pragmas, pragmas) - - db.init(':memory:', pragmas=()) - self.assertEqual(db._pragmas, ()) - - # Test when specified twice, the previous value is overwritten. - db = SqliteDatabase(None, pragmas=pragmas) - db.init(':memory:', pragmas=(('cache_size', -8000),)) - self.assertEqual(db._pragmas, (('cache_size', -8000),)) - - def test_pragmas_as_dict(self): - pragmas = {'journal_mode': 'wal'} - pragma_list = [('journal_mode', 'wal')] - - db = SqliteDatabase(':memory:', pragmas=pragmas) - self.assertEqual(db._pragmas, pragma_list) - - # Test deferred databases correctly handle pragma dicts. - db = SqliteDatabase(None, pragmas=pragmas) - self.assertEqual(db._pragmas, pragma_list) - - db.init(':memory:') - self.assertEqual(db._pragmas, pragma_list) - - db.init(':memory:', pragmas={}) - self.assertEqual(db._pragmas, []) - - def test_pragmas_permanent(self): - db = SqliteDatabase(':memory:') - db.execute_sql('pragma foreign_keys=0') - self.assertEqual(db.foreign_keys, 0) - - db.pragma('foreign_keys', 1, True) - self.assertEqual(db.foreign_keys, 1) - - db.close() - db.connect() - self.assertEqual(db.foreign_keys, 1) - def test_context_settings(self): class TestDatabase(Database): field_types = {'BIGINT': 'TEST_BIGINT', 'TEXT': 'TEST_TEXT'} @@ -240,13 +149,16 @@ class TestDatabase(DatabaseTestCase): self.assertEqual(state['count'], 2) def test_execute_sql(self): - self.database.execute_sql('CREATE TABLE register (val INTEGER);') - self.database.execute_sql('INSERT INTO register (val) VALUES (?), (?)', - (1337, 31337)) + p = self.database.param + self.database.execute_sql('DROP TABLE IF EXISTS register') + self.database.execute_sql('CREATE TABLE register (val INTEGER)') + self.database.execute_sql( + 'INSERT INTO register (val) VALUES (%s), (%s)' % (p, p), + (1337, 31337)) cursor = self.database.execute_sql( 'SELECT val FROM register ORDER BY val') - self.assertEqual(cursor.fetchall(), [(1337,), (31337,)]) - self.database.execute_sql('DROP TABLE register;') + self.assertEqual(list(cursor.fetchall()), [(1337,), (31337,)]) + self.database.execute_sql('DROP TABLE register') def test_bind_helpers(self): db = get_in_memory_db() @@ -401,309 +313,66 @@ class TestDatabaseConnection(DatabaseTestCase): self.database.execute_sql('drop table foo') +class TestSessionTransactions(DatabaseTestCase): + def test_session(self): + # When session is not active, commit and rollback have no effect. + self.assertFalse(self.database.in_transaction()) + self.assertFalse(self.database.session_commit()) + self.assertFalse(self.database.session_rollback()) -# =========================================================================== -# Thread safety -# =========================================================================== + tx = self.database.session_start() + self.assertTrue(self.database.in_transaction()) + self.assertEqual(self.database.transaction_depth(), 1) -class TestThreadSafety(ModelTestCase): - # HACK: This workaround increases the Sqlite busy timeout when tests are - # being run on certain architectures. - if IS_SQLITE and platform.machine() not in ('i386', 'i686', 'x86_64'): - database = new_connection(timeout=60) - nthreads = 4 - nrows = 10 - requires = [User] + tx2 = self.database.session_start() + self.assertTrue(self.database.in_transaction()) + self.assertEqual(self.database.transaction_depth(), 2) - def test_multiple_writers(self): - def create_users(idx): - for i in range(idx * self.nrows, (idx + 1) * self.nrows): - User.create(username='u%d' % i) + self.assertTrue(self.database.top_transaction() is tx2) + self.assertTrue(self.database.session_commit()) + self.assertTrue(self.database.top_transaction() is tx) + self.assertTrue(self.database.in_transaction()) + self.assertEqual(self.database.transaction_depth(), 1) - threads = [] - for i in range(self.nthreads): - threads.append(threading.Thread(target=create_users, args=(i,))) + self.assertTrue(self.database.session_rollback()) + self.assertEqual(self.database.transaction_depth(), 0) - for t in threads: t.start() - for t in threads: t.join() + def test_db_context_manager_nesting(self): + self.database.close() + self.assertTrue(self.database.is_closed()) - self.assertEqual(User.select().count(), self.nrows * self.nthreads) + with self.database: + self.assertFalse(self.database.is_closed()) + self.assertEqual(self.database.transaction_depth(), 1) + with self.database: + self.assertFalse(self.database.is_closed()) + # Inner atomic becomes savepoint (doesn't increase txn depth) + # but the ctx stack grows. + self.assertEqual(len(self.database._state.ctx), 2) - def test_multiple_readers(self): - data = Queue() - def read_user_count(n): - for i in range(n): - data.put(User.select().count()) + # Inner exited but outer still open. + self.assertFalse(self.database.is_closed()) + self.assertEqual(len(self.database._state.ctx), 1) - threads = [] - for i in range(self.nthreads): - threads.append(threading.Thread(target=read_user_count, - args=(self.nrows,))) + self.assertTrue(self.database.is_closed()) - for t in threads: t.start() - for t in threads: t.join() - self.assertEqual(data.qsize(), self.nrows * self.nthreads) - - def test_mt_general(self): - def connect_close(): - for _ in range(self.nrows): - self.database.connect() - with self.database.atomic() as txn: - self.database.execute_sql('select 1').fetchone() - self.database.close() - - threads = [] - for i in range(self.nthreads): - threads.append(threading.Thread(target=connect_close)) - - for t in threads: t.start() - for t in threads: t.join() - - def test_thread_safety_atomic(self): - @self.database.atomic() - def get_one(n): - time.sleep(n) - return User.select().first() - def run(n): - with self.database.atomic(): - self.assertEqual(get_one(n).username, 'u') - User.create(username='u') - threads = [threading.Thread(target=run, args=(i,)) - for i in (0.01, 0.03, 0.05, 0.07, 0.09, 0.02, 0.04, 0.06)] - for t in threads: t.start() - for t in threads: t.join() - - -class TestThreadSafeMetaRegression(ModelTestCase): - def test_thread_safe_meta(self): - d1 = get_in_memory_db() - d2 = get_in_memory_db() - - class Meta: - database = d1 - model_metadata_class = ThreadSafeDatabaseMetadata - attrs = {'Meta': Meta} - for i in range(1, 30): - attrs['f%d' % i] = IntegerField() - M = type('M', (TestModel,), attrs) - - sql = ('SELECT "t1"."f1", "t1"."f2", "t1"."f3", "t1"."f4" ' - 'FROM "m" AS "t1"') - query = M.select(M.f1, M.f2, M.f3, M.f4) - - def swap_db(): - for i in range(100): - self.assertEqual(M._meta.database, d1) - self.assertSQL(query, sql) - with d2.bind_ctx([M]): - self.assertEqual(M._meta.database, d2) - self.assertSQL(query, sql) - self.assertEqual(M._meta.database, d1) - self.assertSQL(query, sql) - - # From a separate thread, swap the database and verify it works - # correctly. - threads = [threading.Thread(target=swap_db) - for i in range(10)] - for t in threads: t.start() - for t in threads: t.join() - - # In the main thread the original database has not been altered. - self.assertEqual(M._meta.database, d1) - self.assertSQL(query, sql) - -# =========================================================================== -# Deferred database, proxy, and schema namespace -# =========================================================================== - -class TestDeferredDatabase(BaseTestCase): - def test_deferred_database(self): - deferred_db = SqliteDatabase(None) - self.assertTrue(deferred_db.deferred) - - class DeferredModel(Model): - class Meta: - database = deferred_db - - self.assertRaises(Exception, deferred_db.connect) - query = DeferredModel.select() - self.assertRaises(Exception, query.execute) - - deferred_db.init(':memory:') - self.assertFalse(deferred_db.deferred) - - conn = deferred_db.connect() - self.assertFalse(deferred_db.is_closed()) - DeferredModel._schema.create_all() - self.assertEqual(list(DeferredModel.select()), []) - - deferred_db.init(None) - self.assertTrue(deferred_db.deferred) - - # The connection was automatically closed. - self.assertTrue(deferred_db.is_closed()) - - -class TestDBProxy(BaseTestCase): - def test_proxy_context_manager(self): - db = Proxy() - class User(Model): - username = TextField() - - class Meta: - database = db - - self.assertRaises(AttributeError, User.create_table) - - sqlite_db = SqliteDatabase(':memory:') - db.initialize(sqlite_db) - User.create_table() - with db: - self.assertFalse(db.is_closed()) + def test_init_closes_open_connection(self): + db = get_in_memory_db() + db.connect() + self.assertFalse(db.is_closed()) + db.init(':memory:') self.assertTrue(db.is_closed()) - def test_db_proxy(self): - db = Proxy() - class BaseModel(Model): - class Meta: - database = db - - class User(BaseModel): - username = TextField() - - class Tweet(BaseModel): - user = ForeignKeyField(User, backref='tweets') - message = TextField() - - sqlite_db = SqliteDatabase(':memory:') - db.initialize(sqlite_db) - - self.assertEqual(User._meta.database.database, ':memory:') - self.assertEqual(Tweet._meta.database.database, ':memory:') - - self.assertTrue(User._meta.database.is_closed()) - self.assertTrue(Tweet._meta.database.is_closed()) - sqlite_db.connect() - self.assertFalse(User._meta.database.is_closed()) - self.assertFalse(Tweet._meta.database.is_closed()) - sqlite_db.close() - - def test_proxy_decorator(self): - db = DatabaseProxy() - - @db.connection_context() - def with_connection(): - self.assertFalse(db.is_closed()) - - @db.atomic() - def with_transaction(): - self.assertTrue(db.in_transaction()) - - @db.manual_commit() - def with_manual_commit(): - self.assertTrue(db.in_transaction()) - - db.initialize(SqliteDatabase(':memory:')) - with_connection() - self.assertTrue(db.is_closed()) - with_transaction() - self.assertFalse(db.in_transaction()) - with_manual_commit() - self.assertFalse(db.in_transaction()) - - def test_proxy_bind_ctx_callbacks(self): - db = Proxy() - class BaseModel(Model): - class Meta: - database = db - - class Hook(BaseModel): - data = BlobField() # Attaches hook to configure blob-type. - - self.assertTrue(Hook.data._constructor is bytearray) - - class CustomSqliteDB(SqliteDatabase): - sentinel = object() - def get_binary_type(self): - return self.sentinel - - custom_db = CustomSqliteDB(':memory:') - - with custom_db.bind_ctx([Hook]): - self.assertTrue(Hook.data._constructor is custom_db.sentinel) - - self.assertTrue(Hook.data._constructor is bytearray) - - custom_db.bind([Hook]) - self.assertTrue(Hook.data._constructor is custom_db.sentinel) - - -class CatToy(TestModel): - description = TextField() - - class Meta: - schema = 'huey' - - -@requires_postgresql -class TestSchemaNamespace(ModelTestCase): - requires = [CatToy] - - def setUp(self): - with self.database: - self.execute('CREATE SCHEMA huey;') - super(TestSchemaNamespace, self).setUp() - - def tearDown(self): - super(TestSchemaNamespace, self).tearDown() - with self.database: - self.execute('DROP SCHEMA huey;') - - def test_schema(self): - toy = CatToy.create(description='fur mouse') - toy_db = CatToy.select().where(CatToy.id == toy.id).get() - self.assertEqual(toy.id, toy_db.id) - self.assertEqual(toy.description, toy_db.description) + # Can reconnect after re-init. + db.connect() + self.assertFalse(db.is_closed()) + db.close() # =========================================================================== -# SQLite isolation, introspection, and ATTACH +# Introspection. # =========================================================================== -class TestSqliteIsolation(ModelTestCase): - database = get_sqlite_db() - requires = [User] - - def test_sqlite_isolation(self): - for username in ('u1', 'u2', 'u3'): User.create(username=username) - - new_db = get_sqlite_db() - curs = new_db.execute_sql('SELECT COUNT(*) FROM users') - self.assertEqual(curs.fetchone()[0], 3) - - self.assertEqual(User.select().count(), 3) - self.assertEqual(User.delete().execute(), 3) - - with self.database.atomic(): - User.create(username='u4') - User.create(username='u5') - - # Second conn does not see the changes. - curs = new_db.execute_sql('SELECT COUNT(*) FROM users') - self.assertEqual(curs.fetchone()[0], 0) - - # Third conn does not see the changes. - new_db2 = get_sqlite_db() - curs = new_db2.execute_sql('SELECT COUNT(*) FROM users') - self.assertEqual(curs.fetchone()[0], 0) - - # Original connection sees its own changes. - self.assertEqual(User.select().count(), 2) - - curs = new_db.execute_sql('SELECT COUNT(*) FROM users') - self.assertEqual(curs.fetchone()[0], 2) - - class UniqueModel(TestModel): name = CharField(unique=True) @@ -884,6 +553,427 @@ class TestIntrospection(ModelTestCase): ('parent_id', 'category', 'name', 'category')]) +# =========================================================================== +# Thread safety +# =========================================================================== + +class TestThreadSafety(ModelTestCase): + # HACK: This workaround increases the Sqlite busy timeout when tests are + # being run on certain architectures. + if IS_SQLITE and platform.machine() not in ('i386', 'i686', 'x86_64'): + database = new_connection(timeout=60) + nthreads = 4 + nrows = 10 + requires = [User] + + def test_thread_safe_false_uses_noop_lock(self): + from peewee import _NoopLock, _ConnectionState + db = SqliteDatabase(':memory:', thread_safe=False) + self.assertIsInstance(db._lock, _NoopLock) + self.assertIsInstance(db._state, _ConnectionState) + + db.connect() + self.assertFalse(db.is_closed()) + db.close() + self.assertTrue(db.is_closed()) + + def test_multiple_writers(self): + def create_users(idx): + for i in range(idx * self.nrows, (idx + 1) * self.nrows): + User.create(username='u%d' % i) + + threads = [] + for i in range(self.nthreads): + threads.append(threading.Thread(target=create_users, args=(i,))) + + for t in threads: t.start() + for t in threads: t.join() + + self.assertEqual(User.select().count(), self.nrows * self.nthreads) + + def test_multiple_readers(self): + data = Queue() + def read_user_count(n): + for i in range(n): + data.put(User.select().count()) + + threads = [] + for i in range(self.nthreads): + threads.append(threading.Thread(target=read_user_count, + args=(self.nrows,))) + + for t in threads: t.start() + for t in threads: t.join() + self.assertEqual(data.qsize(), self.nrows * self.nthreads) + + def test_mt_general(self): + def connect_close(): + for _ in range(self.nrows): + self.database.connect() + with self.database.atomic() as txn: + self.database.execute_sql('select 1').fetchone() + self.database.close() + + threads = [] + for i in range(self.nthreads): + threads.append(threading.Thread(target=connect_close)) + + for t in threads: t.start() + for t in threads: t.join() + + def test_thread_safety_atomic(self): + @self.database.atomic() + def get_one(n): + time.sleep(n) + return User.select().first() + def run(n): + with self.database.atomic(): + self.assertEqual(get_one(n).username, 'u') + User.create(username='u') + threads = [threading.Thread(target=run, args=(i,)) + for i in (0.01, 0.03, 0.05, 0.07, 0.09, 0.02, 0.04, 0.06)] + for t in threads: t.start() + for t in threads: t.join() + + +class TestThreadSafeMetaRegression(ModelTestCase): + def test_thread_safe_meta(self): + d1 = get_in_memory_db() + d2 = get_in_memory_db() + + class Meta: + database = d1 + model_metadata_class = ThreadSafeDatabaseMetadata + attrs = {'Meta': Meta} + for i in range(1, 30): + attrs['f%d' % i] = IntegerField() + M = type('M', (TestModel,), attrs) + + sql = ('SELECT "t1"."f1", "t1"."f2", "t1"."f3", "t1"."f4" ' + 'FROM "m" AS "t1"') + query = M.select(M.f1, M.f2, M.f3, M.f4) + + def swap_db(): + for i in range(100): + self.assertEqual(M._meta.database, d1) + self.assertSQL(query, sql) + with d2.bind_ctx([M]): + self.assertEqual(M._meta.database, d2) + self.assertSQL(query, sql) + self.assertEqual(M._meta.database, d1) + self.assertSQL(query, sql) + + # From a separate thread, swap the database and verify it works + # correctly. + threads = [threading.Thread(target=swap_db) + for i in range(10)] + for t in threads: t.start() + for t in threads: t.join() + + # In the main thread the original database has not been altered. + self.assertEqual(M._meta.database, d1) + self.assertSQL(query, sql) + + +# =========================================================================== +# Deferred database, proxy, and schema namespace +# =========================================================================== + +class TestDeferredDatabase(BaseTestCase): + def test_deferred_database(self): + deferred_db = SqliteDatabase(None) + self.assertTrue(deferred_db.deferred) + + class DeferredModel(Model): + class Meta: + database = deferred_db + + self.assertRaises(Exception, deferred_db.connect) + query = DeferredModel.select() + self.assertRaises(Exception, query.execute) + + deferred_db.init(':memory:') + self.assertFalse(deferred_db.deferred) + + conn = deferred_db.connect() + self.assertFalse(deferred_db.is_closed()) + DeferredModel._schema.create_all() + self.assertEqual(list(DeferredModel.select()), []) + + deferred_db.init(None) + self.assertTrue(deferred_db.deferred) + + # The connection was automatically closed. + self.assertTrue(deferred_db.is_closed()) + + +class TestDBProxy(BaseTestCase): + def test_proxy_context_manager(self): + db = Proxy() + class User(Model): + username = TextField() + + class Meta: + database = db + + self.assertRaises(AttributeError, User.create_table) + + sqlite_db = SqliteDatabase(':memory:') + db.initialize(sqlite_db) + User.create_table() + with db: + self.assertFalse(db.is_closed()) + self.assertTrue(db.is_closed()) + + def test_db_proxy(self): + db = Proxy() + class BaseModel(Model): + class Meta: + database = db + + class User(BaseModel): + username = TextField() + + class Tweet(BaseModel): + user = ForeignKeyField(User, backref='tweets') + message = TextField() + + sqlite_db = SqliteDatabase(':memory:') + db.initialize(sqlite_db) + + self.assertEqual(User._meta.database.database, ':memory:') + self.assertEqual(Tweet._meta.database.database, ':memory:') + + self.assertTrue(User._meta.database.is_closed()) + self.assertTrue(Tweet._meta.database.is_closed()) + sqlite_db.connect() + self.assertFalse(User._meta.database.is_closed()) + self.assertFalse(Tweet._meta.database.is_closed()) + sqlite_db.close() + + def test_proxy_decorator(self): + db = DatabaseProxy() + + @db.connection_context() + def with_connection(): + self.assertFalse(db.is_closed()) + + @db.atomic() + def with_transaction(): + self.assertTrue(db.in_transaction()) + + @db.manual_commit() + def with_manual_commit(): + self.assertTrue(db.in_transaction()) + + db.initialize(SqliteDatabase(':memory:')) + with_connection() + self.assertTrue(db.is_closed()) + with_transaction() + self.assertFalse(db.in_transaction()) + with_manual_commit() + self.assertFalse(db.in_transaction()) + + def test_proxy_bind_ctx_callbacks(self): + db = Proxy() + class BaseModel(Model): + class Meta: + database = db + + class Hook(BaseModel): + data = BlobField() # Attaches hook to configure blob-type. + + self.assertTrue(Hook.data._constructor is bytearray) + + class CustomSqliteDB(SqliteDatabase): + sentinel = object() + def get_binary_type(self): + return self.sentinel + + custom_db = CustomSqliteDB(':memory:') + + with custom_db.bind_ctx([Hook]): + self.assertTrue(Hook.data._constructor is custom_db.sentinel) + + self.assertTrue(Hook.data._constructor is bytearray) + + custom_db.bind([Hook]) + self.assertTrue(Hook.data._constructor is custom_db.sentinel) + + def test_proxy_uninitialized_getattr(self): + p = Proxy() + self.assertRaises(AttributeError, lambda: p.some_method) + + with self.assertRaises(AttributeError): + with p: + pass + + def test_proxy_setattr_error(self): + p = Proxy() + with self.assertRaises(AttributeError): + p.custom_attr = 42 + + +class CatToy(TestModel): + description = TextField() + + class Meta: + schema = 'huey' + + +@requires_postgresql +class TestSchemaNamespace(ModelTestCase): + requires = [CatToy] + + def setUp(self): + with self.database: + self.execute('CREATE SCHEMA huey;') + super(TestSchemaNamespace, self).setUp() + + def tearDown(self): + super(TestSchemaNamespace, self).tearDown() + with self.database: + self.execute('DROP SCHEMA huey;') + + def test_schema(self): + toy = CatToy.create(description='fur mouse') + toy_db = CatToy.select().where(CatToy.id == toy.id).get() + self.assertEqual(toy.id, toy_db.id) + self.assertEqual(toy.description, toy_db.description) + + +# =========================================================================== +# SQLite pragmas, isolation, introspection, and ATTACH +# =========================================================================== + +class TestSqliteDatabaseFeatures(DatabaseTestCase): + database = get_sqlite_db() + + def test_pragmas(self): + self.database.cache_size = -2048 + self.assertEqual(self.database.cache_size, -2048) + self.database.cache_size = -4096 + self.assertEqual(self.database.cache_size, -4096) + + self.database.foreign_keys = 'on' + self.assertEqual(self.database.foreign_keys, 1) + self.database.foreign_keys = 'off' + self.assertEqual(self.database.foreign_keys, 0) + + def test_appid_user_version(self): + self.assertEqual(self.database.application_id, 0) + self.assertEqual(self.database.user_version, 0) + self.database.application_id = 1 + self.database.user_version = 2 + self.assertEqual(self.database.application_id, 1) + self.assertEqual(self.database.user_version, 2) + self.assertTrue(self.database.close()) + self.assertTrue(self.database.connect()) + self.assertEqual(self.database.application_id, 1) + self.assertEqual(self.database.user_version, 2) + + def test_timeout_semantics(self): + self.assertEqual(self.database.timeout, 5) + self.assertEqual(self.database.pragma('busy_timeout'), 5000) + + self.database.timeout = 2.5 + self.assertEqual(self.database.timeout, 2.5) + self.assertEqual(self.database.pragma('busy_timeout'), 2500) + + self.database.close() + self.database.connect() + + self.assertEqual(self.database.timeout, 2.5) + self.assertEqual(self.database.pragma('busy_timeout'), 2500) + + def test_pragmas_deferred(self): + pragmas = (('journal_mode', 'wal'),) + db = SqliteDatabase(None, pragmas=pragmas) + self.assertEqual(db._pragmas, pragmas) + + # Test pragmas preserved after initializing. + db.init(':memory:') + self.assertEqual(db._pragmas, pragmas) + + db = SqliteDatabase(None) + self.assertEqual(db._pragmas, ()) + + # Test pragmas are set and subsequently overwritten. + db.init(':memory:', pragmas=pragmas) + self.assertEqual(db._pragmas, pragmas) + + db.init(':memory:', pragmas=()) + self.assertEqual(db._pragmas, ()) + + # Test when specified twice, the previous value is overwritten. + db = SqliteDatabase(None, pragmas=pragmas) + db.init(':memory:', pragmas=(('cache_size', -8000),)) + self.assertEqual(db._pragmas, (('cache_size', -8000),)) + + def test_pragmas_as_dict(self): + pragmas = {'journal_mode': 'wal'} + pragma_list = [('journal_mode', 'wal')] + + db = SqliteDatabase(':memory:', pragmas=pragmas) + self.assertEqual(db._pragmas, pragma_list) + + # Test deferred databases correctly handle pragma dicts. + db = SqliteDatabase(None, pragmas=pragmas) + self.assertEqual(db._pragmas, pragma_list) + + db.init(':memory:') + self.assertEqual(db._pragmas, pragma_list) + + db.init(':memory:', pragmas={}) + self.assertEqual(db._pragmas, []) + + def test_pragmas_permanent(self): + db = SqliteDatabase(':memory:') + db.execute_sql('pragma foreign_keys=0') + self.assertEqual(db.foreign_keys, 0) + + db.pragma('foreign_keys', 1, True) + self.assertEqual(db.foreign_keys, 1) + + db.close() + db.connect() + self.assertEqual(db.foreign_keys, 1) + + +class TestSqliteIsolation(ModelTestCase): + database = get_sqlite_db() + requires = [User] + + def test_sqlite_isolation(self): + for username in ('u1', 'u2', 'u3'): User.create(username=username) + + new_db = get_sqlite_db() + curs = new_db.execute_sql('SELECT COUNT(*) FROM users') + self.assertEqual(curs.fetchone()[0], 3) + + self.assertEqual(User.select().count(), 3) + self.assertEqual(User.delete().execute(), 3) + + with self.database.atomic(): + User.create(username='u4') + User.create(username='u5') + + # Second conn does not see the changes. + curs = new_db.execute_sql('SELECT COUNT(*) FROM users') + self.assertEqual(curs.fetchone()[0], 0) + + # Third conn does not see the changes. + new_db2 = get_sqlite_db() + curs = new_db2.execute_sql('SELECT COUNT(*) FROM users') + self.assertEqual(curs.fetchone()[0], 0) + + # Original connection sees its own changes. + self.assertEqual(User.select().count(), 2) + + curs = new_db.execute_sql('SELECT COUNT(*) FROM users') + self.assertEqual(curs.fetchone()[0], 2) + + class Data(TestModel): key = TextField() value = TextField() @@ -1107,131 +1197,6 @@ class TestChunkedUtility(BaseTestCase): self.assertEqual(result, [[0, 2], [4, 6], [8]]) -# =========================================================================== -# Gap coverage: Session edge cases and transaction helpers -# =========================================================================== - -class TestSessionEdgeCases(DatabaseTestCase): - def test_session_commit_no_transaction(self): - """session_commit() returns False when no transaction is active.""" - self.assertFalse(self.database.in_transaction()) - self.assertFalse(self.database.session_commit()) - - def test_session_rollback_no_transaction(self): - """session_rollback() returns False when no transaction is active.""" - self.assertFalse(self.database.in_transaction()) - self.assertFalse(self.database.session_rollback()) - - def test_top_transaction_empty(self): - """top_transaction() returns None when no transactions are active.""" - self.assertIsNone(self.database.top_transaction()) - - def test_top_transaction_with_active(self): - """top_transaction() returns innermost transaction.""" - with self.database.atomic() as txn: - top = self.database.top_transaction() - self.assertIsNotNone(top) - self.assertEqual(self.database.transaction_depth(), 1) - - def test_db_context_manager_nesting(self): - """Nested with db: pushes/pops ctx stack; closes only when empty.""" - self.database.close() - with self.database: - self.assertFalse(self.database.is_closed()) - self.assertEqual(self.database.transaction_depth(), 1) - with self.database: - self.assertFalse(self.database.is_closed()) - # Inner atomic becomes savepoint (doesn't increase txn depth) - # but the ctx stack grows. - self.assertEqual(len(self.database._state.ctx), 2) - # Inner exited but outer still open. - self.assertFalse(self.database.is_closed()) - self.assertEqual(len(self.database._state.ctx), 1) - # Both exited — db closed. - self.assertTrue(self.database.is_closed()) - - def test_init_closes_open_connection(self): - """Database.init() closes an existing open connection.""" - db = get_in_memory_db() - db.connect() - self.assertFalse(db.is_closed()) - db.init(':memory:') - self.assertTrue(db.is_closed()) - # Can reconnect after re-init. - db.connect() - self.assertFalse(db.is_closed()) - db.close() - - -# =========================================================================== -# Gap coverage: Database SQL-generation helpers -# =========================================================================== - -class TestDatabaseSQLHelpers(BaseTestCase): - def test_random_sqlite(self): - """SqliteDatabase.random() produces fn.random().""" - db = SqliteDatabase(':memory:') - result = db.random() - self.assertIsInstance(result, Function) - from peewee import Context - ctx = Context() - sql, _ = ctx.sql(result).query() - self.assertEqual(sql, 'random()') - - def test_get_noop_select_sqlite(self): - """SqliteDatabase.get_noop_select() produces SELECT 0 WHERE 0.""" - db = SqliteDatabase(':memory:') - from peewee import Context - ctx = db.get_sql_context() - sql, _ = db.get_noop_select(ctx).query() - self.assertIn('SELECT', sql) - self.assertIn('0', sql) - - def test_mysql_extract_server_version_mysql8(self): - """MySQLDatabase._extract_server_version parses MySQL 8.x.""" - db = MySQLDatabase.__new__(MySQLDatabase) - version = db._extract_server_version('8.0.31') - self.assertEqual(version, (8, 0, 31)) - - def test_mysql_extract_server_version_mariadb(self): - """MySQLDatabase._extract_server_version parses MariaDB 10.x.""" - db = MySQLDatabase.__new__(MySQLDatabase) - version = db._extract_server_version('5.5.5-10.6.12-MariaDB') - self.assertEqual(version, (10, 6, 12)) - - def test_mysql_extract_server_version_unknown(self): - """MySQLDatabase._extract_server_version returns (0,0,0) for unknown.""" - import warnings - db = MySQLDatabase.__new__(MySQLDatabase) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - version = db._extract_server_version('unknown') - self.assertEqual(version, (0, 0, 0)) - self.assertTrue(len(w) > 0) - - def test_mysql_extract_server_version_tuple_passthrough(self): - """MySQLDatabase._extract_server_version passes tuples through.""" - db = MySQLDatabase.__new__(MySQLDatabase) - version = db._extract_server_version((8, 0, 31)) - self.assertEqual(version, (8, 0, 31)) - - def test_thread_safe_false_uses_noop_lock(self): - """Database(thread_safe=False) uses _NoopLock.""" - from peewee import _NoopLock, _ConnectionState - db = SqliteDatabase(':memory:', thread_safe=False) - self.assertIsInstance(db._lock, _NoopLock) - self.assertIsInstance(db._state, _ConnectionState) - # Should still work for basic operations. - db.connect() - self.assertFalse(db.is_closed()) - db.close() - self.assertTrue(db.is_closed()) - - -# =========================================================================== -# Gap coverage: Utility functions -# =========================================================================== - class TestUtilityFunctions(BaseTestCase): def test_make_snake_case(self): from peewee import make_snake_case @@ -1298,13 +1263,11 @@ class TestUtilityFunctions(BaseTestCase): d1 += d2 self.assertEqual(d1, {'a': 1, 'b': 3, 'c': 4}) - def test_attrdict_getattr_missing(self): d = attrdict(a=1) self.assertEqual(d.a, 1) self.assertRaises(AttributeError, lambda: d.missing) def test_query_val_transform(self): - """_query_val_transform handles various Python types.""" import datetime from peewee import _query_val_transform self.assertEqual(_query_val_transform('hello'), "'hello'") @@ -1322,8 +1285,6 @@ class TestUtilityFunctions(BaseTestCase): self.assertIn('hello', result) def test_deprecated_emits_warning(self): - """__deprecated__() emits a DeprecationWarning.""" - import warnings from peewee import __deprecated__ with warnings.catch_warnings(record=True) as w: warnings.simplefilter('always') @@ -1334,24 +1295,48 @@ class TestUtilityFunctions(BaseTestCase): # =========================================================================== -# Gap coverage: Proxy error paths +# Cross-database helpers # =========================================================================== -class TestProxyEdgeCases(BaseTestCase): - def test_proxy_uninitialized_getattr(self): - """Uninitialized Proxy raises AttributeError on attribute access.""" - p = Proxy() - self.assertRaises(AttributeError, lambda: p.some_method) +class TestDatabaseSQLHelpers(BaseTestCase): + def test_random_sqlite(self): + db = SqliteDatabase(':memory:') + self.assertSQL(db.random(), 'random()') - def test_proxy_setattr_error(self): - """Proxy raises AttributeError when setting non-slot attributes.""" - p = Proxy() - with self.assertRaisesCtx(AttributeError): - p.custom_attr = 42 + db = MySQLDatabase.__new__(MySQLDatabase) + self.assertSQL(db.random(), 'rand()') - def test_proxy_uninitialized_context_manager(self): - """Using uninitialized Proxy as context manager raises.""" - p = Proxy() - with self.assertRaisesCtx(AttributeError): - with p: - pass + db = PostgresqlDatabase.__new__(PostgresqlDatabase) + self.assertSQL(db.random(), 'random()') + + def test_get_noop_select_sqlite(self): + db = SqliteDatabase(':memory:') + ctx = db.get_noop_select(Context()) + self.assertSQL(ctx, 'SELECT 0 WHERE 0') + + db = MySQLDatabase.__new__(MySQLDatabase) + ctx = db.get_noop_select(Context()) + self.assertSQL(ctx, 'SELECT 0 WHERE 0=1') + + db = PostgresqlDatabase.__new__(PostgresqlDatabase) + ctx = db.get_noop_select(Context()) + self.assertSQL(ctx, 'SELECT 0 WHERE false') + + def test_mysql_extract_server_version_mysql(self): + db = MySQLDatabase.__new__(MySQLDatabase) + version = db._extract_server_version('8.0.31') + self.assertEqual(version, (8, 0, 31)) + + version = db._extract_server_version('5.5.5-10.6.12-MariaDB') + self.assertEqual(version, (10, 6, 12)) + + db = MySQLDatabase.__new__(MySQLDatabase) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + version = db._extract_server_version('unknown') + self.assertEqual(version, (0, 0, 0)) + self.assertTrue(len(w) > 0) + + db = MySQLDatabase.__new__(MySQLDatabase) + version = db._extract_server_version((8, 0, 31)) + self.assertEqual(version, (8, 0, 31))