SIGNIFICANT: remove playhouse._sqlite_ext.

C implementations of ranking functions are moved to _sqlite_udf.
Capabilities exclusive to the CSqliteExtDatabase implementation are being
migrated over to cysqlite_ext.CySqliteDatabase, which supports them
natively without relying on hacks.
Charles Leifer
2026-02-06 10:37:08 -06:00
parent e666910ded
commit 1aeef460a4
21 changed files with 813 additions and 3709 deletions
+4 -5
@@ -17,11 +17,10 @@ https://github.com/coleifer/peewee/releases
testing, since testing `x in dirty_fields` returns True if one or more fields
exist, due to operator overloads returning a truthy Expression object.
Refs #3028.
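(Illustration of the pitfall, using a hypothetical `User` model: when
`user.dirty_fields` is non-empty, `User.username in user.dirty_fields`
evaluates truthy even if that field is clean, because each `==` comparison
performed by the `in` test returns a truthy `Expression` rather than a
boolean.)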
* Begin **significant** pruning of current Cython Sqlite extensions - most of
these will be moved to `playhouse.cysqlite_ext` which supports them natively
(no more hacks). Those that remain will likely just be Cython implementations
of the various rank() family of functions, and those can probably just get
put into `playhouse.sqlite_udf`. Take warning!
* **Significant**: removal of Cython `_sqlite_ext` extension. The C
implementations of the FTS rank functions are moved to `sqlite_udf`. Most of
the remaining functionality is moved to `playhouse.cysqlite_ext` which
supports it natively (no more hacks).
[View commits](https://github.com/coleifer/peewee/compare/3.19.0...master)
-1
@@ -10,4 +10,3 @@ include playhouse/*.pyx
include playhouse/README.md
recursive-include examples *
recursive-include docs *
recursive-include playhouse/_pysqlite *
+5 -2
@@ -843,15 +843,18 @@ Database
.. py:method:: table_function([name=None])
Class-decorator for registering a :py:class:`TableFunction`. Table
Class-decorator for registering a ``cysqlite.TableFunction``. Table
functions are user-defined functions that, rather than returning a
single, scalar value, can return any number of rows of tabular data.
See `cysqlite docs <https://cysqlite.readthedocs.io/>`_ for details on the
``TableFunction`` API.
Example:
.. code-block:: python
from playhouse.sqlite_ext import TableFunction
from cysqlite import TableFunction
@db.table_function('series')
class Series(TableFunction):
+242
@@ -0,0 +1,242 @@
.. _cysqlite:
cysqlite Extension
==================
.. py:class:: CySqliteDatabase(database[, pragmas=None[, timeout=5[, rank_functions=True[, regexp_function=False[, json_contains=False]]]]])
:param list pragmas: A list of 2-tuples containing pragma key and value to
set every time a connection is opened.
:param timeout: Set the busy-timeout on the SQLite driver (in seconds).
:param bool rank_functions: Make search result ranking functions available.
:param bool regexp_function: Make the REGEXP function available.
:param bool json_contains: Make the json_contains() function available.
Extends :py:class:`SqliteDatabase` and inherits methods for declaring
user-defined functions, aggregates, window functions, table functions,
collations, pragmas, etc.
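Example (a minimal sketch; assumes ``CySqliteDatabase`` accepts the same
``pragmas`` argument as :py:class:`SqliteDatabase`):

.. code-block:: python

    from playhouse.cysqlite_ext import CySqliteDatabase

    db = CySqliteDatabase('my_app.db', pragmas=(
        ('cache_size', -1024 * 64),  # 64MB page-cache.
        ('journal_mode', 'wal')))    # Use WAL-mode.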
.. py:method:: on_commit(fn)
:param fn: callable or ``None`` to clear the current hook.
Register a callback to be executed whenever a transaction is committed
on the current connection. The callback accepts no parameters and the
return value is ignored.
However, if the callback raises a :py:class:`ValueError`, the
transaction will be aborted and rolled-back.
Example:
.. code-block:: python
db = CySqliteDatabase(':memory:')
@db.on_commit
def on_commit():
logger.info('COMMITing changes')
.. py:method:: on_rollback(fn)
:param fn: callable or ``None`` to clear the current hook.
Register a callback to be executed whenever a transaction is rolled
back on the current connection. The callback accepts no parameters and
the return value is ignored.
Example:
.. code-block:: python
@db.on_rollback
def on_rollback():
logger.info('Rolling back changes')
.. py:method:: on_update(fn)
:param fn: callable or ``None`` to clear the current hook.
Register a callback to be executed whenever the database is written to
(via an *UPDATE*, *INSERT* or *DELETE* query). The callback should
accept the following parameters:
* ``query`` - the type of query, either *INSERT*, *UPDATE* or *DELETE*.
* database name - the default database is named *main*.
* table name - name of table being modified.
* rowid - the rowid of the row being modified.
The callback's return value is ignored.
Example:
.. code-block:: python
db = CySqliteDatabase(':memory:')
@db.on_update
def on_update(query_type, db, table, rowid):
# e.g. INSERT row 3 into table users.
logger.info('%s row %s into table %s', query_type, rowid, table)
.. py:method:: authorizer(fn)
:param fn: callable or ``None`` to clear the current authorizer.
Register an authorizer callback. Authorizer callbacks must accept 5
parameters, which vary depending on the operation being checked.
* op: operation code, e.g. ``cysqlite.C_SQLITE_INSERT``.
* p1: operation-specific value, e.g. table name for ``C_SQLITE_INSERT``.
* p2: operation-specific value.
* p3: database name, e.g. ``"main"``.
* p4: inner-most trigger or view responsible for the access attempt if
applicable, else ``None``.
See `sqlite authorizer documentation <https://www.sqlite.org/c3ref/c_alter_table.html>`_
for description of authorizer codes and values for parameters p1 and p2.
The authorizer callback must return one of:
* ``cysqlite.C_SQLITE_OK``: allow operation.
* ``cysqlite.C_SQLITE_IGNORE``: allow statement compilation but prevent
the operation from occurring.
* ``cysqlite.C_SQLITE_DENY``: prevent statement compilation.
More details can be found in the `cysqlite docs <https://cysqlite.readthedocs.org/>`_.
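Example (a sketch of a read-only guard; ``C_SQLITE_UPDATE`` and
``C_SQLITE_DELETE`` are assumed to follow the naming pattern shown above):

.. code-block:: python

    import cysqlite

    db = CySqliteDatabase(':memory:')

    @db.authorizer
    def read_only(op, p1, p2, dbname, trigger):
        # Deny any statement that would modify data.
        if op in (cysqlite.C_SQLITE_INSERT, cysqlite.C_SQLITE_UPDATE,
                  cysqlite.C_SQLITE_DELETE):
            return cysqlite.C_SQLITE_DENY
        return cysqlite.C_SQLITE_OK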
.. py:method:: trace(fn, mask=2)
:param fn: callable or ``None`` to clear the current trace hook.
:param int mask: bitmask of event types to trace. The default value (``2``)
corresponds to ``SQLITE_TRACE_PROFILE``.
Register a trace hook (``sqlite3_trace_v2``). Trace callback must
accept 4 parameters, which vary depending on the operation being
traced.
* event: type of event, e.g. ``cysqlite.TRACE_PROFILE``.
* sid: memory address of statement (only ``cysqlite.TRACE_CLOSE``), else -1.
* sql: SQL string (only ``cysqlite.TRACE_STMT``), else None.
* ns: estimated number of nanoseconds the statement took to run (only
``cysqlite.TRACE_PROFILE``), else -1.
Any return value from callback is ignored.
More details can be found in the `cysqlite docs <https://cysqlite.readthedocs.org/>`_.
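Example (a sketch which logs statement timings; with the default mask only
``TRACE_PROFILE`` events are delivered, so ``sql`` will be ``None``):

.. code-block:: python

    import cysqlite

    @db.trace
    def on_trace(event, sid, sql, ns):
        if event == cysqlite.TRACE_PROFILE:
            logger.info('statement took ~%s ns', ns)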
.. py:method:: progress(fn, n=1)
:param fn: callable or ``None`` to clear the current progress handler.
:param int n: approximate number of VM instructions to execute between
calls to the progress handler.
Register a progress handler (``sqlite3_progress_handler``). Callback
takes no arguments and returns 0 to allow progress to continue or any
non-zero value to interrupt progress.
More details can be found in the `cysqlite docs <https://cysqlite.readthedocs.org/>`_.
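Example (a sketch; ``should_cancel()`` is a hypothetical application helper):

.. code-block:: python

    @db.progress
    def check_cancelled():
        # Returning non-zero interrupts the currently-executing query.
        return 1 if should_cancel() else 0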
.. py:attribute:: autocommit
Property which returns a boolean indicating if autocommit is enabled.
By default, this value will be ``True`` except when inside a
transaction (or :py:meth:`~Database.atomic` block).
Example:
.. code-block:: pycon
>>> db = CySqliteDatabase(':memory:')
>>> db.autocommit
True
>>> with db.atomic():
... print(db.autocommit)
...
False
>>> db.autocommit
True
.. py:method:: backup(destination[, pages=None, name=None, progress=None])
:param SqliteDatabase destination: Database object to serve as
destination for the backup.
:param int pages: Number of pages per iteration. Default value of -1
indicates all pages should be backed-up in a single step.
:param str name: Name of source database (may differ if you used ATTACH
DATABASE to load multiple databases). Defaults to "main".
:param progress: Progress callback, called with three parameters: the
number of pages remaining, the total page count, and whether the
backup is complete.
Example:
.. code-block:: python
master = CySqliteDatabase('master.db')
replica = CySqliteDatabase('replica.db')
# Backup the contents of master to replica.
master.backup(replica)
.. py:method:: backup_to_file(filename[, pages, name, progress])
:param filename: Filename to store the database backup.
:param int pages: Number of pages per iteration. Default value of -1
indicates all pages should be backed-up in a single step.
:param str name: Name of source database (may differ if you used ATTACH
DATABASE to load multiple databases). Defaults to "main".
:param progress: Progress callback, called with three parameters: the
number of pages remaining, the total page count, and whether the
backup is complete.
Backup the current database to a file. The backed-up data is not a
database dump, but an actual SQLite database file.
Example:
.. code-block:: python
db = CySqliteDatabase('app.db')
def nightly_backup():
filename = 'backup-%s.db' % (datetime.date.today())
db.backup_to_file(filename)
.. py:method:: blob_open(table, column, rowid[, read_only=False[, dbname=None]])
:param str table: Name of table containing data.
:param str column: Name of column containing data.
:param int rowid: ID of row to retrieve.
:param bool read_only: Open the blob for reading only.
:param str dbname: Database name (e.g. if multiple databases attached).
:returns: ``cysqlite.Blob`` instance which provides efficient access to
the underlying binary data.
:rtype: cysqlite.Blob
See `cysqlite documentation <https://cysqlite.readthedocs.io/>`_ for
more details.
Example:
.. code-block:: python
class Image(Model):
filename = TextField()
data = BlobField()
buf_size = 1024 * 1024 * 8 # Allocate 8MB for storing file.
rowid = Image.insert({Image.filename: 'thefile.jpg',
Image.data: ZeroBlob(buf_size)}).execute()
# Open the blob, returning a file-like object.
blob = db.blob_open('image', 'data', rowid)
# Write some data to the blob.
blob.write(image_data)
img_size = blob.tell()
# Read the data back out of the blob.
blob.seek(0)
image_data = blob.read(img_size)
+34 -4
@@ -301,7 +301,9 @@ extensions:
* Functions - which take any number of parameters and return a single value.
* Aggregates - which aggregate parameters from multiple rows and return a
single value.
* Window Functions - aggregates which support operating on windows of data.
* Collations - which describe how to sort some value.
* Table Functions - fully user-defined tables (requires ``cysqlite``).
.. note::
For even more extension support, see :py:class:`SqliteExtDatabase`, which
@@ -352,6 +354,34 @@ Example user-defined aggregate:
.group_by(FileChunk.filename)
.order_by(FileChunk.filename, FileChunk.sequence))
Example user-defined window function:
.. code-block:: python
# Window functions are normal aggregates with two additional methods:
# inverse(value) - Perform the inverse of step(value).
# value() - Report value at current step.
@db.aggregate('mysum')
class MySum(object):
def __init__(self):
self._value = 0
def step(self, value):
self._value += (value or 0)
def inverse(self, value):
self._value -= (value or 0) # Do opposite of "step()".
def value(self):
return self._value
def finalize(self):
return self._value
# e.g., aggregate sum of employee salaries over their department.
query = (Employee
.select(
Employee.department,
Employee.salary,
fn.mysum(Employee.salary).over(partition_by=[Employee.department]))
.order_by(Employee.id))
Example collation:
.. code-block:: python
@@ -368,12 +398,12 @@ Example collation:
# Or...
Book.select().order_by(Book.title.asc(collation='reverse'))
Example user-defined table-value function (see :py:class:`TableFunction`
and :py:class:`~SqliteDatabase.table_function`) for additional details:
Example user-defined table-value function (see `cysqlite docs <https://cysqlite.readthedocs.io/>`_
for details on ``TableFunction``).
.. code-block:: python
from playhouse.sqlite_ext import TableFunction
from cysqlite import TableFunction
db = SqliteDatabase('my_app.db')
@@ -419,9 +449,9 @@ For more information, see:
* :py:meth:`SqliteDatabase.func`
* :py:meth:`SqliteDatabase.aggregate`
* :py:meth:`SqliteDatabase.window_function`
* :py:meth:`SqliteDatabase.collation`
* :py:meth:`SqliteDatabase.table_function`
* For even more SQLite extensions, see :ref:`sqlite_ext`
.. _sqlite-locking:
+7 -616
@@ -16,19 +16,13 @@ The ``playhouse.sqlite_ext`` includes even more SQLite features, including:
* :ref:`Full-text search <sqlite-fts>`
* :ref:`JSON extension integration <sqlite-json1>`
* :ref:`Closure table extension support <sqlite-closure-table>`
* :ref:`LSM1 extension support <sqlite-lsm1>`
* :ref:`User-defined table functions <sqlite-vtfunc>`
* Support for online backups using backup API: :py:meth:`~CSqliteExtDatabase.backup_to_file`
* :ref:`BLOB API support, for efficient binary data storage <sqlite-blob>`.
* :ref:`Additional helpers <sqlite-extras>`
Getting started
---------------
To get started with the features described in this document, you will want to
use the :py:class:`SqliteExtDatabase` class from the ``playhouse.sqlite_ext``
module. Furthermore, some features require the ``playhouse._sqlite_ext`` C
extension -- these features will be noted in the documentation.
module.
Instantiating a :py:class:`SqliteExtDatabase`:
@@ -41,209 +35,21 @@ Instantiating a :py:class:`SqliteExtDatabase`:
('journal_mode', 'wal'), # Use WAL-mode (you should always use this!).
('foreign_keys', 1))) # Enforce foreign-key constraints.
.. note::
By default the C extension is not included with the Peewee wheel. If you wish
to build these, you will need to install Peewee via source-distribution:
.. code-block:: console
pip install peewee --no-binary :all:
APIs
----
.. py:class:: SqliteExtDatabase(database[, pragmas=None[, timeout=5[, c_extensions=None[, rank_functions=True[, hash_functions=False[, regexp_function=False]]]]]])
.. py:class:: SqliteExtDatabase(database[, pragmas=None[, timeout=5[, rank_functions=True[, regexp_function=False[, json_contains=False]]]]])
:param list pragmas: A list of 2-tuples containing pragma key and value to
set every time a connection is opened.
:param timeout: Set the busy-timeout on the SQLite driver (in seconds).
:param bool c_extensions: Declare that C extension speedups must/must-not
be used. If set to ``True`` and the extension module is not available,
will raise an :py:class:`ImproperlyConfigured` exception.
:param bool rank_functions: Make search result ranking functions available.
:param bool hash_functions: Make hashing functions available (md5, sha1, etc).
:param bool regexp_function: Make the REGEXP function available.
:param bool json_contains: Make the json_contains() function available.
Extends :py:class:`SqliteDatabase` and inherits methods for declaring
user-defined functions, pragmas, etc.
.. py:class:: CSqliteExtDatabase(database[, pragmas=None[, timeout=5[, c_extensions=None[, rank_functions=True[, hash_functions=False[, regexp_function=False[, replace_busy_handler=False]]]]]]])
:param list pragmas: A list of 2-tuples containing pragma key and value to
set every time a connection is opened.
:param timeout: Set the busy-timeout on the SQLite driver (in seconds).
:param bool c_extensions: Declare that C extension speedups must/must-not
be used. If set to ``True`` and the extension module is not available,
will raise an :py:class:`ImproperlyConfigured` exception.
:param bool rank_functions: Make search result ranking functions available.
:param bool hash_functions: Make hashing functions available (md5, sha1, etc).
:param bool regexp_function: Make the REGEXP function available.
:param bool replace_busy_handler: Use a smarter busy-handler implementation.
Extends :py:class:`SqliteExtDatabase` and requires that the
``playhouse._sqlite_ext`` extension module be available.
.. py:method:: on_commit(fn)
Register a callback to be executed whenever a transaction is committed
on the current connection. The callback accepts no parameters and the
return value is ignored.
However, if the callback raises a :py:class:`ValueError`, the
transaction will be aborted and rolled-back.
Example:
.. code-block:: python
db = CSqliteExtDatabase(':memory:')
@db.on_commit
def on_commit():
logger.info('COMMITing changes')
.. py:method:: on_rollback(fn)
Register a callback to be executed whenever a transaction is rolled
back on the current connection. The callback accepts no parameters and
the return value is ignored.
Example:
.. code-block:: python
@db.on_rollback
def on_rollback():
logger.info('Rolling back changes')
.. py:method:: on_update(fn)
Register a callback to be executed whenever the database is written to
(via an *UPDATE*, *INSERT* or *DELETE* query). The callback should
accept the following parameters:
* ``query`` - the type of query, either *INSERT*, *UPDATE* or *DELETE*.
* database name - the default database is named *main*.
* table name - name of table being modified.
* rowid - the rowid of the row being modified.
The callback's return value is ignored.
Example:
.. code-block:: python
db = CSqliteExtDatabase(':memory:')
@db.on_update
def on_update(query_type, db, table, rowid):
# e.g. INSERT row 3 into table users.
logger.info('%s row %s into table %s', query_type, rowid, table)
.. py:method:: changes()
Return the number of rows modified in the currently-open transaction.
.. py:attribute:: autocommit
Property which returns a boolean indicating if autocommit is enabled.
By default, this value will be ``True`` except when inside a
transaction (or :py:meth:`~Database.atomic` block).
Example:
.. code-block:: pycon
>>> db = CSqliteExtDatabase(':memory:')
>>> db.autocommit
True
>>> with db.atomic():
... print(db.autocommit)
...
False
>>> db.autocommit
True
.. py:method:: backup(destination[, pages=None, name=None, progress=None])
:param SqliteDatabase destination: Database object to serve as
destination for the backup.
:param int pages: Number of pages per iteration. Default value of -1
indicates all pages should be backed-up in a single step.
:param str name: Name of source database (may differ if you used ATTACH
DATABASE to load multiple databases). Defaults to "main".
:param progress: Progress callback, called with three parameters: the
number of pages remaining, the total page count, and whether the
backup is complete.
Example:
.. code-block:: python
master = CSqliteExtDatabase('master.db')
replica = CSqliteExtDatabase('replica.db')
# Backup the contents of master to replica.
master.backup(replica)
.. py:method:: backup_to_file(filename[, pages, name, progress])
:param filename: Filename to store the database backup.
:param int pages: Number of pages per iteration. Default value of -1
indicates all pages should be backed-up in a single step.
:param str name: Name of source database (may differ if you used ATTACH
DATABASE to load multiple databases). Defaults to "main".
:param progress: Progress callback, called with three parameters: the
number of pages remaining, the total page count, and whether the
backup is complete.
Backup the current database to a file. The backed-up data is not a
database dump, but an actual SQLite database file.
Example:
.. code-block:: python
db = CSqliteExtDatabase('app.db')
def nightly_backup():
filename = 'backup-%s.db' % (datetime.date.today())
db.backup_to_file(filename)
.. py:method:: blob_open(table, column, rowid[, read_only=False])
:param str table: Name of table containing data.
:param str column: Name of column containing data.
:param int rowid: ID of row to retrieve.
:param bool read_only: Open the blob for reading only.
:returns: :py:class:`Blob` instance which provides efficient access to
the underlying binary data.
:rtype: Blob
See :py:class:`Blob` and :py:class:`ZeroBlob` for more information.
Example:
.. code-block:: python
class Image(Model):
filename = TextField()
data = BlobField()
buf_size = 1024 * 1024 * 8 # Allocate 8MB for storing file.
rowid = Image.insert({Image.filename: 'thefile.jpg',
Image.data: ZeroBlob(buf_size)}).execute()
# Open the blob, returning a file-like object.
blob = db.blob_open('image', 'data', rowid)
# Write some data to the blob.
blob.write(image_data)
img_size = blob.tell()
# Read the data back out of the blob.
blob.seek(0)
image_data = blob.read(img_size)
.. py:class:: RowIDField()
@@ -259,6 +65,7 @@ APIs
content = TextField()
timestamp = TimestampField()
.. py:class:: DocIDField()
Subclass of :py:class:`RowIDField` for use on virtual tables that
@@ -280,6 +87,7 @@ APIs
class Meta:
database = db
.. py:class:: AutoIncrementField()
SQLite, by default, may reuse primary key values after rows are deleted. To
@@ -288,6 +96,7 @@ APIs
There is a small performance cost for this feature. For more information,
see the SQLite docs on `autoincrement <https://sqlite.org/autoinc.html>`_.
.. py:class:: ISODateTimeField()
SQLite does not have a native DateTime data-type. Python ``datetime``
@@ -295,6 +104,7 @@ APIs
:py:class:`DateTimeField` ensures that the UTC offset is stored properly
for tz-aware datetimes and read-back properly when decoding row data.
.. _sqlite-json1:
.. py:class:: JSONField(json_dumps=None, json_loads=None, ...)
@@ -1445,141 +1255,6 @@ APIs
Generate a model class suitable for accessing the `vocab table <http://sqlite.org/fts5.html#the_fts5vocab_virtual_table_module>`_
corresponding to FTS5 search index.
.. _sqlite-vtfunc:
.. py:class:: TableFunction()
Implement a user-defined table-valued function. Unlike a simple
:ref:`scalar or aggregate <sqlite-user-functions>` function, which returns
a single scalar value, a table-valued function can return any number of
rows of tabular data.
Simple example:
.. code-block:: python
from playhouse.sqlite_ext import TableFunction
class Series(TableFunction):
# Name of columns in each row of generated data.
columns = ['value']
# Name of parameters the function may be called with.
params = ['start', 'stop', 'step']
def initialize(self, start=0, stop=None, step=1):
"""
Table-functions declare an initialize() method, which is
called with whatever arguments the user has called the
function with.
"""
self.start = self.current = start
self.stop = stop or float('Inf')
self.step = step
def iterate(self, idx):
"""
Iterate is called repeatedly by the SQLite database engine
until the required number of rows has been read **or** the
function raises a `StopIteration` signalling no more rows
are available.
"""
if self.current > self.stop:
raise StopIteration
ret, self.current = self.current, self.current + self.step
return (ret,)
# Register the table-function with our database, which ensures it
# is declared whenever a connection is opened.
db.table_function('series')(Series)
# Usage:
cursor = db.execute_sql('SELECT * FROM series(?, ?, ?)', (0, 5, 2))
for value, in cursor:
print(value)
.. note::
A :py:class:`TableFunction` must be registered with a database
connection before it can be used. To ensure the table function is
always available, you can use the
:py:meth:`SqliteDatabase.table_function` decorator to register the
function with the database.
:py:class:`TableFunction` implementations must provide two attributes and
implement two methods, described below.
.. py:attribute:: columns
A list containing the names of the columns for the data returned by the
function. For example, a function that is used to split a string on a
delimiter might specify 3 columns: ``[substring, start_idx, end_idx]``.
.. py:attribute:: params
The names of the parameters the function may be called with. All
parameters, including optional parameters, should be listed. For
example, a function that is used to split a string on a delimiter might
specify 2 params: ``[string, delimiter]``.
.. py:attribute:: name
*Optional* - specify the name for the table function. If not provided,
name will be taken from the class name.
.. py:attribute:: print_tracebacks = True
Print a full traceback for any errors that occur in the
table-function's callback methods. When set to False, only the generic
OperationalError will be visible.
.. py:method:: initialize(**parameter_values)
:param parameter_values: Parameters the function was called with.
:returns: No return value.
The ``initialize`` method is called to initialize the table function
with the parameters the user specified when calling the function.
.. py:method:: iterate(idx)
:param int idx: current iteration step
:returns: A tuple of row data corresponding to the columns named
in the :py:attr:`~TableFunction.columns` attribute.
:raises StopIteration: To signal that no more rows are available.
This function is called repeatedly and returns successive rows of data.
The function may terminate before all rows are consumed (especially if
the user specified a ``LIMIT`` on the results). Alternatively, the
function can signal that no more data is available by raising a
``StopIteration`` exception.
.. py:classmethod:: register(conn)
:param conn: A ``sqlite3.Connection`` object.
Register the table function with a DB-API 2.0 ``sqlite3.Connection``
object. Table-valued functions **must** be registered before they can
be used in a query.
Example:
.. code-block:: python
class MyTableFunction(TableFunction):
name = 'my_func'
# ... other attributes and methods ...
db = SqliteDatabase(':memory:')
db.connect()
MyTableFunction.register(db.connection())
To ensure the :py:class:`TableFunction` is registered every time a
connection is opened, use the :py:meth:`~SqliteDatabase.table_function`
decorator.
.. _sqlite-closure-table:
@@ -1773,287 +1448,3 @@ APIs
.. note::
For an in-depth discussion of the SQLite transitive closure extension,
check out this blog post, `Querying Tree Structures in SQLite using Python and the Transitive Closure Extension <https://charlesleifer.com/blog/querying-tree-structures-in-sqlite-using-python-and-the-transitive-closure-extension/>`_.
.. _sqlite-lsm1:
.. py:class:: LSMTable()
:py:class:`VirtualModel` subclass suitable for working with the `lsm1 extension <https://charlesleifer.com/blog/lsm-key-value-storage-in-sqlite3/>`_
The *lsm1* extension is a virtual table that provides a SQL interface to
the `lsm key/value storage engine from SQLite4 <http://sqlite.org/src4/doc/trunk/www/lsmusr.wiki>`_.
.. note::
The LSM1 extension has not been released yet (SQLite version 3.22 at
time of writing), so consider this feature experimental with potential
to change in subsequent releases.
LSM tables define one primary key column and an arbitrary number of
additional value columns (which are serialized and stored in a single value
field in the storage engine). The primary key must be all of the same type
and use one of the following field types:
* :py:class:`IntegerField`
* :py:class:`TextField`
* :py:class:`BlobField`
Since the LSM storage engine is a key/value store, primary keys (including
integers) must be specified by the application.
.. attention::
Secondary indexes are not supported by the LSM engine, so the only
efficient queries will be lookups (or range queries) on the primary
key. Other fields can be queried and filtered on, but may result in a
full table-scan.
Example model declaration:
.. code-block:: python
db = SqliteExtDatabase('my_app.db')
db.load_extension('lsm.so') # Load shared library.
class EventLog(LSMTable):
timestamp = IntegerField(primary_key=True)
action = TextField()
sender = TextField()
target = TextField()
class Meta:
database = db
filename = 'eventlog.ldb' # LSM data is stored in separate db.
# Declare virtual table.
EventLog.create_table()
Example queries:
.. code-block:: python
# Use dictionary operators to get, set and delete rows from the LSM
# table. Slices may be passed to represent a range of key values.
def get_timestamp():
# Return time as integer expressing time in microseconds.
return int(time.time() * 1000000)
# Create a new row, at current timestamp.
ts = get_timestamp()
EventLog[ts] = ('pageview', 'search', '/blog/some-post/')
# Retrieve row from event log.
log = EventLog[ts]
print(log.action, log.sender, log.target)
# Prints ("pageview", "search", "/blog/some-post/")
# Delete the row.
del EventLog[ts]
# We can also use the "create()" method.
EventLog.create(
timestamp=get_timestamp(),
action='signup',
sender='newsletter',
target='sqlite-news')
Simple key/value model declaration:
.. code-block:: python
class KV(LSMTable):
key = TextField(primary_key=True)
value = TextField()
class Meta:
database = db
filename = 'kv.ldb'
db.create_tables([KV])
For tables consisting of a single value field, Peewee will return the value
directly when getting a single item. You can also request slices of rows,
in which case Peewee returns a corresponding :py:class:`Select` query,
which can be iterated over. Below are some examples:
.. code-block:: pycon
>>> KV['k0'] = 'v0'
>>> print(KV['k0'])
'v0'
>>> data = [{'key': 'k%d' % i, 'value': 'v%d' % i} for i in range(20)]
>>> KV.insert_many(data).execute()
>>> KV.select().count()
20
>>> KV['k8']
'v8'
>>> list(KV['k4.1':'k7.x'])
[Row(key='k5', value='v5'),
Row(key='k6', value='v6'),
Row(key='k7', value='v7')]
>>> list(KV['k6xxx':])
[Row(key='k7', value='v7'),
Row(key='k8', value='v8'),
Row(key='k9', value='v9')]
You can also index the :py:class:`LSMTable` using expressions:
.. code-block:: pycon
>>> list(KV[KV.key > 'k6'])
[Row(key='k7', value='v7'),
Row(key='k8', value='v8'),
Row(key='k9', value='v9')]
>>> list(KV[(KV.key > 'k6') & (KV.value != 'v8')])
[Row(key='k7', value='v7'),
Row(key='k9', value='v9')]
You can delete single rows using ``del`` or multiple rows using slices
or expressions:
.. code-block:: pycon
>>> del KV['k1']
>>> del KV['k3x':'k8']
>>> del KV[KV.key.between('k10', 'k18')]
>>> list(KV[:])
[Row(key='k0', value='v0'),
Row(key='k19', value='v19'),
Row(key='k2', value='v2'),
Row(key='k3', value='v3'),
Row(key='k9', value='v9')]
Attempting to get a single non-existent key will result in a ``DoesNotExist``,
but slices will not raise an exception:
.. code-block:: pycon
>>> KV['k1']
...
KV.DoesNotExist: <Model:KV> instance matching query does not exist: ...
>>> list(KV['k1':'k1'])
[]
.. _sqlite-blob:
.. py:class:: ZeroBlob(length)
:param int length: Size of blob in bytes.
:py:class:`ZeroBlob` is used solely to reserve space for storing a BLOB
that supports incremental I/O. To use the `SQLite BLOB-store <https://www.sqlite.org/c3ref/blob_open.html>`_
it is necessary to first insert a ZeroBlob of the desired size into the
row you wish to use with incremental I/O.
For example, see :py:class:`Blob`.
.. py:class:: Blob(database, table, column, rowid[, read_only=False])
:param database: :py:class:`SqliteExtDatabase` instance.
:param str table: Name of table being accessed.
:param str column: Name of column being accessed.
:param int rowid: Primary-key of row being accessed.
:param bool read_only: Prevent any modifications to the blob data.
Open a blob, stored in the given table/column/row, for incremental I/O.
To allocate storage for new data, you can use the :py:class:`ZeroBlob`,
which is very efficient.
.. code-block:: python
class RawData(Model):
data = BlobField()
# Allocate 100MB of space for writing a large file incrementally:
query = RawData.insert({'data': ZeroBlob(1024 * 1024 * 100)})
rowid = query.execute()
# Now we can open the row for incremental I/O:
blob = Blob(db, 'rawdata', 'data', rowid)
# Read from the file and write to the blob in chunks of 4096 bytes.
while True:
data = file_handle.read(4096)
if not data:
break
blob.write(data)
bytes_written = blob.tell()
blob.close()
.. py:method:: read([n=None])
:param int n: Only read up to *n* bytes from current position in file.
Read up to *n* bytes from the current position in the blob file. If *n*
is not specified, the entire blob will be read.
.. py:method:: seek(offset[, whence=0])
:param int offset: Seek to the given offset in the file.
:param int whence: Seek relative to the specified frame of reference.
Values for ``whence``:
* ``0``: beginning of file
* ``1``: current position
* ``2``: end of file
.. py:method:: tell()
Return current offset within the file.
.. py:method:: write(data)
:param bytes data: Data to be written
Writes the given data, starting at the current position in the file.
.. py:method:: close()
Close the file and free associated resources.
.. py:method:: reopen(rowid)
:param int rowid: Primary key of row to open.
If a blob has already been opened for a given table/column, you can use
the :py:meth:`~Blob.reopen` method to re-use the same :py:class:`Blob`
object for accessing multiple rows in the table.
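Example (a sketch; ``rowids`` is a hypothetical list of primary keys and
``process()`` a hypothetical consumer):

.. code-block:: python

    blob = Blob(db, 'rawdata', 'data', rowids[0], read_only=True)
    for rowid in rowids:
        blob.reopen(rowid)
        process(blob.read())
    blob.close()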
.. _sqlite-extras:
Additional Features
-------------------
The :py:class:`SqliteExtDatabase` can also register other useful functions:
* ``rank_functions`` (enabled by default): registers functions for ranking
search results, such as *bm25* and *lucene*.
* ``hash_functions``: registers md5, sha1, sha256, adler32, crc32 and
murmurhash functions.
* ``regexp_function``: registers a regexp function.
Examples:
.. code-block:: python
def create_new_user(username, password):
# DO NOT DO THIS IN REAL LIFE. PLEASE.
query = User.insert({'username': username, 'password': fn.sha1(password)})
new_user_id = query.execute()
You can use the *murmurhash* function to hash bytes to an integer for compact
storage:
.. code-block:: pycon
>>> db = SqliteExtDatabase(':memory:', hash_functions=True)
>>> db.execute_sql('SELECT murmurhash(?)', ('abcdefg',)).fetchone()
(4188131059,)
-73
@@ -1,73 +0,0 @@
/* cache.h - definitions for the LRU cache
*
* Copyright (C) 2004-2015 Gerhard Häring <gh@ghaering.de>
*
* This file is part of pysqlite.
*
* This software is provided 'as-is', without any express or implied
* warranty. In no event will the authors be held liable for any damages
* arising from the use of this software.
*
* Permission is granted to anyone to use this software for any purpose,
* including commercial applications, and to alter it and redistribute it
* freely, subject to the following restrictions:
*
* 1. The origin of this software must not be misrepresented; you must not
* claim that you wrote the original software. If you use this software
* in a product, an acknowledgment in the product documentation would be
* appreciated but is not required.
* 2. Altered source versions must be plainly marked as such, and must not be
* misrepresented as being the original software.
* 3. This notice may not be removed or altered from any source distribution.
*/
#ifndef PYSQLITE_CACHE_H
#define PYSQLITE_CACHE_H
#include "Python.h"
/* The LRU cache is implemented as a combination of a doubly-linked with a
* dictionary. The list items are of type 'Node' and the dictionary has the
* nodes as values. */
typedef struct _pysqlite_Node
{
PyObject_HEAD
PyObject* key;
PyObject* data;
long count;
struct _pysqlite_Node* prev;
struct _pysqlite_Node* next;
} pysqlite_Node;
typedef struct
{
PyObject_HEAD
int size;
/* a dictionary mapping keys to Node entries */
PyObject* mapping;
/* the factory callable */
PyObject* factory;
pysqlite_Node* first;
pysqlite_Node* last;
/* if set, decrement the factory function when the Cache is deallocated.
* this is almost always desirable, but not in the pysqlite context */
int decref_factory;
} pysqlite_Cache;
extern PyTypeObject pysqlite_NodeType;
extern PyTypeObject pysqlite_CacheType;
int pysqlite_node_init(pysqlite_Node* self, PyObject* args, PyObject* kwargs);
void pysqlite_node_dealloc(pysqlite_Node* self);
int pysqlite_cache_init(pysqlite_Cache* self, PyObject* args, PyObject* kwargs);
void pysqlite_cache_dealloc(pysqlite_Cache* self);
PyObject* pysqlite_cache_get(pysqlite_Cache* self, PyObject* args);
int pysqlite_cache_setup_types(void);
#endif
-129
@@ -1,129 +0,0 @@
/* connection.h - definitions for the connection type
*
* Copyright (C) 2004-2015 Gerhard Häring <gh@ghaering.de>
*
* This file is part of pysqlite.
*
* This software is provided 'as-is', without any express or implied
* warranty. In no event will the authors be held liable for any damages
* arising from the use of this software.
*
* Permission is granted to anyone to use this software for any purpose,
* including commercial applications, and to alter it and redistribute it
* freely, subject to the following restrictions:
*
* 1. The origin of this software must not be misrepresented; you must not
* claim that you wrote the original software. If you use this software
* in a product, an acknowledgment in the product documentation would be
* appreciated but is not required.
* 2. Altered source versions must be plainly marked as such, and must not be
* misrepresented as being the original software.
* 3. This notice may not be removed or altered from any source distribution.
*/
#ifndef PYSQLITE_CONNECTION_H
#define PYSQLITE_CONNECTION_H
#include "Python.h"
#include "pythread.h"
#include "structmember.h"
#include "cache.h"
#include "module.h"
#include "sqlite3.h"
typedef struct
{
PyObject_HEAD
sqlite3* db;
/* the type detection mode. Only 0, PARSE_DECLTYPES, PARSE_COLNAMES or a
* bitwise combination thereof makes sense */
int detect_types;
/* the timeout value in seconds for database locks */
double timeout;
/* for internal use in the timeout handler: when did the timeout handler
* first get called with count=0? */
double timeout_started;
/* None for autocommit, otherwise a PyString with the isolation level */
PyObject* isolation_level;
/* NULL for autocommit, otherwise a string with the BEGIN statement; will be
* freed in connection destructor */
char* begin_statement;
/* 1 if a check should be performed for each API call if the connection is
* used from the same thread it was created in */
int check_same_thread;
int initialized;
/* thread identification of the thread the connection was created in */
long thread_ident;
pysqlite_Cache* statement_cache;
/* Lists of weak references to statements and cursors used within this connection */
PyObject* statements;
PyObject* cursors;
/* Counters for how many statements/cursors were created in the connection. May be
* reset to 0 at certain intervals */
int created_statements;
int created_cursors;
PyObject* row_factory;
/* Determines how bytestrings from SQLite are converted to Python objects:
* - PyUnicode_Type: Python Unicode objects are constructed from UTF-8 bytestrings
* - OptimizedUnicode: Like before, but for ASCII data, only PyStrings are created.
* - PyString_Type: PyStrings are created as-is.
* - Any custom callable: Any object returned from the callable called with the bytestring
* as single parameter.
*/
PyObject* text_factory;
/* remember references to functions/classes used in
* create_function/create/aggregate, use these as dictionary keys, so we
* can keep the total system refcount constant by clearing that dictionary
* in connection_dealloc */
PyObject* function_pinboard;
/* a dictionary of registered collation name => collation callable mappings */
PyObject* collations;
/* Exception objects */
PyObject* Warning;
PyObject* Error;
PyObject* InterfaceError;
PyObject* DatabaseError;
PyObject* DataError;
PyObject* OperationalError;
PyObject* IntegrityError;
PyObject* InternalError;
PyObject* ProgrammingError;
PyObject* NotSupportedError;
} pysqlite_Connection;
extern PyTypeObject pysqlite_ConnectionType;
PyObject* pysqlite_connection_alloc(PyTypeObject* type, int aware);
void pysqlite_connection_dealloc(pysqlite_Connection* self);
PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* kwargs);
PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args);
PyObject* _pysqlite_connection_begin(pysqlite_Connection* self);
PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args);
PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args);
PyObject* pysqlite_connection_new(PyTypeObject* type, PyObject* args, PyObject* kw);
int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject* kwargs);
int pysqlite_connection_register_cursor(pysqlite_Connection* connection, PyObject* cursor);
int pysqlite_check_thread(pysqlite_Connection* self);
int pysqlite_check_connection(pysqlite_Connection* con);
int pysqlite_connection_setup_types(void);
#endif
-58
@@ -1,58 +0,0 @@
/* module.h - definitions for the module
*
* Copyright (C) 2004-2015 Gerhard Häring <gh@ghaering.de>
*
* This file is part of pysqlite.
*
* This software is provided 'as-is', without any express or implied
* warranty. In no event will the authors be held liable for any damages
* arising from the use of this software.
*
* Permission is granted to anyone to use this software for any purpose,
* including commercial applications, and to alter it and redistribute it
* freely, subject to the following restrictions:
*
* 1. The origin of this software must not be misrepresented; you must not
* claim that you wrote the original software. If you use this software
* in a product, an acknowledgment in the product documentation would be
* appreciated but is not required.
* 2. Altered source versions must be plainly marked as such, and must not be
* misrepresented as being the original software.
* 3. This notice may not be removed or altered from any source distribution.
*/
#ifndef PYSQLITE_MODULE_H
#define PYSQLITE_MODULE_H
#include "Python.h"
#define PYSQLITE_VERSION "2.8.2"
extern PyObject* pysqlite_Error;
extern PyObject* pysqlite_Warning;
extern PyObject* pysqlite_InterfaceError;
extern PyObject* pysqlite_DatabaseError;
extern PyObject* pysqlite_InternalError;
extern PyObject* pysqlite_OperationalError;
extern PyObject* pysqlite_ProgrammingError;
extern PyObject* pysqlite_IntegrityError;
extern PyObject* pysqlite_DataError;
extern PyObject* pysqlite_NotSupportedError;
extern PyObject* pysqlite_OptimizedUnicode;
/* the functions time.time() and time.sleep() */
extern PyObject* time_time;
extern PyObject* time_sleep;
/* A dictionary, mapping column types (INTEGER, VARCHAR, etc.) to converter
* functions, that convert the SQL value to the appropriate Python value.
* The key is uppercase.
*/
extern PyObject* converters;
extern int _enable_callback_tracebacks;
extern int pysqlite_BaseTypeAdapted;
#define PARSE_DECLTYPES 1
#define PARSE_COLNAMES 2
#endif
File diff suppressed because it is too large.
+235
@@ -1,4 +1,7 @@
# cython: language_level=3
from libc.stdlib cimport free, malloc
from libc.math cimport log, sqrt
import sys
from difflib import SequenceMatcher
from random import randint
@@ -6,6 +9,238 @@ from random import randint
IS_PY3K = sys.version_info[0] == 3
# FTS ranking functions.
cdef double *get_weights(int ncol, tuple raw_weights):
cdef:
int argc = len(raw_weights)
int icol
double *weights = <double *>malloc(sizeof(double) * ncol)
for icol in range(ncol):
if argc == 0:
weights[icol] = 1.0
elif icol < argc:
weights[icol] = <double>raw_weights[icol]
else:
weights[icol] = 0.0
return weights
def peewee_rank(py_match_info, *raw_weights):
cdef:
unsigned int *match_info
unsigned int *phrase_info
bytes _match_info_buf = bytes(py_match_info)
char *match_info_buf = _match_info_buf
int nphrase, ncol, icol, iphrase, hits, global_hits
int P_O = 0, C_O = 1, X_O = 2
double score = 0.0, weight
double *weights
match_info = <unsigned int *>match_info_buf
nphrase = match_info[P_O]
ncol = match_info[C_O]
weights = get_weights(ncol, raw_weights)
# matchinfo X value corresponds to, for each phrase in the search query, a
# list of 3 values for each column in the search table.
# So if we have a two-phrase search query and three columns of data, the
# following would be the layout:
# p0 : c0=[0, 1, 2], c1=[3, 4, 5], c2=[6, 7, 8]
# p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17]
for iphrase in range(nphrase):
phrase_info = &match_info[X_O + iphrase * ncol * 3]
for icol in range(ncol):
weight = weights[icol]
if weight == 0:
continue
# The idea is that we count the number of times the phrase appears
# in this column of the current row, compared to how many times it
# appears in this column across all rows. The ratio of these values
# provides a rough way to score based on "high value" terms.
hits = phrase_info[3 * icol]
global_hits = phrase_info[3 * icol + 1]
if hits > 0:
score += weight * (<double>hits / <double>global_hits)
free(weights)
return -1 * score
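# Note: peewee_rank() expects the matchinfo() blob in SQLite's default "pcx"
# format: P (number of phrases), C (number of columns), then X (three
# unsigned ints per phrase/column pair). sqlite_ext has historically exposed
# this function to SQLite under the name "fts_rank".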
def peewee_lucene(py_match_info, *raw_weights):
# Usage: peewee_lucene(matchinfo(table, 'pcnalx'), 1)
cdef:
unsigned int *match_info
bytes _match_info_buf = bytes(py_match_info)
char *match_info_buf = _match_info_buf
int nphrase, ncol
double total_docs, term_frequency
double doc_length, docs_with_term, avg_length
double idf, weight, rhs, denom
double *weights
int P_O = 0, C_O = 1, N_O = 2, L_O, X_O
int iphrase, icol, x
double score = 0.0
match_info = <unsigned int *>match_info_buf
nphrase = match_info[P_O]
ncol = match_info[C_O]
total_docs = match_info[N_O]
L_O = 3 + ncol
X_O = L_O + ncol
weights = get_weights(ncol, raw_weights)
for iphrase in range(nphrase):
for icol in range(ncol):
weight = weights[icol]
if weight == 0:
continue
doc_length = match_info[L_O + icol]
x = X_O + (3 * (icol + iphrase * ncol))
term_frequency = match_info[x] # f(qi)
docs_with_term = match_info[x + 2] or 1. # n(qi)
idf = log(total_docs / (docs_with_term + 1.))
tf = sqrt(term_frequency)
fieldNorms = 1.0 / sqrt(doc_length)
score += (idf * tf * fieldNorms)
free(weights)
return -1 * score
def peewee_bm25(py_match_info, *raw_weights):
# Usage: peewee_bm25(matchinfo(table, 'pcnalx'), 1)
# where the second parameter is the index of the column and
# the 3rd and 4th specify k and b.
cdef:
unsigned int *match_info
bytes _match_info_buf = bytes(py_match_info)
char *match_info_buf = _match_info_buf
int nphrase, ncol
double B = 0.75, K = 1.2
double total_docs, term_frequency
double doc_length, docs_with_term, avg_length
double idf, weight, ratio, num, b_part, denom, pc_score
double *weights
int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O
int iphrase, icol, x
double score = 0.0
match_info = <unsigned int *>match_info_buf
# PCNALX = matchinfo format.
# P = 1 = phrase count within query.
# C = 1 = searchable columns in table.
# N = 1 = total rows in table.
# A = c = for each column, avg number of tokens
# L = c = for each column, length of current row (in tokens)
# X = 3 * c * p = for each phrase and table column,
# * phrase count within column for current row.
# * phrase count within column for all rows.
# * total rows for which column contains phrase.
nphrase = match_info[P_O] # n
ncol = match_info[C_O]
total_docs = match_info[N_O] # N
L_O = A_O + ncol
X_O = L_O + ncol
weights = get_weights(ncol, raw_weights)
for iphrase in range(nphrase):
for icol in range(ncol):
weight = weights[icol]
if weight == 0:
continue
x = X_O + (3 * (icol + iphrase * ncol))
term_frequency = match_info[x] # f(qi, D)
docs_with_term = match_info[x + 2] # n(qi)
# log( (N - n(qi) + 0.5) / (n(qi) + 0.5) )
idf = log(
(total_docs - docs_with_term + 0.5) /
(docs_with_term + 0.5))
if idf <= 0.0:
idf = 1e-6
doc_length = match_info[L_O + icol] # |D|
avg_length = match_info[A_O + icol] # avgdl
if avg_length == 0:
avg_length = 1
ratio = doc_length / avg_length
num = term_frequency * (K + 1)
b_part = 1 - B + (B * ratio)
denom = term_frequency + (K * b_part)
pc_score = idf * (num / denom)
score += (pc_score * weight)
free(weights)
return -1 * score
def peewee_bm25f(py_match_info, *raw_weights):
# Usage: peewee_bm25f(matchinfo(table, 'pcnalx'), 1)
# where the second parameter is the index of the column and
# the 3rd and 4th specify k and b.
cdef:
unsigned int *match_info
bytes _match_info_buf = bytes(py_match_info)
char *match_info_buf = _match_info_buf
int nphrase, ncol
double B = 0.75, K = 1.2, epsilon
double total_docs, term_frequency, docs_with_term
double doc_length = 0.0, avg_length = 0.0
double idf, weight, ratio, num, b_part, denom, pc_score
double *weights
int P_O = 0, C_O = 1, N_O = 2, A_O = 3, L_O, X_O
int iphrase, icol, x
double score = 0.0
match_info = <unsigned int *>match_info_buf
nphrase = match_info[P_O] # n
ncol = match_info[C_O]
total_docs = match_info[N_O] # N
L_O = A_O + ncol
X_O = L_O + ncol
for icol in range(ncol):
avg_length += match_info[A_O + icol]
doc_length += match_info[L_O + icol]
if avg_length == 0:
    avg_length = 1
epsilon = 1.0 / (total_docs * avg_length)
ratio = doc_length / avg_length
weights = get_weights(ncol, raw_weights)
for iphrase in range(nphrase):
for icol in range(ncol):
weight = weights[icol]
if weight == 0:
continue
x = X_O + (3 * (icol + iphrase * ncol))
term_frequency = match_info[x] # f(qi, D)
docs_with_term = match_info[x + 2] # n(qi)
# log( (N - n(qi) + 0.5) / (n(qi) + 0.5) )
idf = log(
(total_docs - docs_with_term + 0.5) /
(docs_with_term + 0.5))
idf = epsilon if idf <= 0 else idf
num = term_frequency * (K + 1)
b_part = 1 - B + (B * ratio)
denom = term_frequency + (K * b_part)
pc_score = idf * ((num / denom) + 1.)
score += (pc_score * weight)
free(weights)
return -1 * score
# String UDF.
def damerau_levenshtein_dist(s1, s2):
cdef:
+13 -9
@@ -18,8 +18,8 @@ from playhouse.sqlite_ext import (
SearchField,
VirtualModel,
FTSModel,
FTS5Model,
rank)
FTS5Model)
from playhouse.sqlite_udf import rank
try:
import cysqlite
@@ -45,7 +45,7 @@ def __dbstatus__(flag, return_highwater=False, return_current=False):
def getter(self):
if self._state.conn is None:
raise ImproperlyConfigured('database connection not opened.')
result = sqlite_get_db_status(self._state.conn, flag)
result = self._state.conn.status(flag)
if return_current:
return result[0]
return result[1] if return_highwater else result
@@ -126,25 +126,25 @@ class CySqliteDatabase(SqliteDatabase):
def on_commit(self, fn):
self._commit_hook = fn
if not self.is_closed():
conn.commit_hook(fn)
self.connection().commit_hook(fn)
return fn
def on_rollback(self, fn):
self._rollback_hook = fn
if not self.is_closed():
conn.rollback_hook(fn)
self.connection().rollback_hook(fn)
return fn
def on_update(self, fn):
self._update_hook = fn
if not self.is_closed():
conn.update_hook(fn)
self.connection().update_hook(fn)
return fn
def authorizer(self, fn):
self._authorizer = fn
if not self.is_closed():
conn.authorizer(fn)
self.connection().authorizer(fn)
return fn
def trace(self, fn, mask=2):
@@ -154,7 +154,7 @@ class CySqliteDatabase(SqliteDatabase):
self._trace = (fn, mask)
if not self.is_closed():
args = (None,) if fn is None else self._trace
conn.authorizer(*args)
self.connection().trace(*args)
return fn
def progress(self, fn, n=1):
@@ -164,7 +164,7 @@ class CySqliteDatabase(SqliteDatabase):
self._progress = (fn, n)
if not self.is_closed():
args = (None,) if fn is None else self._progress
conn.progress(*args)
self.connection().progress(*args)
return fn
def begin(self, lock_type='deferred'):
@@ -183,6 +183,10 @@ class CySqliteDatabase(SqliteDatabase):
def autocommit(self):
return self.connection().autocommit()
def blob_open(self, table, column, rowid, read_only=False, dbname=None):
    return self.connection().blob_open(table, column, rowid, read_only,
                                       dbname)
def backup(self, destination, pages=None, name=None, progress=None,
src_name=None):
+13 -461
@@ -1,7 +1,5 @@
import json
import math
import re
import struct
import sys
from peewee import *
@@ -16,22 +14,9 @@ from peewee import OP
from peewee import VirtualField
from peewee import merge_dict
from peewee import sqlite3
try:
from playhouse._sqlite_ext import (
backup,
backup_to_file,
Blob,
ConnectionHelper,
register_hash_functions,
register_rank_functions,
sqlite_get_db_status,
sqlite_get_status,
TableFunction,
ZeroBlob,
)
CYTHON_SQLITE_EXTENSIONS = True
except ImportError:
CYTHON_SQLITE_EXTENSIONS = False
from playhouse.sqlite_udf import JSON
from playhouse.sqlite_udf import RANK
from playhouse.sqlite_udf import register_udf_groups
if sys.version_info[0] == 3:
@@ -948,109 +933,6 @@ def ClosureTable(model_class, foreign_key=None, referencing_class=None,
return type(name, (BaseClosureTable,), {'Meta': Meta})
class LSMTable(VirtualModel):
class Meta:
extension_module = 'lsm1'
filename = None
@classmethod
def clean_options(cls, options):
filename = cls._meta.filename
if not filename:
raise ValueError('LSM1 extension requires that you specify a '
'filename for the LSM database.')
else:
if len(filename) >= 2 and filename[0] != '"':
filename = '"%s"' % filename
if not cls._meta.primary_key:
raise ValueError('LSM1 models must specify a primary-key field.')
key = cls._meta.primary_key
if isinstance(key, AutoField):
raise ValueError('LSM1 models must explicitly declare a primary '
'key field.')
if not isinstance(key, (TextField, BlobField, IntegerField)):
raise ValueError('LSM1 key must be a TextField, BlobField, or '
'IntegerField.')
key._hidden = True
if isinstance(key, IntegerField):
data_type = 'UINT'
elif isinstance(key, BlobField):
data_type = 'BLOB'
else:
data_type = 'TEXT'
cls._meta.prefix_arguments = [filename, '"%s"' % key.name, data_type]
# Does the key map to a scalar value, or a tuple of values?
if len(cls._meta.sorted_fields) == 2:
cls._meta._value_field = cls._meta.sorted_fields[1]
else:
cls._meta._value_field = None
return options
@classmethod
def load_extension(cls, path='lsm.so'):
cls._meta.database.load_extension(path)
@staticmethod
def slice_to_expr(key, idx):
if idx.start is not None and idx.stop is not None:
return key.between(idx.start, idx.stop)
elif idx.start is not None:
return key >= idx.start
elif idx.stop is not None:
return key <= idx.stop
@staticmethod
def _apply_lookup_to_query(query, key, lookup):
if isinstance(lookup, slice):
expr = LSMTable.slice_to_expr(key, lookup)
if expr is not None:
query = query.where(expr)
return query, False
elif isinstance(lookup, Expression):
return query.where(lookup), False
else:
return query.where(key == lookup), True
@classmethod
def get_by_id(cls, pk):
query, is_single = cls._apply_lookup_to_query(
cls.select().namedtuples(),
cls._meta.primary_key,
pk)
if is_single:
row = query.get()
return row[1] if cls._meta._value_field is not None else row
else:
return query
@classmethod
def set_by_id(cls, key, value):
if cls._meta._value_field is not None:
data = {cls._meta._value_field: value}
elif isinstance(value, tuple):
data = {}
for field, fval in zip(cls._meta.sorted_fields[1:], value):
data[field] = fval
elif isinstance(value, dict):
data = value
elif isinstance(value, cls):
data = value.__dict__
data[cls._meta.primary_key] = key
cls.replace(data).execute()
@classmethod
def delete_by_id(cls, pk):
query, is_single = cls._apply_lookup_to_query(
cls.delete(),
cls._meta.primary_key,
pk)
return query.execute()
OP.MATCH = 'MATCH'
def _sqlite_regexp(regex, value):
@@ -1058,36 +940,17 @@ def _sqlite_regexp(regex, value):
class SqliteExtDatabase(SqliteDatabase):
def __init__(self, database, c_extensions=None, rank_functions=True,
hash_functions=False, regexp_function=False,
def __init__(self, database, rank_functions=True, regexp_function=False,
json_contains=False, *args, **kwargs):
super(SqliteExtDatabase, self).__init__(database, *args, **kwargs)
self._row_factory = None
if c_extensions and not CYTHON_SQLITE_EXTENSIONS:
raise ImproperlyConfigured('SqliteExtDatabase initialized with '
'C extensions, but shared library was '
'not found!')
prefer_c = CYTHON_SQLITE_EXTENSIONS and (c_extensions is not False)
if rank_functions:
if prefer_c:
register_rank_functions(self)
else:
self.register_function(bm25, 'fts_bm25')
self.register_function(rank, 'fts_rank')
self.register_function(bm25, 'fts_bm25f') # Fall back to bm25.
self.register_function(bm25, 'fts_lucene')
if hash_functions:
if not prefer_c:
raise ValueError('C extension required to register hash '
'functions.')
register_hash_functions(self)
register_udf_groups(self, RANK)
if regexp_function:
self.register_function(_sqlite_regexp, 'regexp', 2)
if json_contains:
self.register_function(_json_contains, 'json_contains')
self._c_extensions = prefer_c
register_udf_groups(self, JSON)
def _add_conn_hooks(self, conn):
super(SqliteExtDatabase, self)._add_conn_hooks(conn)
@@ -1098,325 +961,14 @@ class SqliteExtDatabase(SqliteDatabase):
self._row_factory = fn
if CYTHON_SQLITE_EXTENSIONS:
SQLITE_STATUS_MEMORY_USED = 0
SQLITE_STATUS_PAGECACHE_USED = 1
SQLITE_STATUS_PAGECACHE_OVERFLOW = 2
SQLITE_STATUS_SCRATCH_USED = 3
SQLITE_STATUS_SCRATCH_OVERFLOW = 4
SQLITE_STATUS_MALLOC_SIZE = 5
SQLITE_STATUS_PARSER_STACK = 6
SQLITE_STATUS_PAGECACHE_SIZE = 7
SQLITE_STATUS_SCRATCH_SIZE = 8
SQLITE_STATUS_MALLOC_COUNT = 9
SQLITE_DBSTATUS_LOOKASIDE_USED = 0
SQLITE_DBSTATUS_CACHE_USED = 1
SQLITE_DBSTATUS_SCHEMA_USED = 2
SQLITE_DBSTATUS_STMT_USED = 3
SQLITE_DBSTATUS_LOOKASIDE_HIT = 4
SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE = 5
SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL = 6
SQLITE_DBSTATUS_CACHE_HIT = 7
SQLITE_DBSTATUS_CACHE_MISS = 8
SQLITE_DBSTATUS_CACHE_WRITE = 9
SQLITE_DBSTATUS_DEFERRED_FKS = 10
#SQLITE_DBSTATUS_CACHE_USED_SHARED = 11
def __status__(flag, return_highwater=False):
"""
Expose a sqlite3_status() call for a particular flag as a property of
the Database object.
"""
def getter(self):
result = sqlite_get_status(flag)
return result[1] if return_highwater else result
return property(getter)
def __dbstatus__(flag, return_highwater=False, return_current=False):
"""
Expose a sqlite3_dbstatus() call for a particular flag as a property of
the Database instance. Unlike sqlite3_status(), the dbstatus properties
pertain to the current connection.
"""
def getter(self):
if self._state.conn is None:
raise ImproperlyConfigured('database connection not opened.')
result = sqlite_get_db_status(self._state.conn, flag)
if return_current:
return result[0]
return result[1] if return_highwater else result
return property(getter)
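# Usage sketch for the two property factories above: class attributes are
# built from status flags, e.g.
#
#   memory_used = __status__(SQLITE_STATUS_MEMORY_USED)        # (cur, high)
#   malloc_size = __status__(SQLITE_STATUS_MALLOC_SIZE, True)  # highwater
#   cache_hit = __dbstatus__(SQLITE_DBSTATUS_CACHE_HIT, False, True)
#
# Reading a dbstatus property requires an open connection, since those
# values are per-connection rather than global.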
class CSqliteExtDatabase(SqliteExtDatabase):
def __init__(self, *args, **kwargs):
self._conn_helper = None
self._commit_hook = self._rollback_hook = self._update_hook = None
self._replace_busy_handler = False
super(CSqliteExtDatabase, self).__init__(*args, **kwargs)
def init(self, database, replace_busy_handler=False, **kwargs):
super(CSqliteExtDatabase, self).init(database, **kwargs)
self._replace_busy_handler = replace_busy_handler
def _close(self, conn):
if self._commit_hook:
self._conn_helper.set_commit_hook(None)
if self._rollback_hook:
self._conn_helper.set_rollback_hook(None)
if self._update_hook:
self._conn_helper.set_update_hook(None)
return super(CSqliteExtDatabase, self)._close(conn)
def _add_conn_hooks(self, conn):
super(CSqliteExtDatabase, self)._add_conn_hooks(conn)
self._conn_helper = ConnectionHelper(conn)
if self._commit_hook is not None:
self._conn_helper.set_commit_hook(self._commit_hook)
if self._rollback_hook is not None:
self._conn_helper.set_rollback_hook(self._rollback_hook)
if self._update_hook is not None:
self._conn_helper.set_update_hook(self._update_hook)
if self._replace_busy_handler:
timeout = self._timeout or 5
self._conn_helper.set_busy_handler(timeout * 1000)
def on_commit(self, fn):
self._commit_hook = fn
if not self.is_closed():
self._conn_helper.set_commit_hook(fn)
return fn
def on_rollback(self, fn):
self._rollback_hook = fn
if not self.is_closed():
self._conn_helper.set_rollback_hook(fn)
return fn
def on_update(self, fn):
self._update_hook = fn
if not self.is_closed():
self._conn_helper.set_update_hook(fn)
return fn
def changes(self):
return self._conn_helper.changes()
@property
def last_insert_rowid(self):
return self._conn_helper.last_insert_rowid()
@property
def autocommit(self):
return self._conn_helper.autocommit()
def backup(self, destination, pages=None, name=None, progress=None):
return backup(self.connection(), destination.connection(),
pages=pages, name=name, progress=progress)
def backup_to_file(self, filename, pages=None, name=None,
progress=None):
return backup_to_file(self.connection(), filename, pages=pages,
name=name, progress=progress)
def blob_open(self, table, column, rowid, read_only=False):
return Blob(self, table, column, rowid, read_only)
# Status properties.
memory_used = __status__(SQLITE_STATUS_MEMORY_USED)
malloc_size = __status__(SQLITE_STATUS_MALLOC_SIZE, True)
malloc_count = __status__(SQLITE_STATUS_MALLOC_COUNT)
pagecache_used = __status__(SQLITE_STATUS_PAGECACHE_USED)
pagecache_overflow = __status__(SQLITE_STATUS_PAGECACHE_OVERFLOW)
pagecache_size = __status__(SQLITE_STATUS_PAGECACHE_SIZE, True)
scratch_used = __status__(SQLITE_STATUS_SCRATCH_USED)
scratch_overflow = __status__(SQLITE_STATUS_SCRATCH_OVERFLOW)
scratch_size = __status__(SQLITE_STATUS_SCRATCH_SIZE, True)
# Connection status properties.
lookaside_used = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_USED)
lookaside_hit = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_HIT, True)
lookaside_miss = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_SIZE,
True)
lookaside_miss_full = __dbstatus__(SQLITE_DBSTATUS_LOOKASIDE_MISS_FULL,
True)
cache_used = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED, False, True)
#cache_used_shared = __dbstatus__(SQLITE_DBSTATUS_CACHE_USED_SHARED,
# False, True)
schema_used = __dbstatus__(SQLITE_DBSTATUS_SCHEMA_USED, False, True)
statement_used = __dbstatus__(SQLITE_DBSTATUS_STMT_USED, False, True)
cache_hit = __dbstatus__(SQLITE_DBSTATUS_CACHE_HIT, False, True)
cache_miss = __dbstatus__(SQLITE_DBSTATUS_CACHE_MISS, False, True)
cache_write = __dbstatus__(SQLITE_DBSTATUS_CACHE_WRITE, False, True)
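# Usage sketch for the hook APIs defined above (file names illustrative;
# the update-hook callback signature follows sqlite3_update_hook:
# operation, database name, table name, rowid):
#
#   db = CSqliteExtDatabase('app.db')
#
#   @db.on_update
#   def audit(query_type, db_name, table, rowid):
#       logger.debug('%s %s.%s rowid=%s', query_type, db_name, table, rowid)
#
#   db.backup_to_file('app-backup.db')  # online backup via sqlite3_backup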
class CSqliteExtDatabase(SqliteExtDatabase):
# XXX: here today, gone tomorrow.
def __init__(self, *args, **kwargs):
warnings.warn('CSqliteExtDatabase is deprecated. For equivalent '
'functionality use cysqlite_ext.CySqliteDatabase.',
DeprecationWarning)
super(CSqliteExtDatabase, self).__init__(*args, **kwargs)
def match(lhs, rhs):
return Expression(lhs, OP.MATCH, rhs)
def _parse_match_info(buf):
# See http://sqlite.org/fts3.html#matchinfo
bufsize = len(buf) # Length in bytes.
return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]
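# e.g. a three-integer matchinfo blob round-trips as a plain list of ints:
#   _parse_match_info(struct.pack('@3I', 1, 2, 3))  # -> [1, 2, 3]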
def get_weights(ncol, raw_weights):
if not raw_weights:
return [1] * ncol
else:
weights = [0] * ncol
for i, weight in enumerate(raw_weights):
weights[i] = weight
return weights
# Ranking implementation, which parses matchinfo.
def rank(raw_match_info, *raw_weights):
# Handle match_info called w/default args 'pcx' - based on the example rank
# function http://sqlite.org/fts3.html#appendix_a
match_info = _parse_match_info(raw_match_info)
score = 0.0
p, c = match_info[:2]
weights = get_weights(c, raw_weights)
# The matchinfo 'x' value contains, for each phrase in the search query, a
# list of 3 values for each column in the search table.
# So if we have a two-phrase search query and three columns of data, the
# following would be the layout:
# p0 : c0=[0, 1, 2], c1=[3, 4, 5], c2=[6, 7, 8]
# p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17]
for phrase_num in range(p):
phrase_info_idx = 2 + (phrase_num * c * 3)
for col_num in range(c):
weight = weights[col_num]
if not weight:
continue
col_idx = phrase_info_idx + (col_num * 3)
# The idea is that we count the number of times the phrase appears
# in this column of the current row, compared to how many times it
# appears in this column across all rows. The ratio of these values
# provides a rough way to score based on "high value" terms.
row_hits = match_info[col_idx]
all_rows_hits = match_info[col_idx + 1]
if row_hits > 0:
score += weight * (float(row_hits) / all_rows_hits)
return -score
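# Worked sketch: a 'pcx' matchinfo buffer for one phrase and one column,
# where the phrase hits twice in this row and ten times across all rows
# (the fifth slot, docs-with-hits, is unused by rank()):
#
#   buf = struct.pack('@5I', 1, 1, 2, 10, 5)
#   rank(buf)  # -> -0.2, i.e. -(2 row hits / 10 total hits) * weight 1.0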
# Okapi BM25 ranking implementation (FTS4 only).
def bm25(raw_match_info, *args):
"""
Usage:
# Format string *must* be pcnalx
# Second parameter to bm25 specifies the index of the column on
# the table being queried.
bm25(matchinfo(document_tbl, 'pcnalx'), 1) AS rank
"""
match_info = _parse_match_info(raw_match_info)
K = 1.2
B = 0.75
score = 0.0
P_O, C_O, N_O, A_O = range(4) # Offsets into the matchinfo buffer.
term_count = match_info[P_O] # n
col_count = match_info[C_O]
total_docs = match_info[N_O] # N
L_O = A_O + col_count
X_O = L_O + col_count
# Worked example of pcnalx for two columns and two phrases, 100 docs total.
# {
# p = 2
# c = 2
# n = 100
# a0 = 4 -- avg number of tokens for col0, e.g. title
# a1 = 40 -- avg number of tokens for col1, e.g. body
# l0 = 5 -- curr doc has 5 tokens in col0
# l1 = 30 -- curr doc has 30 tokens in col1
#
# x000 -- hits this row for phrase0, col0
# x001 -- hits all rows for phrase0, col0
# x002 -- rows with phrase0 in col0 at least once
#
# x010 -- hits this row for phrase0, col1
# x011 -- hits all rows for phrase0, col1
# x012 -- rows with phrase0 in col1 at least once
#
# x100 -- hits this row for phrase1, col0
# x101 -- hits all rows for phrase1, col0
# x102 -- rows with phrase1 in col0 at least once
#
# x110 -- hits this row for phrase1, col1
# x111 -- hits all rows for phrase1, col1
# x112 -- rows with phrase1 in col1 at least once
# }
weights = get_weights(col_count, args)
for i in range(term_count):
for j in range(col_count):
weight = weights[j]
if weight == 0:
continue
x = X_O + (3 * (j + i * col_count))
term_frequency = float(match_info[x]) # f(qi, D)
docs_with_term = float(match_info[x + 2]) # n(qi)
# log( (N - n(qi) + 0.5) / (n(qi) + 0.5) )
idf = math.log(
(total_docs - docs_with_term + 0.5) /
(docs_with_term + 0.5))
if idf <= 0.0:
idf = 1e-6
doc_length = float(match_info[L_O + j]) # |D|
avg_length = float(match_info[A_O + j]) or 1. # avgdl
ratio = doc_length / avg_length
num = term_frequency * (K + 1.0)
b_part = 1.0 - B + (B * ratio)
denom = term_frequency + (K * b_part)
pc_score = idf * (num / denom)
score += (pc_score * weight)
return -score
def _json_contains(src_json, obj_json):
stack = []
try:
stack.append((json.loads(obj_json), json.loads(src_json)))
except:
# Invalid JSON!
return False
while stack:
obj, src = stack.pop()
if isinstance(src, dict):
if isinstance(obj, dict):
for key in obj:
if key not in src:
return False
stack.append((obj[key], src[key]))
elif isinstance(obj, list):
for item in obj:
if item not in src:
return False
elif obj not in src:
return False
elif isinstance(src, list):
if isinstance(obj, dict):
return False
elif isinstance(obj, list):
try:
for i in range(len(obj)):
stack.append((obj[i], src[i]))
except IndexError:
return False
elif obj not in src:
return False
elif obj != src:
return False
return True
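# Containment semantics of the helper above, sketched:
#
#   _json_contains('{"a": 1, "b": [2, 3]}', '{"a": 1}')  # True (subset)
#   _json_contains('{"a": 1}', '{"a": 2}')               # False
#   _json_contains('[1, 2, 3]', '[1, 2]')                # True (positional)
#   _json_contains('not json', '{}')                     # False, not an error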
+198 -110
View File
@@ -1,31 +1,26 @@
import collections
import datetime
import hashlib
import heapq
import json
import math
import os
import random
import re
import struct
import sys
import threading
import zlib
try:
from collections import Counter
except ImportError:
Counter = None
try:
from urlparse import urlparse
except ImportError:
from urllib.parse import urlparse
try:
from playhouse._sqlite_ext import TableFunction
except ImportError:
TableFunction = None
from urllib.parse import urlparse
SQLITE_DATETIME_FORMATS = (
'%Y-%m-%d %H:%M:%S',
'%Y-%m-%d %H:%M:%S.%f',
'%Y-%m-%d %H:%M:%S.%f%z',
'%Y-%m-%d %H:%M:%S',
'%Y-%m-%d %H:%M:%S%z',
'%Y-%m-%d',
'%H:%M:%S',
'%H:%M:%S.%f',
@@ -47,11 +42,12 @@ CONTROL_FLOW = 'control_flow'
DATE = 'date'
FILE = 'file'
HELPER = 'helpers'
JSON = 'json'
MATH = 'math'
RANK = 'rank'
STRING = 'string'
AGGREGATE_COLLECTION = {}
TABLE_FUNCTION_COLLECTION = {}
UDF_COLLECTION = {}
@@ -85,19 +81,10 @@ def aggregate(*groups):
return klass
return decorator
def table_function(*groups):
def decorator(klass):
for group in groups:
TABLE_FUNCTION_COLLECTION.setdefault(group, [])
TABLE_FUNCTION_COLLECTION[group].append(klass)
return klass
return decorator
def udf(*groups):
def udf(group, name=None):
def decorator(fn):
for group in groups:
UDF_COLLECTION.setdefault(group, [])
UDF_COLLECTION[group].append(fn)
UDF_COLLECTION.setdefault(group, [])
UDF_COLLECTION[group].append((fn, name or fn.__name__))
return fn
return decorator
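# Sketch: the reworked decorator takes a single group plus an optional
# SQL-visible name, and records (fn, name) pairs in the registry:
#
#   @udf(STRING, name='reverse_str')
#   def reverse(s):
#       return s[::-1] if s is not None else None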
@@ -112,33 +99,21 @@ def register_aggregate_groups(db, *groups):
seen.add(name)
db.register_aggregate(klass, name)
def register_table_function_groups(db, *groups):
seen = set()
for group in groups:
klasses = TABLE_FUNCTION_COLLECTION.get(group, ())
for klass in klasses:
if klass.name not in seen:
seen.add(klass.name)
db.register_table_function(klass)
def register_udf_groups(db, *groups):
seen = set()
for group in groups:
functions = UDF_COLLECTION.get(group, ())
for function in functions:
name = function.__name__
for function, name in functions:
if name not in seen:
seen.add(name)
db.register_function(function, name)
def register_groups(db, *groups):
register_aggregate_groups(db, *groups)
register_table_function_groups(db, *groups)
register_udf_groups(db, *groups)
def register_all(db):
register_aggregate_groups(db, *AGGREGATE_COLLECTION)
register_table_function_groups(db, *TABLE_FUNCTION_COLLECTION)
register_udf_groups(db, *UDF_COLLECTION)
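# Wiring sketch using the registry helpers above (group constants such as
# RANK, JSON and MATH are defined near the top of this module):
#
#   db = SqliteExtDatabase(':memory:')
#   register_udf_groups(db, RANK, JSON)
#   register_aggregate_groups(db, MATH)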
@@ -288,11 +263,44 @@ def substr_count(haystack, needle):
def strip_chars(haystack, chars):
return haystack.strip(chars)
def _hash(constructor, *args):
hash_obj = constructor()
for arg in args:
hash_obj.update(arg)
return hash_obj.hexdigest()
@udf(JSON)
def json_contains(src_json, obj_json):
stack = []
try:
stack.append((json.loads(obj_json), json.loads(src_json)))
except:
# Invalid JSON!
return False
while stack:
obj, src = stack.pop()
if isinstance(src, dict):
if isinstance(obj, dict):
for key in obj:
if key not in src:
return False
stack.append((obj[key], src[key]))
elif isinstance(obj, list):
for item in obj:
if item not in src:
return False
elif obj not in src:
return False
elif isinstance(src, list):
if isinstance(obj, dict):
return False
elif isinstance(obj, list):
try:
for i in range(len(obj)):
stack.append((obj[i], src[i]))
except IndexError:
return False
elif obj not in src:
return False
elif obj != src:
return False
return True
# Aggregates.
class _heap_agg(object):
@@ -380,26 +388,15 @@ class duration(object):
@aggregate(MATH)
class mode(object):
if Counter:
def __init__(self):
self.items = Counter()
def __init__(self):
self.items = collections.Counter()
def step(self, *args):
self.items.update(args)
def step(self, *args):
self.items.update(args)
def finalize(self):
if self.items:
return self.items.most_common(1)[0][0]
else:
def __init__(self):
self.items = []
def step(self, item):
self.items.append(item)
def finalize(self):
if self.items:
return max(set(self.items), key=self.items.count)
def finalize(self):
if self.items:
return self.items.most_common(1)[0][0]
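# Sketch: once registered, the Counter-based mode aggregate is callable
# from SQL (register_aggregate() as provided by the database class):
#
#   db.register_aggregate(mode, 'mode')
#   db.execute_sql('SELECT mode(n) FROM (SELECT 1 AS n '
#                  'UNION ALL SELECT 2 UNION ALL SELECT 2)').fetchone()
#   # -> (2,)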
@aggregate(MATH)
class minrange(_heap_agg):
@@ -480,57 +477,148 @@ class stddev(object):
return math.sqrt(sum((i - mean) ** 2 for i in self.values) / (self.n - 1))
def _parse_match_info(buf):
# See http://sqlite.org/fts3.html#matchinfo
bufsize = len(buf) # Length in bytes.
return [struct.unpack('@I', buf[i:i+4])[0] for i in range(0, bufsize, 4)]
def get_weights(ncol, raw_weights):
if not raw_weights:
return [1] * ncol
else:
weights = [0] * ncol
for i, weight in enumerate(raw_weights):
weights[i] = weight
return weights
# Ranking implementation, which parses matchinfo.
def rank(raw_match_info, *raw_weights):
# Handle match_info called w/default args 'pcx' - based on the example rank
# function http://sqlite.org/fts3.html#appendix_a
match_info = _parse_match_info(raw_match_info)
score = 0.0
p, c = match_info[:2]
weights = get_weights(c, raw_weights)
# The matchinfo 'x' value contains, for each phrase in the search query, a
# list of 3 values for each column in the search table.
# So if we have a two-phrase search query and three columns of data, the
# following would be the layout:
# p0 : c0=[0, 1, 2], c1=[3, 4, 5], c2=[6, 7, 8]
# p1 : c0=[9, 10, 11], c1=[12, 13, 14], c2=[15, 16, 17]
for phrase_num in range(p):
phrase_info_idx = 2 + (phrase_num * c * 3)
for col_num in range(c):
weight = weights[col_num]
if not weight:
continue
col_idx = phrase_info_idx + (col_num * 3)
# The idea is that we count the number of times the phrase appears
# in this column of the current row, compared to how many times it
# appears in this column across all rows. The ratio of these values
# provides a rough way to score based on "high value" terms.
row_hits = match_info[col_idx]
all_rows_hits = match_info[col_idx + 1]
if row_hits > 0:
score += weight * (float(row_hits) / all_rows_hits)
return -score
# Okapi BM25 ranking implementation (FTS4 only).
def bm25(raw_match_info, *args):
"""
Usage:
# Format string *must* be pcnalx
# Second parameter to bm25 specifies the index of the column on
# the table being queried.
bm25(matchinfo(document_tbl, 'pcnalx'), 1) AS rank
"""
match_info = _parse_match_info(raw_match_info)
K = 1.2
B = 0.75
score = 0.0
P_O, C_O, N_O, A_O = range(4) # Offsets into the matchinfo buffer.
term_count = match_info[P_O] # n
col_count = match_info[C_O]
total_docs = match_info[N_O] # N
L_O = A_O + col_count
X_O = L_O + col_count
# Worked example of pcnalx for two columns and two phrases, 100 docs total.
# {
# p = 2
# c = 2
# n = 100
# a0 = 4 -- avg number of tokens for col0, e.g. title
# a1 = 40 -- avg number of tokens for col1, e.g. body
# l0 = 5 -- curr doc has 5 tokens in col0
# l1 = 30 -- curr doc has 30 tokens in col1
#
# x000 -- hits this row for phrase0, col0
# x001 -- hits all rows for phrase0, col0
# x002 -- rows with phrase0 in col0 at least once
#
# x010 -- hits this row for phrase0, col1
# x011 -- hits all rows for phrase0, col1
# x012 -- rows with phrase0 in col1 at least once
#
# x100 -- hits this row for phrase1, col0
# x101 -- hits all rows for phrase1, col0
# x102 -- rows with phrase1 in col0 at least once
#
# x110 -- hits this row for phrase1, col1
# x111 -- hits all rows for phrase1, col1
# x112 -- rows with phrase1 in col1 at least once
# }
weights = get_weights(col_count, args)
for i in range(term_count):
for j in range(col_count):
weight = weights[j]
if weight == 0:
continue
x = X_O + (3 * (j + i * col_count))
term_frequency = float(match_info[x]) # f(qi, D)
docs_with_term = float(match_info[x + 2]) # n(qi)
# log( (N - n(qi) + 0.5) / (n(qi) + 0.5) )
idf = math.log(
(total_docs - docs_with_term + 0.5) /
(docs_with_term + 0.5))
if idf <= 0.0:
idf = 1e-6
doc_length = float(match_info[L_O + j]) # |D|
avg_length = float(match_info[A_O + j]) or 1. # avgdl
ratio = doc_length / avg_length
num = term_frequency * (K + 1.0)
b_part = 1.0 - B + (B * ratio)
denom = term_frequency + (K * b_part)
pc_score = idf * (num / denom)
score += (pc_score * weight)
return -score
if cython_udf is not None:
rank = udf(RANK, 'fts_rank')(cython_udf.peewee_rank)
lucene = udf(RANK, 'fts_lucene')(cython_udf.peewee_lucene)
bm25 = udf(RANK, 'fts_bm25')(cython_udf.peewee_bm25)
bm25f = udf(RANK, 'fts_bm25f')(cython_udf.peewee_bm25f)
damerau_levenshtein_dist = udf(STRING)(cython_udf.damerau_levenshtein_dist)
levenshtein_dist = udf(STRING)(cython_udf.levenshtein_dist)
str_dist = udf(STRING)(cython_udf.str_dist)
median = aggregate(MATH)(cython_udf.median)
if TableFunction is not None:
@table_function(STRING)
class RegexSearch(TableFunction):
params = ['regex', 'search_string']
columns = ['match']
name = 'regex_search'
def initialize(self, regex=None, search_string=None):
self._iter = re.finditer(regex, search_string)
def iterate(self, idx):
return (next(self._iter).group(0),)
@table_function(DATE)
class DateSeries(TableFunction):
params = ['start', 'stop', 'step_seconds']
columns = ['date']
name = 'date_series'
def initialize(self, start, stop, step_seconds=86400):
self.start = format_date_time_sqlite(start)
self.stop = format_date_time_sqlite(stop)
step_seconds = int(step_seconds)
self.step_seconds = datetime.timedelta(seconds=step_seconds)
if (self.start.hour == 0 and
self.start.minute == 0 and
self.start.second == 0 and
step_seconds >= 86400):
self.format = '%Y-%m-%d'
elif (self.start.year == 1900 and
self.start.month == 1 and
self.start.day == 1 and
self.stop.year == 1900 and
self.stop.month == 1 and
self.stop.day == 1 and
step_seconds < 86400):
self.format = '%H:%M:%S'
else:
self.format = '%Y-%m-%d %H:%M:%S'
def iterate(self, idx):
if self.start > self.stop:
raise StopIteration
current = self.start
self.start += self.step_seconds
return (current.strftime(self.format),)
else:
rank = udf(RANK, 'fts_rank')(rank)
bm25 = udf(RANK, 'fts_bm25')(bm25)
+1 -1
View File
@@ -37,7 +37,7 @@ aiomysql = ["aiomysql", "greenlet"]
asyncpg = ["asyncpg", "greenlet"]
[tool.setuptools]
packages = ["playhouse", "playhouse._pysqlite"]
packages = ["playhouse"]
py-modules = ["peewee", "pwiz"]
exclude-package-data = {"playhouse" = ["*.pyx"]}
+19 -85
View File
@@ -1,36 +1,21 @@
import os
import platform
import sys
try:
from distutils.errors import CCompilerError
from distutils.errors import DistutilsExecError
from distutils.errors import DistutilsPlatformError
except ImportError:
from setuptools._distutils.errors import CCompilerError
from setuptools._distutils.errors import DistutilsExecError
from setuptools._distutils.errors import DistutilsPlatformError
import os
from setuptools import setup
from setuptools.extension import Extension
extension_support = True # Assume we are building C extensions.
# Check if Cython is available and use it to generate extension modules. If
# Cython is not installed, we will fall back to using the pre-generated C files
# (so long as we're running on CPython).
try:
from Cython.Build import cythonize
from Cython.Distutils.extension import Extension
cython_installed = True
except ImportError:
cython_installed = False
else:
if platform.python_implementation() != 'CPython':
cython_installed = extension_support = False
else:
cython_installed = True
if sys.version_info[0] < 3:
FileNotFoundError = EnvironmentError
if platform.python_implementation() != 'CPython':
extension_support = False
elif os.environ.get('NO_SQLITE'):
# Retain backward-compat for not building C extensions.
extension_support = False
else:
extension_support = True
if cython_installed:
src_ext = '.pyx'
@@ -38,67 +23,16 @@ else:
src_ext = '.c'
cythonize = lambda obj: obj
sqlite_udf_module = Extension(
'playhouse._sqlite_udf',
['playhouse/_sqlite_udf' + src_ext])
sqlite_ext_module = Extension(
'playhouse._sqlite_ext',
['playhouse/_sqlite_ext' + src_ext],
libraries=['sqlite3'])
ext_modules = cythonize([sqlite_udf_module, sqlite_ext_module])
def _have_sqlite_extension_support():
import shutil
import tempfile
try:
from distutils.ccompiler import new_compiler
from distutils.sysconfig import customize_compiler
except ImportError:
from setuptools.command.build_ext import customize_compiler
from setuptools.command.build_ext import new_compiler
libraries = ['sqlite3']
c_code = ('#include <sqlite3.h>\n\n'
'int main(int argc, char **argv) { return 0; }')
tmp_dir = tempfile.mkdtemp(prefix='tmp_pw_sqlite3_')
bin_file = os.path.join(tmp_dir, 'test_pw_sqlite3')
src_file = bin_file + '.c'
with open(src_file, 'w') as fh:
fh.write(c_code)
compiler = new_compiler()
customize_compiler(compiler)
success = False
try:
compiler.link_shared_object(
compiler.compile([src_file], output_dir=tmp_dir),
bin_file,
libraries=['sqlite3'])
except CCompilerError:
print('unable to compile sqlite3 C extensions - missing headers?')
except DistutilsExecError:
print('unable to compile sqlite3 C extensions - no c compiler?')
except DistutilsPlatformError:
print('unable to compile sqlite3 C extensions - platform error')
except FileNotFoundError:
print('unable to compile sqlite3 C extensions - no compiler!')
else:
success = True
shutil.rmtree(tmp_dir)
return success
if extension_support:
if os.environ.get('NO_SQLITE'):
print("SQLite extensions will not be built at user's request.")
ext_modules = []
elif not _have_sqlite_extension_support():
print('Could not find libsqlite3, extensions will not be built.')
ext_modules = []
sqlite_udf_module = Extension(
'playhouse._sqlite_udf',
['playhouse/_sqlite_udf' + src_ext])
ext_modules = cythonize([sqlite_udf_module])
else:
ext_modules = []
setup(name='peewee',
packages=['playhouse'],
py_modules=['peewee', 'pwiz'],
ext_modules=ext_modules)
setup(
name='peewee',
packages=['playhouse'],
py_modules=['peewee', 'pwiz'],
ext_modules=ext_modules)
+2 -2
View File
@@ -30,9 +30,9 @@ try:
except:
print('Unable to import CockroachDB tests, skipping.')
try:
from .csqlite_ext import *
from .cysqlite_ext import *
except ImportError:
print('Unable to import sqlite C extension tests, skipping.')
print('Unable to import cysqlite tests, skipping.')
from .dataset import *
from .db_url import *
from .extra_fields import *
+37 -176
View File
@@ -1,10 +1,10 @@
import glob
import os
import sys
import cysqlite
from peewee import *
from peewee import sqlite3
from playhouse.sqlite_ext import CYTHON_SQLITE_EXTENSIONS
from playhouse.sqlite_ext import *
from playhouse.cysqlite_ext import *
from .base import BaseTestCase
from .base import DatabaseTestCase
@@ -13,23 +13,22 @@ from .base import db_loader
from .base import skip_unless
database = CSqliteExtDatabase('peewee_test.db', timeout=100,
hash_functions=1)
database = CySqliteDatabase('peewee_test.db', timeout=100)
class CDatabaseTestCase(DatabaseTestCase):
class CyDatabaseTestCase(DatabaseTestCase):
database = database
def tearDown(self):
super(CDatabaseTestCase, self).tearDown()
if os.path.exists(self.database.database):
os.unlink(self.database.database)
super(CyDatabaseTestCase, self).tearDown()
for filename in glob.glob(self.database.database + '*'):
os.unlink(filename)
def execute(self, sql, *params):
return self.database.execute_sql(sql, params)
class TestCSqliteHelpers(CDatabaseTestCase):
class TestCSqliteHelpers(CyDatabaseTestCase):
def test_autocommit(self):
self.assertTrue(self.database.autocommit)
self.database.begin()
@@ -121,36 +120,7 @@ class TestCSqliteHelpers(CDatabaseTestCase):
self.assertTrue(self.database.cache_used is not None)
HUser = Table('users', ('id', 'username'))
class TestHashFunctions(CDatabaseTestCase):
database = database
def setUp(self):
super(TestHashFunctions, self).setUp()
self.database.execute_sql(
'create table users (id integer not null primary key, '
'username text not null)')
def test_md5(self):
for username in ('charlie', 'huey', 'zaizee'):
HUser.insert({HUser.username: username}).execute(self.database)
query = (HUser
.select(HUser.username,
fn.SUBSTR(fn.SHA1(HUser.username), 1, 6).alias('sha'))
.order_by(HUser.username)
.tuples()
.execute(self.database))
self.assertEqual(query[:], [
('charlie', 'd8cd10'),
('huey', '89b31a'),
('zaizee', 'b4dcf9')])
class TestBackup(CDatabaseTestCase):
class TestBackup(CyDatabaseTestCase):
backup_filenames = set(('test_backup.db', 'test_backup1.db',
'test_backup2.db'))
@@ -172,7 +142,7 @@ class TestBackup(CDatabaseTestCase):
self._populate_test_data()
# Back-up to an in-memory database and verify contents.
other_db = CSqliteExtDatabase(':memory:')
other_db = CySqliteDatabase(':memory:')
self.database.backup(other_db)
cursor = other_db.execute_sql('SELECT value FROM register ORDER BY '
'value;')
@@ -180,7 +150,7 @@ class TestBackup(CDatabaseTestCase):
other_db.close()
def test_backup_preserve_pagesize(self):
db1 = CSqliteExtDatabase('test_backup1.db')
db1 = CySqliteDatabase('test_backup1.db')
with db1.connection_context():
db1.page_size = 8192
self._populate_test_data(db=db1)
@@ -188,7 +158,7 @@ class TestBackup(CDatabaseTestCase):
db1.connect()
self.assertEqual(db1.page_size, 8192)
db2 = CSqliteExtDatabase('test_backup2.db')
db2 = CySqliteDatabase('test_backup2.db')
db1.backup(db2)
self.assertEqual(db2.page_size, 8192)
nrows, = db2.execute_sql('select count(*) from register;').fetchone()
@@ -198,7 +168,7 @@ class TestBackup(CDatabaseTestCase):
self._populate_test_data()
self.database.backup_to_file('test_backup.db')
backup_db = CSqliteExtDatabase('test_backup.db')
backup_db = CySqliteDatabase('test_backup.db')
cursor = backup_db.execute_sql('SELECT value FROM register ORDER BY '
'value;')
self.assertEqual([val for val, in cursor.fetchall()], list(range(100)))
@@ -211,7 +181,7 @@ class TestBackup(CDatabaseTestCase):
def progress(remaining, total, is_done):
accum.append((remaining, total, is_done))
other_db = CSqliteExtDatabase(':memory:')
other_db = CySqliteDatabase(':memory:')
self.database.backup(other_db, pages=1, progress=progress)
self.assertTrue(len(accum) > 0)
@@ -226,127 +196,13 @@ class TestBackup(CDatabaseTestCase):
def broken_progress(remaining, total, is_done):
raise ValueError('broken')
other_db = CSqliteExtDatabase(':memory:')
other_db = CySqliteDatabase(':memory:')
self.assertRaises(ValueError, self.database.backup, other_db,
progress=broken_progress)
other_db.close()
class TestBlob(CDatabaseTestCase):
def setUp(self):
super(TestBlob, self).setUp()
self.Register = Table('register', ('id', 'data'))
self.execute('CREATE TABLE register (id INTEGER NOT NULL PRIMARY KEY, '
'data BLOB NOT NULL)')
def create_blob_row(self, nbytes):
Register = self.Register.bind(self.database)
Register.insert({Register.data: ZeroBlob(nbytes)}).execute()
return self.database.last_insert_rowid
def test_blob(self):
rowid1024 = self.create_blob_row(1024)
rowid16 = self.create_blob_row(16)
blob = Blob(self.database, 'register', 'data', rowid1024)
self.assertEqual(len(blob), 1024)
blob.write(b'x' * 1022)
blob.write(b'zz')
blob.seek(1020)
self.assertEqual(blob.tell(), 1020)
data = blob.read(3)
self.assertEqual(data, b'xxz')
self.assertEqual(blob.read(), b'z')
self.assertEqual(blob.read(), b'')
blob.seek(-10, 2)
self.assertEqual(blob.tell(), 1014)
self.assertEqual(blob.read(), b'xxxxxxxxzz')
blob.reopen(rowid16)
self.assertEqual(blob.tell(), 0)
self.assertEqual(len(blob), 16)
blob.write(b'x' * 15)
self.assertEqual(blob.tell(), 15)
def test_blob_exceed_size(self):
rowid = self.create_blob_row(16)
blob = self.database.blob_open('register', 'data', rowid)
with self.assertRaisesCtx(ValueError):
blob.seek(17, 0)
with self.assertRaisesCtx(ValueError):
blob.write(b'x' * 17)
blob.write(b'x' * 16)
self.assertEqual(blob.tell(), 16)
blob.seek(0)
data = blob.read(17) # Attempting to read more data is OK.
self.assertEqual(data, b'x' * 16)
data = blob.read(1)
self.assertEqual(data, b'')
blob.seek(0)
blob.write(b'0123456789abcdef')
self.assertEqual(blob[0], b'0')
self.assertEqual(blob[-1], b'f')
self.assertRaises(IndexError, lambda: data[17])
blob.close()
def test_blob_errors_opening(self):
rowid = self.create_blob_row(4)
with self.assertRaisesCtx(OperationalError):
blob = self.database.blob_open('register', 'data', rowid + 1)
with self.assertRaisesCtx(OperationalError):
blob = self.database.blob_open('register', 'missing', rowid)
with self.assertRaisesCtx(OperationalError):
blob = self.database.blob_open('missing', 'data', rowid)
def test_blob_operating_on_closed(self):
rowid = self.create_blob_row(4)
blob = self.database.blob_open('register', 'data', rowid)
self.assertEqual(len(blob), 4)
blob.close()
with self.assertRaisesCtx(InterfaceError):
len(blob)
self.assertRaises(InterfaceError, blob.read)
self.assertRaises(InterfaceError, blob.write, b'foo')
self.assertRaises(InterfaceError, blob.seek, 0, 0)
self.assertRaises(InterfaceError, blob.tell)
self.assertRaises(InterfaceError, blob.reopen, rowid)
blob.close() # Safe to call again.
def test_blob_readonly(self):
rowid = self.create_blob_row(4)
blob = self.database.blob_open('register', 'data', rowid)
blob.write(b'huey')
blob.seek(0)
self.assertEqual(blob.read(), b'huey')
blob.close()
blob = self.database.blob_open('register', 'data', rowid, True)
self.assertEqual(blob.read(), b'huey')
blob.seek(0)
with self.assertRaisesCtx(OperationalError):
blob.write(b'meow')
# BLOB is read-only.
self.assertEqual(blob.read(), b'huey')
class DataTypes(TableFunction):
class DataTypes(cysqlite.TableFunction):
columns = ('key', 'value')
params = ()
name = 'data_types'
@@ -369,20 +225,25 @@ class DataTypes(TableFunction):
raise StopIteration
@skip_unless(sqlite3.sqlite_version_info >= (3, 9), 'requires sqlite >= 3.9')
class TestDataTypesTableFunction(CDatabaseTestCase):
database = db_loader('sqlite')
@skip_unless(cysqlite.sqlite_version_info >= (3, 9), 'requires sqlite >= 3.9')
class TestDataTypesTableFunction(CyDatabaseTestCase):
database = db_loader('cysqlite')
def test_data_types_table_function(self):
self.database.register_table_function(DataTypes)
cursor = self.database.execute_sql('SELECT key, value '
'FROM data_types() ORDER BY key')
self.assertEqual(cursor.fetchall(), [
('k0', None),
('k1', 1),
('k2', 2.),
('k3', u'unicode str'),
('k4', b'byte str'),
('k5', 0),
('k6', 1),
])
for _ in range(2):
cursor = self.database.execute_sql('SELECT key, value FROM '
'data_types() ORDER BY key')
self.assertEqual(cursor.fetchall(), [
('k0', None),
('k1', 1),
('k2', 2.),
('k3', u'unicode str'),
('k4', b'byte str'),
('k5', 0),
('k6', 1),
])
# Ensure table re-registered after close.
self.database.close()
self.database.connect()
+3 -439
View File
@@ -6,7 +6,6 @@ import sys
from peewee import *
from peewee import sqlite3
from playhouse.sqlite_ext import *
from playhouse._sqlite_ext import TableFunction
from .base import BaseTestCase
from .base import IS_SQLITE_37
@@ -28,7 +27,7 @@ from .sqlite_helpers import json_text_installed
from .sqlite_helpers import jsonb_installed
database = SqliteExtDatabase(':memory:', c_extensions=False, timeout=100)
database = SqliteExtDatabase(':memory:', timeout=100)
CLOSURE_EXTENSION = os.environ.get('PEEWEE_CLOSURE_EXTENSION')
@@ -40,7 +39,7 @@ if not LSM_EXTENSION and os.path.exists('lsm.so'):
LSM_EXTENSION = './lsm.so'
try:
from playhouse._sqlite_ext import peewee_rank
from playhouse._sqlite_udf import peewee_rank
CYTHON_EXTENSION = True
except ImportError:
CYTHON_EXTENSION = False
@@ -158,249 +157,6 @@ class DT(TestModel):
iso = ISODateTimeField()
class Series(TableFunction):
columns = ['value']
params = ['start', 'stop', 'step']
name = 'series'
def initialize(self, start=0, stop=None, step=1):
self.start = start
self.stop = stop or float('inf')
self.step = step
self.curr = self.start
def iterate(self, idx):
if self.curr > self.stop:
raise StopIteration
ret = self.curr
self.curr += self.step
return (ret,)
class RegexSearch(TableFunction):
columns = ['match']
params = ['regex', 'search_string']
name = 'regex_search'
def initialize(self, regex=None, search_string=None):
if regex and search_string:
self._iter = re.finditer(regex, search_string)
else:
self._iter = None
def iterate(self, idx):
# We do not need `idx`, so just ignore it.
if self._iter is None:
raise StopIteration
else:
return (next(self._iter).group(0),)
class Split(TableFunction):
params = ['data']
columns = ['part']
name = 'str_split'
def initialize(self, data=None):
self._parts = data.split()
self._idx = 0
def iterate(self, idx):
if self._idx < len(self._parts):
result = (self._parts[self._idx],)
self._idx += 1
return result
raise StopIteration
@skip_unless(IS_SQLITE_9, 'requires sqlite >= 3.9')
class TestTableFunction(BaseTestCase):
def setUp(self):
super(TestTableFunction, self).setUp()
self.conn = sqlite3.connect(':memory:')
def tearDown(self):
super(TestTableFunction, self).tearDown()
self.conn.close()
def execute(self, sql, params=None):
return self.conn.execute(sql, params or ())
def test_split(self):
Split.register(self.conn)
curs = self.execute('select part from str_split(?) order by part '
'limit 3', ('well hello huey and zaizee',))
self.assertEqual([row for row, in curs.fetchall()],
['and', 'hello', 'huey'])
def test_split_tbl(self):
Split.register(self.conn)
self.execute('create table post (content TEXT);')
self.execute('insert into post (content) values (?), (?), (?)',
('huey secret post',
'mickey message',
'zaizee diary'))
curs = self.execute('SELECT * FROM post, str_split(post.content)')
results = curs.fetchall()
self.assertEqual(results, [
('huey secret post', 'huey'),
('huey secret post', 'secret'),
('huey secret post', 'post'),
('mickey message', 'mickey'),
('mickey message', 'message'),
('zaizee diary', 'zaizee'),
('zaizee diary', 'diary'),
])
def test_series(self):
Series.register(self.conn)
def assertSeries(params, values, extra_sql=''):
param_sql = ', '.join('?' * len(params))
sql = 'SELECT * FROM series(%s)' % param_sql
if extra_sql:
sql = ' '.join((sql, extra_sql))
curs = self.execute(sql, params)
self.assertEqual([row for row, in curs.fetchall()], values)
assertSeries((0, 10, 2), [0, 2, 4, 6, 8, 10])
assertSeries((5, None, 20), [5, 25, 45, 65, 85], 'LIMIT 5')
assertSeries((4, 0, -1), [4, 3, 2], 'LIMIT 3')
assertSeries((3, 5, 3), [3])
assertSeries((3, 3, 1), [3])
def test_series_tbl(self):
Series.register(self.conn)
self.execute('CREATE TABLE nums (id INTEGER PRIMARY KEY)')
self.execute('INSERT INTO nums DEFAULT VALUES;')
self.execute('INSERT INTO nums DEFAULT VALUES;')
curs = self.execute('SELECT * FROM nums, series(nums.id, nums.id + 2)')
results = curs.fetchall()
self.assertEqual(results, [
(1, 1), (1, 2), (1, 3),
(2, 2), (2, 3), (2, 4)])
curs = self.execute('SELECT * FROM nums, series(nums.id) LIMIT 3')
results = curs.fetchall()
self.assertEqual(results, [(1, 1), (1, 2), (1, 3)])
def test_regex(self):
RegexSearch.register(self.conn)
def assertResults(regex, search_string, values):
sql = 'SELECT * FROM regex_search(?, ?)'
curs = self.execute(sql, (regex, search_string))
self.assertEqual([row for row, in curs.fetchall()], values)
assertResults(
'[0-9]+',
'foo 123 45 bar 678 nuggie 9.0',
['123', '45', '678', '9', '0'])
assertResults(
r'[\w]+@[\w]+\.[\w]{2,3}',
('Dear charlie@example.com, this is nug@baz.com. I am writing on '
'behalf of zaizee@foo.io. He dislikes your blog.'),
['charlie@example.com', 'nug@baz.com', 'zaizee@foo.io'])
assertResults(
'[a-z]+',
'123.pDDFeewXee',
['p', 'eew', 'ee'])
assertResults(
'[0-9]+',
'hello',
[])
def test_regex_tbl(self):
messages = (
'hello foo@example.fap, this is nuggie@example.fap. How are you?',
'baz@example.com wishes to let charlie@crappyblog.com know that '
'huey@example.com hates his blog',
'testing no emails.',
'')
RegexSearch.register(self.conn)
self.execute('create table posts (id integer primary key, msg)')
self.execute('insert into posts (msg) values (?), (?), (?), (?)',
messages)
cur = self.execute('select posts.id, regex_search.rowid, regex_search.match '
'FROM posts, regex_search(?, posts.msg)',
(r'[\w]+@[\w]+\.\w{2,3}',))
results = cur.fetchall()
self.assertEqual(results, [
(1, 1, 'foo@example.fap'),
(1, 2, 'nuggie@example.fap'),
(2, 3, 'baz@example.com'),
(2, 4, 'charlie@crappyblog.com'),
(2, 5, 'huey@example.com'),
])
def test_error_instantiate(self):
class BrokenInstantiate(Series):
name = 'broken_instantiate'
print_tracebacks = False
def __init__(self, *args, **kwargs):
super(BrokenInstantiate, self).__init__(*args, **kwargs)
raise ValueError('broken instantiate')
BrokenInstantiate.register(self.conn)
self.assertRaises(sqlite3.OperationalError, self.execute,
'SELECT * FROM broken_instantiate(1, 10)')
def test_error_init(self):
class BrokenInit(Series):
name = 'broken_init'
print_tracebacks = False
def initialize(self, start=0, stop=None, step=1):
raise ValueError('broken init')
BrokenInit.register(self.conn)
self.assertRaises(sqlite3.OperationalError, self.execute,
'SELECT * FROM broken_init(1, 10)')
self.assertRaises(sqlite3.OperationalError, self.execute,
'SELECT * FROM broken_init(0, 1)')
def test_error_iterate(self):
class BrokenIterate(Series):
name = 'broken_iterate'
print_tracebacks = False
def iterate(self, idx):
raise ValueError('broken iterate')
BrokenIterate.register(self.conn)
self.assertRaises(sqlite3.OperationalError, self.execute,
'SELECT * FROM broken_iterate(1, 10)')
self.assertRaises(sqlite3.OperationalError, self.execute,
'SELECT * FROM broken_iterate(0, 1)')
def test_error_iterate_delayed(self):
# Only raises an exception if the value 7 comes up.
class SomewhatBroken(Series):
name = 'somewhat_broken'
print_tracebacks = False
def iterate(self, idx):
ret = super(SomewhatBroken, self).iterate(idx)
if ret == (7,):
raise ValueError('somewhat broken')
else:
return ret
SomewhatBroken.register(self.conn)
curs = self.execute('SELECT * FROM somewhat_broken(0, 3)')
self.assertEqual(curs.fetchall(), [(0,), (1,), (2,), (3,)])
curs = self.execute('SELECT * FROM somewhat_broken(5, 8)')
self.assertEqual(curs.fetchone(), (5,))
self.assertRaises(sqlite3.OperationalError, curs.fetchall)
curs = self.execute('SELECT * FROM somewhat_broken(0, 2)')
self.assertEqual(curs.fetchall(), [(0,), (1,), (2,)])
@skip_unless(json_installed(), 'requires sqlite json1')
class TestJSONField(ModelTestCase):
database = database
@@ -1367,14 +1123,8 @@ class TestFullTextSearch(BaseFTSTestCase, ModelTestCase):
[(0, -0.85), (1, -0.)])
@skip_unless(CYTHON_EXTENSION, 'requires sqlite c extension')
@skip_unless(CYTHON_EXTENSION, 'requires _sqlite_udf c extension')
class TestFullTextSearchCython(TestFullTextSearch):
database = SqliteExtDatabase(':memory:', c_extensions=CYTHON_EXTENSION)
def test_c_extensions(self):
self.assertTrue(self.database._c_extensions)
self.assertTrue(Post._meta.database._c_extensions)
def test_bm25f(self):
def assertResults(term, expected):
query = MultiColumn.search_bm25f(term, [1.0, 0, 0, 0], True)
@@ -1636,36 +1386,6 @@ class TestFTS5(BaseFTSTestCase, ModelTestCase):
self.assertEqual(FTS5Test.clean_query(a), b)
@skip_unless(CYTHON_EXTENSION, 'requires sqlite c extension')
class TestMurmurHash(ModelTestCase):
database = SqliteExtDatabase(':memory:', c_extensions=CYTHON_EXTENSION,
hash_functions=True)
def assertHash(self, s, e, fn_name='murmurhash'):
func = getattr(fn, fn_name)
query = Select(columns=[func(s)])
cursor = self.database.execute(query)
self.assertEqual(cursor.fetchone()[0], e)
@skip_if(sys.byteorder == 'big', 'fails on big endian')
def test_murmur_hash(self):
self.assertHash('testkey', 2871421366)
self.assertHash('murmur', 3883399899)
self.assertHash('', 0)
self.assertHash('this is a test of a longer string', 2569735385)
self.assertHash(None, None)
@skip_if(sys.version_info[0] == 3, 'requires python 2')
def test_checksums(self):
self.assertHash('testkey', -225678656, 'crc32')
self.assertHash('murmur', 1507884895, 'crc32')
self.assertHash('', 0, 'crc32')
self.assertHash('testkey', 203686666, 'adler32')
self.assertHash('murmur', 155714217, 'adler32')
self.assertHash('', 1, 'adler32')
class TestUserDefinedCallbacks(ModelTestCase):
database = database
requires = [Post, Values]
@@ -2089,162 +1809,6 @@ class TestTransitiveClosureIntegration(BaseTestCase):
database.drop_tables([Node, NodeClosure])
class KV(LSMTable):
key = TextField(primary_key=True)
val_b = BlobField()
val_i = IntegerField()
val_f = FloatField()
val_t = TextField()
class Meta:
database = database
filename = 'test_lsm.ldb'
class KVS(LSMTable):
key = TextField(primary_key=True)
value = TextField()
class Meta:
database = database
filename = 'test_lsm.ldb'
class KVI(LSMTable):
key = IntegerField(primary_key=True)
value = TextField()
class Meta:
database = database
filename = 'test_lsm.ldb'
@skip_unless(LSM_EXTENSION and os.path.exists(LSM_EXTENSION),
'requires lsm1 sqlite extension')
class TestLSM1Extension(BaseTestCase):
def setUp(self):
super(TestLSM1Extension, self).setUp()
if os.path.exists(KV._meta.filename):
os.unlink(KV._meta.filename)
database.connect()
database.load_extension(LSM_EXTENSION.rstrip('.so'))
def tearDown(self):
super(TestLSM1Extension, self).tearDown()
database.unload_extension(LSM_EXTENSION.rstrip('.so'))
database.close()
if os.path.exists(KV._meta.filename):
os.unlink(KV._meta.filename)
def test_lsm_extension(self):
self.assertSQL(KV._schema._create_table(), (
'CREATE VIRTUAL TABLE IF NOT EXISTS "kv" USING lsm1 '
'("test_lsm.ldb", "key", TEXT, "val_b", "val_i", '
'"val_f", "val_t")'), [])
self.assertSQL(KVS._schema._create_table(), (
'CREATE VIRTUAL TABLE IF NOT EXISTS "kvs" USING lsm1 '
'("test_lsm.ldb", "key", TEXT, "value")'), [])
self.assertSQL(KVI._schema._create_table(), (
'CREATE VIRTUAL TABLE IF NOT EXISTS "kvi" USING lsm1 '
'("test_lsm.ldb", "key", UINT, "value")'), [])
def test_lsm_crud_operations(self):
database.create_tables([KV])
with database.transaction():
KV.create(key='k0', val_b=None, val_i=0, val_f=0.1, val_t='v0')
v0 = KV['k0']
self.assertEqual(v0.key, 'k0')
self.assertEqual(v0.val_b, None)
self.assertEqual(v0.val_i, 0)
self.assertEqual(v0.val_f, 0.1)
self.assertEqual(v0.val_t, 'v0')
self.assertRaises(KV.DoesNotExist, lambda: KV['k1'])
# Test that updates work as expected.
KV['k0'] = (None, 1338, 3.14, 'v2-e')
v0_db = KV['k0']
self.assertEqual(v0_db.val_i, 1338)
self.assertEqual(v0_db.val_f, 3.14)
self.assertEqual(v0_db.val_t, 'v2-e')
self.assertEqual(len([item for item in KV.select()]), 1)
del KV['k0']
self.assertEqual(len([item for item in KV.select()]), 0)
def test_insert_replace(self):
database.create_tables([KVS])
KVS.insert({'key': 'k0', 'value': 'v0'}).execute()
self.assertEqual(KVS['k0'], 'v0')
KVS.replace({'key': 'k0', 'value': 'v0-e'}).execute()
self.assertEqual(KVS['k0'], 'v0-e')
# Implicit.
KVS['k0'] = 'v0-x'
self.assertEqual(KVS['k0'], 'v0-x')
def test_index_performance(self):
database.create_tables([KVS])
data = [{'key': 'k%s' % i, 'value': 'v%s' % i} for i in range(20)]
KVS.insert_many(data).execute()
self.assertEqual(KVS.select().count(), 20)
self.assertEqual(KVS['k0'], 'v0')
self.assertEqual(KVS['k19'], 'v19')
keys = [row.key for row in KVS['k4.1':'k8.9']]
self.assertEqual(keys, ['k5', 'k6', 'k7', 'k8'])
keys = sorted([row.key for row in KVS[:'k13']])
self.assertEqual(keys, ['k0', 'k1', 'k10', 'k11', 'k12', 'k13'])
keys = [row.key for row in KVS['k5':]]
self.assertEqual(keys, ['k5', 'k6', 'k7', 'k8', 'k9'])
data = [tuple(row) for row in KVS[KVS.key > 'k5']]
self.assertEqual(data, [
('k6', 'v6'),
('k7', 'v7'),
('k8', 'v8'),
('k9', 'v9')])
del KVS[KVS.key.between('k10', 'k18')]
self.assertEqual(sorted([row.key for row in KVS[:'k2']]),
['k0', 'k1', 'k19', 'k2'])
del KVS['k3.1':'k8.1']
self.assertEqual([row.key for row in KVS[:]],
['k0', 'k1', 'k19', 'k2', 'k3', 'k9'])
del KVS['k1']
self.assertRaises(KVS.DoesNotExist, lambda: KVS['k1'])
def test_index_uint(self):
database.create_tables([KVI])
data = [{'key': i, 'value': 'v%s' % i} for i in range(100)]
with database.transaction():
KVI.insert_many(data).execute()
keys = [row.key for row in KVI[27:33]]
self.assertEqual(keys, [27, 28, 29, 30, 31, 32, 33])
keys = sorted([row.key for row in KVI[KVI.key < 4]])
self.assertEqual(keys, [0, 1, 2, 3])
keys = [row.key for row in KVI[KVI.key > 95]]
self.assertEqual(keys, [96, 97, 98, 99])
@skip_unless(json_installed(), 'requires json1 sqlite extension')
class TestJsonContains(ModelTestCase):
database = SqliteExtDatabase(':memory:', json_contains=True)
-77
View File
@@ -12,10 +12,6 @@ from .base import ModelTestCase
from .base import TestModel
from .base import db_loader
from .base import skip_unless
try:
from playhouse import _sqlite_ext as cython_ext
except ImportError:
cython_ext = None
try:
from playhouse import _sqlite_udf as cython_udf
except ImportError:
@@ -415,76 +411,3 @@ class TestScalarFunctions(BaseTestUDF):
self.assertEqual(
self.sql1('select strip_chars(?, ?)', ' hey foo ', ' '),
'hey foo')
@skip_unless(cython_ext is not None, 'requires sqlite c extension')
@skip_unless(sqlite3.sqlite_version_info >= (3, 9), 'requires sqlite >= 3.9')
class TestVirtualTableFunctions(ModelTestCase):
database = database
requires = MODELS
def sqln(self, sql, *p):
cursor = self.database.execute_sql(sql, p)
return cursor.fetchall()
def test_regex_search(self):
usernames = [
'charlie',
'hu3y17',
'zaizee2012',
'1234.56789',
'hurr durr']
for username in usernames:
User.create(username=username)
rgx = '[0-9]+'
results = self.sqln(
('SELECT user.username, regex_search.match '
'FROM user, regex_search(?, user.username) '
'ORDER BY regex_search.match'),
rgx)
self.assertEqual([row for row in results], [
('1234.56789', '1234'),
('hu3y17', '17'),
('zaizee2012', '2012'),
('hu3y17', '3'),
('1234.56789', '56789'),
])
def test_date_series(self):
ONE_DAY = 86400
def assertValues(start, stop, step_seconds, expected):
results = self.sqln('select * from date_series(?, ?, ?)',
start, stop, step_seconds)
self.assertEqual(results, expected)
assertValues('2015-01-01', '2015-01-05', 86400, [
('2015-01-01',),
('2015-01-02',),
('2015-01-03',),
('2015-01-04',),
('2015-01-05',),
])
assertValues('2015-01-01', '2015-01-05', 86400 / 2, [
('2015-01-01 00:00:00',),
('2015-01-01 12:00:00',),
('2015-01-02 00:00:00',),
('2015-01-02 12:00:00',),
('2015-01-03 00:00:00',),
('2015-01-03 12:00:00',),
('2015-01-04 00:00:00',),
('2015-01-04 12:00:00',),
('2015-01-05 00:00:00',),
])
assertValues('14:20:15', '14:24', 30, [
('14:20:15',),
('14:20:45',),
('14:21:15',),
('14:21:45',),
('14:22:15',),
('14:22:45',),
('14:23:15',),
('14:23:45',),
])