mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-21 16:12:03 -04:00
Merge "Repair async test refactor"
This commit is contained in:
@@ -41,7 +41,7 @@ def create_async_engine(*arg, **kw):
|
||||
|
||||
|
||||
class AsyncConnectable:
|
||||
__slots__ = "_slots_dispatch"
|
||||
__slots__ = "_slots_dispatch", "__weakref__"
|
||||
|
||||
|
||||
@util.create_proxy_methods(
|
||||
|
||||
@@ -22,12 +22,17 @@ import inspect
|
||||
|
||||
from . import config
|
||||
from ..util.concurrency import _util_async_run
|
||||
from ..util.concurrency import _util_async_run_coroutine_function
|
||||
|
||||
# may be set to False if the
|
||||
# --disable-asyncio flag is passed to the test runner.
|
||||
ENABLE_ASYNCIO = True
|
||||
|
||||
|
||||
def _run_coroutine_function(fn, *args, **kwargs):
|
||||
return _util_async_run_coroutine_function(fn, *args, **kwargs)
|
||||
|
||||
|
||||
def _assume_async(fn, *args, **kwargs):
|
||||
"""Run a function in an asyncio loop unconditionally.
|
||||
|
||||
|
||||
@@ -97,7 +97,10 @@ class ConnectionKiller(object):
|
||||
|
||||
self.conns = set()
|
||||
for rec in list(self.testing_engines):
|
||||
rec.dispose()
|
||||
if hasattr(rec, "sync_engine"):
|
||||
rec.sync_engine.dispose()
|
||||
else:
|
||||
rec.dispose()
|
||||
|
||||
def assert_all_closed(self):
|
||||
for rec in self.proxy_refs:
|
||||
@@ -236,10 +239,12 @@ def reconnecting_engine(url=None, options=None):
|
||||
return engine
|
||||
|
||||
|
||||
def testing_engine(url=None, options=None, future=False):
|
||||
def testing_engine(url=None, options=None, future=False, asyncio=False):
|
||||
"""Produce an engine configured by --options with optional overrides."""
|
||||
|
||||
if future or config.db and config.db._is_future:
|
||||
if asyncio:
|
||||
from sqlalchemy.ext.asyncio import create_async_engine as create_engine
|
||||
elif future or config.db and config.db._is_future:
|
||||
from sqlalchemy.future import create_engine
|
||||
else:
|
||||
from sqlalchemy import create_engine
|
||||
@@ -263,7 +268,10 @@ def testing_engine(url=None, options=None, future=False):
|
||||
default_opt.update(options)
|
||||
|
||||
engine = create_engine(url, **options)
|
||||
engine._has_events = True # enable event blocks, helps with profiling
|
||||
if asyncio:
|
||||
engine.sync_engine._has_events = True
|
||||
else:
|
||||
engine._has_events = True # enable event blocks, helps with profiling
|
||||
|
||||
if isinstance(engine.pool, pool.QueuePool):
|
||||
engine.pool._timeout = 0
|
||||
|
||||
@@ -48,11 +48,6 @@ class TestBase(object):
|
||||
# skipped.
|
||||
__skip_if__ = None
|
||||
|
||||
# If this class should be wrapped in asyncio compatibility functions
|
||||
# when using an async engine. This should be set to False only for tests
|
||||
# that use the asyncio features of sqlalchemy directly
|
||||
__asyncio_wrap__ = True
|
||||
|
||||
def assert_(self, val, msg=None):
|
||||
assert val, msg
|
||||
|
||||
@@ -95,12 +90,6 @@ class TestBase(object):
|
||||
# engines.drop_all_tables(metadata, config.db)
|
||||
|
||||
|
||||
class AsyncTestBase(TestBase):
|
||||
"""Mixin marking a test as using its own explicit asyncio patterns."""
|
||||
|
||||
__asyncio_wrap__ = False
|
||||
|
||||
|
||||
class FutureEngineMixin(object):
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
|
||||
@@ -255,7 +255,7 @@ def pytest_pycollect_makeitem(collector, name, obj):
|
||||
if inspect.isclass(obj) and plugin_base.want_class(name, obj):
|
||||
from sqlalchemy.testing import config
|
||||
|
||||
if config.any_async and getattr(obj, "__asyncio_wrap__", True):
|
||||
if config.any_async:
|
||||
obj = _apply_maybe_async(obj)
|
||||
|
||||
ctor = getattr(pytest.Class, "from_parent", pytest.Class)
|
||||
@@ -277,6 +277,13 @@ def pytest_pycollect_makeitem(collector, name, obj):
|
||||
return []
|
||||
|
||||
|
||||
def _is_wrapped_coroutine_function(fn):
|
||||
while hasattr(fn, "__wrapped__"):
|
||||
fn = fn.__wrapped__
|
||||
|
||||
return inspect.iscoroutinefunction(fn)
|
||||
|
||||
|
||||
def _apply_maybe_async(obj, recurse=True):
|
||||
from sqlalchemy.testing import asyncio
|
||||
|
||||
@@ -286,6 +293,7 @@ def _apply_maybe_async(obj, recurse=True):
|
||||
(callable(value) or isinstance(value, classmethod))
|
||||
and not getattr(value, "_maybe_async_applied", False)
|
||||
and (name.startswith("test_") or name in setup_names)
|
||||
and not _is_wrapped_coroutine_function(value)
|
||||
):
|
||||
is_classmethod = False
|
||||
if isinstance(value, classmethod):
|
||||
@@ -656,6 +664,6 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
|
||||
|
||||
@_pytest_fn_decorator
|
||||
def decorate(fn, *args, **kwargs):
|
||||
asyncio._assume_async(fn, *args, **kwargs)
|
||||
asyncio._run_coroutine_function(fn, *args, **kwargs)
|
||||
|
||||
return decorate(fn)
|
||||
|
||||
@@ -136,6 +136,18 @@ class AsyncAdaptedLock:
|
||||
self.mutex.release()
|
||||
|
||||
|
||||
def _util_async_run_coroutine_function(fn, *args, **kwargs):
|
||||
"""for test suite/ util only"""
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
raise Exception(
|
||||
"for async run coroutine we expect that no greenlet or event "
|
||||
"loop is running when we start out"
|
||||
)
|
||||
return loop.run_until_complete(fn(*args, **kwargs))
|
||||
|
||||
|
||||
def _util_async_run(fn, *args, **kwargs):
|
||||
"""for test suite/ util only"""
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@ if compat.py3k:
|
||||
from ._concurrency_py3k import greenlet_spawn
|
||||
from ._concurrency_py3k import AsyncAdaptedLock
|
||||
from ._concurrency_py3k import _util_async_run # noqa F401
|
||||
from ._concurrency_py3k import (
|
||||
_util_async_run_coroutine_function,
|
||||
) # noqa F401, E501
|
||||
from ._concurrency_py3k import asyncio # noqa F401
|
||||
|
||||
if not have_greenlet:
|
||||
@@ -42,3 +45,6 @@ if not have_greenlet:
|
||||
|
||||
def _util_async_run(fn, *arg, **kw): # noqa F81
|
||||
return fn(*arg, **kw)
|
||||
|
||||
def _util_async_run_coroutine_function(fn, *arg, **kw): # noqa F81
|
||||
_not_implemented()
|
||||
|
||||
@@ -26,7 +26,7 @@ def go(*fns):
|
||||
return sum(await_only(fn()) for fn in fns)
|
||||
|
||||
|
||||
class TestAsyncioCompat(fixtures.AsyncTestBase):
|
||||
class TestAsyncioCompat(fixtures.TestBase):
|
||||
@async_test
|
||||
async def test_ok(self):
|
||||
|
||||
@@ -53,7 +53,8 @@ class TestAsyncioCompat(fixtures.AsyncTestBase):
|
||||
to_await = run1()
|
||||
await_fallback(to_await)
|
||||
|
||||
def test_await_only_no_greenlet(self):
|
||||
@async_test
|
||||
async def test_await_only_no_greenlet(self):
|
||||
to_await = run1()
|
||||
with expect_raises_message(
|
||||
exc.InvalidRequestError,
|
||||
@@ -62,7 +63,7 @@ class TestAsyncioCompat(fixtures.AsyncTestBase):
|
||||
await_only(to_await)
|
||||
|
||||
# ensure no warning
|
||||
await_fallback(to_await)
|
||||
await greenlet_spawn(await_fallback, to_await)
|
||||
|
||||
@async_test
|
||||
async def test_await_fallback_error(self):
|
||||
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy.ext.asyncio import engine as _async_engine
|
||||
from sqlalchemy.ext.asyncio import exc as asyncio_exc
|
||||
from sqlalchemy.testing import async_test
|
||||
from sqlalchemy.testing import combinations
|
||||
from sqlalchemy.testing import engines
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import expect_raises
|
||||
from sqlalchemy.testing import expect_raises_message
|
||||
@@ -32,7 +33,7 @@ class EngineFixture(fixtures.TablesTest):
|
||||
|
||||
@testing.fixture
|
||||
def async_engine(self):
|
||||
return create_async_engine(testing.db.url)
|
||||
return engines.testing_engine(asyncio=True)
|
||||
|
||||
@classmethod
|
||||
def define_tables(cls, metadata):
|
||||
@@ -55,6 +56,12 @@ class EngineFixture(fixtures.TablesTest):
|
||||
class AsyncEngineTest(EngineFixture):
|
||||
__backend__ = True
|
||||
|
||||
@testing.fails("the failure is the test")
|
||||
@async_test
|
||||
async def test_we_are_definitely_running_async_tests(self, async_engine):
|
||||
async with async_engine.connect() as conn:
|
||||
eq_(await conn.scalar(text("select 1")), 2)
|
||||
|
||||
def test_proxied_attrs_engine(self, async_engine):
|
||||
sync_engine = async_engine.sync_engine
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@ from sqlalchemy import select
|
||||
from sqlalchemy import testing
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.testing import async_test
|
||||
from sqlalchemy.testing import engines
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import is_
|
||||
from sqlalchemy.testing import mock
|
||||
@@ -24,7 +24,7 @@ class AsyncFixture(_fixtures.FixtureTest):
|
||||
|
||||
@testing.fixture
|
||||
def async_engine(self):
|
||||
return create_async_engine(testing.db.url)
|
||||
return engines.testing_engine(asyncio=True)
|
||||
|
||||
@testing.fixture
|
||||
def async_session(self, async_engine):
|
||||
|
||||
Reference in New Issue
Block a user