Files
Federico Caselli 380c234ce9 Support aiosqlite 0.22.0+
Fixed issue in the aiosqlite driver where SQLAlchemy's setting of
aiosqlite's worker thread to "daemon" stopped working because the aiosqlite
architecture moved the location of the worker thread in version 0.22.0.
This "daemon" flag is necessary so that a program is able to exit if the
SQLite connection itself was not explicitly closed, which is particularly
likely with SQLAlchemy as it maintains SQLite connections in a connection
pool.  While it's perfectly fine to call :meth:`.AsyncEngine.dispose`
before program exit, this is not historically or technically necessary for
any driver of any known backend, since a primary feature of relational
databases is durability.  The change also implements support for
"terminate" with aiosqlite when using version 0.22.1 or greater,
which implements a sync ``.stop()`` method.

Fixes: #13039
Change-Id: I46efcbaab9dd028f673e113d5f6f2ceddfd133ca
2026-01-01 12:54:47 -05:00

1652 lines
56 KiB
Python

import asyncio
import contextlib
import inspect as stdlib_inspect
from unittest.mock import patch
from sqlalchemy import AssertionPool
from sqlalchemy import Column
from sqlalchemy import create_engine
from sqlalchemy import delete
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import func
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import NullPool
from sqlalchemy import QueuePool
from sqlalchemy import select
from sqlalchemy import SingletonThreadPool
from sqlalchemy import StaticPool
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy import true
from sqlalchemy import union_all
from sqlalchemy.engine import cursor as _cursor
from sqlalchemy.ext.asyncio import async_engine_from_config
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import create_async_pool_from_url
from sqlalchemy.ext.asyncio import engine as _async_engine
from sqlalchemy.ext.asyncio import exc as async_exc
from sqlalchemy.ext.asyncio import exc as asyncio_exc
from sqlalchemy.ext.asyncio.base import ReversibleProxy
from sqlalchemy.ext.asyncio.engine import AsyncConnection
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.pool import AsyncAdaptedQueuePool
from sqlalchemy.testing import assertions
from sqlalchemy.testing import async_test
from sqlalchemy.testing import combinations
from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import eq_regex
from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_none
from sqlalchemy.testing import is_not
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from sqlalchemy.testing import ne_
from sqlalchemy.util import greenlet_spawn
class AsyncFixture:
    """Mixin supplying a parametrized fixture that drives transaction
    context managers on an async subject and checks the committed rows."""

    @config.fixture(
        params=[
            (rollback, run_second_execute, begin_nested)
            for rollback in (True, False)
            for run_second_execute in (True, False)
            for begin_nested in (True, False)
        ]
    )
    def async_trans_ctx_manager_fixture(self, request, metadata):
        """Create a ``test`` table and return the async ``run_test`` driver.

        Parametrized over every combination of ``rollback``,
        ``run_second_execute`` and ``begin_nested`` (8 variants).  The
        table (one Integer column ``data``) is created on ``self.bind``
        when that attribute exists and is truthy, otherwise on
        ``config.db``.

        The returned coroutine function is
        ``run_test(subject, trans_on_subject, execute_on_subject)``:

        * ``subject`` -- object exposing ``begin()``, ``execute()`` and
          ``scalar()``; the visible caller passes an ``AsyncConnection``
          (presumably other async subjects are used elsewhere -- confirm).
        * ``trans_on_subject`` -- non-nested variant only: finish the
          transaction via ``subject.commit()/rollback()`` instead of the
          transaction object.
        * ``execute_on_subject`` -- run INSERTs via ``subject`` instead of
          the transaction object.

        When ``run_second_execute`` is set, a further execute after the
        transaction was finished inside its context manager must raise
        ``InvalidRequestError``.  Finally the committed row count is
        asserted.
        """
        rollback, run_second_execute, begin_nested = request.param

        t = Table("test", metadata, Column("data", Integer))
        eng = getattr(self, "bind", None) or config.db

        t.create(eng)

        async def run_test(subject, trans_on_subject, execute_on_subject):
            async with subject.begin() as trans:
                if begin_nested:
                    if not config.requirements.savepoints.enabled:
                        config.skip_test("savepoints not enabled")
                    if execute_on_subject:
                        nested_trans = subject.begin_nested()
                    else:
                        nested_trans = trans.begin_nested()

                    async with nested_trans:
                        if execute_on_subject:
                            await subject.execute(t.insert(), {"data": 10})
                        else:
                            await trans.execute(t.insert(), {"data": 10})

                        # for nested trans, we always commit/rollback on the
                        # "nested trans" object itself.
                        # only Session(future=False) will affect savepoint
                        # transaction for session.commit/rollback

                        if rollback:
                            await nested_trans.rollback()
                        else:
                            await nested_trans.commit()

                        if run_second_execute:
                            # the savepoint is finished but its context
                            # manager is still open: executing now must fail
                            with assertions.expect_raises_message(
                                exc.InvalidRequestError,
                                "Can't operate on closed transaction "
                                "inside context manager. Please complete the "
                                "context manager "
                                "before emitting further commands.",
                            ):
                                if execute_on_subject:
                                    await subject.execute(
                                        t.insert(), {"data": 12}
                                    )
                                else:
                                    await trans.execute(
                                        t.insert(), {"data": 12}
                                    )

                    # outside the nested trans block, but still inside the
                    # transaction block, we can run SQL, and it will be
                    # committed
                    if execute_on_subject:
                        await subject.execute(t.insert(), {"data": 14})
                    else:
                        await trans.execute(t.insert(), {"data": 14})

                else:
                    if execute_on_subject:
                        await subject.execute(t.insert(), {"data": 10})
                    else:
                        await trans.execute(t.insert(), {"data": 10})

                    if trans_on_subject:
                        if rollback:
                            await subject.rollback()
                        else:
                            await subject.commit()
                    else:
                        if rollback:
                            await trans.rollback()
                        else:
                            await trans.commit()

                    if run_second_execute:
                        # the outer transaction is finished while its
                        # context manager is still open: executing must fail
                        with assertions.expect_raises_message(
                            exc.InvalidRequestError,
                            "Can't operate on closed transaction inside "
                            "context "
                            "manager. Please complete the context manager "
                            "before emitting further commands.",
                        ):
                            if execute_on_subject:
                                await subject.execute(t.insert(), {"data": 12})
                            else:
                                await trans.execute(t.insert(), {"data": 12})

            expected_committed = 0
            if begin_nested:
                # begin_nested variant, we inserted a row after the nested
                # block
                expected_committed += 1
            if not rollback:
                # not rollback variant, our row inserted in the target
                # block itself would be committed
                expected_committed += 1

            if execute_on_subject:
                eq_(
                    await subject.scalar(select(func.count()).select_from(t)),
                    expected_committed,
                )
            else:
                # NOTE(review): this is a plain (sync) ``with`` around
                # ``subject.connect()`` inside a coroutine.  The only caller
                # visible here passes execute_on_subject=True, so this
                # branch is not exercised from this view -- confirm it is
                # valid for any caller that reaches it.
                with subject.connect() as conn:
                    eq_(
                        await conn.scalar(select(func.count()).select_from(t)),
                        expected_committed,
                    )

        return run_test
class EngineFixture(AsyncFixture, fixtures.TablesTest):
    """Shared base for the async engine tests.

    Provides asyncio testing engines / connections as fixtures and a
    ``users`` table populated with rows for user_id 1 through 19.
    """

    __requires__ = ("async_dialect",)

    @testing.fixture
    def async_engine(self):
        # asyncio testing engine with the sqlite_share_pool option enabled
        engine_options = {"sqlite_share_pool": True}
        return engines.testing_engine(asyncio=True, options=engine_options)

    @testing.fixture
    def adhoc_async_engine(self):
        # asyncio testing engine without any extra options
        return engines.testing_engine(asyncio=True)

    @testing.fixture
    def async_connection(self, async_engine):
        # wrap a sync Connection from the engine's sync_engine in an
        # AsyncConnection; the ``with`` closes it on fixture teardown
        sync_engine = async_engine.sync_engine
        with sync_engine.connect() as sync_conn:
            yield AsyncConnection(async_engine, sync_conn)

    @classmethod
    def define_tables(cls, metadata):
        user_columns = (
            Column("user_id", Integer, primary_key=True, autoincrement=False),
            Column("user_name", String(20)),
        )
        Table("users", metadata, *user_columns)

    @classmethod
    def insert_data(cls, connection):
        rows = [
            {"user_id": n, "user_name": f"name{n}"} for n in range(1, 20)
        ]
        connection.execute(cls.tables.users.insert(), rows)
class AsyncEngineTest(EngineFixture):
    """Behavioral tests for AsyncEngine / AsyncConnection / AsyncTransaction.

    Covers proxying to the sync objects, equality, pooling, transactions,
    invalidation / termination and engine creation options.
    """

    __backend__ = True

    @testing.fails("the failure is the test")
    @async_test
    async def test_we_are_definitely_running_async_tests(self, async_engine):
        # deliberately wrong expectation: ``select 1`` compared against 2.
        # @testing.fails requires this to fail, proving the coroutine body
        # really executed.
        async with async_engine.connect() as conn:
            eq_(await conn.scalar(text("select 1")), 2)

    @async_test
    async def test_interrupt_ctxmanager_connection(
        self, async_engine, async_trans_ctx_manager_fixture
    ):
        fn = async_trans_ctx_manager_fixture

        async with async_engine.connect() as conn:
            await fn(conn, trans_on_subject=False, execute_on_subject=True)

    def test_proxied_attrs_engine(self, async_engine):
        sync_engine = async_engine.sync_engine

        is_(async_engine.url, sync_engine.url)
        is_(async_engine.pool, sync_engine.pool)
        is_(async_engine.dialect, sync_engine.dialect)
        eq_(async_engine.name, sync_engine.name)
        eq_(async_engine.driver, sync_engine.driver)
        eq_(async_engine.echo, sync_engine.echo)

    @async_test
    async def test_run_async(self, async_engine):
        async def test_meth(async_driver_connection):
            # there's no method that's guaranteed to be on every
            # driver, so just stringify it and compare that to the
            # outside
            return str(async_driver_connection)

        def run_sync_to_async(connection):
            connection_fairy = connection.connection
            async_return = connection_fairy.run_async(
                lambda driver_connection: test_meth(driver_connection)
            )
            # run_async() must hand back the awaited result, not a coroutine
            assert not stdlib_inspect.iscoroutine(async_return)
            return async_return

        async with async_engine.connect() as conn:
            driver_connection = (
                await conn.get_raw_connection()
            ).driver_connection
            res = await conn.run_sync(run_sync_to_async)
            assert not stdlib_inspect.iscoroutine(res)
            eq_(res, str(driver_connection))

    @async_test
    async def test_engine_eq_ne(self, async_engine):
        e2 = _async_engine.AsyncEngine(async_engine.sync_engine)
        e3 = engines.testing_engine(asyncio=True)

        eq_(async_engine, e2)
        ne_(async_engine, e3)

        # ``== None`` (not ``is None``) is intentional: it exercises __eq__
        is_false(async_engine == None)

    def test_no_attach_to_event_loop(self, testing_engine):
        """test #6409

        note this test does not seem to trigger the bug that was originally
        fixed in #6409, when using python 3.10 and higher (the original issue
        can repro in 3.8 at least, based on my testing). It's been simplified
        to no longer explicitly create a new loop, asyncio.run() already
        creates a new loop.

        """
        # ``asyncio`` comes from the module-level import; only threading is
        # needed locally
        import threading

        errs = []

        def go():
            async def main():
                tasks = [task() for _ in range(2)]

                await asyncio.gather(*tasks)
                await engine.dispose()

            async def task():
                async with engine.begin() as connection:
                    result = await connection.execute(select(1))
                    result.all()

            try:
                engine = engines.testing_engine(asyncio=True)

                asyncio.run(main())
            except Exception as err:
                # collected and re-raised on the main thread below
                errs.append(err)

        t = threading.Thread(target=go)
        t.start()
        t.join()
        if errs:
            raise errs[0]

    @async_test
    async def test_connection_info(self, async_engine):
        async with async_engine.connect() as conn:
            conn.info["foo"] = "bar"

            eq_(conn.sync_connection.info, {"foo": "bar"})

    @async_test
    async def test_connection_eq_ne(self, async_engine):
        async with async_engine.connect() as conn:
            c2 = _async_engine.AsyncConnection(
                async_engine, conn.sync_connection
            )

            eq_(conn, c2)

            async with async_engine.connect() as c3:
                ne_(conn, c3)

            # ``== None`` is intentional: it exercises __eq__
            is_false(conn == None)

    @async_test
    async def test_transaction_eq_ne(self, async_engine):
        async with async_engine.connect() as conn:
            t1 = await conn.begin()

            t2 = _async_engine.AsyncTransaction._regenerate_proxy_for_target(
                t1._proxied
            )

            eq_(t1, t2)

            # ``== None`` is intentional: it exercises __eq__
            is_false(t1 == None)

    @testing.variation("simulate_gc", [True, False])
    def test_appropriate_warning_for_gced_connection(
        self, adhoc_async_engine, simulate_gc
    ):
        """test #9237 which builds upon a not really complete solution
        added for #8419."""

        async def go():
            conn = await adhoc_async_engine.connect()
            await conn.begin()
            await conn.execute(select(1))
            pool_connection = await conn.get_raw_connection()
            return pool_connection

        from sqlalchemy.util.concurrency import await_

        pool_connection = await_(go())

        rec = pool_connection._connection_record
        ref = rec.fairy_ref
        pool = pool_connection._pool
        echo = False

        if simulate_gc:
            # not using expect_warnings() here because we also want to do a
            # negative test for warnings, and we want to absolutely make sure
            # the thing here that emits the warning is the correct path
            from sqlalchemy.pool.base import _finalize_fairy

            with (
                mock.patch.object(
                    pool._dialect,
                    "do_rollback",
                    mock.Mock(side_effect=Exception("can't run rollback")),
                ),
                mock.patch("sqlalchemy.util.warn") as m,
            ):
                _finalize_fairy(
                    None, rec, pool, ref, echo, transaction_was_reset=False
                )

            if adhoc_async_engine.dialect.has_terminate:
                expected_msg = (
                    "The garbage collector is trying to clean up.*which will "
                    "be terminated."
                )
            else:
                expected_msg = (
                    "The garbage collector is trying to clean up.*which will "
                    "be dropped, as it cannot be safely terminated."
                )

            # [1] == .args, not in 3.7
            eq_regex(m.mock_calls[0][1][0], expected_msg)
        else:
            # the warning emitted by the pool is inside of a try/except:
            # so it's impossible right now to have this warning "raise".
            # for now, test by using mock.patch

            with mock.patch("sqlalchemy.util.warn") as m:
                pool_connection.close()

            eq_(m.mock_calls, [])

    @async_test
    @testing.skip_if(lambda config: not config.db.dialect.has_terminate)
    async def test_dbapi_terminate(self, adhoc_async_engine):
        conn = await adhoc_async_engine.raw_connection()
        dbapi_conn = conn.dbapi_connection
        dbapi_conn.terminate()
        conn.invalidate()

    @async_test
    async def test_statement_compile(self, async_engine):
        stmt = str(select(1).compile(async_engine))
        async with async_engine.connect() as conn:
            eq_(str(select(1).compile(conn)), stmt)

    def test_clear_compiled_cache(self, async_engine):
        async_engine.sync_engine._compiled_cache["foo"] = "bar"
        eq_(async_engine.sync_engine._compiled_cache["foo"], "bar")
        async_engine.clear_compiled_cache()
        assert "foo" not in async_engine.sync_engine._compiled_cache

    def test_execution_options(self, async_engine):
        a2 = async_engine.execution_options(foo="bar")
        assert isinstance(a2, _async_engine.AsyncEngine)
        eq_(a2.sync_engine._execution_options, {"foo": "bar"})
        eq_(async_engine.sync_engine._execution_options, {})

        # proxied surface covered by these tests:
        #   attrs: url, pool, dialect, engine, name, driver, echo
        #   methods: clear_compiled_cache, update_execution_options,
        #   execution_options, get_execution_options, dispose

    @async_test
    async def test_proxied_attrs_connection(self, async_engine):
        async with async_engine.connect() as conn:
            sync_conn = conn.sync_connection

            is_(conn.engine, async_engine)
            is_(conn.closed, sync_conn.closed)
            is_(conn.dialect, async_engine.sync_engine.dialect)
            eq_(
                conn.default_isolation_level, sync_conn.default_isolation_level
            )

    @async_test
    async def test_transaction_accessor(self, async_connection):
        conn = async_connection
        is_none(conn.get_transaction())
        is_false(conn.in_transaction())
        is_false(conn.in_nested_transaction())

        trans = await conn.begin()

        is_true(conn.in_transaction())
        is_false(conn.in_nested_transaction())

        is_(trans.sync_transaction, conn.get_transaction().sync_transaction)

        nested = await conn.begin_nested()

        is_true(conn.in_transaction())
        is_true(conn.in_nested_transaction())

        is_(
            conn.get_nested_transaction().sync_transaction,
            nested.sync_transaction,
        )
        eq_(conn.get_nested_transaction(), nested)

        is_(trans.sync_transaction, conn.get_transaction().sync_transaction)

        await nested.commit()

        is_true(conn.in_transaction())
        is_false(conn.in_nested_transaction())

        await trans.rollback()

        is_none(conn.get_transaction())
        is_false(conn.in_transaction())
        is_false(conn.in_nested_transaction())

    @testing.requires.queue_pool
    @async_test
    async def test_invalidate(self, async_engine):
        conn = await async_engine.connect()

        is_(conn.invalidated, False)

        connection_fairy = await conn.get_raw_connection()
        is_(connection_fairy.is_valid, True)
        dbapi_connection = connection_fairy.dbapi_connection

        await conn.invalidate()

        if testing.against("postgresql+asyncpg"):
            assert dbapi_connection._connection.is_closed()

        # after invalidation a fresh DBAPI connection is checked out
        new_fairy = await conn.get_raw_connection()
        is_not(new_fairy.dbapi_connection, dbapi_connection)
        is_not(new_fairy, connection_fairy)
        is_(new_fairy.is_valid, True)
        is_(connection_fairy.is_valid, False)
        await conn.close()

    @async_test
    async def test_get_dbapi_connection_raise(self, async_connection):
        with testing.expect_raises_message(
            exc.InvalidRequestError,
            "AsyncConnection.connection accessor is not "
            "implemented as the attribute",
        ):
            async_connection.connection

    @async_test
    async def test_get_raw_connection(self, async_connection):
        pooled = await async_connection.get_raw_connection()
        is_(pooled, async_connection.sync_connection.connection)

    @async_test
    async def test_isolation_level(self, async_connection):
        conn = async_connection
        sync_isolation_level = await greenlet_spawn(
            conn.sync_connection.get_isolation_level
        )
        isolation_level = await conn.get_isolation_level()

        eq_(isolation_level, sync_isolation_level)

        await conn.execution_options(isolation_level="SERIALIZABLE")
        isolation_level = await conn.get_isolation_level()

        eq_(isolation_level, "SERIALIZABLE")

    @testing.combinations(
        (
            AsyncAdaptedQueuePool,
            True,
        ),
        (
            QueuePool,
            False,
        ),
        (NullPool, True),
        (SingletonThreadPool, False),
        (StaticPool, True),
        (AssertionPool, True),
        argnames="pool_cls,should_work",
    )
    @testing.variation("instantiate", [True, False])
    @async_test
    async def test_pool_classes(
        self, async_testing_engine, pool_cls, instantiate, should_work
    ):
        """test #8771"""
        if instantiate:
            if pool_cls in (QueuePool, AsyncAdaptedQueuePool):
                pool = pool_cls(creator=testing.db.pool._creator, timeout=10)
            else:
                pool = pool_cls(
                    creator=testing.db.pool._creator,
                )

            options = {"pool": pool}
        else:
            if pool_cls in (QueuePool, AsyncAdaptedQueuePool):
                options = {"poolclass": pool_cls, "pool_timeout": 10}
            else:
                options = {"poolclass": pool_cls}

        if not should_work:
            with expect_raises_message(
                exc.ArgumentError,
                f"Pool class {pool_cls.__name__} "
                "cannot be used with asyncio engine",
            ):
                async_testing_engine(options=options)
            return

        e = async_testing_engine(options=options)

        if pool_cls is AssertionPool:
            # AssertionPool permits only one checked-out connection, so the
            # concurrent variant below is skipped for it
            async with e.connect() as conn:
                result = await conn.scalar(select(1))
                eq_(result, 1)
            return

        async def go():
            async with e.connect() as conn:
                result = await conn.scalar(select(1))
                eq_(result, 1)
                return result

        eq_(await asyncio.gather(*[go() for _ in range(10)]), [1] * 10)

    def test_cant_use_async_pool_w_create_engine(self):
        """supplemental test for #8771"""

        with expect_raises_message(
            exc.ArgumentError,
            "Pool class AsyncAdaptedQueuePool "
            "cannot be used with non-asyncio engine",
        ):
            create_engine("sqlite://", poolclass=AsyncAdaptedQueuePool)

    @testing.requires.queue_pool
    @async_test
    async def test_dispose(self, async_engine):
        c1 = await async_engine.connect()
        c2 = await async_engine.connect()

        await c1.close()
        await c2.close()

        p1 = async_engine.pool

        if isinstance(p1, AsyncAdaptedQueuePool):
            eq_(async_engine.pool.checkedin(), 2)

        await async_engine.dispose()
        if isinstance(p1, AsyncAdaptedQueuePool):
            eq_(async_engine.pool.checkedin(), 0)
        is_not(p1, async_engine.pool)

    @testing.requires.queue_pool
    @async_test
    async def test_dispose_no_close(self, async_engine):
        c1 = await async_engine.connect()
        c2 = await async_engine.connect()

        await c1.close()
        await c2.close()

        p1 = async_engine.pool

        if isinstance(p1, AsyncAdaptedQueuePool):
            eq_(async_engine.pool.checkedin(), 2)

        await async_engine.dispose(close=False)

        # TODO: test that DBAPI connection was not closed

        if isinstance(p1, AsyncAdaptedQueuePool):
            eq_(async_engine.pool.checkedin(), 0)
        is_not(p1, async_engine.pool)

    @testing.requires.independent_connections
    @async_test
    async def test_init_once_concurrency(self, async_engine):
        async with async_engine.connect() as c1, async_engine.connect() as c2:
            coro = asyncio.gather(c1.scalar(select(1)), c2.scalar(select(2)))
            eq_(await coro, [1, 2])

    @async_test
    async def test_connect_ctxmanager(self, async_engine):
        async with async_engine.connect() as conn:
            result = await conn.execute(select(1))
            eq_(result.scalar(), 1)

    @async_test
    async def test_connect_plain(self, async_engine):
        conn = await async_engine.connect()
        try:
            result = await conn.execute(select(1))
            eq_(result.scalar(), 1)
        finally:
            await conn.close()

    @async_test
    async def test_connection_not_started(self, async_engine):
        # connect() is neither awaited nor entered as a context manager
        conn = async_engine.connect()
        testing.assert_raises_message(
            asyncio_exc.AsyncContextNotStarted,
            "AsyncConnection context has not been started and "
            "object has not been awaited.",
            conn.begin,
        )

    @async_test
    async def test_transaction_commit(self, async_engine):
        users = self.tables.users

        async with async_engine.begin() as conn:
            await conn.execute(delete(users))

        async with async_engine.connect() as conn:
            eq_(await conn.scalar(select(func.count(users.c.user_id))), 0)

    @async_test
    async def test_savepoint_rollback_noctx(self, async_engine):
        users = self.tables.users

        async with async_engine.begin() as conn:
            savepoint = await conn.begin_nested()
            await conn.execute(delete(users))
            await savepoint.rollback()

        async with async_engine.connect() as conn:
            eq_(await conn.scalar(select(func.count(users.c.user_id))), 19)

    @async_test
    async def test_savepoint_commit_noctx(self, async_engine):
        users = self.tables.users

        async with async_engine.begin() as conn:
            savepoint = await conn.begin_nested()
            await conn.execute(delete(users))
            await savepoint.commit()

        async with async_engine.connect() as conn:
            eq_(await conn.scalar(select(func.count(users.c.user_id))), 0)

    @async_test
    async def test_transaction_rollback(self, async_engine):
        users = self.tables.users

        async with async_engine.connect() as conn:
            trans = conn.begin()
            await trans.start()
            await conn.execute(delete(users))
            await trans.rollback()

        async with async_engine.connect() as conn:
            eq_(await conn.scalar(select(func.count(users.c.user_id))), 19)

    @async_test
    async def test_conn_transaction_not_started(self, async_engine):
        async with async_engine.connect() as conn:
            # begin() is neither awaited nor start()-ed
            trans = conn.begin()
            with expect_raises_message(
                asyncio_exc.AsyncContextNotStarted,
                "AsyncTransaction context has not been started "
                "and object has not been awaited.",
            ):
                await trans.rollback()

    @testing.requires.queue_pool
    @async_test
    async def test_pool_exhausted_some_timeout(
        self, testing_engine, async_engine
    ):
        engine = testing_engine(
            asyncio=True,
            options=dict(
                pool_size=1,
                max_overflow=0,
                pool_timeout=0.1,
            ),
        )
        async with engine.connect():
            # the single pooled connection is checked out
            with expect_raises(exc.TimeoutError):
                await engine.connect()

    @async_test
    async def test_engine_aclose(self, async_engine):
        users = self.tables.users
        async with contextlib.aclosing(async_engine.connect()) as conn:
            await conn.start()
            trans = conn.begin()
            await trans.start()
            await conn.execute(delete(users))
            await trans.commit()

        # aclosing() must have called conn.aclose()
        assert conn.closed

    @testing.requires.queue_pool
    @async_test
    async def test_pool_exhausted_no_timeout(
        self, testing_engine, async_engine
    ):
        engine = testing_engine(
            asyncio=True,
            options=dict(
                pool_size=1,
                max_overflow=0,
                pool_timeout=0,
            ),
        )
        async with engine.connect():
            # the single pooled connection is checked out
            with expect_raises(exc.TimeoutError):
                await engine.connect()

    @async_test
    async def test_create_async_engine_server_side_cursor(self, async_engine):
        with expect_raises_message(
            asyncio_exc.AsyncMethodRequired,
            "Can't set server_side_cursors for async engine globally",
        ):
            create_async_engine(
                testing.db.url,
                server_side_cursors=True,
            )

    def test_async_engine_from_config(self):
        # named to avoid shadowing the module-level ``config`` import
        engine_config = {
            "sqlalchemy.url": testing.db.url.render_as_string(
                hide_password=False
            ),
            "sqlalchemy.echo": "true",
        }
        engine = async_engine_from_config(engine_config)
        assert engine.url == testing.db.url
        assert engine.echo is True
        assert engine.dialect.is_async is True

    def test_async_creator_and_creator(self):
        async def ac():
            return None

        def c():
            return None

        with expect_raises_message(
            exc.ArgumentError,
            "Can only specify one of 'async_creator' or 'creator', "
            "not both.",
        ):
            create_async_engine(testing.db.url, creator=c, async_creator=ac)

    @async_test
    async def test_async_creator_invoked(self, async_testing_engine):
        """test for #8215"""

        existing_creator = testing.db.pool._creator

        async def async_creator():
            sync_conn = await greenlet_spawn(existing_creator)
            return sync_conn.driver_connection

        async_creator = mock.Mock(side_effect=async_creator)

        eq_(async_creator.mock_calls, [])

        engine = async_testing_engine(options={"async_creator": async_creator})
        async with engine.connect() as conn:
            result = await conn.scalar(select(1))
            eq_(result, 1)

        eq_(async_creator.mock_calls, [mock.call()])

    @async_test
    async def test_async_creator_accepts_args_if_called_directly(
        self, async_testing_engine
    ):
        """supplemental test for #8215.

        The "async_creator" passed to create_async_engine() is expected to take
        no arguments, the same way as "creator" passed to create_engine()
        works.

        However, the ultimate "async_creator" received by the sync-emulating
        DBAPI *does* take arguments in its ``.connect()`` method, which will be
        all the other arguments passed to ``.connect()``.  This functionality
        is not currently used, however was decided that the creator should
        internally work this way for improved flexibility; see
        https://github.com/sqlalchemy/sqlalchemy/issues/8215#issuecomment-1181791539.
        That contract is tested here.

        """  # noqa: E501

        existing_creator = testing.db.pool._creator

        async def async_creator(x, y, *, z=None):
            sync_conn = await greenlet_spawn(existing_creator)
            return sync_conn.driver_connection

        async_creator = mock.Mock(side_effect=async_creator)

        async_dbapi = testing.db.dialect.loaded_dbapi

        conn = await greenlet_spawn(
            async_dbapi.connect, 5, y=10, z=8, async_creator_fn=async_creator
        )
        try:
            eq_(async_creator.mock_calls, [mock.call(5, y=10, z=8)])
        finally:
            await greenlet_spawn(conn.close)

    @testing.combinations("stream", "stream_scalars", argnames="method")
    @async_test
    async def test_server_side_required_for_scalars(
        self, async_engine, method
    ):
        with mock.patch.object(
            async_engine.dialect, "supports_server_side_cursors", False
        ):
            async with async_engine.connect() as c:
                with expect_raises_message(
                    exc.InvalidRequestError,
                    "Can't use `stream` or `stream_scalars` with the current "
                    "dialect since it does not support server side cursors.",
                ):
                    if method == "stream":
                        await c.stream(select(1))
                    elif method == "stream_scalars":
                        await c.stream_scalars(select(1))
                    else:
                        testing.fail(method)

    @async_test
    @testing.requires.async_dialect_with_await_close
    async def test_active_await_close(self, async_engine):
        select_one_sql = select(1).compile(async_engine.sync_engine).string
        async with async_engine.connect() as conn:
            result = await conn.exec_driver_sql(select_one_sql)
            eq_(result.scalar_one(), 1)
            driver_cursor = result.context.cursor._cursor
            with expect_raises(Exception):
                # because the cursor should be closed
                await driver_cursor.execute(select_one_sql)

    @async_test
    async def test_async_creator_handle_error(self, async_testing_engine):
        """test for #11956"""

        existing_creator = testing.db.pool._creator

        def create_and_break():
            sync_conn = existing_creator()
            cursor = sync_conn.cursor()

            # figure out a way to get a native driver exception.  This really
            # only applies to asyncpg where we rewrite the exception
            # hierarchy with our own emulated exception; other backends raise
            # standard DBAPI exceptions (with some buggy cases here and there
            # which they miss) even though they are async
            try:
                cursor.execute("this will raise an error")
            except Exception as possibly_emulated_error:
                if isinstance(
                    possibly_emulated_error, exc.EmulatedDBAPIException
                ):
                    raise possibly_emulated_error.driver_exception
                # not emulated: re-raise the original error unchanged
                raise

        async def async_creator():
            return await greenlet_spawn(create_and_break)

        engine = async_testing_engine(options={"async_creator": async_creator})

        with expect_raises(exc.DBAPIError):
            await engine.connect()
class AsyncCreatePoolTest(fixtures.TestBase):
    """create_async_pool_from_url() must delegate to the internal
    ``_create_pool_from_url`` with ``_is_async=True``, forwarding the URL
    and any extra keyword arguments."""

    @config.fixture
    def mock_create(self):
        target = "sqlalchemy.ext.asyncio.engine._create_pool_from_url"
        with patch(target) as patched_create:
            yield patched_create

    def test_url_only(self, mock_create):
        url = "sqlite://"
        create_async_pool_from_url(url)
        mock_create.assert_called_once_with(url, _is_async=True)

    def test_pool_args(self, mock_create):
        url = "sqlite://"
        extra = {"foo": 99, "echo": True}
        create_async_pool_from_url(url, **extra)
        mock_create.assert_called_once_with(url, _is_async=True, **extra)
class AsyncEventTest(EngineFixture):
    """The engine events all run in their normal synchronous context.

    We do not provide an asyncio event interface at this time.
    """

    __backend__ = True

    # expected NotImplementedError text when listening on async objects;
    # shared so the four tests below cannot drift apart
    _no_async_events_msg = (
        "asynchronous events are not implemented "
        "at this time. Apply synchronous listeners to the "
        "AsyncEngine.sync_engine or "
        "AsyncConnection.sync_connection attributes."
    )

    @async_test
    async def test_no_async_listeners(self, async_engine):
        with testing.expect_raises_message(
            NotImplementedError, self._no_async_events_msg
        ):
            event.listen(async_engine, "before_cursor_execute", mock.Mock())

        async with async_engine.connect() as conn:
            with testing.expect_raises_message(
                NotImplementedError, self._no_async_events_msg
            ):
                event.listen(conn, "before_cursor_execute", mock.Mock())

    @async_test
    async def test_no_async_listeners_dialect_event(self, async_engine):
        with testing.expect_raises_message(
            NotImplementedError, self._no_async_events_msg
        ):
            event.listen(async_engine, "do_execute", mock.Mock())

    @async_test
    async def test_no_async_listeners_pool_event(self, async_engine):
        with testing.expect_raises_message(
            NotImplementedError, self._no_async_events_msg
        ):
            event.listen(async_engine, "checkout", mock.Mock())

    @async_test
    async def test_sync_before_cursor_execute_engine(self, async_engine):
        canary = mock.Mock()

        event.listen(async_engine.sync_engine, "before_cursor_execute", canary)

        async with async_engine.connect() as conn:
            sync_conn = conn.sync_connection
            await conn.execute(select(1))

        s1 = str(select(1).compile(async_engine))
        eq_(
            canary.mock_calls,
            [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
        )

    @async_test
    async def test_sync_before_cursor_execute_connection(self, async_engine):
        canary = mock.Mock()

        async with async_engine.connect() as conn:
            sync_conn = conn.sync_connection

            # unlike the _engine variant, the engine-level listener is
            # registered only after this connection was opened
            event.listen(
                async_engine.sync_engine, "before_cursor_execute", canary
            )
            await conn.execute(select(1))

        s1 = str(select(1).compile(async_engine))
        eq_(
            canary.mock_calls,
            [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
        )

    @async_test
    async def test_event_on_sync_connection(self, async_engine):
        canary = mock.Mock()

        async with async_engine.connect() as conn:
            event.listen(conn.sync_connection, "begin", canary)
            async with conn.begin():
                eq_(
                    canary.mock_calls,
                    [mock.call(conn.sync_connection)],
                )
class AsyncInspection(EngineFixture):
    """inspect() must refuse AsyncEngine and AsyncConnection targets."""

    __backend__ = True

    @async_test
    async def test_inspect_engine(self, async_engine):
        msg = "Inspection on an AsyncEngine is currently not supported."
        with testing.expect_raises_message(
            exc.NoInspectionAvailable, msg
        ):
            inspect(async_engine)

    @async_test
    async def test_inspect_connection(self, async_engine):
        msg = "Inspection on an AsyncConnection is currently not supported."
        async with async_engine.connect() as conn:
            with testing.expect_raises_message(
                exc.NoInspectionAvailable, msg
            ):
                inspect(conn)
class AsyncResultTest(EngineFixture):
    """Result-object behavior for async connections.

    Covers results produced by ``AsyncConnection.stream()`` and
    ``stream_scalars()`` (async iteration, ``all()``, ``keys()``,
    ``partitions()``, ``one()``, ``fetchmany()``/``fetchall()``, the
    ``scalar*()`` accessors and context-manager closing). It also covers
    the errors raised when a server-side cursor is requested through the
    non-streaming ``execute()`` / ``exec_driver_sql()`` methods.

    The expected values below encode a ``users`` fixture table holding
    ``user_id`` 1..19 with ``user_name`` equal to ``"name<user_id>"``.
    """

    __backend__ = True
    __requires__ = ("server_side_cursors", "async_dialect")

    @staticmethod
    def _apply_filter(result, filter_):
        """Return ``result`` reshaped according to ``filter_``.

        ``"mappings"`` gives ``result.mappings()``.
        ``"scalars"`` gives ``result.scalars(1)`` (the ``user_name``
        column). Any other value returns ``result`` unchanged.
        """
        if filter_ == "mappings":
            return result.mappings()
        if filter_ == "scalars":
            return result.scalars(1)
        return result

    @staticmethod
    def _expected_rows(filter_, user_ids):
        """Build the expected rows for ``user_ids`` in ``filter_``'s shape.

        Returns dicts for ``"mappings"``, bare ``user_name`` strings for
        ``"scalars"`` / ``"stream_scalars"``, and ``(user_id, user_name)``
        tuples otherwise.
        """
        if filter_ == "mappings":
            return [
                {"user_id": uid, "user_name": "name%d" % uid}
                for uid in user_ids
            ]
        if filter_ in ("scalars", "stream_scalars"):
            return ["name%d" % uid for uid in user_ids]
        return [(uid, "name%d" % uid) for uid in user_ids]

    @async_test
    async def test_no_ss_cursor_w_execute(self, async_engine):
        """execute() on a stream_results connection must point to stream()."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            streaming = await conn.execution_options(stream_results=True)
            with expect_raises_message(
                async_exc.AsyncMethodRequired,
                r"Can't use the AsyncConnection.execute\(\) method with a "
                r"server-side cursor. Use the AsyncConnection.stream\(\) "
                r"method for an async streaming result set.",
            ):
                await streaming.execute(select(users))

    @async_test
    async def test_no_ss_cursor_w_exec_driver_sql(self, async_engine):
        """exec_driver_sql() on a stream_results connection is rejected too."""
        async with async_engine.connect() as conn:
            streaming = await conn.execution_options(stream_results=True)
            with expect_raises_message(
                async_exc.AsyncMethodRequired,
                r"Can't use the AsyncConnection.exec_driver_sql\(\) "
                r"method with a "
                r"server-side cursor. Use the AsyncConnection.stream\(\) "
                r"method for an async streaming result set.",
            ):
                await streaming.exec_driver_sql("SELECT * FROM users")

    @async_test
    async def test_stream_ctxmanager(self, async_engine):
        """Leaving ``async with conn.stream(...)`` closes the result.

        This holds even when the body exits through an exception.
        """
        async with async_engine.connect() as conn:
            streaming = await conn.execution_options(stream_results=True)
            async with streaming.stream(select(self.tables.users)) as result:
                assert not result._real_result._soft_closed
                assert not result.closed
                with expect_raises_message(Exception, "hi"):
                    seen = 0
                    async for _ in result:
                        if seen > 2:
                            raise Exception("hi")
                        seen += 1
            assert result._real_result._soft_closed
            assert result.closed

    @async_test
    async def test_stream_scalars_ctxmanager(self, async_engine):
        """Same as test_stream_ctxmanager, but for ``stream_scalars()``."""
        async with async_engine.connect() as conn:
            streaming = await conn.execution_options(stream_results=True)
            stmt = select(self.tables.users)
            async with streaming.stream_scalars(stmt) as result:
                assert not result._real_result._soft_closed
                assert not result.closed
                with expect_raises_message(Exception, "hi"):
                    seen = 0
                    async for _ in result:
                        if seen > 2:
                            raise Exception("hi")
                        seen += 1
            assert result._real_result._soft_closed
            assert result.closed

    @testing.combinations(
        (None,), ("scalars",), ("mappings",), argnames="filter_"
    )
    @async_test
    async def test_all(self, async_engine, filter_):
        """``await result.all()`` returns every row in the filtered shape."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            raw = await conn.stream(select(users))
            fetched = await self._apply_filter(raw, filter_).all()
            eq_(fetched, self._expected_rows(filter_, range(1, 20)))

    @testing.combinations(
        (None,),
        ("scalars",),
        ("stream_scalars",),
        ("mappings",),
        argnames="filter_",
    )
    @async_test
    async def test_aiter(self, async_engine, filter_):
        """``async for`` over a streamed result yields every row."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            if filter_ == "stream_scalars":
                result = await conn.stream_scalars(select(users.c.user_name))
            else:
                result = self._apply_filter(
                    await conn.stream(select(users)), filter_
                )
            collected = [row async for row in result]
            eq_(collected, self._expected_rows(filter_, range(1, 20)))

    @testing.combinations((None,), ("mappings",), argnames="filter_")
    @async_test
    async def test_keys(self, async_engine, filter_):
        """keys() reports the selected column names, with or without mappings."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            result = self._apply_filter(
                await conn.stream(select(users)), filter_
            )
            eq_(result.keys(), ["user_id", "user_name"])
            await result.close()

    @async_test
    async def test_unique_all(self, async_engine):
        """unique() collapses the duplicated rows of a UNION ALL."""
        users = self.tables.users
        doubled = union_all(select(users), select(users)).order_by(
            users.c.user_id
        )
        async with async_engine.connect() as conn:
            result = await conn.stream(doubled)
            deduped = await result.unique().all()
            eq_(deduped, [(uid, "name%d" % uid) for uid in range(1, 20)])

    @async_test
    async def test_columns_all(self, async_engine):
        """columns(1) narrows each row to a single-element tuple."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            result = await conn.stream(select(users))
            narrowed = await result.columns(1).all()
            eq_(narrowed, [("name%d" % uid,) for uid in range(1, 20)])

    @testing.combinations(
        (None,), ("scalars",), ("mappings",), argnames="filter_"
    )
    @testing.combinations(None, 2, 5, 10, argnames="yield_per")
    @testing.combinations("method", "opt", argnames="yield_per_type")
    @async_test
    async def test_partitions(
        self, async_engine, filter_, yield_per, yield_per_type
    ):
        """partitions() chunks rows by yield_per, or by an explicit size.

        When no yield_per is given, the buffered strategy's size (5) is used
        explicitly as the partition size.
        """
        users = self.tables.users
        async with async_engine.connect() as conn:
            stmt = select(users)
            if yield_per and yield_per_type == "opt":
                stmt = stmt.execution_options(yield_per=yield_per)
            result = self._apply_filter(await conn.stream(stmt), filter_)
            if yield_per and yield_per_type == "method":
                result = result.yield_per(yield_per)

            # stream() sets stream_results unconditionally
            strategy = result._real_result.cursor_strategy
            assert isinstance(strategy, _cursor.BufferedRowCursorFetchStrategy)

            if yield_per:
                partition_size = yield_per
                eq_(strategy._bufsize, yield_per)
                chunks = result.partitions()
            else:
                eq_(strategy._bufsize, 5)
                partition_size = 5
                chunks = result.partitions(partition_size)
            check_result = [chunk async for chunk in chunks]

            bounds = [
                (start, min(20, start + partition_size))
                for start in range(1, 21, partition_size)
            ]
            eq_(
                check_result,
                [
                    self._expected_rows(filter_, range(start, stop))
                    for start, stop in bounds
                ],
            )

    @testing.combinations(
        (None,), ("scalars",), ("mappings",), argnames="filter_"
    )
    @async_test
    async def test_one_success(self, async_engine, filter_):
        """one() returns the single row of a LIMIT 1 query in each shape."""
        users = self.tables.users
        stmt = select(users).limit(1).order_by(users.c.user_name)
        async with async_engine.connect() as conn:
            result = await conn.stream(stmt)
            if filter_ == "mappings":
                result = result.mappings()
                expected = {"user_id": 1, "user_name": "name1"}
            elif filter_ == "scalars":
                result = result.scalars()
                expected = 1
            else:
                expected = (1, "name1")
            u1 = await result.one()
            eq_(u1, expected)

    @async_test
    async def test_one_no_result(self, async_engine):
        """one() on an empty result raises NoResultFound."""
        users = self.tables.users
        stmt = select(users).where(users.c.user_name == "nonexistent")
        async with async_engine.connect() as conn:
            result = await conn.stream(stmt)
            with expect_raises_message(
                exc.NoResultFound, "No row was found when one was required"
            ):
                await result.one()

    @async_test
    async def test_one_multi_result(self, async_engine):
        """one() on a two-row result raises MultipleResultsFound."""
        users = self.tables.users
        stmt = select(users).where(users.c.user_name.in_(["name3", "name5"]))
        async with async_engine.connect() as conn:
            result = await conn.stream(stmt)
            with expect_raises_message(
                exc.MultipleResultsFound,
                "Multiple rows were found when exactly one was required",
            ):
                await result.one()

    @testing.combinations(("scalars",), ("stream_scalars",), argnames="case")
    @async_test
    async def test_scalars(self, async_engine, case):
        """scalars() and stream_scalars() both yield the first column."""
        users = self.tables.users
        stmt = select(users).order_by(users.c.user_id)
        async with async_engine.connect() as conn:
            if case == "scalars":
                buffered = await conn.scalars(stmt)
                result = buffered.all()
            elif case == "stream_scalars":
                streamed = await conn.stream_scalars(stmt)
                result = await streamed.all()
        eq_(result, list(range(1, 20)))

    @async_test
    @testing.combinations(("stream",), ("stream_scalars",), argnames="case")
    async def test_stream_fetch_many_not_complete(self, async_engine, case):
        """fetchmany() batches followed by fetchall() drain the stream.

        The cross join yields 19 * 19 rows. Batches of 5, 10 and 7 are
        taken, then fetchall() must return exactly the remainder.
        """
        users = self.tables.users
        big_query = select(users).join(users.alias("other"), true())
        async with async_engine.connect() as conn:
            if case == "stream":
                result = await conn.stream(big_query)
            elif case == "stream_scalars":
                result = await conn.stream_scalars(big_query)
            taken = 0
            for batch_size in (5, 10, 7):
                taken += len(await result.fetchmany(batch_size))
            eq_(taken, 22)
            remainder = await result.fetchall()
            eq_(len(remainder), 19 * 19 - 22)

    @async_test
    @testing.combinations(("stream",), ("execute",), argnames="case")
    async def test_cursor_close(self, async_engine, case):
        """The DBAPI cursor behind a result can be closed via run_sync()."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            if case == "stream":
                result = await conn.stream(select(users))
                dbapi_cursor = result._real_result.cursor
            elif case == "execute":
                result = await conn.execute(select(users))
                dbapi_cursor = result.cursor
            await conn.run_sync(lambda _: dbapi_cursor.close())

    @async_test
    @testing.variation("case", ["scalar_one", "scalar_one_or_none", "scalar"])
    async def test_stream_scalar(self, async_engine, case: testing.Variation):
        """Each scalar accessor on a streamed result returns user_id 1."""
        users = self.tables.users
        stmt = select(users).limit(1).order_by(users.c.user_name)
        async with async_engine.connect() as conn:
            result = await conn.stream(stmt)
            if case.scalar_one:
                accessor = result.scalar_one
            elif case.scalar_one_or_none:
                accessor = result.scalar_one_or_none
            elif case.scalar:
                accessor = result.scalar
            else:
                case.fail()
            u1 = await accessor()
            eq_(u1, 1)
class TextSyncDBAPI(fixtures.TestBase):
    """Behavior of the asyncio extension when paired with a sync DBAPI.

    NOTE(review): the class name looks like a typo for ``TestSyncDBAPI``.
    It is left unchanged so that selection of this class by name keeps
    working.
    """

    __requires__ = ("asyncio",)

    def test_sync_dbapi_raises(self):
        """create_async_engine() rejects the sync pysqlite driver."""
        with expect_raises_message(
            exc.InvalidRequestError,
            "The asyncio extension requires an async driver to be used.",
        ):
            create_async_engine("sqlite:///:memory:")

    @testing.fixture
    def async_engine(self):
        """Yield an AsyncEngine that wraps a sync in-memory SQLite engine.

        The sync dialect is flagged ``is_async`` and
        ``supports_server_side_cursors``. For the duration of the test,
        ``create_server_side_cursor`` is patched to
        ``create_default_cursor``. On teardown the sync engine is disposed.
        This closes its pooled DBAPI connection explicitly instead of
        leaving it open until garbage collection.
        """
        engine = create_engine("sqlite:///:memory:", future=True)
        # presumably lets AsyncEngine's async-driver check pass for a
        # sync dialect, so the "await required" paths below can be hit
        engine.dialect.is_async = True
        engine.dialect.supports_server_side_cursors = True
        try:
            with mock.patch.object(
                engine.dialect.execution_ctx_cls,
                "create_server_side_cursor",
                engine.dialect.execution_ctx_cls.create_default_cursor,
            ):
                yield _async_engine.AsyncEngine(engine)
        finally:
            # Close pooled pysqlite connections deterministically. Since
            # Python 3.13, sqlite3 emits ResourceWarning for connections
            # that are garbage collected without close().
            engine.dispose()

    @async_test
    @combinations(
        lambda conn: conn.exec_driver_sql("select 1"),
        lambda conn: conn.stream(text("select 1")),
        lambda conn: conn.execute(text("select 1")),
        argnames="case",
    )
    async def test_sync_driver_execution(self, async_engine, case):
        """Async execution methods on a sync driver raise AwaitRequired."""
        with expect_raises_message(
            exc.AwaitRequired,
            "The current operation required an async execution but none was",
        ):
            async with async_engine.connect() as conn:
                await case(conn)

    @async_test
    async def test_sync_driver_run_sync(self, async_engine):
        """run_sync() still works against a sync driver."""
        async with async_engine.connect() as conn:
            res = await conn.run_sync(
                lambda conn: conn.scalar(text("select 1"))
            )
            assert res == 1
            assert await conn.run_sync(lambda _: 2) == 2
class AsyncProxyTest(EngineFixture, fixtures.TestBase):
    """Proxy bookkeeping between async facades and the sync objects they wrap.

    Exercises ``AsyncConnection._retrieve_proxy_for_target()`` and the
    ``get_transaction()`` / ``get_nested_transaction()`` accessors. It also
    checks how many entries ``ReversibleProxy._proxy_objects`` holds as
    async engines, connections and transactions are created and deleted.
    """

    @async_test
    async def test_get_transaction(self, async_engine):
        """get_transaction() returns the object begin() produced.

        That object's ``connection`` is the originating AsyncConnection.
        """
        async with async_engine.connect() as conn:
            async with conn.begin() as trans:
                is_(trans.connection, conn)
                is_(conn.get_transaction(), trans)

    @async_test
    async def test_get_nested_transaction(self, async_engine):
        """get_nested_transaction() tracks the innermost begin_nested().

        After n2 is committed it reports n1 again, while get_transaction()
        keeps returning the outer transaction.
        """
        async with async_engine.connect() as conn:
            async with conn.begin() as trans:
                n1 = await conn.begin_nested()
                is_(conn.get_nested_transaction(), n1)
                n2 = await conn.begin_nested()
                is_(conn.get_nested_transaction(), n2)
                await n2.commit()
                is_(conn.get_nested_transaction(), n1)
                is_(conn.get_transaction(), trans)

    @async_test
    async def test_get_connection(self, async_engine):
        """Looking up the proxy for ``conn.sync_connection`` yields ``conn``."""
        async with async_engine.connect() as conn:
            is_(
                AsyncConnection._retrieve_proxy_for_target(
                    conn.sync_connection
                ),
                conn,
            )

    def test_regenerate_connection(self, connection):
        """Repeated lookups for one sync connection give one AsyncConnection.

        Its ``engine`` is non-None and identical across the lookups.
        """
        async_connection = AsyncConnection._retrieve_proxy_for_target(
            connection
        )
        a2 = AsyncConnection._retrieve_proxy_for_target(connection)
        is_(async_connection, a2)
        is_not(async_connection, None)
        is_(async_connection.engine, a2.engine)
        is_not(async_connection.engine, None)

    @testing.requires.predictable_gc
    @async_test
    async def test_gc_engine(self, testing_engine):
        """An AsyncEngine adds one registry entry, removed when it is deleted.

        Relies on ``predictable_gc`` so that ``del`` frees the object at
        once. NOTE(review): the ``testing_engine`` fixture parameter is not
        used in the body.
        """
        ReversibleProxy._proxy_objects.clear()
        eq_(len(ReversibleProxy._proxy_objects), 0)
        async_engine = AsyncEngine(testing.db)
        eq_(len(ReversibleProxy._proxy_objects), 1)
        del async_engine
        eq_(len(ReversibleProxy._proxy_objects), 0)

    @testing.requires.predictable_gc
    @async_test
    async def test_gc_conn(self, testing_engine):
        """Registry size follows engine -> connection -> transaction lifetimes.

        Expected counts are 1 (engine), 2 (+ connection), 3 (+ transaction),
        then back to 1 once the transaction and connection names are
        deleted, and 0 after the engine is deleted. The assertions depend on
        this exact statement order and on ``predictable_gc``.
        """
        ReversibleProxy._proxy_objects.clear()
        async_engine = AsyncEngine(testing.db)
        eq_(len(ReversibleProxy._proxy_objects), 1)
        async with async_engine.connect() as conn:
            eq_(len(ReversibleProxy._proxy_objects), 2)
            async with conn.begin() as trans:
                eq_(len(ReversibleProxy._proxy_objects), 3)
            del trans
        del conn
        eq_(len(ReversibleProxy._proxy_objects), 1)
        del async_engine
        eq_(len(ReversibleProxy._proxy_objects), 0)

    def test_regen_conn_but_not_engine(self, async_engine):
        """A new connection proxy reuses the existing engine proxy.

        An AsyncConnection built for a connection from
        ``async_engine.sync_engine`` reports the existing ``async_engine``
        as its engine rather than a new proxy.
        """
        with async_engine.sync_engine.connect() as sync_conn:
            async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn)
            async_conn2 = AsyncConnection._retrieve_proxy_for_target(sync_conn)
            is_(async_conn, async_conn2)
            is_(async_conn.engine, async_engine)

    def test_regen_trans_but_not_conn(self, connection_no_trans):
        """A sync-side begin() surfaces through the existing async proxy.

        The AsyncTransaction points back at ``async_conn``, wraps the sync
        transaction, and is the same object on repeated calls.
        """
        sync_conn = connection_no_trans
        async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn)
        # transaction is started on the sync side, not via async_conn
        trans = sync_conn.begin()
        async_t1 = async_conn.get_transaction()
        is_(async_t1.connection, async_conn)
        is_(async_t1.sync_transaction, trans)
        async_t2 = async_conn.get_transaction()
        is_(async_t1, async_t2)
class PoolRegenTest(EngineFixture):
    """Pool behavior of an async engine across ``dispose()``."""

    @testing.requires.queue_pool
    @async_test
    @testing.variation("do_dispose", [True, False])
    async def test_gather_after_dispose(self, testing_engine, do_dispose):
        """Ten concurrent checkouts succeed, with or without a prior dispose().

        The engine gets pool_size=10 and max_overflow=10. When
        ``do_dispose`` is set, ``dispose()`` is awaited before the
        concurrent tasks run.
        """
        engine = testing_engine(
            asyncio=True, options=dict(pool_size=10, max_overflow=10)
        )

        async def run_select_one(target_engine):
            # check out a connection and run a trivial statement on it
            async with target_engine.connect() as conn:
                sql = str(select(1).compile(target_engine))
                await conn.exec_driver_sql(sql)

        if do_dispose:
            await engine.dispose()

        await asyncio.gather(*(run_select_one(engine) for _ in range(10)))