diff --git a/playhouse/cysqlite_ext.py b/playhouse/cysqlite_ext.py new file mode 100644 index 00000000..61ea0bc0 --- /dev/null +++ b/playhouse/cysqlite_ext.py @@ -0,0 +1,77 @@ +import logging +from peewee import ImproperlyConfigured +from peewee import SqliteDatabase +from peewee import __exception_wrapper__ + +try: + from cysqlite import Connection +except ImportError: + Connection = None + +logger = logging.getLogger('peewee') + +class CySqliteDatabase(SqliteDatabase): + def _connect(self): + if Connection is None: + raise ImproperlyConfigured('cysqlite is not installed.') + conn = Connection(self.database, timeout=self._timeout, + extensions=True, **self.connect_params) + conn.connect() + try: + self._add_conn_hooks(conn) + except: + conn.close() + raise + return conn + + def _set_pragmas(self, conn): + for pragma, value in self._pragmas: + conn.execute_one('PRAGMA %s = %s;' % (pragma, value)) + + def _attach_databases(self, conn): + for name, db in self._attached.items(): + conn.execute_one('ATTACH DATABASE "%s" AS "%s"' % (db, name)) + + def _load_aggregates(self, conn): + for name, (klass, num_params) in self._aggregates.items(): + conn.create_aggregate(klass, name, num_params) + + def _load_collations(self, conn): + for name, fn in self._collations.items(): + conn.create_collation(fn, name) + + def _load_functions(self, conn): + for name, (fn, num_params, deterministic) in self._functions.items(): + conn.create_function(fn, name, num_params, deterministic) + + def _load_window_functions(self, conn): + for name, (klass, num_params) in self._window_functions.items(): + conn.create_window_function(klass, name, num_params) + + def last_insert_id(self, cursor, query_type=None): + return self.connection().last_insert_rowid() + + def rows_affected(self, cursor): + return self.connection().changes() + + def begin(self, lock_type='deferred'): + with __exception_wrapper__: + self.connection().begin(lock_type) + + def commit(self): + with __exception_wrapper__: + self.connection().commit() + + def rollback(self): + with __exception_wrapper__: + self.connection().rollback() + + def cursor(self): + raise NotImplementedError('cysqlite does not use a cursor interface.') + + def execute_sql(self, sql, params=None): + logger.debug((sql, params)) + with __exception_wrapper__: + conn = self.connection() + stmt = conn.execute(sql, params or ()) + return stmt diff --git a/tests/base.py b/tests/base.py index 6b7fd86b..21d3bf59 100644 --- a/tests/base.py +++ b/tests/base.py @@ -14,6 +14,7 @@ from peewee import * from peewee import sqlite3 from playhouse.cockroachdb import CockroachDatabase from playhouse.cockroachdb import NESTED_TX_MIN_VERSION +from playhouse.cysqlite_ext import CySqliteDatabase from playhouse.mysql_ext import MariaDBConnectorDatabase from playhouse.mysql_ext import MySQLConnectorDatabase from playhouse.psycopg3_ext import Psycopg3Database @@ -26,6 +27,7 @@ def db_loader(engine, name='peewee_test', db_class=None, **params): if db_class is None: engine_aliases = { SqliteDatabase: ['sqlite', 'sqlite3'], + CySqliteDatabase: ['cysqlite'], MySQLDatabase: ['mysql'], PostgresqlDatabase: ['postgres', 'postgresql'], Psycopg3Database: ['psycopg3'], @@ -57,11 +59,12 @@ BACKEND = os.environ.get('PEEWEE_TEST_BACKEND') or 'sqlite' VERBOSITY = int(os.environ.get('PEEWEE_TEST_VERBOSITY') or 1) SLOW_TESTS = bool(os.environ.get('PEEWEE_SLOW_TESTS')) -IS_SQLITE = BACKEND.startswith('sqlite') +IS_SQLITE = BACKEND.startswith(('sqlite', 'cysqlite')) IS_MYSQL = BACKEND.startswith(('mysql', 'maria')) IS_POSTGRESQL = BACKEND.startswith(('postgres', 'psycopg')) IS_CRDB = BACKEND in ('cockroach', 'cockroachdb', 'crdb') IS_PSYCOPG3 = BACKEND == 'psycopg3' +IS_CYSQLITE = BACKEND == 'cysqlite' def make_db_params(key): diff --git a/tests/reflection.py b/tests/reflection.py index 2595a5d6..2f03053c 100644 --- a/tests/reflection.py +++ b/tests/reflection.py @@ -7,6 +7,7 @@ from peewee import * from playhouse.reflection import * from .base import IS_CRDB +from .base import IS_CYSQLITE from .base import IS_SQLITE_OLD from .base import ModelTestCase from .base import TestModel @@ -601,6 +602,7 @@ class TestCyclicalFK(BaseReflectionTestCase): warnings.filterwarnings('ignore') @requires_sqlite + @skip_if(IS_CYSQLITE, 'cysqlite does not implement cursor at the moment.') def test_cyclical_fk(self): # NOTE: this schema was provided by a user. cursor = self.database.cursor() diff --git a/tests/transactions.py b/tests/transactions.py index 5793f92e..7095bb10 100644 --- a/tests/transactions.py +++ b/tests/transactions.py @@ -377,7 +377,7 @@ class TestSession(BaseTransactionTestCase): @skip_unless(IS_SQLITE, 'requires sqlite for transaction lock type') class TestTransactionLockType(BaseTransactionTestCase): def test_lock_type(self): - db2 = new_connection(timeout=0.001) + db2 = new_connection(timeout=0.0001) db2.connect() with self.database.atomic(lock_type='EXCLUSIVE') as txn: