mirror of
https://github.com/coleifer/peewee.git
synced 2026-05-06 07:56:41 -04:00
1589 lines
57 KiB
Python
1589 lines
57 KiB
Python
import asyncio
|
|
import collections
|
|
import contextvars
|
|
import glob
|
|
import itertools
|
|
import tempfile
|
|
import os
|
|
import unittest
|
|
from unittest.mock import Mock, AsyncMock, MagicMock, patch
|
|
|
|
from peewee import *
|
|
from playhouse.pwasyncio import *
|
|
from playhouse.pwasyncio import _State, _ConnectionState, _lazy_cursor_iter
|
|
from .base import MYSQL_PARAMS
|
|
from .base import PSQL_PARAMS
|
|
from .base import IS_MYSQL
|
|
from .base import IS_POSTGRESQL
|
|
|
|
try:
|
|
import asyncpg
|
|
except ImportError:
|
|
asyncpg = None
|
|
|
|
try:
|
|
import aiomysql
|
|
except ImportError:
|
|
aiomysql = None
|
|
|
|
import aiosqlite
|
|
|
|
SQLITE_RETURNING = aiosqlite.sqlite_version_info >= (3, 35, 0)
|
|
|
|
|
|
class TestModel(Model):
|
|
name = CharField()
|
|
value = IntegerField(default=0)
|
|
|
|
class User(Model):
|
|
username = CharField()
|
|
|
|
class Tweet(Model):
|
|
user = ForeignKeyField(User, backref='tweets')
|
|
message = TextField()
|
|
|
|
|
|
class TestGreenletSpawn(unittest.IsolatedAsyncioTestCase):
|
|
async def test_simple_function(self):
|
|
result = await greenlet_spawn(lambda x, y: x + y, 5, 3)
|
|
self.assertEqual(result, 8)
|
|
|
|
async def test_function_with_await(self):
|
|
async def async_helper():
|
|
await asyncio.sleep(0.01)
|
|
return 2
|
|
def func():
|
|
return await_(async_helper()) * 2
|
|
self.assertEqual(await greenlet_spawn(func), 4)
|
|
|
|
async def test_multiple_awaits(self):
|
|
async def fetch_value(val):
|
|
await asyncio.sleep(0.01)
|
|
return val
|
|
def multi():
|
|
return sum([await_(fetch_value(i)) for i in [10, 20, 30]])
|
|
self.assertEqual(await greenlet_spawn(multi), 60)
|
|
|
|
async def test_exception_propagation(self):
|
|
with self.assertRaises(ValueError):
|
|
await greenlet_spawn(lambda: (_ for _ in ()).throw(ValueError('x')))
|
|
|
|
async def test_exception_in_awaitable(self):
|
|
async def fail():
|
|
raise RuntimeError('async error')
|
|
with self.assertRaises(RuntimeError):
|
|
await greenlet_spawn(lambda: await_(fail()))
|
|
|
|
def test_await_outside_greenlet(self):
|
|
with self.assertRaises(MissingGreenletBridge):
|
|
await_(Mock())
|
|
|
|
async def test_contextvars(self):
|
|
var = contextvars.ContextVar('data', default='x')
|
|
state = []
|
|
def get_var():
|
|
state.append(var.get())
|
|
async def aget_var():
|
|
await greenlet_spawn(get_var)
|
|
|
|
var.set('y')
|
|
await aget_var()
|
|
await greenlet_spawn(lambda: await_(aget_var()))
|
|
await aget_var()
|
|
self.assertEqual(state, ['y', 'y', 'y'])
|
|
|
|
|
|
class TestConnectionState(unittest.IsolatedAsyncioTestCase):
|
|
async def test_task_isolation(self):
|
|
cs = _ConnectionState()
|
|
async def worker(tid):
|
|
cs._current().conn = tid
|
|
await asyncio.sleep(0.01)
|
|
return cs._current().conn
|
|
results = await asyncio.gather(*[worker(i) for i in range(5)])
|
|
self.assertEqual(results, [0, 1, 2, 3, 4])
|
|
|
|
async def test_state_attributes(self):
|
|
cs = _ConnectionState()
|
|
cs.set_connection('c')
|
|
cs.transactions.append(1)
|
|
self.assertEqual(cs.conn, 'c')
|
|
self.assertFalse(cs.closed)
|
|
self.assertEqual(cs.transactions, [1])
|
|
|
|
async def test_get_returns_fresh_state(self):
|
|
s = _ConnectionState()._current()
|
|
self.assertIsNone(s.conn)
|
|
self.assertTrue(s.closed)
|
|
self.assertEqual(s.transactions, [])
|
|
|
|
async def test_reset(self):
|
|
cs = _ConnectionState()
|
|
cs.set_connection('x')
|
|
cs.transactions.append(1)
|
|
cs.reset()
|
|
self.assertIsNone(cs.conn)
|
|
self.assertTrue(cs.closed)
|
|
self.assertEqual(cs.transactions, [])
|
|
|
|
async def test_set_connection(self):
|
|
cs = _ConnectionState()
|
|
m = Mock()
|
|
cs.set_connection(m)
|
|
self.assertIs(cs.conn, m)
|
|
self.assertFalse(cs.closed)
|
|
|
|
async def test_done_callback_orphans_connection(self):
|
|
cs = _ConnectionState()
|
|
conn_mock = Mock()
|
|
|
|
async def acquire_and_abandon():
|
|
cs.set_connection(conn_mock)
|
|
return id(asyncio.current_task())
|
|
|
|
task_id = await asyncio.create_task(acquire_and_abandon())
|
|
# After the task completes, the done-callback should have fired.
|
|
await asyncio.sleep(0)
|
|
self.assertNotIn(task_id, cs._states)
|
|
self.assertIn(conn_mock, cs._orphaned_conns)
|
|
|
|
async def test_done_callback_noop_when_closed(self):
|
|
cs = _ConnectionState()
|
|
|
|
async def open_and_close():
|
|
cs.set_connection(Mock())
|
|
cs.reset() # Simulate proper close.
|
|
|
|
await asyncio.create_task(open_and_close())
|
|
await asyncio.sleep(0)
|
|
self.assertEqual(cs._orphaned_conns, [])
|
|
|
|
|
|
def _make_lazy_cursor(rows, batch_size=2):
|
|
it = iter(rows)
|
|
fetch_counts = []
|
|
cleanup_called = []
|
|
|
|
async def fetch_many(count):
|
|
fetch_counts.append(count)
|
|
return list(itertools.islice(it, count))
|
|
|
|
async def cleanup():
|
|
cleanup_called.append(True)
|
|
|
|
cursor = CursorAdapter(
|
|
description=[('id',), ('name',)],
|
|
fetch_many=fetch_many,
|
|
cleanup=cleanup,
|
|
buffer_size=batch_size)
|
|
return cursor, fetch_counts, cleanup_called
|
|
|
|
|
|
class TestCursorAdapter(unittest.IsolatedAsyncioTestCase):
|
|
def test_eager_fetchone(self):
|
|
c = CursorAdapter([(1, 'a'), (2, 'b'), (3, 'c')])
|
|
self.assertEqual(c.fetchone(), (1, 'a'))
|
|
self.assertEqual(c.fetchone(), (2, 'b'))
|
|
self.assertEqual(c.fetchone(), (3, 'c'))
|
|
self.assertIsNone(c.fetchone())
|
|
|
|
def test_eager_fetchall(self):
|
|
rows = [(1,), (2,)]
|
|
c = CursorAdapter(rows)
|
|
self.assertIs(c.fetchall(), rows)
|
|
|
|
def test_eager_iter(self):
|
|
rows = [(1,), (2,), (3,)]
|
|
self.assertEqual(list(CursorAdapter(rows)), rows)
|
|
self.assertEqual(CursorAdapter(rows).rowcount, 3)
|
|
|
|
def test_eager_metadata(self):
|
|
c = CursorAdapter()
|
|
self.assertEqual(c._rows, [])
|
|
self.assertEqual(c.rowcount, 0)
|
|
self.assertEqual(c.description, [])
|
|
self.assertIsNone(c.fetchone())
|
|
self.assertEqual(list(c), [])
|
|
|
|
c = CursorAdapter([(1,)], lastrowid=5, rowcount=1,
|
|
description=[('id',)])
|
|
self.assertEqual(c.lastrowid, 5)
|
|
self.assertEqual(c.rowcount, 1)
|
|
self.assertEqual(c.description, [('id',)])
|
|
|
|
async def test_lazy_fetchone_batches(self):
|
|
rows = [(i,) for i in range(5)]
|
|
cursor, counts, _ = _make_lazy_cursor(rows, batch_size=2)
|
|
|
|
collected = []
|
|
def drain():
|
|
while True:
|
|
r = cursor.fetchone()
|
|
if r is None:
|
|
break
|
|
collected.append(r)
|
|
await greenlet_spawn(drain)
|
|
|
|
self.assertEqual(collected, rows)
|
|
# 2 + 2 + 1 + 0(empty) = 4 calls, each requesting 2
|
|
self.assertEqual(len(counts), 4)
|
|
self.assertTrue(all(c == 2 for c in counts))
|
|
|
|
async def test_lazy_fetchone_empty(self):
|
|
cursor, _, _ = _make_lazy_cursor([], batch_size=2)
|
|
self.assertIsNone(await greenlet_spawn(cursor.fetchone))
|
|
# Already exhausted, still returns None.
|
|
self.assertIsNone(await greenlet_spawn(cursor.fetchone))
|
|
|
|
async def test_lazy_iter(self):
|
|
rows = [(i,) for i in range(7)]
|
|
cursor, counts, _ = _make_lazy_cursor(rows, batch_size=3)
|
|
self.assertEqual(await greenlet_spawn(list, cursor), rows)
|
|
# 3 + 3 + 1 + 0 = 4 calls
|
|
self.assertEqual(len(counts), 4)
|
|
|
|
async def test_lazy_fetchall(self):
|
|
rows = [(1,), (2,), (3,)]
|
|
cursor, _, _ = _make_lazy_cursor(rows, batch_size=10)
|
|
self.assertEqual(await greenlet_spawn(cursor.fetchall), rows)
|
|
|
|
async def test_lazy_buffer_reuse(self):
|
|
rows = [(i,) for i in range(3)]
|
|
cursor, counts, _ = _make_lazy_cursor(rows, batch_size=10)
|
|
await greenlet_spawn(cursor.fetchone)
|
|
self.assertEqual(len(counts), 1)
|
|
await greenlet_spawn(cursor.fetchone)
|
|
await greenlet_spawn(cursor.fetchone)
|
|
self.assertEqual(len(counts), 1) # still 1
|
|
await greenlet_spawn(cursor.fetchone) # second fetch (empty).
|
|
self.assertEqual(len(counts), 2)
|
|
|
|
async def test_lazy_description(self):
|
|
cursor, _, _ = _make_lazy_cursor([], batch_size=2)
|
|
self.assertEqual(cursor.description, [('id',), ('name',)])
|
|
|
|
async def test_lazy_buffer_size_override(self):
|
|
rows = [(i,) for i in range(10)]
|
|
cursor, counts, _ = _make_lazy_cursor(rows, batch_size=5)
|
|
cursor._buffer_size = 3
|
|
await greenlet_spawn(list, cursor)
|
|
self.assertTrue(all(c == 3 for c in counts))
|
|
|
|
async def test_aclose_cleanup(self):
|
|
cursor, _, cleanup = _make_lazy_cursor([], batch_size=2)
|
|
await cursor.aclose()
|
|
self.assertEqual(cleanup, [True])
|
|
self.assertIsNone(cursor._fetch_many)
|
|
self.assertIsNone(cursor._cleanup)
|
|
|
|
async def test_aclose_idempotent(self):
|
|
call_count = []
|
|
async def cleanup():
|
|
call_count.append(1)
|
|
cursor, _, _ = _make_lazy_cursor([], batch_size=2)
|
|
cursor._cleanup = cleanup
|
|
await cursor.aclose()
|
|
await cursor.aclose()
|
|
self.assertEqual(len(call_count), 1)
|
|
|
|
async def test_aclose_noop_for_eager(self):
|
|
await CursorAdapter([(1,)]).aclose() # must not raise
|
|
|
|
async def test_lazy_cursor_iter(self):
|
|
rows = [(1,), (2,), (3,)]
|
|
cursor, _, _ = _make_lazy_cursor(rows, batch_size=10)
|
|
result = await greenlet_spawn(list, _lazy_cursor_iter(cursor))
|
|
self.assertEqual(result, rows)
|
|
|
|
async def test_lazy_cursor_iter_empty(self):
|
|
cursor, _, _ = _make_lazy_cursor([], batch_size=2)
|
|
result = await greenlet_spawn(list, _lazy_cursor_iter(cursor))
|
|
self.assertEqual(result, [])
|
|
|
|
|
|
class TestConnectionWrappers(unittest.IsolatedAsyncioTestCase):
|
|
async def test_sqlite_execute(self):
|
|
mock_cursor = AsyncMock()
|
|
mock_cursor.fetchall.return_value = [(1, 'test')]
|
|
mock_cursor.lastrowid = 1
|
|
mock_cursor.rowcount = 1
|
|
mock_cursor.description = [('id',), ('name',)]
|
|
mock_conn = AsyncMock()
|
|
mock_conn.execute.return_value = mock_cursor
|
|
|
|
result = await AsyncSqliteConnection(mock_conn).execute(
|
|
'SELECT * FROM test')
|
|
self.assertIsInstance(result, CursorAdapter)
|
|
self.assertEqual(result.fetchall(), [(1, 'test')])
|
|
self.assertEqual(result.lastrowid, 1)
|
|
mock_cursor.close.assert_awaited_once()
|
|
|
|
async def test_sqlite_execute_iter_returns_lazy(self):
|
|
mock_cursor = AsyncMock()
|
|
mock_cursor.description = [('a',), ('b',)]
|
|
mock_conn = AsyncMock()
|
|
mock_conn.execute.return_value = mock_cursor
|
|
|
|
conn = AsyncSqliteConnection(mock_conn)
|
|
cursor = await conn.execute_iter('SELECT a, b FROM t')
|
|
self.assertIsInstance(cursor, CursorAdapter)
|
|
self.assertIsNotNone(cursor._fetch_many)
|
|
self.assertEqual(cursor.description, [('a',), ('b',)])
|
|
await cursor.aclose()
|
|
|
|
async def test_sqlite_execute_iter_lock_lifecycle(self):
|
|
mock_cursor = AsyncMock()
|
|
mock_cursor.description = []
|
|
mock_conn = AsyncMock()
|
|
mock_conn.execute.return_value = mock_cursor
|
|
|
|
conn = AsyncSqliteConnection(mock_conn)
|
|
cursor = await conn.execute_iter('SELECT 1')
|
|
self.assertTrue(conn._lock.locked())
|
|
await cursor.aclose()
|
|
self.assertFalse(conn._lock.locked())
|
|
mock_cursor.close.assert_awaited_once()
|
|
|
|
async def test_sqlite_execute_iter_lock_on_failure(self):
|
|
mock_conn = AsyncMock()
|
|
mock_conn.execute.side_effect = RuntimeError('fail')
|
|
conn = AsyncSqliteConnection(mock_conn)
|
|
with self.assertRaises(RuntimeError):
|
|
await conn.execute_iter('invalid')
|
|
self.assertFalse(conn._lock.locked())
|
|
|
|
async def test_mysql_execute(self):
|
|
mock_cursor = AsyncMock()
|
|
mock_cursor.fetchall.return_value = [(1, 'test')]
|
|
mock_cursor.lastrowid = 1
|
|
mock_cursor.rowcount = 1
|
|
mock_cursor.description = [('id',), ('name',)]
|
|
mock_conn = AsyncMock()
|
|
mock_conn.cursor.return_value = mock_cursor
|
|
|
|
result = await AsyncMySQLConnection(mock_conn).execute(
|
|
'SELECT * FROM test')
|
|
self.assertIsInstance(result, CursorAdapter)
|
|
self.assertEqual(result.fetchall(), [(1, 'test')])
|
|
mock_cursor.close.assert_awaited_once()
|
|
|
|
async def test_mysql_cursor_closed_on_error(self):
|
|
mock_cursor = AsyncMock()
|
|
mock_cursor.execute.side_effect = RuntimeError('fail')
|
|
mock_conn = AsyncMock()
|
|
mock_conn.cursor.return_value = mock_cursor
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
await AsyncMySQLConnection(mock_conn).execute('invalid')
|
|
mock_cursor.close.assert_awaited_once()
|
|
|
|
async def test_mysql_concurrent_serialized(self):
|
|
order = []
|
|
async def tracked(sql, params):
|
|
order.append(f'start-{sql}')
|
|
await asyncio.sleep(0.05)
|
|
order.append(f'end-{sql}')
|
|
return []
|
|
mock_cursor = AsyncMock()
|
|
mock_cursor.execute = tracked
|
|
mock_conn = AsyncMock()
|
|
mock_conn.cursor.return_value = mock_cursor
|
|
|
|
conn = AsyncMySQLConnection(mock_conn)
|
|
await asyncio.gather(conn.execute('Q1', None),
|
|
conn.execute('Q2', None))
|
|
idx = {e: i for i, e in enumerate(order)}
|
|
self.assertTrue(idx['end-Q1'] < idx['start-Q2']
|
|
or idx['end-Q2'] < idx['start-Q1'])
|
|
|
|
async def test_mysql_execute_iter_uses_ss_cursor(self):
|
|
import playhouse.pwasyncio as mod
|
|
mock_cursor = AsyncMock()
|
|
mock_cursor.description = [('x',)]
|
|
mock_cursor.execute = AsyncMock()
|
|
mock_conn = AsyncMock()
|
|
mock_conn.cursor.return_value = mock_cursor
|
|
|
|
sentinel = object()
|
|
with patch.object(mod, 'aiomysql') as m:
|
|
m.SSCursor = sentinel
|
|
cursor = await AsyncMySQLConnection(mock_conn).execute_iter(
|
|
'SELECT 1')
|
|
mock_conn.cursor.assert_awaited_once_with(sentinel)
|
|
self.assertIsNotNone(cursor._fetch_many)
|
|
await cursor.aclose()
|
|
|
|
async def test_mysql_execute_iter_lock_lifecycle(self):
|
|
import playhouse.pwasyncio as mod
|
|
mock_cursor = AsyncMock()
|
|
mock_cursor.description = [('x',)]
|
|
mock_cursor.execute = AsyncMock()
|
|
mock_cursor.close = AsyncMock()
|
|
mock_conn = AsyncMock()
|
|
mock_conn.cursor.return_value = mock_cursor
|
|
|
|
conn = AsyncMySQLConnection(mock_conn)
|
|
with patch.object(mod, 'aiomysql', create=True) as m:
|
|
m.SSCursor = object()
|
|
cursor = await conn.execute_iter('SELECT 1')
|
|
self.assertTrue(conn._lock.locked())
|
|
await cursor.aclose()
|
|
self.assertFalse(conn._lock.locked())
|
|
mock_cursor.close.assert_awaited_once()
|
|
|
|
async def test_pg_parameter_conversion(self):
|
|
mock_record = Mock()
|
|
mock_record.keys.return_value = ['id', 'name']
|
|
mock_conn = AsyncMock()
|
|
mock_conn.fetch.return_value = [mock_record]
|
|
|
|
await AsyncPostgresqlConnection(mock_conn).execute(
|
|
'SELECT * FROM t WHERE id = %s AND name = %s', (1, 'x'))
|
|
sql = mock_conn.fetch.call_args[0][0]
|
|
self.assertEqual(sql, 'SELECT * FROM t WHERE id = $1 AND name = $2')
|
|
|
|
async def test_pg_concurrent_serialized(self):
|
|
order = []
|
|
async def tracked(sql, params=None):
|
|
order.append(f'start-{sql}')
|
|
await asyncio.sleep(0.05)
|
|
order.append(f'end-{sql}')
|
|
return []
|
|
mock_conn = AsyncMock()
|
|
mock_conn.fetch = tracked
|
|
|
|
conn = AsyncPostgresqlConnection(mock_conn)
|
|
await asyncio.gather(conn.execute('Q1', None),
|
|
conn.execute('Q2', None))
|
|
idx = {e: i for i, e in enumerate(order)}
|
|
self.assertTrue(idx['end-Q1'] < idx['start-Q2']
|
|
or idx['end-Q2'] < idx['start-Q1'])
|
|
|
|
async def test_pg_no_params(self):
|
|
mock_conn = AsyncMock()
|
|
mock_conn.fetch.return_value = []
|
|
await AsyncPostgresqlConnection(mock_conn).execute(
|
|
'SELECT * FROM t', None)
|
|
mock_conn.fetch.assert_called_once_with('SELECT * FROM t')
|
|
|
|
async def test_pg_empty_results(self):
|
|
mock_conn = AsyncMock()
|
|
mock_conn.fetch.return_value = []
|
|
r = await AsyncPostgresqlConnection(mock_conn).execute(
|
|
'SELECT * FROM empty')
|
|
self.assertEqual(r.fetchall(), [])
|
|
self.assertEqual(r.description, [])
|
|
|
|
def test_translate_placeholders(self):
|
|
f = AsyncPostgresqlConnection._translate_placeholders
|
|
self.assertEqual(f('SELECT 1'), 'SELECT 1')
|
|
self.assertEqual(
|
|
f('SELECT * FROM t WHERE a = %s AND b = %s'),
|
|
'SELECT * FROM t WHERE a = $1 AND b = $2')
|
|
self.assertEqual(
|
|
f('INSERT INTO t VALUES (%s, %s, %s)'),
|
|
'INSERT INTO t VALUES ($1, $2, $3)')
|
|
|
|
def _pg_mocks(self, rows=None):
|
|
rows = rows or []
|
|
attr = MagicMock(); attr.name = 'col1'
|
|
it = iter(rows)
|
|
mock_cursor = AsyncMock()
|
|
async def _fetch(count):
|
|
return list(itertools.islice(it, count))
|
|
mock_cursor.fetch = _fetch
|
|
|
|
mock_stmt = AsyncMock()
|
|
mock_stmt.cursor.return_value = mock_cursor
|
|
mock_stmt.get_attributes = MagicMock(return_value=[attr])
|
|
|
|
mock_tr = AsyncMock()
|
|
mock_conn = MagicMock()
|
|
mock_conn.transaction.return_value = mock_tr
|
|
mock_conn.prepare = AsyncMock(return_value=mock_stmt)
|
|
return mock_conn, mock_tr, mock_stmt, mock_cursor
|
|
|
|
async def test_pg_execute_iter_description(self):
|
|
mock_conn, _, mock_stmt, _ = self._pg_mocks()
|
|
a1, a2 = MagicMock(), MagicMock()
|
|
a1.name = 'id'; a2.name = 'username'
|
|
mock_stmt.get_attributes.return_value = [a1, a2]
|
|
|
|
conn = AsyncPostgresqlConnection(mock_conn)
|
|
cursor = await conn.execute_iter('SELECT id, username FROM users')
|
|
self.assertEqual(cursor.description, [('id',), ('username',)])
|
|
await cursor.aclose()
|
|
|
|
async def test_pg_execute_iter_starts_transaction(self):
|
|
mock_conn, mock_tr, _, _ = self._pg_mocks()
|
|
conn = AsyncPostgresqlConnection(mock_conn)
|
|
cursor = await conn.execute_iter('SELECT 1')
|
|
mock_conn.transaction.assert_called_once()
|
|
mock_tr.start.assert_awaited_once()
|
|
await cursor.aclose()
|
|
|
|
async def test_pg_execute_iter_cleanup_rolls_back(self):
|
|
mock_conn, mock_tr, _, _ = self._pg_mocks()
|
|
conn = AsyncPostgresqlConnection(mock_conn)
|
|
cursor = await conn.execute_iter('SELECT 1')
|
|
self.assertTrue(conn._lock.locked())
|
|
await cursor.aclose()
|
|
mock_tr.rollback.assert_awaited_once()
|
|
self.assertFalse(conn._lock.locked())
|
|
|
|
async def test_pg_execute_iter_translates_placeholders(self):
|
|
mock_conn, _, _, _ = self._pg_mocks()
|
|
conn = AsyncPostgresqlConnection(mock_conn)
|
|
cursor = await conn.execute_iter(
|
|
'SELECT * FROM t WHERE a = %s AND b = %s', params=(1, 2))
|
|
sql = mock_conn.prepare.call_args[0][0]
|
|
self.assertIn('$1', sql)
|
|
self.assertNotIn('%s', sql)
|
|
await cursor.aclose()
|
|
|
|
async def test_pg_execute_iter_lock_on_failure(self):
|
|
mock_conn, _, _, _ = self._pg_mocks()
|
|
mock_conn.prepare = AsyncMock(side_effect=RuntimeError('fail'))
|
|
conn = AsyncPostgresqlConnection(mock_conn)
|
|
with self.assertRaises(RuntimeError):
|
|
await conn.execute_iter('invalid')
|
|
self.assertFalse(conn._lock.locked())
|
|
|
|
|
|
class TestTaskLifecycle(unittest.IsolatedAsyncioTestCase):
|
|
async def asyncSetUp(self):
|
|
with tempfile.NamedTemporaryFile(delete=False) as f:
|
|
self.db_path = f.name
|
|
|
|
self.db = AsyncSqliteDatabase(self.db_path)
|
|
TestModel._meta.set_database(self.db)
|
|
|
|
async with self.db:
|
|
await self.db.acreate_tables([TestModel])
|
|
|
|
async def asyncTearDown(self):
|
|
await self.db.close_pool()
|
|
|
|
if self.db_path and os.path.exists(self.db_path):
|
|
for fname in glob.glob(self.db_path + '*'):
|
|
os.unlink(fname)
|
|
|
|
async def test_task_id_behavior(self):
|
|
async def a1(db):
|
|
accum = []
|
|
async with db:
|
|
accum.append(db._state._current())
|
|
accum.append(await a2(db))
|
|
return accum
|
|
|
|
async def a2(db):
|
|
async with db:
|
|
return db._state._current()
|
|
|
|
def s1(db):
|
|
accum = []
|
|
with db.connection_context():
|
|
accum.append(db._state._current())
|
|
accum.append(s2(db))
|
|
return accum
|
|
|
|
def s2(db):
|
|
with db.connection_context():
|
|
return db._state._current()
|
|
|
|
async with self.db:
|
|
ids = await(a1(self.db))
|
|
ids.extend(await self.db.run(s1, self.db))
|
|
|
|
self.assertEqual(len(ids), 4)
|
|
self.assertEqual(len(set(ids)), 1)
|
|
|
|
async with self.db:
|
|
ids = await asyncio.create_task(a1(self.db))
|
|
ids.extend(await asyncio.create_task(self.db.run(s1, self.db)))
|
|
|
|
self.assertEqual(len(ids), 4)
|
|
self.assertEqual(len(set(ids)), 2)
|
|
|
|
async def test_task_state_cleanup_after_completion(self):
|
|
async def task_with_state():
|
|
async with self.db:
|
|
await self.db.run(TestModel.create, name='test', value=1)
|
|
return id(asyncio.current_task())
|
|
|
|
await asyncio.create_task(task_with_state())
|
|
|
|
# Child task properly closed connection (async with db), so close pool
|
|
# should exit cleanly.
|
|
await asyncio.wait_for(self.db.close_pool(), timeout=2.0)
|
|
|
|
# Verify the write persisted.
|
|
async with self.db:
|
|
self.assertEqual(await self.db.count(TestModel.select()), 1)
|
|
|
|
async def test_concurrent_task_state_isolation(self):
|
|
async def capture(tid):
|
|
async with self.db:
|
|
before = id(self.db._state._current())
|
|
await self.db.run(TestModel.create, name=f't{tid}', value=tid)
|
|
after = id(self.db._state._current())
|
|
self.assertEqual(before, after)
|
|
return before
|
|
|
|
results = await asyncio.gather(*[capture(i) for i in range(5)])
|
|
self.assertTrue(all(results))
|
|
self.assertEqual(len(set(results)), 5)
|
|
|
|
async def test_connection_returned_when_task_dies(self):
|
|
async def acquire_and_abandon():
|
|
await self.db.aconnect()
|
|
return # Connection is not closed, callback must handle cleanup.
|
|
|
|
await asyncio.create_task(acquire_and_abandon())
|
|
|
|
# The done-callback should have moved the connection to orphaned
|
|
# connections, which are handled either via done callback or during
|
|
# pool shutdown.
|
|
await asyncio.wait_for(self.db.close_pool(), timeout=2.0)
|
|
|
|
|
|
class IntegrationTests(object):
|
|
db_path = None
|
|
models = [TestModel, User, Tweet]
|
|
|
|
def get_database(self):
|
|
with tempfile.NamedTemporaryFile(delete=False) as f:
|
|
self.db_path = f.name
|
|
return AsyncSqliteDatabase(self.db_path)
|
|
|
|
def tearDown(self):
|
|
if self.db_path and os.path.exists(self.db_path):
|
|
os.unlink(self.db_path)
|
|
|
|
async def asyncSetUp(self):
|
|
try:
|
|
self.db = self.get_database()
|
|
await self.db.aconnect()
|
|
await self.db.aclose()
|
|
except Exception as exc:
|
|
self.skipTest(f'Cannot connect: {exc}')
|
|
|
|
if isinstance(self.db, AsyncSqliteDatabase):
|
|
self.driver = 'sqlite'
|
|
self.support_returning = SQLITE_RETURNING
|
|
elif isinstance(self.db, AsyncMySQLDatabase):
|
|
self.driver = 'mysql'
|
|
self.support_returning = False
|
|
elif isinstance(self.db, AsyncPostgresqlDatabase):
|
|
self.driver = 'postgresql'
|
|
self.support_returning = True
|
|
else:
|
|
raise ValueError('Unrecognized driver')
|
|
|
|
for m in self.models:
|
|
m._meta.set_database(self.db)
|
|
async with self.db:
|
|
await self.db.adrop_tables(self.models)
|
|
await self.db.acreate_tables(self.models)
|
|
|
|
async def asyncTearDown(self):
|
|
await self.db.aclose()
|
|
async with self.db:
|
|
await self.db.adrop_tables(self.models)
|
|
await self.db.close_pool()
|
|
|
|
async def create_record(self, name='test', value=1):
|
|
return await self.db.run(TestModel.create, name=name, value=value)
|
|
|
|
async def assertCount(self, expected):
|
|
count = await self.db.run(TestModel.select().count)
|
|
self.assertEqual(count, expected)
|
|
|
|
async def assertNames(self, expected):
|
|
curs = await self.db.list(TestModel.select().order_by(TestModel.name))
|
|
self.assertEqual([tm.name for tm in curs], expected)
|
|
|
|
async def seed(self, n=20):
|
|
def _seed():
|
|
with self.db.atomic():
|
|
for i in range(n):
|
|
TestModel.create(name=f'item{i:02d}', value=i * 10)
|
|
await self.db.run(_seed)
|
|
|
|
async def test_pool_created_on_connect(self):
|
|
await self.db.aclose()
|
|
await self.db.close_pool()
|
|
self.assertIsNone(self.db._pool)
|
|
await self.db.aconnect()
|
|
self.assertIsNotNone(self.db._pool)
|
|
self.assertIsNotNone(self.db._state.conn)
|
|
self.assertFalse(self.db.is_closed())
|
|
await self.db.aclose()
|
|
self.assertIsNone(self.db._state.conn)
|
|
self.assertTrue(self.db.is_closed())
|
|
|
|
async def test_is_closed(self):
|
|
for i in range(2):
|
|
await self.db.aconnect()
|
|
self.assertFalse(self.db.is_closed())
|
|
await self.db.aclose()
|
|
self.assertTrue(self.db.is_closed())
|
|
|
|
async def test_multiple_close_safe(self):
|
|
await self.db.aclose()
|
|
self.assertTrue(self.db.is_closed())
|
|
await self.db.aclose()
|
|
await self.db.aconnect()
|
|
self.assertFalse(self.db.is_closed())
|
|
|
|
async def test_reconnect_after_pool_close(self):
|
|
await self.create_record('first', 1)
|
|
await self.db.aclose()
|
|
await self.db.close_pool()
|
|
self.assertIsNone(self.db._pool)
|
|
async with self.db:
|
|
await self.assertCount(1)
|
|
self.assertIsNotNone(self.db._pool)
|
|
self.assertTrue(self.db.is_closed())
|
|
|
|
async def test_connection_reuse_within_task(self):
|
|
await self.db.aconnect()
|
|
c1 = self.db._state.conn
|
|
await self.create_record('a', 1)
|
|
c2 = self.db._state.conn
|
|
await self.create_record('b', 2)
|
|
self.assertIs(c1, c2)
|
|
self.assertIs(c2, self.db._state.conn)
|
|
|
|
async def test_closing_flag_prevents_connect(self):
|
|
self.db._closing = True
|
|
try:
|
|
with self.assertRaises(InterfaceError):
|
|
await self.db.aconnect()
|
|
finally:
|
|
self.db._closing = False
|
|
|
|
async def test_double_close_pool(self):
|
|
await self.db.aclose()
|
|
await self.db.close_pool()
|
|
await self.db.close_pool()
|
|
|
|
async def test_dead_connection_replaced(self):
|
|
if self.driver == 'mysql':
|
|
self.skipTest('closing underlying conn incompatible with aiomysql')
|
|
return
|
|
|
|
await self.db.aconnect()
|
|
conn = self.db._state.conn
|
|
await conn.close()
|
|
await self.create_record('test', 1)
|
|
self.assertIsNot(self.db._state.conn, conn)
|
|
await self.assertCount(1)
|
|
|
|
async def test_context_manager(self):
|
|
async with self.db:
|
|
self.assertIsNotNone(self.db._state.conn)
|
|
self.assertFalse(self.db._state.closed)
|
|
self.assertFalse(self.db.is_closed())
|
|
self.assertIsNone(self.db._state.conn)
|
|
self.assertTrue(self.db._state.closed)
|
|
self.assertTrue(self.db.is_closed())
|
|
|
|
async def test_exception_in_context_manager(self):
|
|
try:
|
|
async with self.db:
|
|
raise RuntimeError('fail')
|
|
except RuntimeError:
|
|
pass
|
|
self.assertTrue(self.db._state.closed)
|
|
self.assertTrue(self.db.is_closed())
|
|
async with self.db:
|
|
await self.create_record('after_error', 1)
|
|
self.assertFalse(self.db.is_closed())
|
|
await self.assertCount(1)
|
|
|
|
self.assertTrue(self.db.is_closed())
|
|
|
|
async def test_execute_sql(self):
|
|
iq, iparams = User.insert(username='x').sql()
|
|
sq, _= User.select().sql()
|
|
await self.db.aexecute_sql(iq, iparams)
|
|
r = await self.db.aexecute_sql(sq)
|
|
self.assertEqual(r.fetchall()[0][1], 'x')
|
|
|
|
async def test_multiple_tasks_raw_sql(self):
|
|
iq, _ = User.insert(username='x').sql()
|
|
sq, _ = User.select(User.username).where(User.username == 'x').sql()
|
|
|
|
async def worker(tid):
|
|
username = f'u{tid}'
|
|
await self.db.aconnect()
|
|
await self.db.aexecute_sql(iq, (username,))
|
|
r = await self.db.aexecute_sql(sq, (username,))
|
|
row = r.fetchone()
|
|
self.assertEqual(row[0], username)
|
|
await self.db.aclose()
|
|
return row
|
|
|
|
results = await asyncio.gather(*[worker(i) for i in range(3)])
|
|
self.assertEqual(sorted(results), [('u0',), ('u1',), ('u2',)])
|
|
|
|
async def test_list(self):
|
|
await self.seed(5)
|
|
query = TestModel.select().order_by(TestModel.value)
|
|
results = await self.db.list(query)
|
|
self.assertEqual(len(results), 5)
|
|
self.assertIsInstance(results[0], TestModel)
|
|
self.assertEqual([r.value for r in results], [0, 10, 20, 30, 40])
|
|
|
|
async def test_list_empty(self):
|
|
self.assertEqual(await self.db.list(TestModel.select()), [])
|
|
|
|
async def test_get(self):
|
|
rec = await self.create_record('unique', 999)
|
|
q = TestModel.select().where(TestModel.name == 'unique')
|
|
fetched = await self.db.get(q)
|
|
self.assertEqual(fetched.id, rec.id)
|
|
self.assertEqual(fetched.name, 'unique')
|
|
self.assertEqual(fetched.value, 999)
|
|
|
|
async def test_get_not_found(self):
|
|
with self.assertRaises(TestModel.DoesNotExist):
|
|
q = TestModel.select().where(TestModel.id == 0)
|
|
await self.db.get(q)
|
|
|
|
async def test_scalar(self):
|
|
await self.seed(10)
|
|
query = TestModel.select(fn.MAX(TestModel.value))
|
|
self.assertEqual(await self.db.scalar(query), 90)
|
|
|
|
async def test_scalar_no_results(self):
|
|
query = TestModel.select(fn.COUNT(TestModel.id))
|
|
self.assertEqual(await self.db.scalar(query), 0)
|
|
|
|
await self.seed(5)
|
|
self.assertEqual(await self.db.scalar(query), 5)
|
|
|
|
async def test_count(self):
|
|
self.assertEqual(await self.db.count(TestModel.select()), 0)
|
|
await self.seed(5)
|
|
self.assertEqual(await self.db.count(TestModel.select()), 5)
|
|
|
|
async def test_exists(self):
|
|
self.assertFalse(await self.db.exists(TestModel.select()))
|
|
await self.create_record('x', 1)
|
|
self.assertTrue(await self.db.exists(TestModel.select()))
|
|
|
|
async def test_aexecute(self):
|
|
q = TestModel.insert_many([(f'item{i}', i) for i in range(10)])
|
|
if self.support_returning:
|
|
q = q.returning(TestModel.name)
|
|
res = await self.db.aexecute(q)
|
|
self.assertEqual([t.name for t in res],
|
|
[f'item{i}' for i in range(10)])
|
|
else:
|
|
await self.db.aexecute(q)
|
|
await self.assertCount(10)
|
|
|
|
q = (TestModel
|
|
.update(value=TestModel.value * 10)
|
|
.where(TestModel.value < 3))
|
|
|
|
if self.support_returning:
|
|
q = q.returning(TestModel.name, TestModel.value)
|
|
res = await self.db.aexecute(q)
|
|
self.assertEqual(sorted([(t.name, t.value) for t in res]),
|
|
[('item0', 0), ('item1', 10), ('item2', 20)])
|
|
else:
|
|
res = await self.db.aexecute(q)
|
|
self.assertEqual(res, 2)
|
|
|
|
q = TestModel.select().where(TestModel.value >= 10)
|
|
self.assertEqual(await self.db.run(q.count), 2)
|
|
|
|
rows = await self.db.aexecute(q.order_by(TestModel.value))
|
|
self.assertEqual([r.name for r in rows], ['item1', 'item2'])
|
|
|
|
q = TestModel.delete().where(TestModel.value >= 10)
|
|
if self.support_returning:
|
|
q = q.returning(TestModel.name, TestModel.value)
|
|
res = await self.db.aexecute(q)
|
|
self.assertEqual(sorted([(t.name, t.value) for t in res]),
|
|
[('item1', 10), ('item2', 20)])
|
|
else:
|
|
res = await self.db.aexecute(q)
|
|
self.assertEqual(res, 2)
|
|
|
|
async def test_run_contextvars(self):
|
|
var = contextvars.ContextVar('v', default='x')
|
|
state = []
|
|
def do_run():
|
|
state.append(var.get())
|
|
var.set('y')
|
|
state.append(var.get())
|
|
await self.db.run(do_run)
|
|
state.append(var.get())
|
|
self.assertEqual(state, ['y', 'y', 'y'])
|
|
|
|
async def test_create(self):
|
|
tm = await self.create_record('test1', 100)
|
|
self.assertEqual(tm.name, 'test1')
|
|
self.assertEqual(tm.value, 100)
|
|
|
|
tm = await self.db.run(TestModel.create, name='test2', value=101)
|
|
self.assertEqual(tm.name, 'test2')
|
|
self.assertEqual(tm.value, 101)
|
|
|
|
await self.assertCount(2)
|
|
await self.assertNames(['test1', 'test2'])
|
|
|
|
async def test_select(self):
|
|
tm = await self.create_record('test1', 100)
|
|
res = await self.db.list(TestModel.select())
|
|
self.assertEqual(len(res), 1)
|
|
self.assertEqual(res[0].name, 'test1')
|
|
self.assertEqual(res[0], tm)
|
|
|
|
async def test_filter(self):
|
|
await self.seed(20)
|
|
query = TestModel.select().where(TestModel.value > 100)
|
|
results = await self.db.list(query)
|
|
self.assertEqual(len(results), 9)
|
|
|
|
async def test_ordering(self):
|
|
await self.seed(20)
|
|
query = TestModel.select().order_by(TestModel.value.desc()).limit(5)
|
|
results = await self.db.list(query)
|
|
self.assertEqual(results[0].value, 190)
|
|
self.assertEqual(results[4].value, 150)
|
|
|
|
async def test_create_save_update(self):
|
|
await self.create_record('test1', 100)
|
|
|
|
def do_update():
|
|
r = TestModel.get(TestModel.name == 'test1')
|
|
r.value = 999; r.save()
|
|
return TestModel.get(TestModel.name == 'test1').value
|
|
|
|
self.assertEqual(await self.db.run(do_update), 999)
|
|
|
|
uq = TestModel.update(name='test1x').where(TestModel.name == 'test1')
|
|
res = await self.db.aexecute(uq)
|
|
#self.assertEqual(res, 1)
|
|
|
|
q = TestModel.select().where(TestModel.name == 'test1x')
|
|
tm = await self.db.get(q)
|
|
self.assertEqual(tm.value, 999)
|
|
|
|
async def test_update(self):
|
|
await self.seed(50)
|
|
await self.db.aexecute(TestModel
|
|
.update(value=TestModel.value + 1000)
|
|
.where(TestModel.value < 250))
|
|
query = TestModel.select().where(TestModel.value >= 1000)
|
|
self.assertEqual(await self.db.count(query), 25)
|
|
|
|
query = TestModel.select(fn.SUM(TestModel.value))
|
|
self.assertEqual(await self.db.scalar(query), 37250)
|
|
|
|
async def test_delete(self):
|
|
await self.seed(20)
|
|
await self.db.aexecute(TestModel.delete().where(TestModel.value < 50))
|
|
await self.assertCount(15)
|
|
|
|
tm = await self.db.get(TestModel.select())
|
|
await self.db.run(tm.delete_instance)
|
|
await self.assertCount(14)
|
|
|
|
async def test_bulk_create(self):
|
|
recs = [TestModel(name=f'b{i}', value=i) for i in range(100)]
|
|
await self.db.run(TestModel.bulk_create, recs, batch_size=25)
|
|
await self.assertCount(100)
|
|
|
|
async def test_bulk_update(self):
|
|
if self.driver == 'postgresql':
|
|
self.skipTest('bulk_update incompatible with asyncpg')
|
|
return
|
|
|
|
accum = [await self.db.run(TestModel.create, name=f'b{i}', value=i)
|
|
for i in range(5)]
|
|
for tm in accum:
|
|
tm.name += '-x'
|
|
tm.value += 100
|
|
|
|
await self.db.run(TestModel.bulk_update, accum,
|
|
fields=[TestModel.name, TestModel.value])
|
|
|
|
q = await self.db.list(TestModel.select().order_by(TestModel.value))
|
|
self.assertEqual([(tm.name, tm.value) for tm in q],
|
|
[('b0-x', 100), ('b1-x', 101), ('b2-x', 102),
|
|
('b3-x', 103), ('b4-x', 104)])
|
|
|
|
async def test_insert_many(self):
|
|
def insert():
|
|
data = [{'name': f'i{i}', 'value': i} for i in range(100)]
|
|
TestModel.insert_many(data).execute()
|
|
await self.db.run(insert)
|
|
await self.assertCount(100)
|
|
|
|
data = [{'name': f'i{i}', 'value': i} for i in range(100, 200)]
|
|
await self.db.aexecute(TestModel.insert_many(data))
|
|
await self.assertCount(200)
|
|
|
|
data = [(f'i{i}', i) for i in range(200, 300)]
|
|
iq = (TestModel
|
|
.insert_many(data, fields=[TestModel.name, TestModel.value]))
|
|
await self.db.aexecute(iq)
|
|
await self.assertCount(300)
|
|
|
|
async def test_atomic(self):
|
|
async with self.db.atomic():
|
|
await self.create_record('a', 1)
|
|
await self.assertCount(1)
|
|
|
|
async with self.db.atomic() as txn:
|
|
await self.create_record('b', 2)
|
|
await self.assertCount(2)
|
|
await self.assertNames(['a', 'b'])
|
|
await txn.arollback()
|
|
await self.assertCount(1)
|
|
await self.create_record('c', 3)
|
|
await self.create_record('d', 4)
|
|
|
|
await self.assertCount(3)
|
|
await self.assertNames(['a', 'c', 'd'])
|
|
|
|
async def test_transaction_commit(self):
|
|
def create_in_tx():
|
|
with self.db.atomic():
|
|
TestModel.create(name='tx1')
|
|
TestModel.create(name='tx2')
|
|
await self.db.run(create_in_tx)
|
|
await self.assertCount(2)
|
|
|
|
async with self.db.atomic():
|
|
await self.db.run(TestModel.create, name='tx1')
|
|
await self.db.run(TestModel.create, name='tx2')
|
|
await self.assertCount(4)
|
|
|
|
async def test_transaction_rollback(self):
|
|
def failing():
|
|
with self.db.atomic():
|
|
TestModel.create(name='tx1')
|
|
raise ValueError('fail')
|
|
with self.assertRaises(ValueError):
|
|
await self.db.run(failing)
|
|
await self.assertCount(0)
|
|
|
|
async with self.db.atomic() as txn:
|
|
await self.create_record('tx2')
|
|
await self.assertCount(1)
|
|
await txn.arollback()
|
|
await self.assertCount(0)
|
|
|
|
async def test_nested_transactions(self):
|
|
def nested():
|
|
with self.db.atomic():
|
|
TestModel.create(name='o1', value=1)
|
|
with self.db.atomic():
|
|
TestModel.create(name='i1', value=2)
|
|
TestModel.create(name='i2', value=3)
|
|
TestModel.create(name='o2', value=4)
|
|
await self.db.run(nested)
|
|
await self.assertCount(4)
|
|
await self.assertNames(['i1', 'i2', 'o1', 'o2'])
|
|
|
|
async with self.db.atomic():
|
|
await self.db.run(TestModel.create, name='o3', value=1)
|
|
async with self.db.atomic():
|
|
await self.db.run(TestModel.create, name='i3', value=2)
|
|
await self.db.run(TestModel.create, name='i4', value=3)
|
|
await self.db.run(TestModel.create, name='o4', value=4)
|
|
|
|
await self.assertCount(8)
|
|
await self.assertNames(['i1', 'i2', 'i3', 'i4',
|
|
'o1', 'o2', 'o3', 'o4'])
|
|
|
|
async def test_nested_implicit_rollback(self):
|
|
def nested():
|
|
with self.db.atomic():
|
|
TestModel.create(name='o1', value=1)
|
|
try:
|
|
with self.db.atomic():
|
|
TestModel.create(name='i1', value=2)
|
|
raise ValueError('fail')
|
|
except ValueError:
|
|
pass
|
|
TestModel.create(name='o2', value=3)
|
|
await self.db.run(nested)
|
|
await self.assertCount(2)
|
|
await self.assertNames(['o1', 'o2'])
|
|
|
|
async with self.db.atomic():
|
|
await self.db.run(TestModel.create, name='o3', value=1)
|
|
try:
|
|
async with self.db.atomic():
|
|
await self.db.run(TestModel.create, name='i3', value=2)
|
|
raise ValueError('fail')
|
|
except ValueError:
|
|
pass
|
|
await self.assertCount(3)
|
|
await self.db.run(TestModel.create, name='o4', value=3)
|
|
|
|
await self.assertCount(4)
|
|
await self.assertNames(['o1', 'o2', 'o3', 'o4'])
|
|
|
|
async def test_nested_explicit_rollback(self):
|
|
def nested():
|
|
with self.db.atomic():
|
|
TestModel.create(name='o1')
|
|
with self.db.atomic() as sp:
|
|
TestModel.create(name='i1')
|
|
self.assertEqual(TestModel.select().count(), 2)
|
|
sp.rollback()
|
|
self.assertEqual(TestModel.select().count(), 1)
|
|
TestModel.create(name='o2')
|
|
|
|
await self.db.run(nested)
|
|
await self.assertCount(2)
|
|
await self.assertNames(['o1', 'o2'])
|
|
|
|
async with self.db.atomic():
|
|
await self.db.run(TestModel.create, name='o3')
|
|
async with self.db.atomic() as sp:
|
|
await self.db.run(TestModel.create, name='i2')
|
|
await self.assertCount(4)
|
|
await sp.arollback()
|
|
await self.assertCount(3)
|
|
await self.db.run(TestModel.create, name='o4')
|
|
|
|
await self.assertCount(4)
|
|
await self.assertNames(['o1', 'o2', 'o3', 'o4'])
|
|
|
|
async def test_nested_mix(self):
|
|
async with self.db.atomic():
|
|
await self.create_record('t1')
|
|
async with self.db.atomic():
|
|
await self.create_record('t2')
|
|
async with self.db.atomic():
|
|
await self.create_record('t3')
|
|
try:
|
|
async with self.db.atomic():
|
|
await self.create_record('t4')
|
|
await self.assertCount(4)
|
|
raise ValueError('fail')
|
|
except ValueError:
|
|
pass
|
|
async with self.db.atomic() as sp:
|
|
await self.create_record('t4')
|
|
await self.assertCount(4)
|
|
await sp.arollback()
|
|
await self.assertCount(3)
|
|
try:
|
|
async with self.db.atomic():
|
|
await self.create_record('t5')
|
|
await self.assertCount(4)
|
|
raise ValueError('fail')
|
|
except ValueError:
|
|
await self.assertCount(3)
|
|
|
|
await self.assertCount(3)
|
|
await self.assertNames(['t1', 't2', 't3'])
|
|
|
|
try:
|
|
async with self.db.atomic():
|
|
await self.create_record('t6')
|
|
async with self.db.atomic():
|
|
await self.create_record('t7')
|
|
async with self.db.atomic():
|
|
await self.create_record('t8')
|
|
await self.assertCount(6)
|
|
raise ValueError('fail')
|
|
except ValueError:
|
|
pass
|
|
|
|
await self.assertCount(3)
|
|
await self.assertNames(['t1', 't2', 't3'])
|
|
|
|
async def test_acommit_arollback(self):
|
|
async with self.db.atomic() as txn:
|
|
await self.create_record('committed', 1)
|
|
await txn.acommit()
|
|
await self.create_record('not-committed', 2)
|
|
await txn.arollback()
|
|
|
|
await self.assertCount(1)
|
|
await self.assertNames(['committed'])
|
|
|
|
async def test_concurrent_reads_writes(self):
|
|
await self.seed(10)
|
|
|
|
async def writer(sid):
|
|
def _write():
|
|
for i in range(5):
|
|
TestModel.create(name=f'w{sid}-{i}', value=sid * 100 + i)
|
|
|
|
async with self.db:
|
|
await self.db.run(_write)
|
|
|
|
async def reader():
|
|
async with self.db:
|
|
query = TestModel.select()
|
|
return await self.db.run(lambda: len(list(query)))
|
|
|
|
await asyncio.gather(*[writer(i) for i in range(3)])
|
|
reads = await asyncio.gather(*[reader() for _ in range(3)])
|
|
self.assertTrue(all(r >= 10 for r in reads))
|
|
await self.assertCount(25)
|
|
|
|
async def test_isolated_connections_per_task(self):
|
|
async def worker(tid):
|
|
async with self.db:
|
|
c1 = self.db._state.conn
|
|
await self.create_record(f't{tid}', tid)
|
|
return c1 is self.db._state.conn
|
|
|
|
results = await asyncio.gather(*[worker(i) for i in range(5)])
|
|
self.assertTrue(all(results))
|
|
await self.assertCount(5)
|
|
|
|
async def test_many_concurrent_tasks(self):
|
|
ntasks = 50 if self.driver == 'sqlite' else 10
|
|
async def task(tid):
|
|
async with self.db:
|
|
await self.create_record(f't{tid}', tid)
|
|
await asyncio.gather(*[task(i) for i in range(ntasks)])
|
|
await self.assertCount(ntasks)
|
|
|
|
async def test_syntax_error_recovery(self):
|
|
with self.assertRaises(Exception):
|
|
await self.db.aexecute_sql('INVALID SQL')
|
|
await self.create_record('after_error', 1)
|
|
await self.assertCount(1)
|
|
|
|
async def test_concurrent_errors(self):
|
|
errors, successes = [], []
|
|
|
|
async def worker(tid):
|
|
async with self.db:
|
|
try:
|
|
def work():
|
|
TestModel.create(name=f't{tid}', value=tid)
|
|
if tid % 2 == 0:
|
|
raise ValueError(f'Task {tid} fails')
|
|
await self.db.run(work)
|
|
successes.append(tid)
|
|
except ValueError:
|
|
errors.append(tid)
|
|
|
|
await asyncio.gather(*[worker(i) for i in range(10)])
|
|
self.assertEqual(sorted(errors), [0, 2, 4, 6, 8])
|
|
self.assertEqual(sorted(successes), [1, 3, 5, 7, 9])
|
|
await self.assertCount(10)
|
|
|
|
async def test_iterate_yields_model_instances(self):
|
|
await self.seed(20)
|
|
results = []
|
|
query = TestModel.select().order_by(TestModel.value)
|
|
async for obj in self.db.iterate(query):
|
|
results.append(obj)
|
|
|
|
self.assertEqual(len(results), 20)
|
|
self.assertTrue(all(isinstance(r, TestModel) for r in results))
|
|
self.assertEqual(results[0].name, 'item00')
|
|
self.assertEqual(results[0].value, 0)
|
|
self.assertEqual(results[-1].name, 'item19')
|
|
self.assertEqual(results[-1].value, 190)
|
|
|
|
async def test_iterate_matches_list(self):
|
|
await self.seed(20)
|
|
query = TestModel.select().order_by(TestModel.name)
|
|
|
|
eager = await self.db.list(query)
|
|
lazy = [obj async for obj in self.db.iterate(query)]
|
|
self.assertEqual(len(eager), len(lazy))
|
|
|
|
for e, l in zip(eager, lazy):
|
|
self.assertEqual(e.name, l.name)
|
|
self.assertEqual(e.value, l.value)
|
|
|
|
async def test_iterate_dicts(self):
|
|
await self.seed(5)
|
|
query = TestModel.select().order_by(TestModel.name)
|
|
results = [row async for row in self.db.iterate(query.dicts())]
|
|
|
|
self.assertEqual(len(results), 5)
|
|
self.assertIsInstance(results[0], dict)
|
|
self.assertEqual(results[0]['name'], 'item00')
|
|
self.assertEqual(results[-1]['name'], 'item04')
|
|
|
|
async def test_iterate_tuples(self):
|
|
await self.seed(5)
|
|
query = TestModel.select(TestModel.name).order_by(TestModel.name)
|
|
results = [row async for row in self.db.iterate(query.tuples())]
|
|
|
|
self.assertEqual(len(results), 5)
|
|
self.assertIsInstance(results[0], tuple)
|
|
self.assertEqual(results[0][0], 'item00')
|
|
self.assertEqual(results[-1][0], 'item04')
|
|
|
|
async def test_iterate_namedtuples(self):
|
|
await self.seed(5)
|
|
query = TestModel.select(TestModel.name).order_by(TestModel.name)
|
|
results = [row async for row in self.db.iterate(query.namedtuples())]
|
|
|
|
self.assertEqual(len(results), 5)
|
|
self.assertEqual(results[0].name, 'item00')
|
|
self.assertEqual(results[0][0], 'item00')
|
|
self.assertEqual(results[-1].name, 'item04')
|
|
self.assertEqual(results[-1][0], 'item04')
|
|
|
|
async def test_iterate_with_where(self):
|
|
await self.seed(20)
|
|
query = (TestModel.select()
|
|
.where(TestModel.value >= 150)
|
|
.order_by(TestModel.value))
|
|
results = [row async for row in self.db.iterate(query)]
|
|
|
|
self.assertEqual(len(results), 5)
|
|
self.assertEqual(results[0].value, 150)
|
|
self.assertEqual(results[-1].value, 190)
|
|
|
|
async def test_iterate_empty(self):
|
|
query = TestModel.select().where(TestModel.id == 0)
|
|
results = [row async for row in self.db.iterate(query)]
|
|
self.assertEqual(results, [])
|
|
|
|
async def test_iterate_buffer_size(self):
|
|
await self.seed(20)
|
|
query = TestModel.select().order_by(TestModel.value)
|
|
results = [obj async for obj in self.db.iterate(query, buffer_size=3)]
|
|
|
|
self.assertEqual(len(results), 20)
|
|
self.assertEqual(results[0].value, 0)
|
|
self.assertEqual(results[-1].value, 190)
|
|
|
|
async def test_iterate_early_break(self):
|
|
await self.seed(20)
|
|
count = 0
|
|
query = TestModel.select().order_by(TestModel.value)
|
|
async for obj in self.db.iterate(query):
|
|
count += 1
|
|
if count == 5:
|
|
break
|
|
self.assertEqual(count, 5)
|
|
# Database still usable (lock released).
|
|
self.assertEqual(await self.db.count(TestModel.select()), 20)
|
|
|
|
async def test_iterate_aggregation(self):
|
|
await self.seed(20)
|
|
query = (TestModel
|
|
.select(fn.AVG(TestModel.value).alias('avg_val'))
|
|
.dicts())
|
|
results = [row async for row in self.db.iterate(query)]
|
|
|
|
self.assertEqual(len(results), 1)
|
|
self.assertEqual(results[0]['avg_val'], 95.0)
|
|
|
|
async def test_iterate_sequential(self):
|
|
await self.seed(20)
|
|
query = (TestModel.select()
|
|
.where(TestModel.value < 50)
|
|
.order_by(TestModel.value))
|
|
r1 = [obj.value async for obj in self.db.iterate(query)]
|
|
|
|
query = (TestModel.select()
|
|
.where(TestModel.value >= 150)
|
|
.order_by(TestModel.value))
|
|
r2 = [obj.value async for obj in self.db.iterate(query)]
|
|
|
|
self.assertEqual(r1, [0, 10, 20, 30, 40])
|
|
self.assertEqual(r2, [150, 160, 170, 180, 190])
|
|
|
|
async def test_iterate_break_then_iterate_again(self):
|
|
await self.seed(20)
|
|
query = TestModel.select().order_by(TestModel.value)
|
|
async for obj in self.db.iterate(query):
|
|
break
|
|
results = []
|
|
async for obj in self.db.iterate(query):
|
|
results.append(obj.value)
|
|
self.assertEqual(len(results), 20)
|
|
|
|
async def test_iterate_multi(self):
|
|
await self.seed(10)
|
|
async def iterate_multi():
|
|
async with self.db:
|
|
query = TestModel.select().order_by(TestModel.value)
|
|
return [obj.id async for obj in self.db.iterate(query)]
|
|
|
|
results = await asyncio.gather(*[iterate_multi() for i in range(5)])
|
|
self.assertEqual(len(results), 5)
|
|
self.assertTrue(all(len(r) == 10 for r in results))
|
|
|
|
async def test_basic_crud(self):
|
|
rec = await self.create_record('testx', value=2)
|
|
self.assertEqual(rec.name, 'testx')
|
|
fetched = await self.db.run(TestModel.get, TestModel.name == 'testx')
|
|
self.assertEqual(fetched.value, 2)
|
|
|
|
def update():
|
|
r = TestModel.get(TestModel.id == rec.id)
|
|
r.value = 100; r.save()
|
|
return TestModel.get(TestModel.id == rec.id)
|
|
self.assertEqual((await self.db.run(update)).value, 100)
|
|
|
|
await self.db.run(rec.delete_instance)
|
|
await self.assertCount(0)
|
|
|
|
async def test_foreign_keys(self):
|
|
users = [User(username=f'u{i}') for i in range(3)]
|
|
await self.db.run(User.bulk_create, users)
|
|
self.assertEqual(await self.db.run(User.select().count), 3)
|
|
users = await self.db.list(User.select())
|
|
|
|
async with self.db.atomic():
|
|
for u in users:
|
|
for i in range(2):
|
|
await self.db.run(
|
|
Tweet.create, user=u, message=f'{u.username}-{i}')
|
|
|
|
self.assertEqual(await self.db.run(Tweet.select().count), 6)
|
|
|
|
q = Tweet.select().where(Tweet.message == 'u0-0')
|
|
tweet = await self.db.get(q)
|
|
self.assertEqual(await self.db.run(lambda: tweet.user.username), 'u0')
|
|
|
|
q = (Tweet.select(Tweet, User)
|
|
.join(User)
|
|
.where(Tweet.message == 'u0-0'))
|
|
tweet = await self.db.get(q)
|
|
self.assertEqual(tweet.user.username, 'u0')
|
|
|
|
q = User.select().where(User.username == 'u2')
|
|
user = await self.db.get(q)
|
|
tweets = await self.db.list(user.tweets.order_by(Tweet.id))
|
|
self.assertEqual([t.message for t in tweets], ['u2-0', 'u2-1'])
|
|
|
|
users_q = User.select().order_by(User.username)
|
|
tweets_q = Tweet.select().order_by(Tweet.message)
|
|
await self.db.aprefetch(users_q, tweets_q)
|
|
self.assertEqual(
|
|
[(u.username, [t.message for t in u.tweets]) for u in users_q],
|
|
[('u0', ['u0-0', 'u0-1']),
|
|
('u1', ['u1-0', 'u1-1']),
|
|
('u2', ['u2-0', 'u2-1'])])
|
|
|
|
async def test_transactions(self):
|
|
def ok_tx():
|
|
with self.db.atomic():
|
|
TestModel.create(name='t1', value=1)
|
|
TestModel.create(name='t2', value=2)
|
|
await self.db.run(ok_tx)
|
|
await self.assertCount(2)
|
|
|
|
def bad_tx():
|
|
with self.db.atomic():
|
|
TestModel.create(name='t3', value=3)
|
|
raise ValueError('fail')
|
|
with self.assertRaises(ValueError):
|
|
await self.db.run(bad_tx)
|
|
|
|
async with self.db.atomic():
|
|
await self.create_record('t4')
|
|
try:
|
|
async with self.db.atomic():
|
|
await self.create_record('t5')
|
|
await self.assertCount(4)
|
|
raise ValueError('fail')
|
|
except ValueError:
|
|
pass
|
|
await self.assertCount(3)
|
|
|
|
await self.assertCount(3)
|
|
await self.assertNames(['t1', 't2', 't4'])
|
|
|
|
|
|
class TestSqliteIntegration(IntegrationTests, unittest.IsolatedAsyncioTestCase):
|
|
def get_database(self):
|
|
with tempfile.NamedTemporaryFile(delete=False) as f:
|
|
self.db_path = f.name
|
|
return AsyncSqliteDatabase(self.db_path)
|
|
|
|
async def test_pragmas(self):
|
|
db = AsyncSqliteDatabase(':memory:', pragmas={'user_version': '99'})
|
|
conn = await db.aconnect()
|
|
r = await conn.execute('PRAGMA user_version')
|
|
self.assertEqual(r.fetchone(), (99,))
|
|
await db.close_pool()
|
|
|
|
async def test_custom_functions(self):
|
|
db = AsyncSqliteDatabase(':memory:')
|
|
|
|
@db.func()
|
|
def title_case(s):
|
|
return s.title()
|
|
|
|
async with db:
|
|
r = await db.aexecute_sql('SELECT title_case(?)', ('test foo',))
|
|
self.assertEqual(r.fetchone(), ('Test Foo',))
|
|
await db.close_pool()
|
|
|
|
async def test_constraint_violation_recovery(self):
|
|
await self.db.aexecute_sql(
|
|
'CREATE TABLE ut (id INTEGER PRIMARY KEY, v TEXT UNIQUE)')
|
|
await self.db.aexecute_sql(
|
|
'INSERT INTO ut (v) VALUES (?)', ('x',))
|
|
with self.assertRaises(IntegrityError):
|
|
await self.db.aexecute_sql(
|
|
'INSERT INTO ut (v) VALUES (?)', ('x',))
|
|
await self.db.aexecute_sql(
|
|
'INSERT INTO ut (v) VALUES (?)', ('y',))
|
|
|
|
|
|
@unittest.skipIf(not IS_POSTGRESQL, 'skipping postgres test')
|
|
@unittest.skipUnless(asyncpg, 'asyncpg not installed')
|
|
class TestPostgresqlIntegration(IntegrationTests, unittest.IsolatedAsyncioTestCase):
|
|
def get_database(self):
|
|
return AsyncPostgresqlDatabase('peewee_test', **PSQL_PARAMS)
|
|
|
|
async def test_placeholder_conversion(self):
|
|
def insert():
|
|
return self.db.execute_sql(
|
|
'INSERT INTO testmodel (name, value) VALUES (%s, %s)',
|
|
('placeholder_test', 999))
|
|
await self.db.run(insert)
|
|
|
|
def query():
|
|
r = self.db.execute_sql(
|
|
'SELECT * FROM testmodel WHERE name = %s',
|
|
('placeholder_test',))
|
|
return r.fetchone()
|
|
row = await self.db.run(query)
|
|
self.assertIsNotNone(row)
|
|
self.assertEqual(row['name'], 'placeholder_test')
|
|
self.assertEqual(row['value'], 999)
|
|
|
|
curs = await self.db.aexecute_sql('select %s', ('test',))
|
|
self.assertEqual(curs.fetchone()[0], 'test')
|
|
|
|
async def test_iterator_with_transaction(self):
|
|
async with self.db.atomic() as tx:
|
|
await self.seed(2)
|
|
q = TestModel.select().order_by(TestModel.value)
|
|
results = [obj.value async for obj in self.db.iterate(q)]
|
|
self.assertEqual(results, [0, 10])
|
|
|
|
await self.assertCount(2)
|
|
|
|
|
|
@unittest.skipIf(not IS_MYSQL, 'skipping mysql test')
|
|
@unittest.skipUnless(aiomysql, 'aiomysql not installed')
|
|
class TestMySQLIntegration(IntegrationTests, unittest.IsolatedAsyncioTestCase):
|
|
def get_database(self):
|
|
return AsyncMySQLDatabase('peewee_test', **MYSQL_PARAMS)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|