mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-09 10:20:00 -04:00
3d7022a9c8
Refactored all asyncio dialects so that exceptions which occur on failed connection attempts are appropriately wrapped with SQLAlchemy exception objects, allowing for consistent error handling. Fixes: #11956 Change-Id: Ic3fdbf334f059f92b03896b6429efa50968ca8a8
1631 lines
56 KiB
Python
1631 lines
56 KiB
Python
import asyncio
|
|
import contextlib
|
|
import inspect as stdlib_inspect
|
|
from unittest.mock import patch
|
|
|
|
from sqlalchemy import AssertionPool
|
|
from sqlalchemy import Column
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy import delete
|
|
from sqlalchemy import event
|
|
from sqlalchemy import exc
|
|
from sqlalchemy import func
|
|
from sqlalchemy import inspect
|
|
from sqlalchemy import Integer
|
|
from sqlalchemy import NullPool
|
|
from sqlalchemy import QueuePool
|
|
from sqlalchemy import select
|
|
from sqlalchemy import SingletonThreadPool
|
|
from sqlalchemy import StaticPool
|
|
from sqlalchemy import String
|
|
from sqlalchemy import Table
|
|
from sqlalchemy import testing
|
|
from sqlalchemy import text
|
|
from sqlalchemy import true
|
|
from sqlalchemy import union_all
|
|
from sqlalchemy.engine import cursor as _cursor
|
|
from sqlalchemy.ext.asyncio import async_engine_from_config
|
|
from sqlalchemy.ext.asyncio import create_async_engine
|
|
from sqlalchemy.ext.asyncio import create_async_pool_from_url
|
|
from sqlalchemy.ext.asyncio import engine as _async_engine
|
|
from sqlalchemy.ext.asyncio import exc as async_exc
|
|
from sqlalchemy.ext.asyncio import exc as asyncio_exc
|
|
from sqlalchemy.ext.asyncio.base import ReversibleProxy
|
|
from sqlalchemy.ext.asyncio.engine import AsyncConnection
|
|
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
|
from sqlalchemy.pool import AsyncAdaptedQueuePool
|
|
from sqlalchemy.testing import assertions
|
|
from sqlalchemy.testing import async_test
|
|
from sqlalchemy.testing import combinations
|
|
from sqlalchemy.testing import config
|
|
from sqlalchemy.testing import engines
|
|
from sqlalchemy.testing import eq_
|
|
from sqlalchemy.testing import eq_regex
|
|
from sqlalchemy.testing import expect_raises
|
|
from sqlalchemy.testing import expect_raises_message
|
|
from sqlalchemy.testing import fixtures
|
|
from sqlalchemy.testing import is_
|
|
from sqlalchemy.testing import is_false
|
|
from sqlalchemy.testing import is_none
|
|
from sqlalchemy.testing import is_not
|
|
from sqlalchemy.testing import is_true
|
|
from sqlalchemy.testing import mock
|
|
from sqlalchemy.testing import ne_
|
|
from sqlalchemy.util import greenlet_spawn
|
|
|
|
|
|
class AsyncFixture:
|
|
@config.fixture(
|
|
params=[
|
|
(rollback, run_second_execute, begin_nested)
|
|
for rollback in (True, False)
|
|
for run_second_execute in (True, False)
|
|
for begin_nested in (True, False)
|
|
]
|
|
)
|
|
def async_trans_ctx_manager_fixture(self, request, metadata):
|
|
rollback, run_second_execute, begin_nested = request.param
|
|
|
|
t = Table("test", metadata, Column("data", Integer))
|
|
eng = getattr(self, "bind", None) or config.db
|
|
|
|
t.create(eng)
|
|
|
|
async def run_test(subject, trans_on_subject, execute_on_subject):
|
|
async with subject.begin() as trans:
|
|
if begin_nested:
|
|
if not config.requirements.savepoints.enabled:
|
|
config.skip_test("savepoints not enabled")
|
|
if execute_on_subject:
|
|
nested_trans = subject.begin_nested()
|
|
else:
|
|
nested_trans = trans.begin_nested()
|
|
|
|
async with nested_trans:
|
|
if execute_on_subject:
|
|
await subject.execute(t.insert(), {"data": 10})
|
|
else:
|
|
await trans.execute(t.insert(), {"data": 10})
|
|
|
|
# for nested trans, we always commit/rollback on the
|
|
# "nested trans" object itself.
|
|
# only Session(future=False) will affect savepoint
|
|
# transaction for session.commit/rollback
|
|
|
|
if rollback:
|
|
await nested_trans.rollback()
|
|
else:
|
|
await nested_trans.commit()
|
|
|
|
if run_second_execute:
|
|
with assertions.expect_raises_message(
|
|
exc.InvalidRequestError,
|
|
"Can't operate on closed transaction "
|
|
"inside context manager. Please complete the "
|
|
"context manager "
|
|
"before emitting further commands.",
|
|
):
|
|
if execute_on_subject:
|
|
await subject.execute(
|
|
t.insert(), {"data": 12}
|
|
)
|
|
else:
|
|
await trans.execute(
|
|
t.insert(), {"data": 12}
|
|
)
|
|
|
|
# outside the nested trans block, but still inside the
|
|
# transaction block, we can run SQL, and it will be
|
|
# committed
|
|
if execute_on_subject:
|
|
await subject.execute(t.insert(), {"data": 14})
|
|
else:
|
|
await trans.execute(t.insert(), {"data": 14})
|
|
|
|
else:
|
|
if execute_on_subject:
|
|
await subject.execute(t.insert(), {"data": 10})
|
|
else:
|
|
await trans.execute(t.insert(), {"data": 10})
|
|
|
|
if trans_on_subject:
|
|
if rollback:
|
|
await subject.rollback()
|
|
else:
|
|
await subject.commit()
|
|
else:
|
|
if rollback:
|
|
await trans.rollback()
|
|
else:
|
|
await trans.commit()
|
|
|
|
if run_second_execute:
|
|
with assertions.expect_raises_message(
|
|
exc.InvalidRequestError,
|
|
"Can't operate on closed transaction inside "
|
|
"context "
|
|
"manager. Please complete the context manager "
|
|
"before emitting further commands.",
|
|
):
|
|
if execute_on_subject:
|
|
await subject.execute(t.insert(), {"data": 12})
|
|
else:
|
|
await trans.execute(t.insert(), {"data": 12})
|
|
|
|
expected_committed = 0
|
|
if begin_nested:
|
|
# begin_nested variant, we inserted a row after the nested
|
|
# block
|
|
expected_committed += 1
|
|
if not rollback:
|
|
# not rollback variant, our row inserted in the target
|
|
# block itself would be committed
|
|
expected_committed += 1
|
|
|
|
if execute_on_subject:
|
|
eq_(
|
|
await subject.scalar(select(func.count()).select_from(t)),
|
|
expected_committed,
|
|
)
|
|
else:
|
|
with subject.connect() as conn:
|
|
eq_(
|
|
await conn.scalar(select(func.count()).select_from(t)),
|
|
expected_committed,
|
|
)
|
|
|
|
return run_test
|
|
|
|
|
|
class EngineFixture(AsyncFixture, fixtures.TablesTest):
|
|
__requires__ = ("async_dialect",)
|
|
|
|
@testing.fixture
|
|
def async_engine(self):
|
|
return engines.testing_engine(asyncio=True, transfer_staticpool=True)
|
|
|
|
@testing.fixture
|
|
def async_connection(self, async_engine):
|
|
with async_engine.sync_engine.connect() as conn:
|
|
yield AsyncConnection(async_engine, conn)
|
|
|
|
@classmethod
|
|
def define_tables(cls, metadata):
|
|
Table(
|
|
"users",
|
|
metadata,
|
|
Column("user_id", Integer, primary_key=True, autoincrement=False),
|
|
Column("user_name", String(20)),
|
|
)
|
|
|
|
@classmethod
|
|
def insert_data(cls, connection):
|
|
users = cls.tables.users
|
|
connection.execute(
|
|
users.insert(),
|
|
[{"user_id": i, "user_name": "name%d" % i} for i in range(1, 20)],
|
|
)
|
|
|
|
|
|
class AsyncEngineTest(EngineFixture):
|
|
__backend__ = True
|
|
|
|
@testing.fails("the failure is the test")
|
|
@async_test
|
|
async def test_we_are_definitely_running_async_tests(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
eq_(await conn.scalar(text("select 1")), 2)
|
|
|
|
@async_test
|
|
async def test_interrupt_ctxmanager_connection(
|
|
self, async_engine, async_trans_ctx_manager_fixture
|
|
):
|
|
fn = async_trans_ctx_manager_fixture
|
|
|
|
async with async_engine.connect() as conn:
|
|
await fn(conn, trans_on_subject=False, execute_on_subject=True)
|
|
|
|
def test_proxied_attrs_engine(self, async_engine):
|
|
sync_engine = async_engine.sync_engine
|
|
|
|
is_(async_engine.url, sync_engine.url)
|
|
is_(async_engine.pool, sync_engine.pool)
|
|
is_(async_engine.dialect, sync_engine.dialect)
|
|
eq_(async_engine.name, sync_engine.name)
|
|
eq_(async_engine.driver, sync_engine.driver)
|
|
eq_(async_engine.echo, sync_engine.echo)
|
|
|
|
@async_test
|
|
async def test_run_async(self, async_engine):
|
|
async def test_meth(async_driver_connection):
|
|
# there's no method that's guaranteed to be on every
|
|
# driver, so just stringify it and compare that to the
|
|
# outside
|
|
return str(async_driver_connection)
|
|
|
|
def run_sync_to_async(connection):
|
|
connection_fairy = connection.connection
|
|
async_return = connection_fairy.run_async(
|
|
lambda driver_connection: test_meth(driver_connection)
|
|
)
|
|
assert not stdlib_inspect.iscoroutine(async_return)
|
|
return async_return
|
|
|
|
async with async_engine.connect() as conn:
|
|
driver_connection = (
|
|
await conn.get_raw_connection()
|
|
).driver_connection
|
|
res = await conn.run_sync(run_sync_to_async)
|
|
assert not stdlib_inspect.iscoroutine(res)
|
|
eq_(res, str(driver_connection))
|
|
|
|
@async_test
|
|
async def test_engine_eq_ne(self, async_engine):
|
|
e2 = _async_engine.AsyncEngine(async_engine.sync_engine)
|
|
e3 = engines.testing_engine(asyncio=True, transfer_staticpool=True)
|
|
|
|
eq_(async_engine, e2)
|
|
ne_(async_engine, e3)
|
|
|
|
is_false(async_engine == None)
|
|
|
|
def test_no_attach_to_event_loop(self, testing_engine):
|
|
"""test #6409
|
|
|
|
note this test does not seem to trigger the bug that was originally
|
|
fixed in #6409, when using python 3.10 and higher (the original issue
|
|
can repro in 3.8 at least, based on my testing). It's been simplified
|
|
to no longer explicitly create a new loop, asyncio.run() already
|
|
creates a new loop.
|
|
|
|
"""
|
|
|
|
import asyncio
|
|
import threading
|
|
|
|
errs = []
|
|
|
|
def go():
|
|
async def main():
|
|
tasks = [task() for _ in range(2)]
|
|
|
|
await asyncio.gather(*tasks)
|
|
await engine.dispose()
|
|
|
|
async def task():
|
|
async with engine.begin() as connection:
|
|
result = await connection.execute(select(1))
|
|
result.all()
|
|
|
|
try:
|
|
engine = engines.testing_engine(
|
|
asyncio=True, transfer_staticpool=False
|
|
)
|
|
|
|
asyncio.run(main())
|
|
except Exception as err:
|
|
errs.append(err)
|
|
|
|
t = threading.Thread(target=go)
|
|
t.start()
|
|
t.join()
|
|
|
|
if errs:
|
|
raise errs[0]
|
|
|
|
@async_test
|
|
async def test_connection_info(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
conn.info["foo"] = "bar"
|
|
|
|
eq_(conn.sync_connection.info, {"foo": "bar"})
|
|
|
|
@async_test
|
|
async def test_connection_eq_ne(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
c2 = _async_engine.AsyncConnection(
|
|
async_engine, conn.sync_connection
|
|
)
|
|
|
|
eq_(conn, c2)
|
|
|
|
async with async_engine.connect() as c3:
|
|
ne_(conn, c3)
|
|
|
|
is_false(conn == None)
|
|
|
|
@async_test
|
|
async def test_transaction_eq_ne(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
t1 = await conn.begin()
|
|
|
|
t2 = _async_engine.AsyncTransaction._regenerate_proxy_for_target(
|
|
t1._proxied
|
|
)
|
|
|
|
eq_(t1, t2)
|
|
|
|
is_false(t1 == None)
|
|
|
|
@testing.variation("simulate_gc", [True, False])
|
|
def test_appropriate_warning_for_gced_connection(
|
|
self, async_engine, simulate_gc
|
|
):
|
|
"""test #9237 which builds upon a not really complete solution
|
|
added for #8419."""
|
|
|
|
async def go():
|
|
conn = await async_engine.connect()
|
|
await conn.begin()
|
|
await conn.execute(select(1))
|
|
pool_connection = await conn.get_raw_connection()
|
|
return pool_connection
|
|
|
|
from sqlalchemy.util.concurrency import await_
|
|
|
|
pool_connection = await_(go())
|
|
|
|
rec = pool_connection._connection_record
|
|
ref = rec.fairy_ref
|
|
pool = pool_connection._pool
|
|
echo = False
|
|
|
|
if simulate_gc:
|
|
# not using expect_warnings() here because we also want to do a
|
|
# negative test for warnings, and we want to absolutely make sure
|
|
# the thing here that emits the warning is the correct path
|
|
from sqlalchemy.pool.base import _finalize_fairy
|
|
|
|
with (
|
|
mock.patch.object(
|
|
pool._dialect,
|
|
"do_rollback",
|
|
mock.Mock(side_effect=Exception("can't run rollback")),
|
|
),
|
|
mock.patch("sqlalchemy.util.warn") as m,
|
|
):
|
|
_finalize_fairy(
|
|
None, rec, pool, ref, echo, transaction_was_reset=False
|
|
)
|
|
|
|
if async_engine.dialect.has_terminate:
|
|
expected_msg = (
|
|
"The garbage collector is trying to clean up.*which will "
|
|
"be terminated."
|
|
)
|
|
else:
|
|
expected_msg = (
|
|
"The garbage collector is trying to clean up.*which will "
|
|
"be dropped, as it cannot be safely terminated."
|
|
)
|
|
|
|
# [1] == .args, not in 3.7
|
|
eq_regex(m.mock_calls[0][1][0], expected_msg)
|
|
else:
|
|
# the warning emitted by the pool is inside of a try/except:
|
|
# so it's impossible right now to have this warning "raise".
|
|
# for now, test by using mock.patch
|
|
|
|
with mock.patch("sqlalchemy.util.warn") as m:
|
|
pool_connection.close()
|
|
|
|
eq_(m.mock_calls, [])
|
|
|
|
@async_test
|
|
async def test_statement_compile(self, async_engine):
|
|
stmt = str(select(1).compile(async_engine))
|
|
async with async_engine.connect() as conn:
|
|
eq_(str(select(1).compile(conn)), stmt)
|
|
|
|
def test_clear_compiled_cache(self, async_engine):
|
|
async_engine.sync_engine._compiled_cache["foo"] = "bar"
|
|
eq_(async_engine.sync_engine._compiled_cache["foo"], "bar")
|
|
async_engine.clear_compiled_cache()
|
|
assert "foo" not in async_engine.sync_engine._compiled_cache
|
|
|
|
def test_execution_options(self, async_engine):
|
|
a2 = async_engine.execution_options(foo="bar")
|
|
assert isinstance(a2, _async_engine.AsyncEngine)
|
|
eq_(a2.sync_engine._execution_options, {"foo": "bar"})
|
|
eq_(async_engine.sync_engine._execution_options, {})
|
|
|
|
"""
|
|
|
|
attr uri, pool, dialect, engine, name, driver, echo
|
|
methods clear_compiled_cache, update_execution_options,
|
|
execution_options, get_execution_options, dispose
|
|
|
|
"""
|
|
|
|
@async_test
|
|
async def test_proxied_attrs_connection(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
sync_conn = conn.sync_connection
|
|
|
|
is_(conn.engine, async_engine)
|
|
is_(conn.closed, sync_conn.closed)
|
|
is_(conn.dialect, async_engine.sync_engine.dialect)
|
|
eq_(
|
|
conn.default_isolation_level, sync_conn.default_isolation_level
|
|
)
|
|
|
|
@async_test
|
|
async def test_transaction_accessor(self, async_connection):
|
|
conn = async_connection
|
|
is_none(conn.get_transaction())
|
|
is_false(conn.in_transaction())
|
|
is_false(conn.in_nested_transaction())
|
|
|
|
trans = await conn.begin()
|
|
|
|
is_true(conn.in_transaction())
|
|
is_false(conn.in_nested_transaction())
|
|
|
|
is_(trans.sync_transaction, conn.get_transaction().sync_transaction)
|
|
|
|
nested = await conn.begin_nested()
|
|
|
|
is_true(conn.in_transaction())
|
|
is_true(conn.in_nested_transaction())
|
|
|
|
is_(
|
|
conn.get_nested_transaction().sync_transaction,
|
|
nested.sync_transaction,
|
|
)
|
|
eq_(conn.get_nested_transaction(), nested)
|
|
|
|
is_(trans.sync_transaction, conn.get_transaction().sync_transaction)
|
|
|
|
await nested.commit()
|
|
|
|
is_true(conn.in_transaction())
|
|
is_false(conn.in_nested_transaction())
|
|
|
|
await trans.rollback()
|
|
|
|
is_none(conn.get_transaction())
|
|
is_false(conn.in_transaction())
|
|
is_false(conn.in_nested_transaction())
|
|
|
|
@testing.requires.queue_pool
|
|
@async_test
|
|
async def test_invalidate(self, async_engine):
|
|
conn = await async_engine.connect()
|
|
|
|
is_(conn.invalidated, False)
|
|
|
|
connection_fairy = await conn.get_raw_connection()
|
|
is_(connection_fairy.is_valid, True)
|
|
dbapi_connection = connection_fairy.dbapi_connection
|
|
|
|
await conn.invalidate()
|
|
|
|
if testing.against("postgresql+asyncpg"):
|
|
assert dbapi_connection._connection.is_closed()
|
|
|
|
new_fairy = await conn.get_raw_connection()
|
|
is_not(new_fairy.dbapi_connection, dbapi_connection)
|
|
is_not(new_fairy, connection_fairy)
|
|
is_(new_fairy.is_valid, True)
|
|
is_(connection_fairy.is_valid, False)
|
|
await conn.close()
|
|
|
|
@async_test
|
|
async def test_get_dbapi_connection_raise(self, async_connection):
|
|
with testing.expect_raises_message(
|
|
exc.InvalidRequestError,
|
|
"AsyncConnection.connection accessor is not "
|
|
"implemented as the attribute",
|
|
):
|
|
async_connection.connection
|
|
|
|
@async_test
|
|
async def test_get_raw_connection(self, async_connection):
|
|
pooled = await async_connection.get_raw_connection()
|
|
is_(pooled, async_connection.sync_connection.connection)
|
|
|
|
@async_test
|
|
async def test_isolation_level(self, async_connection):
|
|
conn = async_connection
|
|
sync_isolation_level = await greenlet_spawn(
|
|
conn.sync_connection.get_isolation_level
|
|
)
|
|
isolation_level = await conn.get_isolation_level()
|
|
|
|
eq_(isolation_level, sync_isolation_level)
|
|
|
|
await conn.execution_options(isolation_level="SERIALIZABLE")
|
|
isolation_level = await conn.get_isolation_level()
|
|
|
|
eq_(isolation_level, "SERIALIZABLE")
|
|
|
|
@testing.combinations(
|
|
(
|
|
AsyncAdaptedQueuePool,
|
|
True,
|
|
),
|
|
(
|
|
QueuePool,
|
|
False,
|
|
),
|
|
(NullPool, True),
|
|
(SingletonThreadPool, False),
|
|
(StaticPool, True),
|
|
(AssertionPool, True),
|
|
argnames="pool_cls,should_work",
|
|
)
|
|
@testing.variation("instantiate", [True, False])
|
|
@async_test
|
|
async def test_pool_classes(
|
|
self, async_testing_engine, pool_cls, instantiate, should_work
|
|
):
|
|
"""test #8771"""
|
|
if instantiate:
|
|
if pool_cls in (QueuePool, AsyncAdaptedQueuePool):
|
|
pool = pool_cls(creator=testing.db.pool._creator, timeout=10)
|
|
else:
|
|
pool = pool_cls(
|
|
creator=testing.db.pool._creator,
|
|
)
|
|
|
|
options = {"pool": pool}
|
|
else:
|
|
if pool_cls in (QueuePool, AsyncAdaptedQueuePool):
|
|
options = {"poolclass": pool_cls, "pool_timeout": 10}
|
|
else:
|
|
options = {"poolclass": pool_cls}
|
|
|
|
if not should_work:
|
|
with expect_raises_message(
|
|
exc.ArgumentError,
|
|
f"Pool class {pool_cls.__name__} "
|
|
"cannot be used with asyncio engine",
|
|
):
|
|
async_testing_engine(options=options)
|
|
return
|
|
|
|
e = async_testing_engine(options=options)
|
|
|
|
if pool_cls is AssertionPool:
|
|
async with e.connect() as conn:
|
|
result = await conn.scalar(select(1))
|
|
eq_(result, 1)
|
|
return
|
|
|
|
async def go():
|
|
async with e.connect() as conn:
|
|
result = await conn.scalar(select(1))
|
|
eq_(result, 1)
|
|
return result
|
|
|
|
eq_(await asyncio.gather(*[go() for i in range(10)]), [1] * 10)
|
|
|
|
def test_cant_use_async_pool_w_create_engine(self):
|
|
"""supplemental test for #8771"""
|
|
|
|
with expect_raises_message(
|
|
exc.ArgumentError,
|
|
"Pool class AsyncAdaptedQueuePool "
|
|
"cannot be used with non-asyncio engine",
|
|
):
|
|
create_engine("sqlite://", poolclass=AsyncAdaptedQueuePool)
|
|
|
|
@testing.requires.queue_pool
|
|
@async_test
|
|
async def test_dispose(self, async_engine):
|
|
c1 = await async_engine.connect()
|
|
c2 = await async_engine.connect()
|
|
|
|
await c1.close()
|
|
await c2.close()
|
|
|
|
p1 = async_engine.pool
|
|
|
|
if isinstance(p1, AsyncAdaptedQueuePool):
|
|
eq_(async_engine.pool.checkedin(), 2)
|
|
|
|
await async_engine.dispose()
|
|
if isinstance(p1, AsyncAdaptedQueuePool):
|
|
eq_(async_engine.pool.checkedin(), 0)
|
|
is_not(p1, async_engine.pool)
|
|
|
|
@testing.requires.queue_pool
|
|
@async_test
|
|
async def test_dispose_no_close(self, async_engine):
|
|
c1 = await async_engine.connect()
|
|
c2 = await async_engine.connect()
|
|
|
|
await c1.close()
|
|
await c2.close()
|
|
|
|
p1 = async_engine.pool
|
|
|
|
if isinstance(p1, AsyncAdaptedQueuePool):
|
|
eq_(async_engine.pool.checkedin(), 2)
|
|
|
|
await async_engine.dispose(close=False)
|
|
|
|
# TODO: test that DBAPI connection was not closed
|
|
|
|
if isinstance(p1, AsyncAdaptedQueuePool):
|
|
eq_(async_engine.pool.checkedin(), 0)
|
|
is_not(p1, async_engine.pool)
|
|
|
|
@testing.requires.independent_connections
|
|
@async_test
|
|
async def test_init_once_concurrency(self, async_engine):
|
|
async with async_engine.connect() as c1, async_engine.connect() as c2:
|
|
coro = asyncio.gather(c1.scalar(select(1)), c2.scalar(select(2)))
|
|
eq_(await coro, [1, 2])
|
|
|
|
@async_test
|
|
async def test_connect_ctxmanager(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
result = await conn.execute(select(1))
|
|
eq_(result.scalar(), 1)
|
|
|
|
@async_test
|
|
async def test_connect_plain(self, async_engine):
|
|
conn = await async_engine.connect()
|
|
try:
|
|
result = await conn.execute(select(1))
|
|
eq_(result.scalar(), 1)
|
|
finally:
|
|
await conn.close()
|
|
|
|
@async_test
|
|
async def test_connection_not_started(self, async_engine):
|
|
conn = async_engine.connect()
|
|
testing.assert_raises_message(
|
|
asyncio_exc.AsyncContextNotStarted,
|
|
"AsyncConnection context has not been started and "
|
|
"object has not been awaited.",
|
|
conn.begin,
|
|
)
|
|
|
|
@async_test
|
|
async def test_transaction_commit(self, async_engine):
|
|
users = self.tables.users
|
|
|
|
async with async_engine.begin() as conn:
|
|
await conn.execute(delete(users))
|
|
|
|
async with async_engine.connect() as conn:
|
|
eq_(await conn.scalar(select(func.count(users.c.user_id))), 0)
|
|
|
|
@async_test
|
|
async def test_savepoint_rollback_noctx(self, async_engine):
|
|
users = self.tables.users
|
|
|
|
async with async_engine.begin() as conn:
|
|
savepoint = await conn.begin_nested()
|
|
await conn.execute(delete(users))
|
|
await savepoint.rollback()
|
|
|
|
async with async_engine.connect() as conn:
|
|
eq_(await conn.scalar(select(func.count(users.c.user_id))), 19)
|
|
|
|
@async_test
|
|
async def test_savepoint_commit_noctx(self, async_engine):
|
|
users = self.tables.users
|
|
|
|
async with async_engine.begin() as conn:
|
|
savepoint = await conn.begin_nested()
|
|
await conn.execute(delete(users))
|
|
await savepoint.commit()
|
|
|
|
async with async_engine.connect() as conn:
|
|
eq_(await conn.scalar(select(func.count(users.c.user_id))), 0)
|
|
|
|
@async_test
|
|
async def test_transaction_rollback(self, async_engine):
|
|
users = self.tables.users
|
|
|
|
async with async_engine.connect() as conn:
|
|
trans = conn.begin()
|
|
await trans.start()
|
|
await conn.execute(delete(users))
|
|
await trans.rollback()
|
|
|
|
async with async_engine.connect() as conn:
|
|
eq_(await conn.scalar(select(func.count(users.c.user_id))), 19)
|
|
|
|
@async_test
|
|
async def test_conn_transaction_not_started(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
trans = conn.begin()
|
|
with expect_raises_message(
|
|
asyncio_exc.AsyncContextNotStarted,
|
|
"AsyncTransaction context has not been started "
|
|
"and object has not been awaited.",
|
|
):
|
|
await trans.rollback(),
|
|
|
|
@testing.requires.queue_pool
|
|
@async_test
|
|
async def test_pool_exhausted_some_timeout(self, async_engine):
|
|
engine = create_async_engine(
|
|
testing.db.url,
|
|
pool_size=1,
|
|
max_overflow=0,
|
|
pool_timeout=0.1,
|
|
)
|
|
async with engine.connect():
|
|
with expect_raises(exc.TimeoutError):
|
|
await engine.connect()
|
|
|
|
@async_test
|
|
async def test_engine_aclose(self, async_engine):
|
|
users = self.tables.users
|
|
async with contextlib.aclosing(async_engine.connect()) as conn:
|
|
await conn.start()
|
|
trans = conn.begin()
|
|
await trans.start()
|
|
await conn.execute(delete(users))
|
|
await trans.commit()
|
|
assert conn.closed
|
|
|
|
@testing.requires.queue_pool
|
|
@async_test
|
|
async def test_pool_exhausted_no_timeout(self, async_engine):
|
|
engine = create_async_engine(
|
|
testing.db.url,
|
|
pool_size=1,
|
|
max_overflow=0,
|
|
pool_timeout=0,
|
|
)
|
|
async with engine.connect():
|
|
with expect_raises(exc.TimeoutError):
|
|
await engine.connect()
|
|
|
|
@async_test
|
|
async def test_create_async_engine_server_side_cursor(self, async_engine):
|
|
with expect_raises_message(
|
|
asyncio_exc.AsyncMethodRequired,
|
|
"Can't set server_side_cursors for async engine globally",
|
|
):
|
|
create_async_engine(
|
|
testing.db.url,
|
|
server_side_cursors=True,
|
|
)
|
|
|
|
def test_async_engine_from_config(self):
|
|
config = {
|
|
"sqlalchemy.url": testing.db.url.render_as_string(
|
|
hide_password=False
|
|
),
|
|
"sqlalchemy.echo": "true",
|
|
}
|
|
engine = async_engine_from_config(config)
|
|
assert engine.url == testing.db.url
|
|
assert engine.echo is True
|
|
assert engine.dialect.is_async is True
|
|
|
|
def test_async_creator_and_creator(self):
|
|
async def ac():
|
|
return None
|
|
|
|
def c():
|
|
return None
|
|
|
|
with expect_raises_message(
|
|
exc.ArgumentError,
|
|
"Can only specify one of 'async_creator' or 'creator', "
|
|
"not both.",
|
|
):
|
|
create_async_engine(testing.db.url, creator=c, async_creator=ac)
|
|
|
|
@async_test
|
|
async def test_async_creator_invoked(self, async_testing_engine):
|
|
"""test for #8215"""
|
|
|
|
existing_creator = testing.db.pool._creator
|
|
|
|
async def async_creator():
|
|
sync_conn = await greenlet_spawn(existing_creator)
|
|
return sync_conn.driver_connection
|
|
|
|
async_creator = mock.Mock(side_effect=async_creator)
|
|
|
|
eq_(async_creator.mock_calls, [])
|
|
|
|
engine = async_testing_engine(options={"async_creator": async_creator})
|
|
async with engine.connect() as conn:
|
|
result = await conn.scalar(select(1))
|
|
eq_(result, 1)
|
|
|
|
eq_(async_creator.mock_calls, [mock.call()])
|
|
|
|
@async_test
|
|
async def test_async_creator_accepts_args_if_called_directly(
|
|
self, async_testing_engine
|
|
):
|
|
"""supplemental test for #8215.
|
|
|
|
The "async_creator" passed to create_async_engine() is expected to take
|
|
no arguments, the same way as "creator" passed to create_engine()
|
|
works.
|
|
|
|
However, the ultimate "async_creator" received by the sync-emulating
|
|
DBAPI *does* take arguments in its ``.connect()`` method, which will be
|
|
all the other arguments passed to ``.connect()``. This functionality
|
|
is not currently used, however was decided that the creator should
|
|
internally work this way for improved flexibility; see
|
|
https://github.com/sqlalchemy/sqlalchemy/issues/8215#issuecomment-1181791539.
|
|
That contract is tested here.
|
|
|
|
""" # noqa: E501
|
|
|
|
existing_creator = testing.db.pool._creator
|
|
|
|
async def async_creator(x, y, *, z=None):
|
|
sync_conn = await greenlet_spawn(existing_creator)
|
|
return sync_conn.driver_connection
|
|
|
|
async_creator = mock.Mock(side_effect=async_creator)
|
|
|
|
async_dbapi = testing.db.dialect.loaded_dbapi
|
|
|
|
conn = await greenlet_spawn(
|
|
async_dbapi.connect, 5, y=10, z=8, async_creator_fn=async_creator
|
|
)
|
|
try:
|
|
eq_(async_creator.mock_calls, [mock.call(5, y=10, z=8)])
|
|
finally:
|
|
await greenlet_spawn(conn.close)
|
|
|
|
@testing.combinations("stream", "stream_scalars", argnames="method")
|
|
@async_test
|
|
async def test_server_side_required_for_scalars(
|
|
self, async_engine, method
|
|
):
|
|
with mock.patch.object(
|
|
async_engine.dialect, "supports_server_side_cursors", False
|
|
):
|
|
async with async_engine.connect() as c:
|
|
with expect_raises_message(
|
|
exc.InvalidRequestError,
|
|
"Cant use `stream` or `stream_scalars` with the current "
|
|
"dialect since it does not support server side cursors.",
|
|
):
|
|
if method == "stream":
|
|
await c.stream(select(1))
|
|
elif method == "stream_scalars":
|
|
await c.stream_scalars(select(1))
|
|
else:
|
|
testing.fail(method)
|
|
|
|
@async_test
|
|
@testing.requires.async_dialect_with_await_close
|
|
async def test_active_await_close(self, async_engine):
|
|
select_one_sql = select(1).compile(async_engine.sync_engine).string
|
|
|
|
async with async_engine.connect() as conn:
|
|
result = await conn.exec_driver_sql(select_one_sql)
|
|
eq_(result.scalar_one(), 1)
|
|
driver_cursor = result.context.cursor._cursor
|
|
|
|
with expect_raises(Exception):
|
|
# because the cursor should be closed
|
|
await driver_cursor.execute(select_one_sql)
|
|
|
|
@async_test
|
|
async def test_async_creator_handle_error(self, async_testing_engine):
|
|
"""test for #11956"""
|
|
|
|
existing_creator = testing.db.pool._creator
|
|
|
|
def create_and_break():
|
|
sync_conn = existing_creator()
|
|
cursor = sync_conn.cursor()
|
|
|
|
# figure out a way to get a native driver exception. This really
|
|
# only applies to asyncpg where we rewrite the exception
|
|
# hierarchy with our own emulated exception; other backends raise
|
|
# standard DBAPI exceptions (with some buggy cases here and there
|
|
# which they miss) even though they are async
|
|
try:
|
|
cursor.execute("this will raise an error")
|
|
except Exception as possibly_emulated_error:
|
|
if isinstance(
|
|
possibly_emulated_error, exc.EmulatedDBAPIException
|
|
):
|
|
raise possibly_emulated_error.driver_exception
|
|
else:
|
|
raise possibly_emulated_error
|
|
|
|
async def async_creator():
|
|
return await greenlet_spawn(create_and_break)
|
|
|
|
engine = async_testing_engine(options={"async_creator": async_creator})
|
|
|
|
with expect_raises(exc.DBAPIError):
|
|
await engine.connect()
|
|
|
|
|
|
class AsyncCreatePoolTest(fixtures.TestBase):
|
|
@config.fixture
|
|
def mock_create(self):
|
|
with patch(
|
|
"sqlalchemy.ext.asyncio.engine._create_pool_from_url",
|
|
) as p:
|
|
yield p
|
|
|
|
def test_url_only(self, mock_create):
|
|
create_async_pool_from_url("sqlite://")
|
|
mock_create.assert_called_once_with("sqlite://", _is_async=True)
|
|
|
|
def test_pool_args(self, mock_create):
|
|
create_async_pool_from_url("sqlite://", foo=99, echo=True)
|
|
mock_create.assert_called_once_with(
|
|
"sqlite://", foo=99, echo=True, _is_async=True
|
|
)
|
|
|
|
|
|
class AsyncEventTest(EngineFixture):
|
|
"""The engine events all run in their normal synchronous context.
|
|
|
|
we do not provide an asyncio event interface at this time.
|
|
|
|
"""
|
|
|
|
__backend__ = True
|
|
|
|
@async_test
|
|
async def test_no_async_listeners(self, async_engine):
|
|
with testing.expect_raises_message(
|
|
NotImplementedError,
|
|
"asynchronous events are not implemented "
|
|
"at this time. Apply synchronous listeners to the "
|
|
"AsyncEngine.sync_engine or "
|
|
"AsyncConnection.sync_connection attributes.",
|
|
):
|
|
event.listen(async_engine, "before_cursor_execute", mock.Mock())
|
|
|
|
async with async_engine.connect() as conn:
|
|
with testing.expect_raises_message(
|
|
NotImplementedError,
|
|
"asynchronous events are not implemented "
|
|
"at this time. Apply synchronous listeners to the "
|
|
"AsyncEngine.sync_engine or "
|
|
"AsyncConnection.sync_connection attributes.",
|
|
):
|
|
event.listen(conn, "before_cursor_execute", mock.Mock())
|
|
|
|
@async_test
|
|
async def test_no_async_listeners_dialect_event(self, async_engine):
|
|
with testing.expect_raises_message(
|
|
NotImplementedError,
|
|
"asynchronous events are not implemented "
|
|
"at this time. Apply synchronous listeners to the "
|
|
"AsyncEngine.sync_engine or "
|
|
"AsyncConnection.sync_connection attributes.",
|
|
):
|
|
event.listen(async_engine, "do_execute", mock.Mock())
|
|
|
|
@async_test
|
|
async def test_no_async_listeners_pool_event(self, async_engine):
|
|
with testing.expect_raises_message(
|
|
NotImplementedError,
|
|
"asynchronous events are not implemented "
|
|
"at this time. Apply synchronous listeners to the "
|
|
"AsyncEngine.sync_engine or "
|
|
"AsyncConnection.sync_connection attributes.",
|
|
):
|
|
event.listen(async_engine, "checkout", mock.Mock())
|
|
|
|
@async_test
|
|
async def test_sync_before_cursor_execute_engine(self, async_engine):
|
|
canary = mock.Mock()
|
|
|
|
event.listen(async_engine.sync_engine, "before_cursor_execute", canary)
|
|
|
|
async with async_engine.connect() as conn:
|
|
sync_conn = conn.sync_connection
|
|
await conn.execute(select(1))
|
|
|
|
s1 = str(select(1).compile(async_engine))
|
|
eq_(
|
|
canary.mock_calls,
|
|
[mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
|
|
)
|
|
|
|
@async_test
|
|
async def test_sync_before_cursor_execute_connection(self, async_engine):
|
|
canary = mock.Mock()
|
|
|
|
async with async_engine.connect() as conn:
|
|
sync_conn = conn.sync_connection
|
|
|
|
event.listen(
|
|
async_engine.sync_engine, "before_cursor_execute", canary
|
|
)
|
|
await conn.execute(select(1))
|
|
|
|
s1 = str(select(1).compile(async_engine))
|
|
eq_(
|
|
canary.mock_calls,
|
|
[mock.call(sync_conn, mock.ANY, s1, mock.ANY, mock.ANY, False)],
|
|
)
|
|
|
|
@async_test
|
|
async def test_event_on_sync_connection(self, async_engine):
|
|
canary = mock.Mock()
|
|
|
|
async with async_engine.connect() as conn:
|
|
event.listen(conn.sync_connection, "begin", canary)
|
|
async with conn.begin():
|
|
eq_(
|
|
canary.mock_calls,
|
|
[mock.call(conn.sync_connection)],
|
|
)
|
|
|
|
|
|
class AsyncInspection(EngineFixture):
|
|
__backend__ = True
|
|
|
|
@async_test
|
|
async def test_inspect_engine(self, async_engine):
|
|
with testing.expect_raises_message(
|
|
exc.NoInspectionAvailable,
|
|
"Inspection on an AsyncEngine is currently not supported.",
|
|
):
|
|
inspect(async_engine)
|
|
|
|
@async_test
|
|
async def test_inspect_connection(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
with testing.expect_raises_message(
|
|
exc.NoInspectionAvailable,
|
|
"Inspection on an AsyncConnection is currently not supported.",
|
|
):
|
|
inspect(conn)
|
|
|
|
|
|
class AsyncResultTest(EngineFixture):
|
|
__backend__ = True
|
|
__requires__ = ("server_side_cursors", "async_dialect")
|
|
|
|
@async_test
|
|
async def test_no_ss_cursor_w_execute(self, async_engine):
|
|
users = self.tables.users
|
|
async with async_engine.connect() as conn:
|
|
conn = await conn.execution_options(stream_results=True)
|
|
with expect_raises_message(
|
|
async_exc.AsyncMethodRequired,
|
|
r"Can't use the AsyncConnection.execute\(\) method with a "
|
|
r"server-side cursor. Use the AsyncConnection.stream\(\) "
|
|
r"method for an async streaming result set.",
|
|
):
|
|
await conn.execute(select(users))
|
|
|
|
@async_test
|
|
async def test_no_ss_cursor_w_exec_driver_sql(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
conn = await conn.execution_options(stream_results=True)
|
|
with expect_raises_message(
|
|
async_exc.AsyncMethodRequired,
|
|
r"Can't use the AsyncConnection.exec_driver_sql\(\) "
|
|
r"method with a "
|
|
r"server-side cursor. Use the AsyncConnection.stream\(\) "
|
|
r"method for an async streaming result set.",
|
|
):
|
|
await conn.exec_driver_sql("SELECT * FROM users")
|
|
|
|
@async_test
|
|
async def test_stream_ctxmanager(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
conn = await conn.execution_options(stream_results=True)
|
|
|
|
async with conn.stream(select(self.tables.users)) as result:
|
|
assert not result._real_result._soft_closed
|
|
assert not result.closed
|
|
with expect_raises_message(Exception, "hi"):
|
|
i = 0
|
|
async for row in result:
|
|
if i > 2:
|
|
raise Exception("hi")
|
|
i += 1
|
|
assert result._real_result._soft_closed
|
|
assert result.closed
|
|
|
|
@async_test
|
|
async def test_stream_scalars_ctxmanager(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
conn = await conn.execution_options(stream_results=True)
|
|
|
|
async with conn.stream_scalars(
|
|
select(self.tables.users)
|
|
) as result:
|
|
assert not result._real_result._soft_closed
|
|
assert not result.closed
|
|
with expect_raises_message(Exception, "hi"):
|
|
i = 0
|
|
async for scalar in result:
|
|
if i > 2:
|
|
raise Exception("hi")
|
|
i += 1
|
|
assert result._real_result._soft_closed
|
|
assert result.closed
|
|
|
|
@testing.combinations(
|
|
(None,), ("scalars",), ("mappings",), argnames="filter_"
|
|
)
|
|
@async_test
|
|
async def test_all(self, async_engine, filter_):
|
|
users = self.tables.users
|
|
async with async_engine.connect() as conn:
|
|
result = await conn.stream(select(users))
|
|
|
|
if filter_ == "mappings":
|
|
result = result.mappings()
|
|
elif filter_ == "scalars":
|
|
result = result.scalars(1)
|
|
|
|
all_ = await result.all()
|
|
if filter_ == "mappings":
|
|
eq_(
|
|
all_,
|
|
[
|
|
{"user_id": i, "user_name": "name%d" % i}
|
|
for i in range(1, 20)
|
|
],
|
|
)
|
|
elif filter_ == "scalars":
|
|
eq_(
|
|
all_,
|
|
["name%d" % i for i in range(1, 20)],
|
|
)
|
|
else:
|
|
eq_(all_, [(i, "name%d" % i) for i in range(1, 20)])
|
|
|
|
@testing.combinations(
|
|
(None,),
|
|
("scalars",),
|
|
("stream_scalars",),
|
|
("mappings",),
|
|
argnames="filter_",
|
|
)
|
|
@async_test
|
|
async def test_aiter(self, async_engine, filter_):
|
|
users = self.tables.users
|
|
async with async_engine.connect() as conn:
|
|
if filter_ == "stream_scalars":
|
|
result = await conn.stream_scalars(select(users.c.user_name))
|
|
else:
|
|
result = await conn.stream(select(users))
|
|
|
|
if filter_ == "mappings":
|
|
result = result.mappings()
|
|
elif filter_ == "scalars":
|
|
result = result.scalars(1)
|
|
|
|
rows = []
|
|
|
|
async for row in result:
|
|
rows.append(row)
|
|
|
|
if filter_ == "mappings":
|
|
eq_(
|
|
rows,
|
|
[
|
|
{"user_id": i, "user_name": "name%d" % i}
|
|
for i in range(1, 20)
|
|
],
|
|
)
|
|
elif filter_ in ("scalars", "stream_scalars"):
|
|
eq_(
|
|
rows,
|
|
["name%d" % i for i in range(1, 20)],
|
|
)
|
|
else:
|
|
eq_(rows, [(i, "name%d" % i) for i in range(1, 20)])
|
|
|
|
@testing.combinations((None,), ("mappings",), argnames="filter_")
|
|
@async_test
|
|
async def test_keys(self, async_engine, filter_):
|
|
users = self.tables.users
|
|
async with async_engine.connect() as conn:
|
|
result = await conn.stream(select(users))
|
|
|
|
if filter_ == "mappings":
|
|
result = result.mappings()
|
|
|
|
eq_(result.keys(), ["user_id", "user_name"])
|
|
|
|
await result.close()
|
|
|
|
@async_test
|
|
async def test_unique_all(self, async_engine):
|
|
users = self.tables.users
|
|
async with async_engine.connect() as conn:
|
|
result = await conn.stream(
|
|
union_all(select(users), select(users)).order_by(
|
|
users.c.user_id
|
|
)
|
|
)
|
|
|
|
all_ = await result.unique().all()
|
|
eq_(all_, [(i, "name%d" % i) for i in range(1, 20)])
|
|
|
|
@async_test
|
|
async def test_columns_all(self, async_engine):
|
|
users = self.tables.users
|
|
async with async_engine.connect() as conn:
|
|
result = await conn.stream(select(users))
|
|
|
|
all_ = await result.columns(1).all()
|
|
eq_(all_, [("name%d" % i,) for i in range(1, 20)])
|
|
|
|
@testing.combinations(
|
|
(None,), ("scalars",), ("mappings",), argnames="filter_"
|
|
)
|
|
@testing.combinations(None, 2, 5, 10, argnames="yield_per")
|
|
@testing.combinations("method", "opt", argnames="yield_per_type")
|
|
@async_test
|
|
async def test_partitions(
|
|
self, async_engine, filter_, yield_per, yield_per_type
|
|
):
|
|
users = self.tables.users
|
|
async with async_engine.connect() as conn:
|
|
stmt = select(users)
|
|
if yield_per and yield_per_type == "opt":
|
|
stmt = stmt.execution_options(yield_per=yield_per)
|
|
result = await conn.stream(stmt)
|
|
|
|
if filter_ == "mappings":
|
|
result = result.mappings()
|
|
elif filter_ == "scalars":
|
|
result = result.scalars(1)
|
|
|
|
if yield_per and yield_per_type == "method":
|
|
result = result.yield_per(yield_per)
|
|
|
|
check_result = []
|
|
|
|
# stream() sets stream_results unconditionally
|
|
assert isinstance(
|
|
result._real_result.cursor_strategy,
|
|
_cursor.BufferedRowCursorFetchStrategy,
|
|
)
|
|
|
|
if yield_per:
|
|
partition_size = yield_per
|
|
|
|
eq_(result._real_result.cursor_strategy._bufsize, yield_per)
|
|
|
|
async for partition in result.partitions():
|
|
check_result.append(partition)
|
|
else:
|
|
eq_(result._real_result.cursor_strategy._bufsize, 5)
|
|
|
|
partition_size = 5
|
|
async for partition in result.partitions(partition_size):
|
|
check_result.append(partition)
|
|
|
|
ranges = [
|
|
(i, min(20, i + partition_size))
|
|
for i in range(1, 21, partition_size)
|
|
]
|
|
|
|
if filter_ == "mappings":
|
|
eq_(
|
|
check_result,
|
|
[
|
|
[
|
|
{"user_id": i, "user_name": "name%d" % i}
|
|
for i in range(a, b)
|
|
]
|
|
for (a, b) in ranges
|
|
],
|
|
)
|
|
elif filter_ == "scalars":
|
|
eq_(
|
|
check_result,
|
|
[["name%d" % i for i in range(a, b)] for (a, b) in ranges],
|
|
)
|
|
else:
|
|
eq_(
|
|
check_result,
|
|
[
|
|
[(i, "name%d" % i) for i in range(a, b)]
|
|
for (a, b) in ranges
|
|
],
|
|
)
|
|
|
|
@testing.combinations(
|
|
(None,), ("scalars",), ("mappings",), argnames="filter_"
|
|
)
|
|
@async_test
|
|
async def test_one_success(self, async_engine, filter_):
|
|
users = self.tables.users
|
|
async with async_engine.connect() as conn:
|
|
result = await conn.stream(
|
|
select(users).limit(1).order_by(users.c.user_name)
|
|
)
|
|
|
|
if filter_ == "mappings":
|
|
result = result.mappings()
|
|
elif filter_ == "scalars":
|
|
result = result.scalars()
|
|
u1 = await result.one()
|
|
|
|
if filter_ == "mappings":
|
|
eq_(u1, {"user_id": 1, "user_name": "name%d" % 1})
|
|
elif filter_ == "scalars":
|
|
eq_(u1, 1)
|
|
else:
|
|
eq_(u1, (1, "name%d" % 1))
|
|
|
|
@async_test
|
|
async def test_one_no_result(self, async_engine):
|
|
users = self.tables.users
|
|
async with async_engine.connect() as conn:
|
|
result = await conn.stream(
|
|
select(users).where(users.c.user_name == "nonexistent")
|
|
)
|
|
|
|
with expect_raises_message(
|
|
exc.NoResultFound, "No row was found when one was required"
|
|
):
|
|
await result.one()
|
|
|
|
@async_test
|
|
async def test_one_multi_result(self, async_engine):
|
|
users = self.tables.users
|
|
async with async_engine.connect() as conn:
|
|
result = await conn.stream(
|
|
select(users).where(users.c.user_name.in_(["name3", "name5"]))
|
|
)
|
|
|
|
with expect_raises_message(
|
|
exc.MultipleResultsFound,
|
|
"Multiple rows were found when exactly one was required",
|
|
):
|
|
await result.one()
|
|
|
|
@testing.combinations(("scalars",), ("stream_scalars",), argnames="case")
|
|
@async_test
|
|
async def test_scalars(self, async_engine, case):
|
|
users = self.tables.users
|
|
stmt = select(users).order_by(users.c.user_id)
|
|
async with async_engine.connect() as conn:
|
|
if case == "scalars":
|
|
result = (await conn.scalars(stmt)).all()
|
|
elif case == "stream_scalars":
|
|
result = await (await conn.stream_scalars(stmt)).all()
|
|
|
|
eq_(result, list(range(1, 20)))
|
|
|
|
@async_test
|
|
@testing.combinations(("stream",), ("stream_scalars",), argnames="case")
|
|
async def test_stream_fetch_many_not_complete(self, async_engine, case):
|
|
users = self.tables.users
|
|
big_query = select(users).join(users.alias("other"), true())
|
|
async with async_engine.connect() as conn:
|
|
if case == "stream":
|
|
result = await conn.stream(big_query)
|
|
elif case == "stream_scalars":
|
|
result = await conn.stream_scalars(big_query)
|
|
|
|
f1 = await result.fetchmany(5)
|
|
f2 = await result.fetchmany(10)
|
|
f3 = await result.fetchmany(7)
|
|
eq_(len(f1) + len(f2) + len(f3), 22)
|
|
|
|
res = await result.fetchall()
|
|
eq_(len(res), 19 * 19 - 22)
|
|
|
|
@async_test
|
|
@testing.combinations(("stream",), ("execute",), argnames="case")
|
|
async def test_cursor_close(self, async_engine, case):
|
|
users = self.tables.users
|
|
async with async_engine.connect() as conn:
|
|
if case == "stream":
|
|
result = await conn.stream(select(users))
|
|
cursor = result._real_result.cursor
|
|
elif case == "execute":
|
|
result = await conn.execute(select(users))
|
|
cursor = result.cursor
|
|
|
|
await conn.run_sync(lambda _: cursor.close())
|
|
|
|
@async_test
|
|
@testing.variation("case", ["scalar_one", "scalar_one_or_none", "scalar"])
|
|
async def test_stream_scalar(self, async_engine, case: testing.Variation):
|
|
users = self.tables.users
|
|
async with async_engine.connect() as conn:
|
|
result = await conn.stream(
|
|
select(users).limit(1).order_by(users.c.user_name)
|
|
)
|
|
|
|
if case.scalar_one:
|
|
u1 = await result.scalar_one()
|
|
elif case.scalar_one_or_none:
|
|
u1 = await result.scalar_one_or_none()
|
|
elif case.scalar:
|
|
u1 = await result.scalar()
|
|
else:
|
|
case.fail()
|
|
|
|
eq_(u1, 1)
|
|
|
|
|
|
class TextSyncDBAPI(fixtures.TestBase):
|
|
__requires__ = ("asyncio",)
|
|
|
|
def test_sync_dbapi_raises(self):
|
|
with expect_raises_message(
|
|
exc.InvalidRequestError,
|
|
"The asyncio extension requires an async driver to be used.",
|
|
):
|
|
create_async_engine("sqlite:///:memory:")
|
|
|
|
@testing.fixture
|
|
def async_engine(self):
|
|
engine = create_engine("sqlite:///:memory:", future=True)
|
|
engine.dialect.is_async = True
|
|
engine.dialect.supports_server_side_cursors = True
|
|
with mock.patch.object(
|
|
engine.dialect.execution_ctx_cls,
|
|
"create_server_side_cursor",
|
|
engine.dialect.execution_ctx_cls.create_default_cursor,
|
|
):
|
|
yield _async_engine.AsyncEngine(engine)
|
|
|
|
@async_test
|
|
@combinations(
|
|
lambda conn: conn.exec_driver_sql("select 1"),
|
|
lambda conn: conn.stream(text("select 1")),
|
|
lambda conn: conn.execute(text("select 1")),
|
|
argnames="case",
|
|
)
|
|
async def test_sync_driver_execution(self, async_engine, case):
|
|
with expect_raises_message(
|
|
exc.AwaitRequired,
|
|
"The current operation required an async execution but none was",
|
|
):
|
|
async with async_engine.connect() as conn:
|
|
await case(conn)
|
|
|
|
@async_test
|
|
async def test_sync_driver_run_sync(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
res = await conn.run_sync(
|
|
lambda conn: conn.scalar(text("select 1"))
|
|
)
|
|
assert res == 1
|
|
assert await conn.run_sync(lambda _: 2) == 2
|
|
|
|
|
|
class AsyncProxyTest(EngineFixture, fixtures.TestBase):
|
|
@async_test
|
|
async def test_get_transaction(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
async with conn.begin() as trans:
|
|
is_(trans.connection, conn)
|
|
is_(conn.get_transaction(), trans)
|
|
|
|
@async_test
|
|
async def test_get_nested_transaction(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
async with conn.begin() as trans:
|
|
n1 = await conn.begin_nested()
|
|
|
|
is_(conn.get_nested_transaction(), n1)
|
|
|
|
n2 = await conn.begin_nested()
|
|
|
|
is_(conn.get_nested_transaction(), n2)
|
|
|
|
await n2.commit()
|
|
|
|
is_(conn.get_nested_transaction(), n1)
|
|
|
|
is_(conn.get_transaction(), trans)
|
|
|
|
@async_test
|
|
async def test_get_connection(self, async_engine):
|
|
async with async_engine.connect() as conn:
|
|
is_(
|
|
AsyncConnection._retrieve_proxy_for_target(
|
|
conn.sync_connection
|
|
),
|
|
conn,
|
|
)
|
|
|
|
def test_regenerate_connection(self, connection):
|
|
async_connection = AsyncConnection._retrieve_proxy_for_target(
|
|
connection
|
|
)
|
|
|
|
a2 = AsyncConnection._retrieve_proxy_for_target(connection)
|
|
is_(async_connection, a2)
|
|
is_not(async_connection, None)
|
|
|
|
is_(async_connection.engine, a2.engine)
|
|
is_not(async_connection.engine, None)
|
|
|
|
@testing.requires.predictable_gc
|
|
@async_test
|
|
async def test_gc_engine(self, testing_engine):
|
|
ReversibleProxy._proxy_objects.clear()
|
|
|
|
eq_(len(ReversibleProxy._proxy_objects), 0)
|
|
|
|
async_engine = AsyncEngine(testing.db)
|
|
|
|
eq_(len(ReversibleProxy._proxy_objects), 1)
|
|
|
|
del async_engine
|
|
|
|
eq_(len(ReversibleProxy._proxy_objects), 0)
|
|
|
|
@testing.requires.predictable_gc
|
|
@async_test
|
|
async def test_gc_conn(self, testing_engine):
|
|
ReversibleProxy._proxy_objects.clear()
|
|
|
|
async_engine = AsyncEngine(testing.db)
|
|
|
|
eq_(len(ReversibleProxy._proxy_objects), 1)
|
|
|
|
async with async_engine.connect() as conn:
|
|
eq_(len(ReversibleProxy._proxy_objects), 2)
|
|
|
|
async with conn.begin() as trans:
|
|
eq_(len(ReversibleProxy._proxy_objects), 3)
|
|
|
|
del trans
|
|
|
|
del conn
|
|
|
|
eq_(len(ReversibleProxy._proxy_objects), 1)
|
|
|
|
del async_engine
|
|
|
|
eq_(len(ReversibleProxy._proxy_objects), 0)
|
|
|
|
def test_regen_conn_but_not_engine(self, async_engine):
|
|
with async_engine.sync_engine.connect() as sync_conn:
|
|
async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn)
|
|
async_conn2 = AsyncConnection._retrieve_proxy_for_target(sync_conn)
|
|
|
|
is_(async_conn, async_conn2)
|
|
is_(async_conn.engine, async_engine)
|
|
|
|
def test_regen_trans_but_not_conn(self, connection_no_trans):
|
|
sync_conn = connection_no_trans
|
|
|
|
async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn)
|
|
|
|
trans = sync_conn.begin()
|
|
|
|
async_t1 = async_conn.get_transaction()
|
|
|
|
is_(async_t1.connection, async_conn)
|
|
is_(async_t1.sync_transaction, trans)
|
|
|
|
async_t2 = async_conn.get_transaction()
|
|
is_(async_t1, async_t2)
|
|
|
|
|
|
class PoolRegenTest(EngineFixture):
|
|
@testing.requires.queue_pool
|
|
@async_test
|
|
@testing.variation("do_dispose", [True, False])
|
|
async def test_gather_after_dispose(self, testing_engine, do_dispose):
|
|
engine = testing_engine(
|
|
asyncio=True, options=dict(pool_size=10, max_overflow=10)
|
|
)
|
|
|
|
async def thing(engine):
|
|
async with engine.connect() as conn:
|
|
await conn.exec_driver_sql(str(select(1).compile(engine)))
|
|
|
|
if do_dispose:
|
|
await engine.dispose()
|
|
|
|
tasks = [thing(engine) for _ in range(10)]
|
|
await asyncio.gather(*tasks)
|