mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-07 01:10:52 -04:00
380c234ce9
Fixed issue in the aiosqlite driver where SQLAlchemy's setting of aiosqlite's worker thread to "daemon" stopped working because the aiosqlite architecture moved the location of the worker thread in version 0.22.0. This "daemon" flag is necessary so that a program is able to exit if the SQLite connection itself was not explicitly closed, which is particularly likely with SQLAlchemy as it maintains SQLite connections in a connection pool. While it's perfectly fine to call :meth:`.AsyncEngine.dispose` before program exit, this is not historically or technically necessary for any driver of any known backend, since a primary feature of relational databases is durability. The change also implements support for "terminate" with aiosqlite when using version 0.22.1 or greater, which implements a sync ``.stop()`` method. Fixes: #13039 Change-Id: I46efcbaab9dd028f673e113d5f6f2ceddfd133ca
1652 lines
56 KiB
Python
1652 lines
56 KiB
Python
import asyncio
|
|
import contextlib
|
|
import inspect as stdlib_inspect
|
|
from unittest.mock import patch
|
|
|
|
from sqlalchemy import AssertionPool
|
|
from sqlalchemy import Column
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy import delete
|
|
from sqlalchemy import event
|
|
from sqlalchemy import exc
|
|
from sqlalchemy import func
|
|
from sqlalchemy import inspect
|
|
from sqlalchemy import Integer
|
|
from sqlalchemy import NullPool
|
|
from sqlalchemy import QueuePool
|
|
from sqlalchemy import select
|
|
from sqlalchemy import SingletonThreadPool
|
|
from sqlalchemy import StaticPool
|
|
from sqlalchemy import String
|
|
from sqlalchemy import Table
|
|
from sqlalchemy import testing
|
|
from sqlalchemy import text
|
|
from sqlalchemy import true
|
|
from sqlalchemy import union_all
|
|
from sqlalchemy.engine import cursor as _cursor
|
|
from sqlalchemy.ext.asyncio import async_engine_from_config
|
|
from sqlalchemy.ext.asyncio import create_async_engine
|
|
from sqlalchemy.ext.asyncio import create_async_pool_from_url
|
|
from sqlalchemy.ext.asyncio import engine as _async_engine
|
|
from sqlalchemy.ext.asyncio import exc as async_exc
|
|
from sqlalchemy.ext.asyncio import exc as asyncio_exc
|
|
from sqlalchemy.ext.asyncio.base import ReversibleProxy
|
|
from sqlalchemy.ext.asyncio.engine import AsyncConnection
|
|
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
|
from sqlalchemy.pool import AsyncAdaptedQueuePool
|
|
from sqlalchemy.testing import assertions
|
|
from sqlalchemy.testing import async_test
|
|
from sqlalchemy.testing import combinations
|
|
from sqlalchemy.testing import config
|
|
from sqlalchemy.testing import engines
|
|
from sqlalchemy.testing import eq_
|
|
from sqlalchemy.testing import eq_regex
|
|
from sqlalchemy.testing import expect_raises
|
|
from sqlalchemy.testing import expect_raises_message
|
|
from sqlalchemy.testing import fixtures
|
|
from sqlalchemy.testing import is_
|
|
from sqlalchemy.testing import is_false
|
|
from sqlalchemy.testing import is_none
|
|
from sqlalchemy.testing import is_not
|
|
from sqlalchemy.testing import is_true
|
|
from sqlalchemy.testing import mock
|
|
from sqlalchemy.testing import ne_
|
|
from sqlalchemy.util import greenlet_spawn
|
|
|
|
|
|
class AsyncFixture:
    """Mixin supplying a parametrized fixture that exercises async
    transaction context managers under every combination of
    rollback vs. commit, issuing a second execute after close, and
    savepoint (``begin_nested``) use.
    """

    @config.fixture(
        params=[
            (rollback, run_second_execute, begin_nested)
            for rollback in (True, False)
            for run_second_execute in (True, False)
            for begin_nested in (True, False)
        ]
    )
    def async_trans_ctx_manager_fixture(self, request, metadata):
        """Return an async ``run_test(subject, trans_on_subject,
        execute_on_subject)`` callable that drives the parametrized
        transaction scenario against ``subject`` (an AsyncConnection,
        AsyncEngine, session, etc.) and verifies the committed row count.
        """
        rollback, run_second_execute, begin_nested = request.param

        t = Table("test", metadata, Column("data", Integer))
        eng = getattr(self, "bind", None) or config.db

        t.create(eng)

        async def run_test(subject, trans_on_subject, execute_on_subject):
            async with subject.begin() as trans:
                if begin_nested:
                    if not config.requirements.savepoints.enabled:
                        config.skip_test("savepoints not enabled")
                    if execute_on_subject:
                        nested_trans = subject.begin_nested()
                    else:
                        nested_trans = trans.begin_nested()

                    async with nested_trans:
                        if execute_on_subject:
                            await subject.execute(t.insert(), {"data": 10})
                        else:
                            await trans.execute(t.insert(), {"data": 10})

                        # for nested trans, we always commit/rollback on the
                        # "nested trans" object itself.
                        # only Session(future=False) will affect savepoint
                        # transaction for session.commit/rollback

                        if rollback:
                            await nested_trans.rollback()
                        else:
                            await nested_trans.commit()

                        if run_second_execute:
                            with assertions.expect_raises_message(
                                exc.InvalidRequestError,
                                "Can't operate on closed transaction "
                                "inside context manager. Please complete the "
                                "context manager "
                                "before emitting further commands.",
                            ):
                                if execute_on_subject:
                                    await subject.execute(
                                        t.insert(), {"data": 12}
                                    )
                                else:
                                    await trans.execute(
                                        t.insert(), {"data": 12}
                                    )

                    # outside the nested trans block, but still inside the
                    # transaction block, we can run SQL, and it will be
                    # committed
                    if execute_on_subject:
                        await subject.execute(t.insert(), {"data": 14})
                    else:
                        await trans.execute(t.insert(), {"data": 14})

                else:
                    if execute_on_subject:
                        await subject.execute(t.insert(), {"data": 10})
                    else:
                        await trans.execute(t.insert(), {"data": 10})

                    if trans_on_subject:
                        if rollback:
                            await subject.rollback()
                        else:
                            await subject.commit()
                    else:
                        if rollback:
                            await trans.rollback()
                        else:
                            await trans.commit()

                    if run_second_execute:
                        with assertions.expect_raises_message(
                            exc.InvalidRequestError,
                            "Can't operate on closed transaction inside "
                            "context "
                            "manager. Please complete the context manager "
                            "before emitting further commands.",
                        ):
                            if execute_on_subject:
                                await subject.execute(t.insert(), {"data": 12})
                            else:
                                await trans.execute(t.insert(), {"data": 12})

            expected_committed = 0
            if begin_nested:
                # begin_nested variant, we inserted a row after the nested
                # block
                expected_committed += 1
            if not rollback:
                # not rollback variant, our row inserted in the target
                # block itself would be committed
                expected_committed += 1

            if execute_on_subject:
                eq_(
                    await subject.scalar(select(func.count()).select_from(t)),
                    expected_committed,
                )
            else:
                # FIX: subject.connect() returns an async context manager;
                # a plain ``with`` would fail at runtime (no __enter__).
                async with subject.connect() as conn:
                    eq_(
                        await conn.scalar(select(func.count()).select_from(t)),
                        expected_committed,
                    )

        return run_test
|
|
|
|
|
|
class EngineFixture(AsyncFixture, fixtures.TablesTest):
    """Tables-based fixture providing async engines / connections plus a
    ``users`` table pre-populated with 19 rows."""

    __requires__ = ("async_dialect",)

    @testing.fixture
    def async_engine(self):
        # sqlite_share_pool keeps data visible across pooled SQLite
        # connections within a single test
        return engines.testing_engine(
            asyncio=True, options={"sqlite_share_pool": True}
        )

    @testing.fixture
    def adhoc_async_engine(self):
        # engine with default options, for tests that need their own pool
        return engines.testing_engine(asyncio=True)

    @testing.fixture
    def async_connection(self, async_engine):
        # wrap a sync connection, whose lifespan is managed by the
        # surrounding ``with`` block, in an AsyncConnection facade
        with async_engine.sync_engine.connect() as conn:
            yield AsyncConnection(async_engine, conn)

    @classmethod
    def define_tables(cls, metadata):
        Table(
            "users",
            metadata,
            Column("user_id", Integer, primary_key=True, autoincrement=False),
            Column("user_name", String(20)),
        )

    @classmethod
    def insert_data(cls, connection):
        users = cls.tables.users
        connection.execute(
            users.insert(),
            [{"user_id": i, "user_name": "name%d" % i} for i in range(1, 20)],
        )
|
|
|
|
|
|
class AsyncEngineTest(EngineFixture):
    """Core behavioral tests for AsyncEngine / AsyncConnection."""

    __backend__ = True

    @testing.fails("the failure is the test")
    @async_test
    async def test_we_are_definitely_running_async_tests(self, async_engine):
        # sanity check that @async_test really executes: asserting 1 == 2
        # must fail, which @testing.fails expects
        async with async_engine.connect() as conn:
            eq_(await conn.scalar(text("select 1")), 2)

    @async_test
    async def test_interrupt_ctxmanager_connection(
        self, async_engine, async_trans_ctx_manager_fixture
    ):
        fn = async_trans_ctx_manager_fixture

        async with async_engine.connect() as conn:
            await fn(conn, trans_on_subject=False, execute_on_subject=True)

    def test_proxied_attrs_engine(self, async_engine):
        # AsyncEngine proxies these attributes straight to sync_engine
        sync_engine = async_engine.sync_engine

        is_(async_engine.url, sync_engine.url)
        is_(async_engine.pool, sync_engine.pool)
        is_(async_engine.dialect, sync_engine.dialect)
        eq_(async_engine.name, sync_engine.name)
        eq_(async_engine.driver, sync_engine.driver)
        eq_(async_engine.echo, sync_engine.echo)

    @async_test
    async def test_run_async(self, async_engine):
        async def test_meth(async_driver_connection):
            # there's no method that's guaranteed to be on every
            # driver, so just stringify it and compare that to the
            # outside
            return str(async_driver_connection)

        def run_sync_to_async(connection):
            connection_fairy = connection.connection
            async_return = connection_fairy.run_async(
                lambda driver_connection: test_meth(driver_connection)
            )
            # run_async must resolve the coroutine, not return it
            assert not stdlib_inspect.iscoroutine(async_return)
            return async_return

        async with async_engine.connect() as conn:
            driver_connection = (
                await conn.get_raw_connection()
            ).driver_connection
            res = await conn.run_sync(run_sync_to_async)
            assert not stdlib_inspect.iscoroutine(res)
            eq_(res, str(driver_connection))

    @async_test
    async def test_engine_eq_ne(self, async_engine):
        # equality is based on the wrapped sync engine
        e2 = _async_engine.AsyncEngine(async_engine.sync_engine)
        e3 = engines.testing_engine(asyncio=True)

        eq_(async_engine, e2)
        ne_(async_engine, e3)

        # deliberately uses == (not "is") to exercise __eq__ with None
        is_false(async_engine == None)

    def test_no_attach_to_event_loop(self, testing_engine):
        """test #6409

        note this test does not seem to trigger the bug that was originally
        fixed in #6409, when using python 3.10 and higher (the original issue
        can repro in 3.8 at least, based on my testing). It's been simplified
        to no longer explicitly create a new loop, asyncio.run() already
        creates a new loop.

        """

        import asyncio
        import threading

        errs = []

        def go():
            async def main():
                tasks = [task() for _ in range(2)]

                await asyncio.gather(*tasks)
                await engine.dispose()

            async def task():
                async with engine.begin() as connection:
                    result = await connection.execute(select(1))
                    result.all()

            try:
                engine = engines.testing_engine(asyncio=True)

                asyncio.run(main())
            except Exception as err:
                # collect and re-raise on the main thread so pytest sees it
                errs.append(err)

        t = threading.Thread(target=go)
        t.start()
        t.join()

        if errs:
            raise errs[0]
|
|
|
|
    @async_test
    async def test_connection_info(self, async_engine):
        # .info dict is shared with the underlying sync connection
        async with async_engine.connect() as conn:
            conn.info["foo"] = "bar"

            eq_(conn.sync_connection.info, {"foo": "bar"})

    @async_test
    async def test_connection_eq_ne(self, async_engine):
        async with async_engine.connect() as conn:
            c2 = _async_engine.AsyncConnection(
                async_engine, conn.sync_connection
            )

            eq_(conn, c2)

            async with async_engine.connect() as c3:
                ne_(conn, c3)

            # == (not "is") on purpose; exercises __eq__ with None
            is_false(conn == None)

    @async_test
    async def test_transaction_eq_ne(self, async_engine):
        async with async_engine.connect() as conn:
            t1 = await conn.begin()

            # regenerating a proxy for the same sync transaction must
            # compare equal to the original proxy
            t2 = _async_engine.AsyncTransaction._regenerate_proxy_for_target(
                t1._proxied
            )

            eq_(t1, t2)

            is_false(t1 == None)

    @testing.variation("simulate_gc", [True, False])
    def test_appropriate_warning_for_gced_connection(
        self, adhoc_async_engine, simulate_gc
    ):
        """test #9237 which builds upon a not really complete solution
        added for #8419."""

        async def go():
            conn = await adhoc_async_engine.connect()
            await conn.begin()
            await conn.execute(select(1))
            pool_connection = await conn.get_raw_connection()
            return pool_connection

        from sqlalchemy.util.concurrency import await_

        pool_connection = await_(go())

        rec = pool_connection._connection_record
        ref = rec.fairy_ref
        pool = pool_connection._pool
        echo = False

        if simulate_gc:
            # not using expect_warnings() here because we also want to do a
            # negative test for warnings, and we want to absolutely make sure
            # the thing here that emits the warning is the correct path
            from sqlalchemy.pool.base import _finalize_fairy

            with (
                mock.patch.object(
                    pool._dialect,
                    "do_rollback",
                    mock.Mock(side_effect=Exception("can't run rollback")),
                ),
                mock.patch("sqlalchemy.util.warn") as m,
            ):
                _finalize_fairy(
                    None, rec, pool, ref, echo, transaction_was_reset=False
                )

            if adhoc_async_engine.dialect.has_terminate:
                expected_msg = (
                    "The garbage collector is trying to clean up.*which will "
                    "be terminated."
                )
            else:
                expected_msg = (
                    "The garbage collector is trying to clean up.*which will "
                    "be dropped, as it cannot be safely terminated."
                )

            # [1] == .args, not in 3.7
            eq_regex(m.mock_calls[0][1][0], expected_msg)
        else:
            # the warning emitted by the pool is inside of a try/except:
            # so it's impossible right now to have this warning "raise".
            # for now, test by using mock.patch

            with mock.patch("sqlalchemy.util.warn") as m:
                pool_connection.close()

            eq_(m.mock_calls, [])

    @async_test
    @testing.skip_if(lambda config: not config.db.dialect.has_terminate)
    async def test_dbapi_terminate(self, adhoc_async_engine):
        # terminate() on the adapted DBAPI connection must not raise
        conn = await adhoc_async_engine.raw_connection()
        dbapi_conn = conn.dbapi_connection
        dbapi_conn.terminate()
        conn.invalidate()

    @async_test
    async def test_statement_compile(self, async_engine):
        # compiling against the async engine or connection is equivalent
        stmt = str(select(1).compile(async_engine))
        async with async_engine.connect() as conn:
            eq_(str(select(1).compile(conn)), stmt)

    def test_clear_compiled_cache(self, async_engine):
        async_engine.sync_engine._compiled_cache["foo"] = "bar"
        eq_(async_engine.sync_engine._compiled_cache["foo"], "bar")
        async_engine.clear_compiled_cache()
        assert "foo" not in async_engine.sync_engine._compiled_cache
|
|
|
|
    def test_execution_options(self, async_engine):
        # execution_options() returns a new AsyncEngine; the original
        # engine's options are unchanged
        a2 = async_engine.execution_options(foo="bar")
        assert isinstance(a2, _async_engine.AsyncEngine)
        eq_(a2.sync_engine._execution_options, {"foo": "bar"})
        eq_(async_engine.sync_engine._execution_options, {})

    # dev checklist (bare string, no runtime effect): proxied surface
    # covered by the tests in this class
    """

    attr uri, pool, dialect, engine, name, driver, echo

    methods clear_compiled_cache, update_execution_options,
    execution_options, get_execution_options, dispose

    """

    @async_test
    async def test_proxied_attrs_connection(self, async_engine):
        async with async_engine.connect() as conn:
            sync_conn = conn.sync_connection

            is_(conn.engine, async_engine)
            is_(conn.closed, sync_conn.closed)
            is_(conn.dialect, async_engine.sync_engine.dialect)
            eq_(
                conn.default_isolation_level, sync_conn.default_isolation_level
            )

    @async_test
    async def test_transaction_accessor(self, async_connection):
        # walk through begin/begin_nested/commit/rollback verifying the
        # in_transaction()/get_transaction() accessors at each step
        conn = async_connection
        is_none(conn.get_transaction())
        is_false(conn.in_transaction())
        is_false(conn.in_nested_transaction())

        trans = await conn.begin()

        is_true(conn.in_transaction())
        is_false(conn.in_nested_transaction())

        is_(trans.sync_transaction, conn.get_transaction().sync_transaction)

        nested = await conn.begin_nested()

        is_true(conn.in_transaction())
        is_true(conn.in_nested_transaction())

        is_(
            conn.get_nested_transaction().sync_transaction,
            nested.sync_transaction,
        )
        eq_(conn.get_nested_transaction(), nested)

        is_(trans.sync_transaction, conn.get_transaction().sync_transaction)

        await nested.commit()

        is_true(conn.in_transaction())
        is_false(conn.in_nested_transaction())

        await trans.rollback()

        is_none(conn.get_transaction())
        is_false(conn.in_transaction())
        is_false(conn.in_nested_transaction())

    @testing.requires.queue_pool
    @async_test
    async def test_invalidate(self, async_engine):
        conn = await async_engine.connect()

        is_(conn.invalidated, False)

        connection_fairy = await conn.get_raw_connection()
        is_(connection_fairy.is_valid, True)
        dbapi_connection = connection_fairy.dbapi_connection

        await conn.invalidate()

        if testing.against("postgresql+asyncpg"):
            # asyncpg exposes closed state on the inner connection
            assert dbapi_connection._connection.is_closed()

        # after invalidate, a new pooled connection is handed out
        new_fairy = await conn.get_raw_connection()
        is_not(new_fairy.dbapi_connection, dbapi_connection)
        is_not(new_fairy, connection_fairy)
        is_(new_fairy.is_valid, True)
        is_(connection_fairy.is_valid, False)
        await conn.close()

    @async_test
    async def test_get_dbapi_connection_raise(self, async_connection):
        # .connection is an async method on AsyncConnection, not an attr
        with testing.expect_raises_message(
            exc.InvalidRequestError,
            "AsyncConnection.connection accessor is not "
            "implemented as the attribute",
        ):
            async_connection.connection

    @async_test
    async def test_get_raw_connection(self, async_connection):
        pooled = await async_connection.get_raw_connection()
        is_(pooled, async_connection.sync_connection.connection)

    @async_test
    async def test_isolation_level(self, async_connection):
        conn = async_connection
        sync_isolation_level = await greenlet_spawn(
            conn.sync_connection.get_isolation_level
        )
        isolation_level = await conn.get_isolation_level()

        eq_(isolation_level, sync_isolation_level)

        await conn.execution_options(isolation_level="SERIALIZABLE")
        isolation_level = await conn.get_isolation_level()

        eq_(isolation_level, "SERIALIZABLE")

    @testing.combinations(
        (
            AsyncAdaptedQueuePool,
            True,
        ),
        (
            QueuePool,
            False,
        ),
        (NullPool, True),
        (SingletonThreadPool, False),
        (StaticPool, True),
        (AssertionPool, True),
        argnames="pool_cls,should_work",
    )
    @testing.variation("instantiate", [True, False])
    @async_test
    async def test_pool_classes(
        self, async_testing_engine, pool_cls, instantiate, should_work
    ):
        """test #8771"""
        if instantiate:
            if pool_cls in (QueuePool, AsyncAdaptedQueuePool):
                pool = pool_cls(creator=testing.db.pool._creator, timeout=10)
            else:
                pool = pool_cls(
                    creator=testing.db.pool._creator,
                )

            options = {"pool": pool}
        else:
            if pool_cls in (QueuePool, AsyncAdaptedQueuePool):
                options = {"poolclass": pool_cls, "pool_timeout": 10}
            else:
                options = {"poolclass": pool_cls}

        if not should_work:
            # sync-only pool classes must be rejected at engine creation
            with expect_raises_message(
                exc.ArgumentError,
                f"Pool class {pool_cls.__name__} "
                "cannot be used with asyncio engine",
            ):
                async_testing_engine(options=options)
            return

        e = async_testing_engine(options=options)

        if pool_cls is AssertionPool:
            # AssertionPool allows only one checkout at a time, so no
            # concurrent gather below
            async with e.connect() as conn:
                result = await conn.scalar(select(1))
                eq_(result, 1)
            return

        async def go():
            async with e.connect() as conn:
                result = await conn.scalar(select(1))
                eq_(result, 1)
                return result

        eq_(await asyncio.gather(*[go() for i in range(10)]), [1] * 10)
|
|
|
|
    def test_cant_use_async_pool_w_create_engine(self):
        """supplemental test for #8771"""

        with expect_raises_message(
            exc.ArgumentError,
            "Pool class AsyncAdaptedQueuePool "
            "cannot be used with non-asyncio engine",
        ):
            create_engine("sqlite://", poolclass=AsyncAdaptedQueuePool)

    @testing.requires.queue_pool
    @async_test
    async def test_dispose(self, async_engine):
        c1 = await async_engine.connect()
        c2 = await async_engine.connect()

        await c1.close()
        await c2.close()

        p1 = async_engine.pool

        if isinstance(p1, AsyncAdaptedQueuePool):
            eq_(async_engine.pool.checkedin(), 2)

        # dispose closes checked-in connections and replaces the pool
        await async_engine.dispose()
        if isinstance(p1, AsyncAdaptedQueuePool):
            eq_(async_engine.pool.checkedin(), 0)
        is_not(p1, async_engine.pool)

    @testing.requires.queue_pool
    @async_test
    async def test_dispose_no_close(self, async_engine):
        c1 = await async_engine.connect()
        c2 = await async_engine.connect()

        await c1.close()
        await c2.close()

        p1 = async_engine.pool

        if isinstance(p1, AsyncAdaptedQueuePool):
            eq_(async_engine.pool.checkedin(), 2)

        # close=False detaches the pool without closing its connections
        await async_engine.dispose(close=False)

        # TODO: test that DBAPI connection was not closed

        if isinstance(p1, AsyncAdaptedQueuePool):
            eq_(async_engine.pool.checkedin(), 0)
        is_not(p1, async_engine.pool)

    @testing.requires.independent_connections
    @async_test
    async def test_init_once_concurrency(self, async_engine):
        # two first-connect initializations racing concurrently must both
        # succeed
        async with async_engine.connect() as c1, async_engine.connect() as c2:
            coro = asyncio.gather(c1.scalar(select(1)), c2.scalar(select(2)))
            eq_(await coro, [1, 2])

    @async_test
    async def test_connect_ctxmanager(self, async_engine):
        async with async_engine.connect() as conn:
            result = await conn.execute(select(1))
            eq_(result.scalar(), 1)

    @async_test
    async def test_connect_plain(self, async_engine):
        conn = await async_engine.connect()
        try:
            result = await conn.execute(select(1))
            eq_(result.scalar(), 1)
        finally:
            await conn.close()

    @async_test
    async def test_connection_not_started(self, async_engine):
        # using the connection before awaiting / entering it must raise
        conn = async_engine.connect()
        testing.assert_raises_message(
            asyncio_exc.AsyncContextNotStarted,
            "AsyncConnection context has not been started and "
            "object has not been awaited.",
            conn.begin,
        )

    @async_test
    async def test_transaction_commit(self, async_engine):
        users = self.tables.users

        async with async_engine.begin() as conn:
            await conn.execute(delete(users))

        async with async_engine.connect() as conn:
            eq_(await conn.scalar(select(func.count(users.c.user_id))), 0)

    @async_test
    async def test_savepoint_rollback_noctx(self, async_engine):
        users = self.tables.users

        async with async_engine.begin() as conn:
            savepoint = await conn.begin_nested()
            await conn.execute(delete(users))
            await savepoint.rollback()

        # rollback of the savepoint restores all 19 fixture rows
        async with async_engine.connect() as conn:
            eq_(await conn.scalar(select(func.count(users.c.user_id))), 19)

    @async_test
    async def test_savepoint_commit_noctx(self, async_engine):
        users = self.tables.users

        async with async_engine.begin() as conn:
            savepoint = await conn.begin_nested()
            await conn.execute(delete(users))
            await savepoint.commit()

        async with async_engine.connect() as conn:
            eq_(await conn.scalar(select(func.count(users.c.user_id))), 0)

    @async_test
    async def test_transaction_rollback(self, async_engine):
        users = self.tables.users

        async with async_engine.connect() as conn:
            trans = conn.begin()
            await trans.start()
            await conn.execute(delete(users))
            await trans.rollback()

        async with async_engine.connect() as conn:
            eq_(await conn.scalar(select(func.count(users.c.user_id))), 19)
|
|
|
|
@async_test
|
|
async def test_conn_transaction_not_started(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
trans = conn.begin()
|
|
with expect_raises_message(
|
|
asyncio_exc.AsyncContextNotStarted,
|
|
"AsyncTransaction context has not been started "
|
|
"and object has not been awaited.",
|
|
):
|
|
await trans.rollback(),
|
|
|
|
    @testing.requires.queue_pool
    @async_test
    async def test_pool_exhausted_some_timeout(
        self, testing_engine, async_engine
    ):
        # pool of size 1, no overflow: second checkout times out quickly
        engine = testing_engine(
            asyncio=True,
            options=dict(
                pool_size=1,
                max_overflow=0,
                pool_timeout=0.1,
            ),
        )
        async with engine.connect():
            with expect_raises(exc.TimeoutError):
                await engine.connect()

    @async_test
    async def test_engine_aclose(self, async_engine):
        # contextlib.aclosing() drives AsyncConnection.aclose()
        users = self.tables.users
        async with contextlib.aclosing(async_engine.connect()) as conn:
            await conn.start()
            trans = conn.begin()
            await trans.start()
            await conn.execute(delete(users))
            await trans.commit()
        assert conn.closed

    @testing.requires.queue_pool
    @async_test
    async def test_pool_exhausted_no_timeout(
        self, testing_engine, async_engine
    ):
        # pool_timeout=0: exhaustion raises immediately
        engine = testing_engine(
            asyncio=True,
            options=dict(
                pool_size=1,
                max_overflow=0,
                pool_timeout=0,
            ),
        )
        async with engine.connect():
            with expect_raises(exc.TimeoutError):
                await engine.connect()

    @async_test
    async def test_create_async_engine_server_side_cursor(self, async_engine):
        with expect_raises_message(
            asyncio_exc.AsyncMethodRequired,
            "Can't set server_side_cursors for async engine globally",
        ):
            create_async_engine(
                testing.db.url,
                server_side_cursors=True,
            )

    def test_async_engine_from_config(self):
        config = {
            "sqlalchemy.url": testing.db.url.render_as_string(
                hide_password=False
            ),
            "sqlalchemy.echo": "true",
        }
        engine = async_engine_from_config(config)
        assert engine.url == testing.db.url
        assert engine.echo is True
        assert engine.dialect.is_async is True

    def test_async_creator_and_creator(self):
        # 'creator' and 'async_creator' are mutually exclusive
        async def ac():
            return None

        def c():
            return None

        with expect_raises_message(
            exc.ArgumentError,
            "Can only specify one of 'async_creator' or 'creator', "
            "not both.",
        ):
            create_async_engine(testing.db.url, creator=c, async_creator=ac)

    @async_test
    async def test_async_creator_invoked(self, async_testing_engine):
        """test for #8215"""

        existing_creator = testing.db.pool._creator

        async def async_creator():
            sync_conn = await greenlet_spawn(existing_creator)
            return sync_conn.driver_connection

        async_creator = mock.Mock(side_effect=async_creator)

        eq_(async_creator.mock_calls, [])

        engine = async_testing_engine(options={"async_creator": async_creator})
        async with engine.connect() as conn:
            result = await conn.scalar(select(1))
            eq_(result, 1)

        # exactly one no-arg call made by the pool
        eq_(async_creator.mock_calls, [mock.call()])

    @async_test
    async def test_async_creator_accepts_args_if_called_directly(
        self, async_testing_engine
    ):
        """supplemental test for #8215.

        The "async_creator" passed to create_async_engine() is expected to take
        no arguments, the same way as "creator" passed to create_engine()
        works.

        However, the ultimate "async_creator" received by the sync-emulating
        DBAPI *does* take arguments in its ``.connect()`` method, which will be
        all the other arguments passed to ``.connect()``. This functionality
        is not currently used, however was decided that the creator should
        internally work this way for improved flexibility; see
        https://github.com/sqlalchemy/sqlalchemy/issues/8215#issuecomment-1181791539.
        That contract is tested here.

        """  # noqa: E501

        existing_creator = testing.db.pool._creator

        async def async_creator(x, y, *, z=None):
            sync_conn = await greenlet_spawn(existing_creator)
            return sync_conn.driver_connection

        async_creator = mock.Mock(side_effect=async_creator)

        async_dbapi = testing.db.dialect.loaded_dbapi

        conn = await greenlet_spawn(
            async_dbapi.connect, 5, y=10, z=8, async_creator_fn=async_creator
        )
        try:
            eq_(async_creator.mock_calls, [mock.call(5, y=10, z=8)])
        finally:
            await greenlet_spawn(conn.close)

    @testing.combinations("stream", "stream_scalars", argnames="method")
    @async_test
    async def test_server_side_required_for_scalars(
        self, async_engine, method
    ):
        with mock.patch.object(
            async_engine.dialect, "supports_server_side_cursors", False
        ):
            async with async_engine.connect() as c:
                with expect_raises_message(
                    exc.InvalidRequestError,
                    "Can't use `stream` or `stream_scalars` with the current "
                    "dialect since it does not support server side cursors.",
                ):
                    if method == "stream":
                        await c.stream(select(1))
                    elif method == "stream_scalars":
                        await c.stream_scalars(select(1))
                    else:
                        testing.fail(method)

    @async_test
    @testing.requires.async_dialect_with_await_close
    async def test_active_await_close(self, async_engine):
        select_one_sql = select(1).compile(async_engine.sync_engine).string

        async with async_engine.connect() as conn:
            result = await conn.exec_driver_sql(select_one_sql)
            eq_(result.scalar_one(), 1)
            driver_cursor = result.context.cursor._cursor

        with expect_raises(Exception):
            # because the cursor should be closed
            await driver_cursor.execute(select_one_sql)

    @async_test
    async def test_async_creator_handle_error(self, async_testing_engine):
        """test for #11956"""

        existing_creator = testing.db.pool._creator

        def create_and_break():
            sync_conn = existing_creator()
            cursor = sync_conn.cursor()

            # figure out a way to get a native driver exception. This really
            # only applies to asyncpg where we rewrite the exception
            # hierarchy with our own emulated exception; other backends raise
            # standard DBAPI exceptions (with some buggy cases here and there
            # which they miss) even though they are async
            try:
                cursor.execute("this will raise an error")
            except Exception as possibly_emulated_error:
                if isinstance(
                    possibly_emulated_error, exc.EmulatedDBAPIException
                ):
                    raise possibly_emulated_error.driver_exception
                else:
                    raise possibly_emulated_error

        async def async_creator():
            return await greenlet_spawn(create_and_break)

        engine = async_testing_engine(options={"async_creator": async_creator})

        # errors raised in the creator must surface as DBAPIError
        with expect_raises(exc.DBAPIError):
            await engine.connect()
|
|
|
|
|
|
class AsyncCreatePoolTest(fixtures.TestBase):
    """Verify create_async_pool_from_url() delegates to the internal sync
    pool factory, tacking on ``_is_async=True``."""

    @config.fixture
    def mock_create(self):
        # swap out the underlying factory so calls can be inspected
        target = "sqlalchemy.ext.asyncio.engine._create_pool_from_url"
        with patch(target) as patched:
            yield patched

    def test_url_only(self, mock_create):
        # bare URL: only the async marker is added
        create_async_pool_from_url("sqlite://")
        mock_create.assert_called_once_with("sqlite://", _is_async=True)

    def test_pool_args(self, mock_create):
        # arbitrary keyword args pass straight through alongside the marker
        create_async_pool_from_url("sqlite://", foo=99, echo=True)
        mock_create.assert_called_once_with(
            "sqlite://", foo=99, echo=True, _is_async=True
        )
|
|
|
|
|
|
class AsyncEventTest(EngineFixture):
    """The engine events all run in their normal synchronous context.

    we do not provide an asyncio event interface at this time.

    """

    __backend__ = True

    @async_test
    async def test_no_async_listeners(self, async_engine):
        # listening directly on the async facade is rejected; listeners
        # must go on the sync_engine / sync_connection
        with testing.expect_raises_message(
            NotImplementedError,
            "asynchronous events are not implemented "
            "at this time. Apply synchronous listeners to the "
            "AsyncEngine.sync_engine or "
            "AsyncConnection.sync_connection attributes.",
        ):
            event.listen(async_engine, "before_cursor_execute", mock.Mock())

        async with async_engine.connect() as conn:
            with testing.expect_raises_message(
                NotImplementedError,
                "asynchronous events are not implemented "
                "at this time. Apply synchronous listeners to the "
                "AsyncEngine.sync_engine or "
                "AsyncConnection.sync_connection attributes.",
            ):
                event.listen(conn, "before_cursor_execute", mock.Mock())

    @async_test
    async def test_no_async_listeners_dialect_event(self, async_engine):
        with testing.expect_raises_message(
            NotImplementedError,
            "asynchronous events are not implemented "
            "at this time. Apply synchronous listeners to the "
            "AsyncEngine.sync_engine or "
            "AsyncConnection.sync_connection attributes.",
        ):
            event.listen(async_engine, "do_execute", mock.Mock())

    @async_test
    async def test_no_async_listeners_pool_event(self, async_engine):
        with testing.expect_raises_message(
            NotImplementedError,
            "asynchronous events are not implemented "
            "at this time. Apply synchronous listeners to the "
            "AsyncEngine.sync_engine or "
            "AsyncConnection.sync_connection attributes.",
        ):
            event.listen(async_engine, "checkout", mock.Mock())

    @async_test
    async def test_sync_before_cursor_execute_engine(self, async_engine):
        canary = mock.Mock()

        event.listen(async_engine.sync_engine, "before_cursor_execute", canary)

        async with async_engine.connect() as conn:
            sync_conn = conn.sync_connection
            await conn.execute(select(1))

        s1 = str(select(1).compile(async_engine))
        eq_(
            canary.mock_calls,
            [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
        )

    @async_test
    async def test_sync_before_cursor_execute_connection(self, async_engine):
        canary = mock.Mock()

        async with async_engine.connect() as conn:
            sync_conn = conn.sync_connection

            # listener added after connect still fires for this connection
            event.listen(
                async_engine.sync_engine, "before_cursor_execute", canary
            )
            await conn.execute(select(1))

        s1 = str(select(1).compile(async_engine))
        eq_(
            canary.mock_calls,
            [mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
        )

    @async_test
    async def test_event_on_sync_connection(self, async_engine):
        canary = mock.Mock()

        async with async_engine.connect() as conn:
            event.listen(conn.sync_connection, "begin", canary)
            async with conn.begin():
                eq_(
                    canary.mock_calls,
                    [mock.call(conn.sync_connection)],
                )
|
|
|
|
|
|
class AsyncInspection(EngineFixture):
    """``inspect()`` is not supported on the async facade objects."""

    __backend__ = True

    @async_test
    async def test_inspect_engine(self, async_engine):
        """inspect() on an AsyncEngine raises NoInspectionAvailable."""
        expected = "Inspection on an AsyncEngine is currently not supported."
        with testing.expect_raises_message(
            exc.NoInspectionAvailable, expected
        ):
            inspect(async_engine)

    @async_test
    async def test_inspect_connection(self, async_engine):
        """inspect() on an AsyncConnection raises NoInspectionAvailable."""
        expected = (
            "Inspection on an AsyncConnection is currently not supported."
        )
        async with async_engine.connect() as conn:
            with testing.expect_raises_message(
                exc.NoInspectionAvailable, expected
            ):
                inspect(conn)
|
|
|
|
|
|
class AsyncResultTest(EngineFixture):
    """Tests for the streaming AsyncResult API: ``conn.stream()``,
    ``conn.stream_scalars()`` and their interaction with server-side
    cursors, filters, partitions, and close semantics."""

    __backend__ = True
    __requires__ = ("server_side_cursors", "async_dialect")

    @async_test
    async def test_no_ss_cursor_w_execute(self, async_engine):
        """execute() is rejected when stream_results is in effect."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            conn = await conn.execution_options(stream_results=True)
            with expect_raises_message(
                async_exc.AsyncMethodRequired,
                r"Can't use the AsyncConnection.execute\(\) method with a "
                r"server-side cursor. Use the AsyncConnection.stream\(\) "
                r"method for an async streaming result set.",
            ):
                await conn.execute(select(users))

    @async_test
    async def test_no_ss_cursor_w_exec_driver_sql(self, async_engine):
        """exec_driver_sql() is likewise rejected with stream_results."""
        async with async_engine.connect() as conn:
            conn = await conn.execution_options(stream_results=True)
            with expect_raises_message(
                async_exc.AsyncMethodRequired,
                r"Can't use the AsyncConnection.exec_driver_sql\(\) "
                r"method with a "
                r"server-side cursor. Use the AsyncConnection.stream\(\) "
                r"method for an async streaming result set.",
            ):
                await conn.exec_driver_sql("SELECT * FROM users")

    @async_test
    async def test_stream_ctxmanager(self, async_engine):
        """The stream() context manager closes the result even when the
        iteration body raises."""
        async with async_engine.connect() as conn:
            conn = await conn.execution_options(stream_results=True)

            async with conn.stream(select(self.tables.users)) as result:
                assert not result._real_result._soft_closed
                assert not result.closed
                with expect_raises_message(Exception, "hi"):
                    i = 0
                    async for row in result:
                        if i > 2:
                            raise Exception("hi")
                        i += 1
            # the context manager closed the result despite the raise
            assert result._real_result._soft_closed
            assert result.closed

    @async_test
    async def test_stream_scalars_ctxmanager(self, async_engine):
        """stream_scalars() context manager also closes on error."""
        async with async_engine.connect() as conn:
            conn = await conn.execution_options(stream_results=True)

            async with conn.stream_scalars(
                select(self.tables.users)
            ) as result:
                assert not result._real_result._soft_closed
                assert not result.closed
                with expect_raises_message(Exception, "hi"):
                    i = 0
                    async for scalar in result:
                        if i > 2:
                            raise Exception("hi")
                        i += 1
            assert result._real_result._soft_closed
            assert result.closed

    @testing.combinations(
        (None,), ("scalars",), ("mappings",), argnames="filter_"
    )
    @async_test
    async def test_all(self, async_engine, filter_):
        """all() returns every row, honoring mappings()/scalars() filters."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            result = await conn.stream(select(users))

            if filter_ == "mappings":
                result = result.mappings()
            elif filter_ == "scalars":
                result = result.scalars(1)

            all_ = await result.all()
            if filter_ == "mappings":
                eq_(
                    all_,
                    [
                        {"user_id": i, "user_name": "name%d" % i}
                        for i in range(1, 20)
                    ],
                )
            elif filter_ == "scalars":
                eq_(
                    all_,
                    ["name%d" % i for i in range(1, 20)],
                )
            else:
                eq_(all_, [(i, "name%d" % i) for i in range(1, 20)])

    @testing.combinations(
        (None,),
        ("scalars",),
        ("stream_scalars",),
        ("mappings",),
        argnames="filter_",
    )
    @async_test
    async def test_aiter(self, async_engine, filter_):
        """``async for`` over the result yields rows per the active filter."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            if filter_ == "stream_scalars":
                result = await conn.stream_scalars(select(users.c.user_name))
            else:
                result = await conn.stream(select(users))

            if filter_ == "mappings":
                result = result.mappings()
            elif filter_ == "scalars":
                result = result.scalars(1)

            rows = []

            async for row in result:
                rows.append(row)

            if filter_ == "mappings":
                eq_(
                    rows,
                    [
                        {"user_id": i, "user_name": "name%d" % i}
                        for i in range(1, 20)
                    ],
                )
            elif filter_ in ("scalars", "stream_scalars"):
                eq_(
                    rows,
                    ["name%d" % i for i in range(1, 20)],
                )
            else:
                eq_(rows, [(i, "name%d" % i) for i in range(1, 20)])

    @testing.combinations((None,), ("mappings",), argnames="filter_")
    @async_test
    async def test_keys(self, async_engine, filter_):
        """keys() is available synchronously on the streaming result."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            result = await conn.stream(select(users))

            if filter_ == "mappings":
                result = result.mappings()

            eq_(result.keys(), ["user_id", "user_name"])

            await result.close()

    @async_test
    async def test_unique_all(self, async_engine):
        """unique() de-duplicates rows from a doubled UNION ALL."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            result = await conn.stream(
                union_all(select(users), select(users)).order_by(
                    users.c.user_id
                )
            )

            all_ = await result.unique().all()
            eq_(all_, [(i, "name%d" % i) for i in range(1, 20)])

    @async_test
    async def test_columns_all(self, async_engine):
        """columns(1) restricts rows to the selected column index."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            result = await conn.stream(select(users))

            all_ = await result.columns(1).all()
            eq_(all_, [("name%d" % i,) for i in range(1, 20)])

    @testing.combinations(
        (None,), ("scalars",), ("mappings",), argnames="filter_"
    )
    @testing.combinations(None, 2, 5, 10, argnames="yield_per")
    @testing.combinations("method", "opt", argnames="yield_per_type")
    @async_test
    async def test_partitions(
        self, async_engine, filter_, yield_per, yield_per_type
    ):
        """partitions() chunks rows by yield_per (as an execution option or
        via the yield_per() method) or by an explicit partition size."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            stmt = select(users)
            if yield_per and yield_per_type == "opt":
                stmt = stmt.execution_options(yield_per=yield_per)
            result = await conn.stream(stmt)

            if filter_ == "mappings":
                result = result.mappings()
            elif filter_ == "scalars":
                result = result.scalars(1)

            if yield_per and yield_per_type == "method":
                result = result.yield_per(yield_per)

            check_result = []

            # stream() sets stream_results unconditionally
            assert isinstance(
                result._real_result.cursor_strategy,
                _cursor.BufferedRowCursorFetchStrategy,
            )

            if yield_per:
                partition_size = yield_per

                # yield_per also controls the fetch buffer size
                eq_(result._real_result.cursor_strategy._bufsize, yield_per)

                async for partition in result.partitions():
                    check_result.append(partition)
            else:
                # default buffer size when no yield_per is given
                eq_(result._real_result.cursor_strategy._bufsize, 5)

                partition_size = 5
                async for partition in result.partitions(partition_size):
                    check_result.append(partition)

            # expected (start, stop) ranges for each partition of the
            # 19 rows, user_id 1..19
            ranges = [
                (i, min(20, i + partition_size))
                for i in range(1, 21, partition_size)
            ]

            if filter_ == "mappings":
                eq_(
                    check_result,
                    [
                        [
                            {"user_id": i, "user_name": "name%d" % i}
                            for i in range(a, b)
                        ]
                        for (a, b) in ranges
                    ],
                )
            elif filter_ == "scalars":
                eq_(
                    check_result,
                    [["name%d" % i for i in range(a, b)] for (a, b) in ranges],
                )
            else:
                eq_(
                    check_result,
                    [
                        [(i, "name%d" % i) for i in range(a, b)]
                        for (a, b) in ranges
                    ],
                )

    @testing.combinations(
        (None,), ("scalars",), ("mappings",), argnames="filter_"
    )
    @async_test
    async def test_one_success(self, async_engine, filter_):
        """one() returns the single row in the shape of the active filter."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            result = await conn.stream(
                select(users).limit(1).order_by(users.c.user_name)
            )

            if filter_ == "mappings":
                result = result.mappings()
            elif filter_ == "scalars":
                result = result.scalars()
            u1 = await result.one()

            if filter_ == "mappings":
                eq_(u1, {"user_id": 1, "user_name": "name%d" % 1})
            elif filter_ == "scalars":
                eq_(u1, 1)
            else:
                eq_(u1, (1, "name%d" % 1))

    @async_test
    async def test_one_no_result(self, async_engine):
        """one() with zero rows raises NoResultFound."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            result = await conn.stream(
                select(users).where(users.c.user_name == "nonexistent")
            )

            with expect_raises_message(
                exc.NoResultFound, "No row was found when one was required"
            ):
                await result.one()

    @async_test
    async def test_one_multi_result(self, async_engine):
        """one() with more than one row raises MultipleResultsFound."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            result = await conn.stream(
                select(users).where(users.c.user_name.in_(["name3", "name5"]))
            )

            with expect_raises_message(
                exc.MultipleResultsFound,
                "Multiple rows were found when exactly one was required",
            ):
                await result.one()

    @testing.combinations(("scalars",), ("stream_scalars",), argnames="case")
    @async_test
    async def test_scalars(self, async_engine, case):
        """scalars() and stream_scalars() produce the same scalar values."""
        users = self.tables.users
        stmt = select(users).order_by(users.c.user_id)
        async with async_engine.connect() as conn:
            if case == "scalars":
                result = (await conn.scalars(stmt)).all()
            elif case == "stream_scalars":
                result = await (await conn.stream_scalars(stmt)).all()

        eq_(result, list(range(1, 20)))

    @async_test
    @testing.combinations(("stream",), ("stream_scalars",), argnames="case")
    async def test_stream_fetch_many_not_complete(self, async_engine, case):
        """fetchmany() then fetchall() exhausts the remaining rows of a
        large (19x19 self-join) streaming result."""
        users = self.tables.users
        big_query = select(users).join(users.alias("other"), true())
        async with async_engine.connect() as conn:
            if case == "stream":
                result = await conn.stream(big_query)
            elif case == "stream_scalars":
                result = await conn.stream_scalars(big_query)

            f1 = await result.fetchmany(5)
            f2 = await result.fetchmany(10)
            f3 = await result.fetchmany(7)
            eq_(len(f1) + len(f2) + len(f3), 22)

            res = await result.fetchall()
            eq_(len(res), 19 * 19 - 22)

    @async_test
    @testing.combinations(("stream",), ("execute",), argnames="case")
    async def test_cursor_close(self, async_engine, case):
        """The underlying DBAPI cursor can be closed via run_sync()."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            if case == "stream":
                result = await conn.stream(select(users))
                cursor = result._real_result.cursor
            elif case == "execute":
                result = await conn.execute(select(users))
                cursor = result.cursor

            await conn.run_sync(lambda _: cursor.close())

    @async_test
    @testing.variation("case", ["scalar_one", "scalar_one_or_none", "scalar"])
    async def test_stream_scalar(self, async_engine, case: testing.Variation):
        """scalar_one()/scalar_one_or_none()/scalar() on a streamed result."""
        users = self.tables.users
        async with async_engine.connect() as conn:
            result = await conn.stream(
                select(users).limit(1).order_by(users.c.user_name)
            )

            if case.scalar_one:
                u1 = await result.scalar_one()
            elif case.scalar_one_or_none:
                u1 = await result.scalar_one_or_none()
            elif case.scalar:
                u1 = await result.scalar()
            else:
                case.fail()

        eq_(u1, 1)
|
|
|
|
|
|
class TextSyncDBAPI(fixtures.TestBase):
    """Behavior of the asyncio extension over a plain synchronous DBAPI."""

    __requires__ = ("asyncio",)

    def test_sync_dbapi_raises(self):
        """create_async_engine() refuses a non-async driver URL."""
        with expect_raises_message(
            exc.InvalidRequestError,
            "The asyncio extension requires an async driver to be used.",
        ):
            create_async_engine("sqlite:///:memory:")

    @testing.fixture
    def async_engine(self):
        """An AsyncEngine wrapping a sync sqlite engine whose dialect
        is flagged as async-capable for testing purposes."""
        engine = create_engine("sqlite:///:memory:", future=True)
        engine.dialect.is_async = True
        engine.dialect.supports_server_side_cursors = True
        # server-side cursor requests fall back to a default cursor so
        # stream() paths can be exercised against sqlite
        with mock.patch.object(
            engine.dialect.execution_ctx_cls,
            "create_server_side_cursor",
            engine.dialect.execution_ctx_cls.create_default_cursor,
        ):
            yield _async_engine.AsyncEngine(engine)

    @async_test
    @combinations(
        lambda conn: conn.exec_driver_sql("select 1"),
        lambda conn: conn.stream(text("select 1")),
        lambda conn: conn.execute(text("select 1")),
        argnames="case",
    )
    async def test_sync_driver_execution(self, async_engine, case):
        """Executing against the sync driver raises AwaitRequired."""
        with expect_raises_message(
            exc.AwaitRequired,
            "The current operation required an async execution but none was",
        ):
            async with async_engine.connect() as conn:
                await case(conn)

    @async_test
    async def test_sync_driver_run_sync(self, async_engine):
        """run_sync() still works over the sync driver."""
        async with async_engine.connect() as conn:
            res = await conn.run_sync(
                lambda conn: conn.scalar(text("select 1"))
            )
            assert res == 1
            assert await conn.run_sync(lambda _: 2) == 2
|
|
|
|
|
|
class AsyncProxyTest(EngineFixture, fixtures.TestBase):
    """Tests for the ReversibleProxy registry that maps sync engine /
    connection / transaction objects back to their async facades."""

    @async_test
    async def test_get_transaction(self, async_engine):
        """get_transaction() returns the live AsyncTransaction proxy."""
        async with async_engine.connect() as conn:
            async with conn.begin() as trans:
                is_(trans.connection, conn)
                is_(conn.get_transaction(), trans)

    @async_test
    async def test_get_nested_transaction(self, async_engine):
        """get_nested_transaction() tracks the innermost open SAVEPOINT."""
        async with async_engine.connect() as conn:
            async with conn.begin() as trans:
                n1 = await conn.begin_nested()

                is_(conn.get_nested_transaction(), n1)

                n2 = await conn.begin_nested()

                is_(conn.get_nested_transaction(), n2)

                # committing the inner savepoint pops back to the outer one
                await n2.commit()

                is_(conn.get_nested_transaction(), n1)

                is_(conn.get_transaction(), trans)

    @async_test
    async def test_get_connection(self, async_engine):
        """_retrieve_proxy_for_target maps a sync conn back to its proxy."""
        async with async_engine.connect() as conn:
            is_(
                AsyncConnection._retrieve_proxy_for_target(
                    conn.sync_connection
                ),
                conn,
            )

    def test_regenerate_connection(self, connection):
        """Repeated retrieval for the same sync connection yields the
        same proxy object, including its .engine."""
        async_connection = AsyncConnection._retrieve_proxy_for_target(
            connection
        )

        a2 = AsyncConnection._retrieve_proxy_for_target(connection)
        is_(async_connection, a2)
        is_not(async_connection, None)

        is_(async_connection.engine, a2.engine)
        is_not(async_connection.engine, None)

    @testing.requires.predictable_gc
    @async_test
    async def test_gc_engine(self, testing_engine):
        """An AsyncEngine's proxy registry entry is released when the
        engine is garbage collected."""
        ReversibleProxy._proxy_objects.clear()

        eq_(len(ReversibleProxy._proxy_objects), 0)

        async_engine = AsyncEngine(testing.db)

        eq_(len(ReversibleProxy._proxy_objects), 1)

        del async_engine

        # requires predictable (refcounting) GC for the immediate drop
        eq_(len(ReversibleProxy._proxy_objects), 0)

    @testing.requires.predictable_gc
    @async_test
    async def test_gc_conn(self, testing_engine):
        """Connection and transaction proxy entries are released as the
        corresponding objects go out of scope."""
        ReversibleProxy._proxy_objects.clear()

        async_engine = AsyncEngine(testing.db)

        eq_(len(ReversibleProxy._proxy_objects), 1)

        async with async_engine.connect() as conn:
            eq_(len(ReversibleProxy._proxy_objects), 2)

            async with conn.begin() as trans:
                eq_(len(ReversibleProxy._proxy_objects), 3)

            del trans

        del conn

        eq_(len(ReversibleProxy._proxy_objects), 1)

        del async_engine

        eq_(len(ReversibleProxy._proxy_objects), 0)

    def test_regen_conn_but_not_engine(self, async_engine):
        """A regenerated connection proxy reuses the existing AsyncEngine
        proxy rather than creating a new one."""
        with async_engine.sync_engine.connect() as sync_conn:
            async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn)
            async_conn2 = AsyncConnection._retrieve_proxy_for_target(sync_conn)

            is_(async_conn, async_conn2)
            is_(async_conn.engine, async_engine)

    def test_regen_trans_but_not_conn(self, connection_no_trans):
        """A transaction begun on the sync side is proxied on demand and
        the proxy is stable across retrievals."""
        sync_conn = connection_no_trans

        async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn)

        trans = sync_conn.begin()

        async_t1 = async_conn.get_transaction()

        is_(async_t1.connection, async_conn)
        is_(async_t1.sync_transaction, trans)

        async_t2 = async_conn.get_transaction()
        is_(async_t1, async_t2)
|
|
|
|
|
|
class PoolRegenTest(EngineFixture):
|
|
    @testing.requires.queue_pool
    @async_test
    @testing.variation("do_dispose", [True, False])
    async def test_gather_after_dispose(self, testing_engine, do_dispose):
        """Ten concurrent connect/execute tasks via asyncio.gather()
        succeed, both with and without a prior engine.dispose()."""
        engine = testing_engine(
            asyncio=True, options=dict(pool_size=10, max_overflow=10)
        )

        async def thing(engine):
            # one connection checkout + a trivial statement per task
            async with engine.connect() as conn:
                await conn.exec_driver_sql(str(select(1).compile(engine)))

        if do_dispose:
            # dispose replaces the pool; subsequent connects regenerate it
            await engine.dispose()

        tasks = [thing(engine) for _ in range(10)]
        await asyncio.gather(*tasks)
|