Allow custom sync session class in `AsyncSession`.

The :class:`_asyncio.AsyncSession` now supports overriding which
:class:`_orm.Session` it uses as the proxied instance. A custom ``Session``
class can be passed using the :paramref:`.AsyncSession.sync_session_class`
parameter or by subclassing the ``AsyncSession`` and specifying a custom
:attr:`.AsyncSession.sync_session_class`.

Fixes: #6689
Change-Id: Idf9c24eae6c9f4e2fff292ed748feaa449a8deaa
This commit is contained in:
Federico Caselli
2021-08-27 22:45:56 +02:00
committed by mike bayer
parent 9131a5208f
commit af0824fd79
4 changed files with 107 additions and 7 deletions
+9
View File
@@ -0,0 +1,9 @@
.. change::
:tags: asyncio, usecase
:tickets: 6746
The :class:`_asyncio.AsyncSession` now supports overriding which
:class:`_orm.Session` it uses as the proxied instance. A custom ``Session``
class can be passed using the :paramref:`.AsyncSession.sync_session_class`
parameter or by subclassing the ``AsyncSession`` and specifying a custom
:attr:`.AsyncSession.sync_session_class`.
+3
View File
@@ -581,6 +581,9 @@ ORM Session API Documentation
.. autoclass:: AsyncSession
:members:
:exclude-members: sync_session_class
.. autoattribute:: sync_session_class
.. autoclass:: AsyncSessionTransaction
:members:
+49 -3
View File
@@ -51,9 +51,16 @@ _STREAM_OPTIONS = util.immutabledict({"stream_results": True})
class AsyncSession(ReversibleProxy):
"""Asyncio version of :class:`_orm.Session`.
The :class:`_asyncio.AsyncSession` is a proxy for a traditional
:class:`_orm.Session` instance.
.. versionadded:: 1.4
To use an :class:`_asyncio.AsyncSession` with custom :class:`_orm.Session`
implementations, see the
:paramref:`_asyncio.AsyncSession.sync_session_class` parameter.
"""
_is_asyncio = True
@@ -68,7 +75,25 @@ class AsyncSession(ReversibleProxy):
dispatch = None
def __init__(self, bind=None, binds=None, **kw):
def __init__(self, bind=None, binds=None, sync_session_class=None, **kw):
r"""Construct a new :class:`_asyncio.AsyncSession`.
All parameters other than ``sync_session_class`` are passed to the
``sync_session_class`` callable directly to instantiate a new
:class:`_orm.Session`. Refer to :meth:`_orm.Session.__init__` for
parameter documentation.
:param sync_session_class:
A :class:`_orm.Session` subclass or other callable which will be used
to construct the :class:`_orm.Session` which will be proxied. This
parameter may be used to provide custom :class:`_orm.Session`
subclasses. Defaults to the
:attr:`_asyncio.AsyncSession.sync_session_class` class-level
attribute.
.. versionadded:: 1.4.24
"""
kw["future"] = True
if bind:
self.bind = bind
@@ -81,10 +106,30 @@ class AsyncSession(ReversibleProxy):
for key, b in binds.items()
}
if sync_session_class:
self.sync_session_class = sync_session_class
self.sync_session = self._proxied = self._assign_proxied(
Session(bind=bind, binds=binds, **kw)
self.sync_session_class(bind=bind, binds=binds, **kw)
)
sync_session_class = Session
"""The class or callable that provides the
underlying :class:`_orm.Session` instance for a particular
:class:`_asyncio.AsyncSession`.
At the class level, this attribute is the default value for the
:paramref:`_asyncio.AsyncSession.sync_session_class` parameter. Custom
subclasses of :class:`_asyncio.AsyncSession` can override this.
At the instance level, this attribute indicates the current class or
callable that was used to provide the :class:`_orm.Session` instance for
this :class:`_asyncio.AsyncSession` instance.
.. versionadded:: 1.4.24
"""
async def refresh(
self, instance, attribute_names=None, with_for_update=None
):
@@ -141,7 +186,8 @@ class AsyncSession(ReversibleProxy):
**kw
):
"""Execute a statement and return a buffered
:class:`_engine.Result` object."""
:class:`_engine.Result` object.
"""
if execution_options:
execution_options = util.immutabledict(execution_options).union(
+46 -4
View File
@@ -14,11 +14,13 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio.base import ReversibleProxy
from sqlalchemy.orm import relationship
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from sqlalchemy.testing import async_test
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_true
from sqlalchemy.testing import mock
from .test_engine_py3k import AsyncFixture as _AsyncFixture
from ...orm import _fixtures
@@ -724,8 +726,6 @@ class AsyncProxyTest(AsyncFixture):
is_(inspect(u3).async_session, None)
def test_inspect_session_no_asyncio_used(self):
from sqlalchemy.orm import Session
User = self.classes.User
s1 = Session(testing.db)
@@ -734,8 +734,6 @@ class AsyncProxyTest(AsyncFixture):
is_(inspect(u1).async_session, None)
def test_inspect_session_no_asyncio_imported(self):
from sqlalchemy.orm import Session
with mock.patch("sqlalchemy.orm.state._async_provider", None):
User = self.classes.User
@@ -758,3 +756,47 @@ class AsyncProxyTest(AsyncFixture):
del async_session
eq_(len(ReversibleProxy._proxy_objects), 0)
class _MySession(Session):
pass
class _MyAS(AsyncSession):
sync_session_class = _MySession
class OverrideSyncSession(AsyncFixture):
def test_default(self, async_engine):
ass = AsyncSession(async_engine)
is_true(isinstance(ass.sync_session, Session))
is_(ass.sync_session.__class__, Session)
is_(ass.sync_session_class, Session)
def test_init_class(self, async_engine):
ass = AsyncSession(async_engine, sync_session_class=_MySession)
is_true(isinstance(ass.sync_session, _MySession))
is_(ass.sync_session_class, _MySession)
def test_init_sessionmaker(self, async_engine):
sm = sessionmaker(
async_engine, class_=AsyncSession, sync_session_class=_MySession
)
ass = sm()
is_true(isinstance(ass.sync_session, _MySession))
is_(ass.sync_session_class, _MySession)
def test_subclass(self, async_engine):
ass = _MyAS(async_engine)
is_true(isinstance(ass.sync_session, _MySession))
is_(ass.sync_session_class, _MySession)
def test_subclass_override(self, async_engine):
ass = _MyAS(async_engine, sync_session_class=Session)
is_true(not isinstance(ass.sync_session, _MySession))
is_(ass.sync_session_class, Session)