Merge "Repair async test refactor"

This commit is contained in:
mike bayer
2021-01-03 00:52:41 +00:00
committed by Gerrit Code Review
10 changed files with 60 additions and 24 deletions
+1 -1
View File
@@ -41,7 +41,7 @@ def create_async_engine(*arg, **kw):
class AsyncConnectable:
__slots__ = "_slots_dispatch"
__slots__ = "_slots_dispatch", "__weakref__"
@util.create_proxy_methods(
+5
View File
@@ -22,12 +22,17 @@ import inspect
from . import config
from ..util.concurrency import _util_async_run
from ..util.concurrency import _util_async_run_coroutine_function
# may be set to False if the
# --disable-asyncio flag is passed to the test runner.
ENABLE_ASYNCIO = True
def _run_coroutine_function(fn, *args, **kwargs):
return _util_async_run_coroutine_function(fn, *args, **kwargs)
def _assume_async(fn, *args, **kwargs):
"""Run a function in an asyncio loop unconditionally.
+12 -4
View File
@@ -97,7 +97,10 @@ class ConnectionKiller(object):
self.conns = set()
for rec in list(self.testing_engines):
rec.dispose()
if hasattr(rec, "sync_engine"):
rec.sync_engine.dispose()
else:
rec.dispose()
def assert_all_closed(self):
for rec in self.proxy_refs:
@@ -236,10 +239,12 @@ def reconnecting_engine(url=None, options=None):
return engine
def testing_engine(url=None, options=None, future=False):
def testing_engine(url=None, options=None, future=False, asyncio=False):
"""Produce an engine configured by --options with optional overrides."""
if future or config.db and config.db._is_future:
if asyncio:
from sqlalchemy.ext.asyncio import create_async_engine as create_engine
elif future or config.db and config.db._is_future:
from sqlalchemy.future import create_engine
else:
from sqlalchemy import create_engine
@@ -263,7 +268,10 @@ def testing_engine(url=None, options=None, future=False):
default_opt.update(options)
engine = create_engine(url, **options)
engine._has_events = True # enable event blocks, helps with profiling
if asyncio:
engine.sync_engine._has_events = True
else:
engine._has_events = True # enable event blocks, helps with profiling
if isinstance(engine.pool, pool.QueuePool):
engine.pool._timeout = 0
-11
View File
@@ -48,11 +48,6 @@ class TestBase(object):
# skipped.
__skip_if__ = None
# If this class should be wrapped in asyncio compatibility functions
# when using an async engine. This should be set to False only for tests
# that use the asyncio features of sqlalchemy directly
__asyncio_wrap__ = True
def assert_(self, val, msg=None):
assert val, msg
@@ -95,12 +90,6 @@ class TestBase(object):
# engines.drop_all_tables(metadata, config.db)
class AsyncTestBase(TestBase):
"""Mixin marking a test as using its own explicit asyncio patterns."""
__asyncio_wrap__ = False
class FutureEngineMixin(object):
@classmethod
def setup_class(cls):
+10 -2
View File
@@ -255,7 +255,7 @@ def pytest_pycollect_makeitem(collector, name, obj):
if inspect.isclass(obj) and plugin_base.want_class(name, obj):
from sqlalchemy.testing import config
if config.any_async and getattr(obj, "__asyncio_wrap__", True):
if config.any_async:
obj = _apply_maybe_async(obj)
ctor = getattr(pytest.Class, "from_parent", pytest.Class)
@@ -277,6 +277,13 @@ def pytest_pycollect_makeitem(collector, name, obj):
return []
def _is_wrapped_coroutine_function(fn):
while hasattr(fn, "__wrapped__"):
fn = fn.__wrapped__
return inspect.iscoroutinefunction(fn)
def _apply_maybe_async(obj, recurse=True):
from sqlalchemy.testing import asyncio
@@ -286,6 +293,7 @@ def _apply_maybe_async(obj, recurse=True):
(callable(value) or isinstance(value, classmethod))
and not getattr(value, "_maybe_async_applied", False)
and (name.startswith("test_") or name in setup_names)
and not _is_wrapped_coroutine_function(value)
):
is_classmethod = False
if isinstance(value, classmethod):
@@ -656,6 +664,6 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
@_pytest_fn_decorator
def decorate(fn, *args, **kwargs):
asyncio._assume_async(fn, *args, **kwargs)
asyncio._run_coroutine_function(fn, *args, **kwargs)
return decorate(fn)
+12
View File
@@ -136,6 +136,18 @@ class AsyncAdaptedLock:
self.mutex.release()
def _util_async_run_coroutine_function(fn, *args, **kwargs):
"""for test suite/ util only"""
loop = asyncio.get_event_loop()
if loop.is_running():
raise Exception(
"for async run coroutine we expect that no greenlet or event "
"loop is running when we start out"
)
return loop.run_until_complete(fn(*args, **kwargs))
def _util_async_run(fn, *args, **kwargs):
"""for test suite/ util only"""
+6
View File
@@ -14,6 +14,9 @@ if compat.py3k:
from ._concurrency_py3k import greenlet_spawn
from ._concurrency_py3k import AsyncAdaptedLock
from ._concurrency_py3k import _util_async_run # noqa F401
from ._concurrency_py3k import (
_util_async_run_coroutine_function,
) # noqa F401, E501
from ._concurrency_py3k import asyncio # noqa F401
if not have_greenlet:
@@ -42,3 +45,6 @@ if not have_greenlet:
def _util_async_run(fn, *arg, **kw): # noqa F81
return fn(*arg, **kw)
def _util_async_run_coroutine_function(fn, *arg, **kw): # noqa F81
_not_implemented()
+4 -3
View File
@@ -26,7 +26,7 @@ def go(*fns):
return sum(await_only(fn()) for fn in fns)
class TestAsyncioCompat(fixtures.AsyncTestBase):
class TestAsyncioCompat(fixtures.TestBase):
@async_test
async def test_ok(self):
@@ -53,7 +53,8 @@ class TestAsyncioCompat(fixtures.AsyncTestBase):
to_await = run1()
await_fallback(to_await)
def test_await_only_no_greenlet(self):
@async_test
async def test_await_only_no_greenlet(self):
to_await = run1()
with expect_raises_message(
exc.InvalidRequestError,
@@ -62,7 +63,7 @@ class TestAsyncioCompat(fixtures.AsyncTestBase):
await_only(to_await)
# ensure no warning
await_fallback(to_await)
await greenlet_spawn(await_fallback, to_await)
@async_test
async def test_await_fallback_error(self):
+8 -1
View File
@@ -17,6 +17,7 @@ from sqlalchemy.ext.asyncio import engine as _async_engine
from sqlalchemy.ext.asyncio import exc as asyncio_exc
from sqlalchemy.testing import async_test
from sqlalchemy.testing import combinations
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import expect_raises
from sqlalchemy.testing import expect_raises_message
@@ -32,7 +33,7 @@ class EngineFixture(fixtures.TablesTest):
@testing.fixture
def async_engine(self):
return create_async_engine(testing.db.url)
return engines.testing_engine(asyncio=True)
@classmethod
def define_tables(cls, metadata):
@@ -55,6 +56,12 @@ class EngineFixture(fixtures.TablesTest):
class AsyncEngineTest(EngineFixture):
__backend__ = True
@testing.fails("the failure is the test")
@async_test
async def test_we_are_definitely_running_async_tests(self, async_engine):
async with async_engine.connect() as conn:
eq_(await conn.scalar(text("select 1")), 2)
def test_proxied_attrs_engine(self, async_engine):
sync_engine = async_engine.sync_engine
+2 -2
View File
@@ -5,10 +5,10 @@ from sqlalchemy import select
from sqlalchemy import testing
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import sessionmaker
from sqlalchemy.testing import async_test
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
@@ -24,7 +24,7 @@ class AsyncFixture(_fixtures.FixtureTest):
@testing.fixture
def async_engine(self):
return create_async_engine(testing.db.url)
return engines.testing_engine(asyncio=True)
@testing.fixture
def async_session(self, async_engine):