mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-15 21:27:23 -04:00
generalize scoped_session proxying and apply to asyncio elements
Reworked the proxy creation used by scoped_session() to be based on fully copied code with augmented docstrings and moved it into langhelpers. asyncio session, engine, connection can now take advantage of it so that all non-async methods are availble. Overall implementation of most important accessors / methods on AsyncConnection, etc. , including awaitable versions of invalidate, execution_options, etc. In order to support an event dispatcher on the async classes while still allowing them to hold __slots__, make some adjustments to the event system to allow that to be present, at least rudimentally. Fixes: #5628 Change-Id: I5eb6929fc1e4fdac99e4b767dcfd49672d56e2b2
This commit is contained in:
@@ -184,14 +184,17 @@ class Connection(Connectable):
|
||||
r""" Set non-SQL options for the connection which take effect
|
||||
during execution.
|
||||
|
||||
The method returns a copy of this :class:`_engine.Connection`
|
||||
which references
|
||||
the same underlying DBAPI connection, but also defines the given
|
||||
execution options which will take effect for a call to
|
||||
:meth:`execute`. As the new :class:`_engine.Connection`
|
||||
references the same
|
||||
underlying resource, it's usually a good idea to ensure that the copies
|
||||
will be discarded immediately, which is implicit if used as in::
|
||||
For a "future" style connection, this method returns this same
|
||||
:class:`_future.Connection` object with the new options added.
|
||||
|
||||
For a legacy connection, this method returns a copy of this
|
||||
:class:`_engine.Connection` which references the same underlying DBAPI
|
||||
connection, but also defines the given execution options which will
|
||||
take effect for a call to
|
||||
:meth:`execute`. As the new :class:`_engine.Connection` references the
|
||||
same underlying resource, it's usually a good idea to ensure that
|
||||
the copies will be discarded immediately, which is implicit if used
|
||||
as in::
|
||||
|
||||
result = connection.execution_options(stream_results=True).\
|
||||
execute(stmt)
|
||||
@@ -549,9 +552,10 @@ class Connection(Connectable):
|
||||
"""Invalidate the underlying DBAPI connection associated with
|
||||
this :class:`_engine.Connection`.
|
||||
|
||||
The underlying DBAPI connection is literally closed (if
|
||||
possible), and is discarded. Its source connection pool will
|
||||
typically lazily create a new connection to replace it.
|
||||
An attempt will be made to close the underlying DBAPI connection
|
||||
immediately; however if this operation fails, the error is logged
|
||||
but not raised. The connection is then discarded whether or not
|
||||
close() succeeded.
|
||||
|
||||
Upon the next use (where "use" typically means using the
|
||||
:meth:`_engine.Connection.execute` method or similar),
|
||||
@@ -580,6 +584,10 @@ class Connection(Connectable):
|
||||
will at the connection pool level invoke the
|
||||
:meth:`_events.PoolEvents.invalidate` event.
|
||||
|
||||
:param exception: an optional ``Exception`` instance that's the
|
||||
reason for the invalidation. is passed along to event handlers
|
||||
and logging functions.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`pool_connection_invalidation`
|
||||
|
||||
@@ -195,7 +195,14 @@ def _create_dispatcher_class(cls, classname, bases, dict_):
|
||||
dispatch_cls._event_names.append(ls.name)
|
||||
|
||||
if getattr(cls, "_dispatch_target", None):
|
||||
cls._dispatch_target.dispatch = dispatcher(cls)
|
||||
the_cls = cls._dispatch_target
|
||||
if (
|
||||
hasattr(the_cls, "__slots__")
|
||||
and "_slots_dispatch" in the_cls.__slots__
|
||||
):
|
||||
cls._dispatch_target.dispatch = slots_dispatcher(cls)
|
||||
else:
|
||||
cls._dispatch_target.dispatch = dispatcher(cls)
|
||||
|
||||
|
||||
def _remove_dispatcher(cls):
|
||||
@@ -304,5 +311,29 @@ class dispatcher(object):
|
||||
def __get__(self, obj, cls):
|
||||
if obj is None:
|
||||
return self.dispatch
|
||||
obj.__dict__["dispatch"] = disp = self.dispatch._for_instance(obj)
|
||||
|
||||
disp = self.dispatch._for_instance(obj)
|
||||
try:
|
||||
obj.__dict__["dispatch"] = disp
|
||||
except AttributeError as ae:
|
||||
util.raise_(
|
||||
TypeError(
|
||||
"target %r doesn't have __dict__, should it be "
|
||||
"defining _slots_dispatch?" % (obj,)
|
||||
),
|
||||
replace_context=ae,
|
||||
)
|
||||
return disp
|
||||
|
||||
|
||||
class slots_dispatcher(dispatcher):
|
||||
def __get__(self, obj, cls):
|
||||
if obj is None:
|
||||
return self.dispatch
|
||||
|
||||
if hasattr(obj, "_slots_dispatch"):
|
||||
return obj._slots_dispatch
|
||||
|
||||
disp = self.dispatch._for_instance(obj)
|
||||
obj._slots_dispatch = disp
|
||||
return disp
|
||||
|
||||
@@ -2,6 +2,8 @@ from .engine import AsyncConnection # noqa
|
||||
from .engine import AsyncEngine # noqa
|
||||
from .engine import AsyncTransaction # noqa
|
||||
from .engine import create_async_engine # noqa
|
||||
from .events import AsyncConnectionEvents # noqa
|
||||
from .events import AsyncSessionEvents # noqa
|
||||
from .result import AsyncMappingResult # noqa
|
||||
from .result import AsyncResult # noqa
|
||||
from .result import AsyncScalarResult # noqa
|
||||
|
||||
@@ -8,12 +8,11 @@ from .base import StartableContext
|
||||
from .result import AsyncResult
|
||||
from ... import exc
|
||||
from ... import util
|
||||
from ...engine import Connection
|
||||
from ...engine import create_engine as _create_engine
|
||||
from ...engine import Engine
|
||||
from ...engine import Result
|
||||
from ...engine import Transaction
|
||||
from ...engine.base import OptionEngineMixin
|
||||
from ...future import Connection
|
||||
from ...future import Engine
|
||||
from ...sql import Executable
|
||||
from ...util.concurrency import greenlet_spawn
|
||||
|
||||
@@ -41,7 +40,24 @@ def create_async_engine(*arg, **kw):
|
||||
return AsyncEngine(sync_engine)
|
||||
|
||||
|
||||
class AsyncConnection(StartableContext):
|
||||
class AsyncConnectable:
|
||||
__slots__ = "_slots_dispatch"
|
||||
|
||||
|
||||
@util.create_proxy_methods(
|
||||
Connection,
|
||||
":class:`_future.Connection`",
|
||||
":class:`_asyncio.AsyncConnection`",
|
||||
classmethods=[],
|
||||
methods=[],
|
||||
attributes=[
|
||||
"closed",
|
||||
"invalidated",
|
||||
"dialect",
|
||||
"default_isolation_level",
|
||||
],
|
||||
)
|
||||
class AsyncConnection(StartableContext, AsyncConnectable):
|
||||
"""An asyncio proxy for a :class:`_engine.Connection`.
|
||||
|
||||
:class:`_asyncio.AsyncConnection` is acquired using the
|
||||
@@ -58,15 +74,23 @@ class AsyncConnection(StartableContext):
|
||||
|
||||
""" # noqa
|
||||
|
||||
# AsyncConnection is a thin proxy; no state should be added here
|
||||
# that is not retrievable from the "sync" engine / connection, e.g.
|
||||
# current transaction, info, etc. It should be possible to
|
||||
# create a new AsyncConnection that matches this one given only the
|
||||
# "sync" elements.
|
||||
__slots__ = (
|
||||
"sync_engine",
|
||||
"sync_connection",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, sync_engine: Engine, sync_connection: Optional[Connection] = None
|
||||
self,
|
||||
async_engine: "AsyncEngine",
|
||||
sync_connection: Optional[Connection] = None,
|
||||
):
|
||||
self.sync_engine = sync_engine
|
||||
self.engine = async_engine
|
||||
self.sync_engine = async_engine.sync_engine
|
||||
self.sync_connection = sync_connection
|
||||
|
||||
async def start(self):
|
||||
@@ -79,6 +103,34 @@ class AsyncConnection(StartableContext):
|
||||
self.sync_connection = await (greenlet_spawn(self.sync_engine.connect))
|
||||
return self
|
||||
|
||||
@property
|
||||
def connection(self):
|
||||
"""Not implemented for async; call
|
||||
:meth:`_asyncio.AsyncConnection.get_raw_connection`.
|
||||
|
||||
"""
|
||||
raise exc.InvalidRequestError(
|
||||
"AsyncConnection.connection accessor is not implemented as the "
|
||||
"attribute may need to reconnect on an invalidated connection. "
|
||||
"Use the get_raw_connection() method."
|
||||
)
|
||||
|
||||
async def get_raw_connection(self):
|
||||
"""Return the pooled DBAPI-level connection in use by this
|
||||
:class:`_asyncio.AsyncConnection`.
|
||||
|
||||
This is typically the SQLAlchemy connection-pool proxied connection
|
||||
which then has an attribute .connection that refers to the actual
|
||||
DBAPI-level connection.
|
||||
"""
|
||||
conn = self._sync_connection()
|
||||
|
||||
return await greenlet_spawn(getattr, conn, "connection")
|
||||
|
||||
@property
|
||||
def _proxied(self):
|
||||
return self.sync_connection
|
||||
|
||||
def _sync_connection(self):
|
||||
if not self.sync_connection:
|
||||
self._raise_for_not_started()
|
||||
@@ -94,6 +146,43 @@ class AsyncConnection(StartableContext):
|
||||
self._sync_connection()
|
||||
return AsyncTransaction(self, nested=True)
|
||||
|
||||
async def invalidate(self, exception=None):
|
||||
"""Invalidate the underlying DBAPI connection associated with
|
||||
this :class:`_engine.Connection`.
|
||||
|
||||
See the method :meth:`_engine.Connection.invalidate` for full
|
||||
detail on this method.
|
||||
|
||||
"""
|
||||
|
||||
conn = self._sync_connection()
|
||||
return await greenlet_spawn(conn.invalidate, exception=exception)
|
||||
|
||||
async def get_isolation_level(self):
|
||||
conn = self._sync_connection()
|
||||
return await greenlet_spawn(conn.get_isolation_level)
|
||||
|
||||
async def set_isolation_level(self):
|
||||
conn = self._sync_connection()
|
||||
return await greenlet_spawn(conn.get_isolation_level)
|
||||
|
||||
async def execution_options(self, **opt):
|
||||
r"""Set non-SQL options for the connection which take effect
|
||||
during execution.
|
||||
|
||||
This returns this :class:`_asyncio.AsyncConnection` object with
|
||||
the new options added.
|
||||
|
||||
See :meth:`_future.Connection.execution_options` for full details
|
||||
on this method.
|
||||
|
||||
"""
|
||||
|
||||
conn = self._sync_connection()
|
||||
c2 = await greenlet_spawn(conn.execution_options, **opt)
|
||||
assert c2 is conn
|
||||
return self
|
||||
|
||||
async def commit(self):
|
||||
"""Commit the transaction that is currently in progress.
|
||||
|
||||
@@ -287,7 +376,19 @@ class AsyncConnection(StartableContext):
|
||||
await self.close()
|
||||
|
||||
|
||||
class AsyncEngine:
|
||||
@util.create_proxy_methods(
|
||||
Engine,
|
||||
":class:`_future.Engine`",
|
||||
":class:`_asyncio.AsyncEngine`",
|
||||
classmethods=[],
|
||||
methods=[
|
||||
"clear_compiled_cache",
|
||||
"update_execution_options",
|
||||
"get_execution_options",
|
||||
],
|
||||
attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
|
||||
)
|
||||
class AsyncEngine(AsyncConnectable):
|
||||
"""An asyncio proxy for a :class:`_engine.Engine`.
|
||||
|
||||
:class:`_asyncio.AsyncEngine` is acquired using the
|
||||
@@ -301,7 +402,12 @@ class AsyncEngine:
|
||||
|
||||
""" # noqa
|
||||
|
||||
__slots__ = ("sync_engine",)
|
||||
# AsyncEngine is a thin proxy; no state should be added here
|
||||
# that is not retrievable from the "sync" engine / connection, e.g.
|
||||
# current transaction, info, etc. It should be possible to
|
||||
# create a new AsyncEngine that matches this one given only the
|
||||
# "sync" elements.
|
||||
__slots__ = ("sync_engine", "_proxied")
|
||||
|
||||
_connection_cls = AsyncConnection
|
||||
|
||||
@@ -327,7 +433,7 @@ class AsyncEngine:
|
||||
await self.conn.close()
|
||||
|
||||
def __init__(self, sync_engine: Engine):
|
||||
self.sync_engine = sync_engine
|
||||
self.sync_engine = self._proxied = sync_engine
|
||||
|
||||
def begin(self):
|
||||
"""Return a context manager which when entered will deliver an
|
||||
@@ -363,7 +469,7 @@ class AsyncEngine:
|
||||
|
||||
"""
|
||||
|
||||
return self._connection_cls(self.sync_engine)
|
||||
return self._connection_cls(self)
|
||||
|
||||
async def raw_connection(self) -> Any:
|
||||
"""Return a "raw" DBAPI connection from the connection pool.
|
||||
@@ -375,12 +481,33 @@ class AsyncEngine:
|
||||
"""
|
||||
return await greenlet_spawn(self.sync_engine.raw_connection)
|
||||
|
||||
def execution_options(self, **opt):
|
||||
"""Return a new :class:`_asyncio.AsyncEngine` that will provide
|
||||
:class:`_asyncio.AsyncConnection` objects with the given execution
|
||||
options.
|
||||
|
||||
class AsyncOptionEngine(OptionEngineMixin, AsyncEngine):
|
||||
pass
|
||||
Proxied from :meth:`_future.Engine.execution_options`. See that
|
||||
method for details.
|
||||
|
||||
"""
|
||||
|
||||
AsyncEngine._option_cls = AsyncOptionEngine
|
||||
return AsyncEngine(self.sync_engine.execution_options(**opt))
|
||||
|
||||
async def dispose(self):
|
||||
"""Dispose of the connection pool used by this
|
||||
:class:`_asyncio.AsyncEngine`.
|
||||
|
||||
This will close all connection pool connections that are
|
||||
**currently checked in**. See the documentation for the underlying
|
||||
:meth:`_future.Engine.dispose` method for further notes.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_future.Engine.dispose`
|
||||
|
||||
"""
|
||||
|
||||
return await greenlet_spawn(self.sync_engine.dispose)
|
||||
|
||||
|
||||
class AsyncTransaction(StartableContext):
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
from .engine import AsyncConnectable
|
||||
from .session import AsyncSession
|
||||
from ...engine import events as engine_event
|
||||
from ...orm import events as orm_event
|
||||
|
||||
|
||||
class AsyncConnectionEvents(engine_event.ConnectionEvents):
|
||||
_target_class_doc = "SomeEngine"
|
||||
_dispatch_target = AsyncConnectable
|
||||
|
||||
@classmethod
|
||||
def _listen(cls, event_key, retval=False):
|
||||
raise NotImplementedError(
|
||||
"asynchronous events are not implemented at this time. Apply "
|
||||
"synchronous listeners to the AsyncEngine.sync_engine or "
|
||||
"AsyncConnection.sync_connection attributes."
|
||||
)
|
||||
|
||||
|
||||
class AsyncSessionEvents(orm_event.SessionEvents):
|
||||
_target_class_doc = "SomeSession"
|
||||
_dispatch_target = AsyncSession
|
||||
|
||||
@classmethod
|
||||
def _listen(cls, event_key, retval=False):
|
||||
raise NotImplementedError(
|
||||
"asynchronous events are not implemented at this time. Apply "
|
||||
"synchronous listeners to the AsyncSession.sync_session."
|
||||
)
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import List
|
||||
from typing import Mapping
|
||||
from typing import Optional
|
||||
|
||||
@@ -15,6 +14,35 @@ from ...sql import Executable
|
||||
from ...util.concurrency import greenlet_spawn
|
||||
|
||||
|
||||
@util.create_proxy_methods(
|
||||
Session,
|
||||
":class:`_orm.Session`",
|
||||
":class:`_asyncio.AsyncSession`",
|
||||
classmethods=["object_session", "identity_key"],
|
||||
methods=[
|
||||
"__contains__",
|
||||
"__iter__",
|
||||
"add",
|
||||
"add_all",
|
||||
"delete",
|
||||
"expire",
|
||||
"expire_all",
|
||||
"expunge",
|
||||
"expunge_all",
|
||||
"get_bind",
|
||||
"is_modified",
|
||||
],
|
||||
attributes=[
|
||||
"dirty",
|
||||
"deleted",
|
||||
"new",
|
||||
"identity_map",
|
||||
"is_active",
|
||||
"autoflush",
|
||||
"no_autoflush",
|
||||
"info",
|
||||
],
|
||||
)
|
||||
class AsyncSession:
|
||||
"""Asyncio version of :class:`_orm.Session`.
|
||||
|
||||
@@ -23,6 +51,16 @@ class AsyncSession:
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"binds",
|
||||
"bind",
|
||||
"sync_session",
|
||||
"_proxied",
|
||||
"_slots_dispatch",
|
||||
)
|
||||
|
||||
dispatch = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bind: AsyncEngine = None,
|
||||
@@ -31,46 +69,18 @@ class AsyncSession:
|
||||
):
|
||||
kw["future"] = True
|
||||
if bind:
|
||||
self.bind = engine
|
||||
bind = engine._get_sync_engine(bind)
|
||||
|
||||
if binds:
|
||||
self.binds = binds
|
||||
binds = {
|
||||
key: engine._get_sync_engine(b) for key, b in binds.items()
|
||||
}
|
||||
|
||||
self.sync_session = Session(bind=bind, binds=binds, **kw)
|
||||
|
||||
def add(self, instance: object) -> None:
|
||||
"""Place an object in this :class:`_asyncio.AsyncSession`.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_orm.Session.add`
|
||||
|
||||
"""
|
||||
self.sync_session.add(instance)
|
||||
|
||||
def add_all(self, instances: List[object]) -> None:
|
||||
"""Add the given collection of instances to this
|
||||
:class:`_asyncio.AsyncSession`."""
|
||||
|
||||
self.sync_session.add_all(instances)
|
||||
|
||||
def expire_all(self):
|
||||
"""Expires all persistent instances within this Session.
|
||||
|
||||
See :meth:`_orm.Session.expire_all` for usage details.
|
||||
|
||||
"""
|
||||
self.sync_session.expire_all()
|
||||
|
||||
def expire(self, instance, attribute_names=None):
|
||||
"""Expire the attributes on an instance.
|
||||
|
||||
See :meth:`._orm.Session.expire` for usage details.
|
||||
|
||||
"""
|
||||
self.sync_session.expire()
|
||||
self.sync_session = self._proxied = Session(
|
||||
bind=bind, binds=binds, **kw
|
||||
)
|
||||
|
||||
async def refresh(
|
||||
self, instance, attribute_names=None, with_for_update=None
|
||||
@@ -178,8 +188,17 @@ class AsyncSession:
|
||||
:class:`.Session` object's transactional state.
|
||||
|
||||
"""
|
||||
|
||||
# POSSIBLY TODO: here, we see that the sync engine / connection
|
||||
# that are generated from AsyncEngine / AsyncConnection don't
|
||||
# provide any backlink from those sync objects back out to the
|
||||
# async ones. it's not *too* big a deal since AsyncEngine/Connection
|
||||
# are just proxies and all the state is actually in the sync
|
||||
# version of things. However! it has to stay that way :)
|
||||
sync_connection = await greenlet_spawn(self.sync_session.connection)
|
||||
return engine.AsyncConnection(sync_connection.engine, sync_connection)
|
||||
return engine.AsyncConnection(
|
||||
engine.AsyncEngine(sync_connection.engine), sync_connection
|
||||
)
|
||||
|
||||
def begin(self, **kw):
|
||||
"""Return an :class:`_asyncio.AsyncSessionTransaction` object.
|
||||
@@ -218,14 +237,22 @@ class AsyncSession:
|
||||
return AsyncSessionTransaction(self, nested=True)
|
||||
|
||||
async def rollback(self):
|
||||
"""Rollback the current transaction in progress."""
|
||||
return await greenlet_spawn(self.sync_session.rollback)
|
||||
|
||||
async def commit(self):
|
||||
"""Commit the current transaction in progress."""
|
||||
return await greenlet_spawn(self.sync_session.commit)
|
||||
|
||||
async def close(self):
|
||||
"""Close this :class:`_asyncio.AsyncSession`."""
|
||||
return await greenlet_spawn(self.sync_session.close)
|
||||
|
||||
@classmethod
|
||||
async def close_all(self):
|
||||
"""Close all :class:`_asyncio.AsyncSession` sessions."""
|
||||
return await greenlet_spawn(self.sync_session.close_all)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
|
||||
@@ -1371,7 +1371,8 @@ class SessionEvents(event.Events):
|
||||
elif isinstance(target, Session):
|
||||
return target
|
||||
else:
|
||||
return None
|
||||
# allows alternate SessionEvents-like-classes to be consulted
|
||||
return event.Events._accept_with(target)
|
||||
|
||||
@classmethod
|
||||
def _listen(cls, event_key, raw=False, restore_load_context=False, **kw):
|
||||
|
||||
@@ -9,14 +9,60 @@ from . import class_mapper
|
||||
from . import exc as orm_exc
|
||||
from .session import Session
|
||||
from .. import exc as sa_exc
|
||||
from ..util import create_proxy_methods
|
||||
from ..util import ScopedRegistry
|
||||
from ..util import ThreadLocalRegistry
|
||||
from ..util import warn
|
||||
|
||||
|
||||
__all__ = ["scoped_session"]
|
||||
|
||||
|
||||
@create_proxy_methods(
|
||||
Session,
|
||||
":class:`_orm.Session`",
|
||||
":class:`_orm.scoping.scoped_session`",
|
||||
classmethods=["close_all", "object_session", "identity_key"],
|
||||
methods=[
|
||||
"__contains__",
|
||||
"__iter__",
|
||||
"add",
|
||||
"add_all",
|
||||
"begin",
|
||||
"begin_nested",
|
||||
"close",
|
||||
"commit",
|
||||
"connection",
|
||||
"delete",
|
||||
"execute",
|
||||
"expire",
|
||||
"expire_all",
|
||||
"expunge",
|
||||
"expunge_all",
|
||||
"flush",
|
||||
"get_bind",
|
||||
"is_modified",
|
||||
"bulk_save_objects",
|
||||
"bulk_insert_mappings",
|
||||
"bulk_update_mappings",
|
||||
"merge",
|
||||
"query",
|
||||
"refresh",
|
||||
"rollback",
|
||||
"scalar",
|
||||
],
|
||||
attributes=[
|
||||
"bind",
|
||||
"dirty",
|
||||
"deleted",
|
||||
"new",
|
||||
"identity_map",
|
||||
"is_active",
|
||||
"autoflush",
|
||||
"no_autoflush",
|
||||
"info",
|
||||
"autocommit",
|
||||
],
|
||||
)
|
||||
class scoped_session(object):
|
||||
"""Provides scoped management of :class:`.Session` objects.
|
||||
|
||||
@@ -53,6 +99,10 @@ class scoped_session(object):
|
||||
else:
|
||||
self.registry = ThreadLocalRegistry(session_factory)
|
||||
|
||||
@property
|
||||
def _proxied(self):
|
||||
return self.registry()
|
||||
|
||||
def __call__(self, **kw):
|
||||
r"""Return the current :class:`.Session`, creating it
|
||||
using the :attr:`.scoped_session.session_factory` if not present.
|
||||
@@ -156,50 +206,3 @@ class scoped_session(object):
|
||||
|
||||
ScopedSession = scoped_session
|
||||
"""Old name for backwards compatibility."""
|
||||
|
||||
|
||||
def instrument(name):
|
||||
def do(self, *args, **kwargs):
|
||||
return getattr(self.registry(), name)(*args, **kwargs)
|
||||
|
||||
return do
|
||||
|
||||
|
||||
for meth in Session.public_methods:
|
||||
setattr(scoped_session, meth, instrument(meth))
|
||||
|
||||
|
||||
def makeprop(name):
|
||||
def set_(self, attr):
|
||||
setattr(self.registry(), name, attr)
|
||||
|
||||
def get(self):
|
||||
return getattr(self.registry(), name)
|
||||
|
||||
return property(get, set_)
|
||||
|
||||
|
||||
for prop in (
|
||||
"bind",
|
||||
"dirty",
|
||||
"deleted",
|
||||
"new",
|
||||
"identity_map",
|
||||
"is_active",
|
||||
"autoflush",
|
||||
"no_autoflush",
|
||||
"info",
|
||||
"autocommit",
|
||||
):
|
||||
setattr(scoped_session, prop, makeprop(prop))
|
||||
|
||||
|
||||
def clslevel(name):
|
||||
def do(cls, *args, **kwargs):
|
||||
return getattr(Session, name)(*args, **kwargs)
|
||||
|
||||
return classmethod(do)
|
||||
|
||||
|
||||
for prop in ("close_all", "object_session", "identity_key"):
|
||||
setattr(scoped_session, prop, clslevel(prop))
|
||||
|
||||
@@ -835,35 +835,6 @@ class Session(_SessionClassMethods):
|
||||
|
||||
"""
|
||||
|
||||
public_methods = (
|
||||
"__contains__",
|
||||
"__iter__",
|
||||
"add",
|
||||
"add_all",
|
||||
"begin",
|
||||
"begin_nested",
|
||||
"close",
|
||||
"commit",
|
||||
"connection",
|
||||
"delete",
|
||||
"execute",
|
||||
"expire",
|
||||
"expire_all",
|
||||
"expunge",
|
||||
"expunge_all",
|
||||
"flush",
|
||||
"get_bind",
|
||||
"is_modified",
|
||||
"bulk_save_objects",
|
||||
"bulk_insert_mappings",
|
||||
"bulk_update_mappings",
|
||||
"merge",
|
||||
"query",
|
||||
"refresh",
|
||||
"rollback",
|
||||
"scalar",
|
||||
)
|
||||
|
||||
@util.deprecated_params(
|
||||
autocommit=(
|
||||
"2.0",
|
||||
@@ -3028,7 +2999,14 @@ class Session(_SessionClassMethods):
|
||||
will unexpire attributes on access.
|
||||
|
||||
"""
|
||||
state = attributes.instance_state(obj)
|
||||
try:
|
||||
state = attributes.instance_state(obj)
|
||||
except exc.NO_STATE as err:
|
||||
util.raise_(
|
||||
exc.UnmappedInstanceError(obj),
|
||||
replace_context=err,
|
||||
)
|
||||
|
||||
to_attach = self._before_attach(state, obj)
|
||||
state._load_pending = True
|
||||
if to_attach:
|
||||
|
||||
@@ -509,6 +509,7 @@ class _ConnectionRecord(object):
|
||||
"Soft " if soft else "",
|
||||
self.connection,
|
||||
)
|
||||
|
||||
if soft:
|
||||
self._soft_invalidate_time = time.time()
|
||||
else:
|
||||
|
||||
@@ -372,7 +372,7 @@ def _pytest_fn_decorator(target):
|
||||
if add_positional_parameters:
|
||||
spec.args.extend(add_positional_parameters)
|
||||
|
||||
metadata = dict(target="target", fn="fn", name=fn.__name__)
|
||||
metadata = dict(target="target", fn="__fn", name=fn.__name__)
|
||||
metadata.update(format_argspec_plus(spec, grouped=False))
|
||||
code = (
|
||||
"""\
|
||||
@@ -382,7 +382,7 @@ def %(name)s(%(args)s):
|
||||
% metadata
|
||||
)
|
||||
decorated = _exec_code_in_env(
|
||||
code, {"target": target, "fn": fn}, fn.__name__
|
||||
code, {"target": target, "__fn": fn}, fn.__name__
|
||||
)
|
||||
if not add_positional_parameters:
|
||||
decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
|
||||
|
||||
@@ -123,6 +123,7 @@ from .langhelpers import coerce_kw_type # noqa
|
||||
from .langhelpers import constructor_copy # noqa
|
||||
from .langhelpers import constructor_key # noqa
|
||||
from .langhelpers import counter # noqa
|
||||
from .langhelpers import create_proxy_methods # noqa
|
||||
from .langhelpers import decode_slice # noqa
|
||||
from .langhelpers import decorator # noqa
|
||||
from .langhelpers import dictlike_iteritems # noqa
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
modules, classes, hierarchies, attributes, functions, and methods.
|
||||
|
||||
"""
|
||||
|
||||
from functools import update_wrapper
|
||||
import hashlib
|
||||
import inspect
|
||||
@@ -462,6 +463,8 @@ def format_argspec_plus(fn, grouped=True):
|
||||
passed positionally.
|
||||
apply_kw
|
||||
Like apply_pos, except keyword-ish args are passed as keywords.
|
||||
apply_pos_proxied
|
||||
Like apply_pos but omits the self/cls argument
|
||||
|
||||
Example::
|
||||
|
||||
@@ -478,16 +481,27 @@ def format_argspec_plus(fn, grouped=True):
|
||||
spec = fn
|
||||
|
||||
args = compat.inspect_formatargspec(*spec)
|
||||
if spec[0]:
|
||||
self_arg = spec[0][0]
|
||||
elif spec[1]:
|
||||
self_arg = "%s[0]" % spec[1]
|
||||
else:
|
||||
self_arg = None
|
||||
|
||||
apply_pos = compat.inspect_formatargspec(
|
||||
spec[0], spec[1], spec[2], None, spec[4]
|
||||
)
|
||||
|
||||
if spec[0]:
|
||||
self_arg = spec[0][0]
|
||||
|
||||
apply_pos_proxied = compat.inspect_formatargspec(
|
||||
spec[0][1:], spec[1], spec[2], None, spec[4]
|
||||
)
|
||||
|
||||
elif spec[1]:
|
||||
# im not sure what this is
|
||||
self_arg = "%s[0]" % spec[1]
|
||||
|
||||
apply_pos_proxied = apply_pos
|
||||
else:
|
||||
self_arg = None
|
||||
apply_pos_proxied = apply_pos
|
||||
|
||||
num_defaults = 0
|
||||
if spec[3]:
|
||||
num_defaults += len(spec[3])
|
||||
@@ -513,6 +527,7 @@ def format_argspec_plus(fn, grouped=True):
|
||||
self_arg=self_arg,
|
||||
apply_pos=apply_pos,
|
||||
apply_kw=apply_kw,
|
||||
apply_pos_proxied=apply_pos_proxied,
|
||||
)
|
||||
else:
|
||||
return dict(
|
||||
@@ -520,6 +535,7 @@ def format_argspec_plus(fn, grouped=True):
|
||||
self_arg=self_arg,
|
||||
apply_pos=apply_pos[1:-1],
|
||||
apply_kw=apply_kw[1:-1],
|
||||
apply_pos_proxied=apply_pos_proxied[1:-1],
|
||||
)
|
||||
|
||||
|
||||
@@ -534,17 +550,140 @@ def format_argspec_init(method, grouped=True):
|
||||
|
||||
"""
|
||||
if method is object.__init__:
|
||||
args = grouped and "(self)" or "self"
|
||||
args = "(self)" if grouped else "self"
|
||||
proxied = "()" if grouped else ""
|
||||
else:
|
||||
try:
|
||||
return format_argspec_plus(method, grouped=grouped)
|
||||
except TypeError:
|
||||
args = (
|
||||
grouped
|
||||
and "(self, *args, **kwargs)"
|
||||
or "self, *args, **kwargs"
|
||||
"(self, *args, **kwargs)"
|
||||
if grouped
|
||||
else "self, *args, **kwargs"
|
||||
)
|
||||
return dict(self_arg="self", args=args, apply_pos=args, apply_kw=args)
|
||||
proxied = "(*args, **kwargs)" if grouped else "*args, **kwargs"
|
||||
return dict(
|
||||
self_arg="self",
|
||||
args=args,
|
||||
apply_pos=args,
|
||||
apply_kw=args,
|
||||
apply_pos_proxied=proxied,
|
||||
)
|
||||
|
||||
|
||||
def create_proxy_methods(
|
||||
target_cls,
|
||||
target_cls_sphinx_name,
|
||||
proxy_cls_sphinx_name,
|
||||
classmethods=(),
|
||||
methods=(),
|
||||
attributes=(),
|
||||
):
|
||||
"""A class decorator that will copy attributes to a proxy class.
|
||||
|
||||
The class to be instrumented must define a single accessor "_proxied".
|
||||
|
||||
"""
|
||||
|
||||
def decorate(cls):
|
||||
def instrument(name, clslevel=False):
|
||||
fn = getattr(target_cls, name)
|
||||
spec = compat.inspect_getfullargspec(fn)
|
||||
env = {}
|
||||
|
||||
spec = _update_argspec_defaults_into_env(spec, env)
|
||||
caller_argspec = format_argspec_plus(spec, grouped=False)
|
||||
|
||||
metadata = {
|
||||
"name": fn.__name__,
|
||||
"apply_pos_proxied": caller_argspec["apply_pos_proxied"],
|
||||
"args": caller_argspec["args"],
|
||||
"self_arg": caller_argspec["self_arg"],
|
||||
}
|
||||
|
||||
if clslevel:
|
||||
code = (
|
||||
"def %(name)s(%(args)s):\n"
|
||||
" return target_cls.%(name)s(%(apply_pos_proxied)s)"
|
||||
% metadata
|
||||
)
|
||||
env["target_cls"] = target_cls
|
||||
else:
|
||||
code = (
|
||||
"def %(name)s(%(args)s):\n"
|
||||
" return %(self_arg)s._proxied.%(name)s(%(apply_pos_proxied)s)" # noqa E501
|
||||
% metadata
|
||||
)
|
||||
|
||||
proxy_fn = _exec_code_in_env(code, env, fn.__name__)
|
||||
proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__
|
||||
proxy_fn.__doc__ = inject_docstring_text(
|
||||
fn.__doc__,
|
||||
".. container:: class_bases\n\n "
|
||||
"Proxied for the %s class on behalf of the %s class."
|
||||
% (target_cls_sphinx_name, proxy_cls_sphinx_name),
|
||||
1,
|
||||
)
|
||||
|
||||
if clslevel:
|
||||
proxy_fn = classmethod(proxy_fn)
|
||||
|
||||
return proxy_fn
|
||||
|
||||
def makeprop(name):
|
||||
attr = target_cls.__dict__.get(name, None)
|
||||
|
||||
if attr is not None:
|
||||
doc = inject_docstring_text(
|
||||
attr.__doc__,
|
||||
".. container:: class_bases\n\n "
|
||||
"Proxied for the %s class on behalf of the %s class."
|
||||
% (
|
||||
target_cls_sphinx_name,
|
||||
proxy_cls_sphinx_name,
|
||||
),
|
||||
1,
|
||||
)
|
||||
else:
|
||||
doc = None
|
||||
|
||||
code = (
|
||||
"def set_(self, attr):\n"
|
||||
" self._proxied.%(name)s = attr\n"
|
||||
"def get(self):\n"
|
||||
" return self._proxied.%(name)s\n"
|
||||
"get.__doc__ = doc\n"
|
||||
"getset = property(get, set_)"
|
||||
) % {"name": name}
|
||||
|
||||
getset = _exec_code_in_env(code, {"doc": doc}, "getset")
|
||||
|
||||
return getset
|
||||
|
||||
for meth in methods:
|
||||
if hasattr(cls, meth):
|
||||
raise TypeError(
|
||||
"class %s already has a method %s" % (cls, meth)
|
||||
)
|
||||
setattr(cls, meth, instrument(meth))
|
||||
|
||||
for prop in attributes:
|
||||
if hasattr(cls, prop):
|
||||
raise TypeError(
|
||||
"class %s already has a method %s" % (cls, prop)
|
||||
)
|
||||
setattr(cls, prop, makeprop(prop))
|
||||
|
||||
for prop in classmethods:
|
||||
if hasattr(cls, prop):
|
||||
raise TypeError(
|
||||
"class %s already has a method %s" % (cls, prop)
|
||||
)
|
||||
setattr(cls, prop, instrument(prop, clslevel=True))
|
||||
|
||||
return cls
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def getargspec_init(method):
|
||||
|
||||
+77
-49
@@ -15,7 +15,19 @@ from sqlalchemy.testing.mock import Mock
|
||||
from sqlalchemy.testing.util import gc_collect
|
||||
|
||||
|
||||
class EventsTest(fixtures.TestBase):
|
||||
class TearDownLocalEventsFixture(object):
|
||||
def tearDown(self):
|
||||
classes = set()
|
||||
for entry in event.base._registrars.values():
|
||||
for evt_cls in entry:
|
||||
if evt_cls.__module__ == __name__:
|
||||
classes.add(evt_cls)
|
||||
|
||||
for evt_cls in classes:
|
||||
event.base._remove_dispatcher(evt_cls)
|
||||
|
||||
|
||||
class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase):
|
||||
"""Test class- and instance-level event registration."""
|
||||
|
||||
def setUp(self):
|
||||
@@ -34,9 +46,6 @@ class EventsTest(fixtures.TestBase):
|
||||
|
||||
self.Target = Target
|
||||
|
||||
def tearDown(self):
|
||||
event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events)
|
||||
|
||||
def test_register_class(self):
|
||||
def listen(x, y):
|
||||
pass
|
||||
@@ -258,7 +267,60 @@ class EventsTest(fixtures.TestBase):
|
||||
)
|
||||
|
||||
|
||||
class NamedCallTest(fixtures.TestBase):
|
||||
class SlotsEventsTest(fixtures.TestBase):
|
||||
@testing.requires.python3
|
||||
def test_no_slots_dispatch(self):
|
||||
class Target(object):
|
||||
__slots__ = ()
|
||||
|
||||
class TargetEvents(event.Events):
|
||||
_dispatch_target = Target
|
||||
|
||||
def event_one(self, x, y):
|
||||
pass
|
||||
|
||||
def event_two(self, x):
|
||||
pass
|
||||
|
||||
def event_three(self, x):
|
||||
pass
|
||||
|
||||
t1 = Target()
|
||||
|
||||
with testing.expect_raises_message(
|
||||
TypeError,
|
||||
r"target .*Target.* doesn't have __dict__, should it "
|
||||
"be defining _slots_dispatch",
|
||||
):
|
||||
event.listen(t1, "event_one", Mock())
|
||||
|
||||
def test_slots_dispatch(self):
|
||||
class Target(object):
|
||||
__slots__ = ("_slots_dispatch",)
|
||||
|
||||
class TargetEvents(event.Events):
|
||||
_dispatch_target = Target
|
||||
|
||||
def event_one(self, x, y):
|
||||
pass
|
||||
|
||||
def event_two(self, x):
|
||||
pass
|
||||
|
||||
def event_three(self, x):
|
||||
pass
|
||||
|
||||
t1 = Target()
|
||||
|
||||
m1 = Mock()
|
||||
event.listen(t1, "event_one", m1)
|
||||
|
||||
t1.dispatch.event_one(2, 4)
|
||||
|
||||
eq_(m1.mock_calls, [call(2, 4)])
|
||||
|
||||
|
||||
class NamedCallTest(TearDownLocalEventsFixture, fixtures.TestBase):
|
||||
def _fixture(self):
|
||||
class TargetEventsOne(event.Events):
|
||||
def event_one(self, x, y):
|
||||
@@ -373,7 +435,7 @@ class NamedCallTest(fixtures.TestBase):
|
||||
eq_(canary.mock_calls, [call({"x": 4, "y": 5, "z": 8, "q": 5})])
|
||||
|
||||
|
||||
class LegacySignatureTest(fixtures.TestBase):
|
||||
class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase):
|
||||
"""test adaption of legacy args"""
|
||||
|
||||
def setUp(self):
|
||||
@@ -397,11 +459,6 @@ class LegacySignatureTest(fixtures.TestBase):
|
||||
|
||||
self.TargetOne = TargetOne
|
||||
|
||||
def tearDown(self):
|
||||
event.base._remove_dispatcher(
|
||||
self.TargetOne.__dict__["dispatch"].events
|
||||
)
|
||||
|
||||
def test_legacy_accept(self):
|
||||
canary = Mock()
|
||||
|
||||
@@ -550,12 +607,7 @@ class LegacySignatureTest(fixtures.TestBase):
|
||||
)
|
||||
|
||||
|
||||
class ClsLevelListenTest(fixtures.TestBase):
|
||||
def tearDown(self):
|
||||
event.base._remove_dispatcher(
|
||||
self.TargetOne.__dict__["dispatch"].events
|
||||
)
|
||||
|
||||
class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
|
||||
def setUp(self):
|
||||
class TargetEventsOne(event.Events):
|
||||
def event_one(self, x, y):
|
||||
@@ -622,7 +674,7 @@ class ClsLevelListenTest(fixtures.TestBase):
|
||||
assert handler2 not in s2.dispatch.event_one
|
||||
|
||||
|
||||
class AcceptTargetsTest(fixtures.TestBase):
|
||||
class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
|
||||
"""Test default target acceptance."""
|
||||
|
||||
def setUp(self):
|
||||
@@ -643,14 +695,6 @@ class AcceptTargetsTest(fixtures.TestBase):
|
||||
self.TargetOne = TargetOne
|
||||
self.TargetTwo = TargetTwo
|
||||
|
||||
def tearDown(self):
|
||||
event.base._remove_dispatcher(
|
||||
self.TargetOne.__dict__["dispatch"].events
|
||||
)
|
||||
event.base._remove_dispatcher(
|
||||
self.TargetTwo.__dict__["dispatch"].events
|
||||
)
|
||||
|
||||
def test_target_accept(self):
|
||||
"""Test that events of the same name are routed to the correct
|
||||
collection based on the type of target given.
|
||||
@@ -687,7 +731,7 @@ class AcceptTargetsTest(fixtures.TestBase):
|
||||
eq_(list(t2.dispatch.event_one), [listen_two, listen_four])
|
||||
|
||||
|
||||
class CustomTargetsTest(fixtures.TestBase):
|
||||
class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
|
||||
"""Test custom target acceptance."""
|
||||
|
||||
def setUp(self):
|
||||
@@ -707,9 +751,6 @@ class CustomTargetsTest(fixtures.TestBase):
|
||||
|
||||
self.Target = Target
|
||||
|
||||
def tearDown(self):
|
||||
event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events)
|
||||
|
||||
def test_indirect(self):
|
||||
def listen(x, y):
|
||||
pass
|
||||
@@ -727,7 +768,7 @@ class CustomTargetsTest(fixtures.TestBase):
|
||||
)
|
||||
|
||||
|
||||
class SubclassGrowthTest(fixtures.TestBase):
|
||||
class SubclassGrowthTest(TearDownLocalEventsFixture, fixtures.TestBase):
|
||||
"""test that ad-hoc subclasses are garbage collected."""
|
||||
|
||||
def setUp(self):
|
||||
@@ -752,7 +793,7 @@ class SubclassGrowthTest(fixtures.TestBase):
|
||||
eq_(self.Target.__subclasses__(), [])
|
||||
|
||||
|
||||
class ListenOverrideTest(fixtures.TestBase):
|
||||
class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase):
|
||||
"""Test custom listen functions which change the listener function
|
||||
signature."""
|
||||
|
||||
@@ -778,9 +819,6 @@ class ListenOverrideTest(fixtures.TestBase):
|
||||
|
||||
self.Target = Target
|
||||
|
||||
def tearDown(self):
|
||||
event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events)
|
||||
|
||||
def test_listen_override(self):
|
||||
listen_one = Mock()
|
||||
listen_two = Mock()
|
||||
@@ -816,7 +854,7 @@ class ListenOverrideTest(fixtures.TestBase):
|
||||
eq_(listen_one.mock_calls, [call(12)])
|
||||
|
||||
|
||||
class PropagateTest(fixtures.TestBase):
|
||||
class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
|
||||
def setUp(self):
|
||||
class TargetEvents(event.Events):
|
||||
def event_one(self, arg):
|
||||
@@ -850,7 +888,7 @@ class PropagateTest(fixtures.TestBase):
|
||||
eq_(listen_two.mock_calls, [])
|
||||
|
||||
|
||||
class JoinTest(fixtures.TestBase):
|
||||
class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase):
|
||||
def setUp(self):
|
||||
class TargetEvents(event.Events):
|
||||
def event_one(self, target, arg):
|
||||
@@ -875,11 +913,6 @@ class JoinTest(fixtures.TestBase):
|
||||
self.TargetFactory = TargetFactory
|
||||
self.TargetElement = TargetElement
|
||||
|
||||
def tearDown(self):
|
||||
for cls in (self.TargetElement, self.TargetFactory, self.BaseTarget):
|
||||
if "dispatch" in cls.__dict__:
|
||||
event.base._remove_dispatcher(cls.__dict__["dispatch"].events)
|
||||
|
||||
def test_neither(self):
|
||||
element = self.TargetFactory().create()
|
||||
element.run_event(1)
|
||||
@@ -1075,7 +1108,7 @@ class JoinTest(fixtures.TestBase):
|
||||
)
|
||||
|
||||
|
||||
class DisableClsPropagateTest(fixtures.TestBase):
|
||||
class DisableClsPropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
|
||||
def setUp(self):
|
||||
class TargetEvents(event.Events):
|
||||
def event_one(self, target, arg):
|
||||
@@ -1093,11 +1126,6 @@ class DisableClsPropagateTest(fixtures.TestBase):
|
||||
self.BaseTarget = BaseTarget
|
||||
self.SubTarget = SubTarget
|
||||
|
||||
def tearDown(self):
|
||||
for cls in (self.SubTarget, self.BaseTarget):
|
||||
if "dispatch" in cls.__dict__:
|
||||
event.base._remove_dispatcher(cls.__dict__["dispatch"].events)
|
||||
|
||||
def test_listen_invoke_clslevel(self):
|
||||
canary = Mock()
|
||||
|
||||
@@ -1132,7 +1160,7 @@ class DisableClsPropagateTest(fixtures.TestBase):
|
||||
eq_(canary.mock_calls, [])
|
||||
|
||||
|
||||
class RemovalTest(fixtures.TestBase):
|
||||
class RemovalTest(TearDownLocalEventsFixture, fixtures.TestBase):
|
||||
def _fixture(self):
|
||||
class TargetEvents(event.Events):
|
||||
def event_one(self, x, y):
|
||||
|
||||
+199
-163
@@ -2300,7 +2300,14 @@ class SymbolTest(fixtures.TestBase):
|
||||
|
||||
|
||||
class _Py3KFixtures(object):
|
||||
pass
|
||||
def _kw_only_fixture(self):
|
||||
pass
|
||||
|
||||
def _kw_plus_posn_fixture(self):
|
||||
pass
|
||||
|
||||
def _kw_opt_fixture(self):
|
||||
pass
|
||||
|
||||
|
||||
if util.py3k:
|
||||
@@ -2321,9 +2328,193 @@ def _kw_opt_fixture(self, a, *, b, c="c"):
|
||||
for k in _locals:
|
||||
setattr(_Py3KFixtures, k, _locals[k])
|
||||
|
||||
py3k_fixtures = _Py3KFixtures()
|
||||
|
||||
|
||||
class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
|
||||
def _test_format_argspec_plus(self, fn, wanted, grouped=None):
|
||||
@testing.combinations(
|
||||
(
|
||||
lambda: None,
|
||||
{
|
||||
"args": "()",
|
||||
"self_arg": None,
|
||||
"apply_kw": "()",
|
||||
"apply_pos": "()",
|
||||
"apply_pos_proxied": "()",
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
lambda: None,
|
||||
{
|
||||
"args": "",
|
||||
"self_arg": None,
|
||||
"apply_kw": "",
|
||||
"apply_pos": "",
|
||||
"apply_pos_proxied": "",
|
||||
},
|
||||
False,
|
||||
),
|
||||
(
|
||||
lambda self: None,
|
||||
{
|
||||
"args": "(self)",
|
||||
"self_arg": "self",
|
||||
"apply_kw": "(self)",
|
||||
"apply_pos": "(self)",
|
||||
"apply_pos_proxied": "()",
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
lambda self: None,
|
||||
{
|
||||
"args": "self",
|
||||
"self_arg": "self",
|
||||
"apply_kw": "self",
|
||||
"apply_pos": "self",
|
||||
"apply_pos_proxied": "",
|
||||
},
|
||||
False,
|
||||
),
|
||||
(
|
||||
lambda *a: None,
|
||||
{
|
||||
"args": "(*a)",
|
||||
"self_arg": "a[0]",
|
||||
"apply_kw": "(*a)",
|
||||
"apply_pos": "(*a)",
|
||||
"apply_pos_proxied": "(*a)",
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
lambda **kw: None,
|
||||
{
|
||||
"args": "(**kw)",
|
||||
"self_arg": None,
|
||||
"apply_kw": "(**kw)",
|
||||
"apply_pos": "(**kw)",
|
||||
"apply_pos_proxied": "(**kw)",
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
lambda *a, **kw: None,
|
||||
{
|
||||
"args": "(*a, **kw)",
|
||||
"self_arg": "a[0]",
|
||||
"apply_kw": "(*a, **kw)",
|
||||
"apply_pos": "(*a, **kw)",
|
||||
"apply_pos_proxied": "(*a, **kw)",
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
lambda a, *b: None,
|
||||
{
|
||||
"args": "(a, *b)",
|
||||
"self_arg": "a",
|
||||
"apply_kw": "(a, *b)",
|
||||
"apply_pos": "(a, *b)",
|
||||
"apply_pos_proxied": "(*b)",
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
lambda a, **b: None,
|
||||
{
|
||||
"args": "(a, **b)",
|
||||
"self_arg": "a",
|
||||
"apply_kw": "(a, **b)",
|
||||
"apply_pos": "(a, **b)",
|
||||
"apply_pos_proxied": "(**b)",
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
lambda a, *b, **c: None,
|
||||
{
|
||||
"args": "(a, *b, **c)",
|
||||
"self_arg": "a",
|
||||
"apply_kw": "(a, *b, **c)",
|
||||
"apply_pos": "(a, *b, **c)",
|
||||
"apply_pos_proxied": "(*b, **c)",
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
lambda a, b=1, **c: None,
|
||||
{
|
||||
"args": "(a, b=1, **c)",
|
||||
"self_arg": "a",
|
||||
"apply_kw": "(a, b=b, **c)",
|
||||
"apply_pos": "(a, b, **c)",
|
||||
"apply_pos_proxied": "(b, **c)",
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
lambda a=1, b=2: None,
|
||||
{
|
||||
"args": "(a=1, b=2)",
|
||||
"self_arg": "a",
|
||||
"apply_kw": "(a=a, b=b)",
|
||||
"apply_pos": "(a, b)",
|
||||
"apply_pos_proxied": "(b)",
|
||||
},
|
||||
True,
|
||||
),
|
||||
(
|
||||
lambda a=1, b=2: None,
|
||||
{
|
||||
"args": "a=1, b=2",
|
||||
"self_arg": "a",
|
||||
"apply_kw": "a=a, b=b",
|
||||
"apply_pos": "a, b",
|
||||
"apply_pos_proxied": "b",
|
||||
},
|
||||
False,
|
||||
),
|
||||
(
|
||||
py3k_fixtures._kw_only_fixture,
|
||||
{
|
||||
"args": "self, a, *, b, c",
|
||||
"self_arg": "self",
|
||||
"apply_pos": "self, a, *, b, c",
|
||||
"apply_kw": "self, a, b=b, c=c",
|
||||
"apply_pos_proxied": "a, *, b, c",
|
||||
},
|
||||
False,
|
||||
testing.requires.python3,
|
||||
),
|
||||
(
|
||||
py3k_fixtures._kw_plus_posn_fixture,
|
||||
{
|
||||
"args": "self, a, *args, b, c",
|
||||
"self_arg": "self",
|
||||
"apply_pos": "self, a, *args, b, c",
|
||||
"apply_kw": "self, a, b=b, c=c, *args",
|
||||
"apply_pos_proxied": "a, *args, b, c",
|
||||
},
|
||||
False,
|
||||
testing.requires.python3,
|
||||
),
|
||||
(
|
||||
py3k_fixtures._kw_opt_fixture,
|
||||
{
|
||||
"args": "self, a, *, b, c='c'",
|
||||
"self_arg": "self",
|
||||
"apply_pos": "self, a, *, b, c",
|
||||
"apply_kw": "self, a, b=b, c=c",
|
||||
"apply_pos_proxied": "a, *, b, c",
|
||||
},
|
||||
False,
|
||||
testing.requires.python3,
|
||||
),
|
||||
argnames="fn,wanted,grouped",
|
||||
)
|
||||
def test_specs(self, fn, wanted, grouped):
|
||||
|
||||
# test direct function
|
||||
if grouped is None:
|
||||
@@ -2340,167 +2531,6 @@ class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
|
||||
parsed = util.format_argspec_plus(spec, grouped=grouped)
|
||||
eq_(parsed, wanted)
|
||||
|
||||
def test_specs(self):
|
||||
self._test_format_argspec_plus(
|
||||
lambda: None,
|
||||
{
|
||||
"args": "()",
|
||||
"self_arg": None,
|
||||
"apply_kw": "()",
|
||||
"apply_pos": "()",
|
||||
},
|
||||
)
|
||||
|
||||
self._test_format_argspec_plus(
|
||||
lambda: None,
|
||||
{"args": "", "self_arg": None, "apply_kw": "", "apply_pos": ""},
|
||||
grouped=False,
|
||||
)
|
||||
|
||||
self._test_format_argspec_plus(
|
||||
lambda self: None,
|
||||
{
|
||||
"args": "(self)",
|
||||
"self_arg": "self",
|
||||
"apply_kw": "(self)",
|
||||
"apply_pos": "(self)",
|
||||
},
|
||||
)
|
||||
|
||||
self._test_format_argspec_plus(
|
||||
lambda self: None,
|
||||
{
|
||||
"args": "self",
|
||||
"self_arg": "self",
|
||||
"apply_kw": "self",
|
||||
"apply_pos": "self",
|
||||
},
|
||||
grouped=False,
|
||||
)
|
||||
|
||||
self._test_format_argspec_plus(
|
||||
lambda *a: None,
|
||||
{
|
||||
"args": "(*a)",
|
||||
"self_arg": "a[0]",
|
||||
"apply_kw": "(*a)",
|
||||
"apply_pos": "(*a)",
|
||||
},
|
||||
)
|
||||
|
||||
self._test_format_argspec_plus(
|
||||
lambda **kw: None,
|
||||
{
|
||||
"args": "(**kw)",
|
||||
"self_arg": None,
|
||||
"apply_kw": "(**kw)",
|
||||
"apply_pos": "(**kw)",
|
||||
},
|
||||
)
|
||||
|
||||
self._test_format_argspec_plus(
|
||||
lambda *a, **kw: None,
|
||||
{
|
||||
"args": "(*a, **kw)",
|
||||
"self_arg": "a[0]",
|
||||
"apply_kw": "(*a, **kw)",
|
||||
"apply_pos": "(*a, **kw)",
|
||||
},
|
||||
)
|
||||
|
||||
self._test_format_argspec_plus(
|
||||
lambda a, *b: None,
|
||||
{
|
||||
"args": "(a, *b)",
|
||||
"self_arg": "a",
|
||||
"apply_kw": "(a, *b)",
|
||||
"apply_pos": "(a, *b)",
|
||||
},
|
||||
)
|
||||
|
||||
self._test_format_argspec_plus(
|
||||
lambda a, **b: None,
|
||||
{
|
||||
"args": "(a, **b)",
|
||||
"self_arg": "a",
|
||||
"apply_kw": "(a, **b)",
|
||||
"apply_pos": "(a, **b)",
|
||||
},
|
||||
)
|
||||
|
||||
self._test_format_argspec_plus(
|
||||
lambda a, *b, **c: None,
|
||||
{
|
||||
"args": "(a, *b, **c)",
|
||||
"self_arg": "a",
|
||||
"apply_kw": "(a, *b, **c)",
|
||||
"apply_pos": "(a, *b, **c)",
|
||||
},
|
||||
)
|
||||
|
||||
self._test_format_argspec_plus(
|
||||
lambda a, b=1, **c: None,
|
||||
{
|
||||
"args": "(a, b=1, **c)",
|
||||
"self_arg": "a",
|
||||
"apply_kw": "(a, b=b, **c)",
|
||||
"apply_pos": "(a, b, **c)",
|
||||
},
|
||||
)
|
||||
|
||||
self._test_format_argspec_plus(
|
||||
lambda a=1, b=2: None,
|
||||
{
|
||||
"args": "(a=1, b=2)",
|
||||
"self_arg": "a",
|
||||
"apply_kw": "(a=a, b=b)",
|
||||
"apply_pos": "(a, b)",
|
||||
},
|
||||
)
|
||||
|
||||
self._test_format_argspec_plus(
|
||||
lambda a=1, b=2: None,
|
||||
{
|
||||
"args": "a=1, b=2",
|
||||
"self_arg": "a",
|
||||
"apply_kw": "a=a, b=b",
|
||||
"apply_pos": "a, b",
|
||||
},
|
||||
grouped=False,
|
||||
)
|
||||
|
||||
if util.py3k:
|
||||
self._test_format_argspec_plus(
|
||||
self._kw_only_fixture,
|
||||
{
|
||||
"args": "self, a, *, b, c",
|
||||
"self_arg": "self",
|
||||
"apply_pos": "self, a, *, b, c",
|
||||
"apply_kw": "self, a, b=b, c=c",
|
||||
},
|
||||
grouped=False,
|
||||
)
|
||||
self._test_format_argspec_plus(
|
||||
self._kw_plus_posn_fixture,
|
||||
{
|
||||
"args": "self, a, *args, b, c",
|
||||
"self_arg": "self",
|
||||
"apply_pos": "self, a, *args, b, c",
|
||||
"apply_kw": "self, a, b=b, c=c, *args",
|
||||
},
|
||||
grouped=False,
|
||||
)
|
||||
self._test_format_argspec_plus(
|
||||
self._kw_opt_fixture,
|
||||
{
|
||||
"args": "self, a, *, b, c='c'",
|
||||
"self_arg": "self",
|
||||
"apply_pos": "self, a, *, b, c",
|
||||
"apply_kw": "self, a, b=b, c=c",
|
||||
},
|
||||
grouped=False,
|
||||
)
|
||||
|
||||
@testing.requires.cpython
|
||||
def test_init_grouped(self):
|
||||
object_spec = {
|
||||
@@ -2508,17 +2538,20 @@ class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
|
||||
"self_arg": "self",
|
||||
"apply_pos": "(self)",
|
||||
"apply_kw": "(self)",
|
||||
"apply_pos_proxied": "()",
|
||||
}
|
||||
wrapper_spec = {
|
||||
"args": "(self, *args, **kwargs)",
|
||||
"self_arg": "self",
|
||||
"apply_pos": "(self, *args, **kwargs)",
|
||||
"apply_kw": "(self, *args, **kwargs)",
|
||||
"apply_pos_proxied": "(*args, **kwargs)",
|
||||
}
|
||||
custom_spec = {
|
||||
"args": "(slef, a=123)",
|
||||
"self_arg": "slef", # yes, slef
|
||||
"apply_pos": "(slef, a)",
|
||||
"apply_pos_proxied": "(a)",
|
||||
"apply_kw": "(slef, a=a)",
|
||||
}
|
||||
|
||||
@@ -2532,18 +2565,21 @@ class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
|
||||
"self_arg": "self",
|
||||
"apply_pos": "self",
|
||||
"apply_kw": "self",
|
||||
"apply_pos_proxied": "",
|
||||
}
|
||||
wrapper_spec = {
|
||||
"args": "self, *args, **kwargs",
|
||||
"self_arg": "self",
|
||||
"apply_pos": "self, *args, **kwargs",
|
||||
"apply_kw": "self, *args, **kwargs",
|
||||
"apply_pos_proxied": "*args, **kwargs",
|
||||
}
|
||||
custom_spec = {
|
||||
"args": "slef, a=123",
|
||||
"self_arg": "slef", # yes, slef
|
||||
"apply_pos": "slef, a",
|
||||
"apply_kw": "slef, a=a",
|
||||
"apply_pos_proxied": "a",
|
||||
}
|
||||
|
||||
self._test_init(False, object_spec, wrapper_spec, custom_spec)
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Integer
|
||||
@@ -9,13 +10,19 @@ from sqlalchemy import select
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Table
|
||||
from sqlalchemy import testing
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import union_all
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.ext.asyncio import engine as _async_engine
|
||||
from sqlalchemy.ext.asyncio import exc as asyncio_exc
|
||||
from sqlalchemy.testing import async_test
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import fixtures
|
||||
from sqlalchemy.testing import is_
|
||||
from sqlalchemy.testing import is_not
|
||||
from sqlalchemy.testing import mock
|
||||
from sqlalchemy.testing.asyncio import assert_raises_message_async
|
||||
from sqlalchemy.util.concurrency import greenlet_spawn
|
||||
|
||||
|
||||
class EngineFixture(fixtures.TablesTest):
|
||||
@@ -50,6 +57,117 @@ class EngineFixture(fixtures.TablesTest):
|
||||
class AsyncEngineTest(EngineFixture):
|
||||
__backend__ = True
|
||||
|
||||
def test_proxied_attrs_engine(self, async_engine):
|
||||
sync_engine = async_engine.sync_engine
|
||||
|
||||
is_(async_engine.url, sync_engine.url)
|
||||
is_(async_engine.pool, sync_engine.pool)
|
||||
is_(async_engine.dialect, sync_engine.dialect)
|
||||
eq_(async_engine.name, sync_engine.name)
|
||||
eq_(async_engine.driver, sync_engine.driver)
|
||||
eq_(async_engine.echo, sync_engine.echo)
|
||||
|
||||
def test_clear_compiled_cache(self, async_engine):
|
||||
async_engine.sync_engine._compiled_cache["foo"] = "bar"
|
||||
eq_(async_engine.sync_engine._compiled_cache["foo"], "bar")
|
||||
async_engine.clear_compiled_cache()
|
||||
assert "foo" not in async_engine.sync_engine._compiled_cache
|
||||
|
||||
def test_execution_options(self, async_engine):
|
||||
a2 = async_engine.execution_options(foo="bar")
|
||||
assert isinstance(a2, _async_engine.AsyncEngine)
|
||||
eq_(a2.sync_engine._execution_options, {"foo": "bar"})
|
||||
eq_(async_engine.sync_engine._execution_options, {})
|
||||
|
||||
"""
|
||||
|
||||
attr uri, pool, dialect, engine, name, driver, echo
|
||||
methods clear_compiled_cache, update_execution_options,
|
||||
execution_options, get_execution_options, dispose
|
||||
|
||||
"""
|
||||
|
||||
@async_test
|
||||
async def test_proxied_attrs_connection(self, async_engine):
|
||||
conn = await async_engine.connect()
|
||||
|
||||
sync_conn = conn.sync_connection
|
||||
|
||||
is_(conn.engine, async_engine)
|
||||
is_(conn.closed, sync_conn.closed)
|
||||
is_(conn.dialect, async_engine.sync_engine.dialect)
|
||||
eq_(conn.default_isolation_level, sync_conn.default_isolation_level)
|
||||
|
||||
@async_test
|
||||
async def test_invalidate(self, async_engine):
|
||||
conn = await async_engine.connect()
|
||||
|
||||
is_(conn.invalidated, False)
|
||||
|
||||
connection_fairy = await conn.get_raw_connection()
|
||||
is_(connection_fairy.is_valid, True)
|
||||
dbapi_connection = connection_fairy.connection
|
||||
|
||||
await conn.invalidate()
|
||||
assert dbapi_connection._connection.is_closed()
|
||||
|
||||
new_fairy = await conn.get_raw_connection()
|
||||
is_not(new_fairy.connection, dbapi_connection)
|
||||
is_not(new_fairy, connection_fairy)
|
||||
is_(new_fairy.is_valid, True)
|
||||
is_(connection_fairy.is_valid, False)
|
||||
|
||||
@async_test
|
||||
async def test_get_dbapi_connection_raise(self, async_engine):
|
||||
|
||||
conn = await async_engine.connect()
|
||||
|
||||
with testing.expect_raises_message(
|
||||
exc.InvalidRequestError,
|
||||
"AsyncConnection.connection accessor is not "
|
||||
"implemented as the attribute",
|
||||
):
|
||||
conn.connection
|
||||
|
||||
@async_test
|
||||
async def test_get_raw_connection(self, async_engine):
|
||||
|
||||
conn = await async_engine.connect()
|
||||
|
||||
pooled = await conn.get_raw_connection()
|
||||
is_(pooled, conn.sync_connection.connection)
|
||||
|
||||
@async_test
|
||||
async def test_isolation_level(self, async_engine):
|
||||
conn = await async_engine.connect()
|
||||
|
||||
sync_isolation_level = await greenlet_spawn(
|
||||
conn.sync_connection.get_isolation_level
|
||||
)
|
||||
isolation_level = await conn.get_isolation_level()
|
||||
|
||||
eq_(isolation_level, sync_isolation_level)
|
||||
|
||||
await conn.execution_options(isolation_level="SERIALIZABLE")
|
||||
isolation_level = await conn.get_isolation_level()
|
||||
|
||||
eq_(isolation_level, "SERIALIZABLE")
|
||||
|
||||
@async_test
|
||||
async def test_dispose(self, async_engine):
|
||||
c1 = await async_engine.connect()
|
||||
c2 = await async_engine.connect()
|
||||
|
||||
await c1.close()
|
||||
await c2.close()
|
||||
|
||||
p1 = async_engine.pool
|
||||
eq_(async_engine.pool.checkedin(), 2)
|
||||
|
||||
await async_engine.dispose()
|
||||
eq_(async_engine.pool.checkedin(), 0)
|
||||
is_not(p1, async_engine.pool)
|
||||
|
||||
@async_test
|
||||
async def test_init_once_concurrency(self, async_engine):
|
||||
c1 = async_engine.connect()
|
||||
@@ -169,6 +287,70 @@ class AsyncEngineTest(EngineFixture):
|
||||
)
|
||||
|
||||
|
||||
class AsyncEventTest(EngineFixture):
|
||||
"""The engine events all run in their normal synchronous context.
|
||||
|
||||
we do not provide an asyncio event interface at this time.
|
||||
|
||||
"""
|
||||
|
||||
__backend__ = True
|
||||
|
||||
@async_test
|
||||
async def test_no_async_listeners(self, async_engine):
|
||||
with testing.expect_raises_message(
|
||||
NotImplementedError,
|
||||
"asynchronous events are not implemented "
|
||||
"at this time. Apply synchronous listeners to the "
|
||||
"AsyncEngine.sync_engine or "
|
||||
"AsyncConnection.sync_connection attributes.",
|
||||
):
|
||||
event.listen(async_engine, "before_cursor_execute", mock.Mock())
|
||||
|
||||
conn = await async_engine.connect()
|
||||
|
||||
with testing.expect_raises_message(
|
||||
NotImplementedError,
|
||||
"asynchronous events are not implemented "
|
||||
"at this time. Apply synchronous listeners to the "
|
||||
"AsyncEngine.sync_engine or "
|
||||
"AsyncConnection.sync_connection attributes.",
|
||||
):
|
||||
event.listen(conn, "before_cursor_execute", mock.Mock())
|
||||
|
||||
@async_test
|
||||
async def test_sync_before_cursor_execute_engine(self, async_engine):
|
||||
canary = mock.Mock()
|
||||
|
||||
event.listen(async_engine.sync_engine, "before_cursor_execute", canary)
|
||||
|
||||
async with async_engine.connect() as conn:
|
||||
sync_conn = conn.sync_connection
|
||||
await conn.execute(text("select 1"))
|
||||
|
||||
eq_(
|
||||
canary.mock_calls,
|
||||
[mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)],
|
||||
)
|
||||
|
||||
@async_test
|
||||
async def test_sync_before_cursor_execute_connection(self, async_engine):
|
||||
canary = mock.Mock()
|
||||
|
||||
async with async_engine.connect() as conn:
|
||||
sync_conn = conn.sync_connection
|
||||
|
||||
event.listen(
|
||||
async_engine.sync_engine, "before_cursor_execute", canary
|
||||
)
|
||||
await conn.execute(text("select 1"))
|
||||
|
||||
eq_(
|
||||
canary.mock_calls,
|
||||
[mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)],
|
||||
)
|
||||
|
||||
|
||||
class AsyncResultTest(EngineFixture):
|
||||
@testing.combinations(
|
||||
(None,), ("scalars",), ("mappings",), argnames="filter_"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import exc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
@@ -9,6 +10,7 @@ from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.testing import async_test
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import is_
|
||||
from sqlalchemy.testing import mock
|
||||
from ...orm import _fixtures
|
||||
|
||||
|
||||
@@ -139,6 +141,27 @@ class AsyncSessionTransactionTest(AsyncFixture):
|
||||
|
||||
eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
|
||||
|
||||
@async_test
|
||||
async def test_delete(self, async_session):
|
||||
User = self.classes.User
|
||||
|
||||
async with async_session.begin():
|
||||
u1 = User(name="u1")
|
||||
|
||||
async_session.add(u1)
|
||||
|
||||
await async_session.flush()
|
||||
|
||||
conn = await async_session.connection()
|
||||
|
||||
eq_(await conn.scalar(select(func.count(User.id))), 1)
|
||||
|
||||
async_session.delete(u1)
|
||||
|
||||
await async_session.flush()
|
||||
|
||||
eq_(await conn.scalar(select(func.count(User.id))), 0)
|
||||
|
||||
@async_test
|
||||
async def test_flush(self, async_session):
|
||||
User = self.classes.User
|
||||
@@ -198,3 +221,38 @@ class AsyncSessionTransactionTest(AsyncFixture):
|
||||
|
||||
is_(new_u_merged, u1)
|
||||
eq_(u1.name, "new u1")
|
||||
|
||||
|
||||
class AsyncEventTest(AsyncFixture):
|
||||
"""The engine events all run in their normal synchronous context.
|
||||
|
||||
we do not provide an asyncio event interface at this time.
|
||||
|
||||
"""
|
||||
|
||||
__backend__ = True
|
||||
|
||||
@async_test
|
||||
async def test_no_async_listeners(self, async_session):
|
||||
with testing.expect_raises(
|
||||
NotImplementedError,
|
||||
"NotImplementedError: asynchronous events are not implemented "
|
||||
"at this time. Apply synchronous listeners to the "
|
||||
"AsyncEngine.sync_engine or "
|
||||
"AsyncConnection.sync_connection attributes.",
|
||||
):
|
||||
event.listen(async_session, "before_flush", mock.Mock())
|
||||
|
||||
@async_test
|
||||
async def test_sync_before_commit(self, async_session):
|
||||
canary = mock.Mock()
|
||||
|
||||
event.listen(async_session.sync_session, "before_commit", canary)
|
||||
|
||||
async with async_session.begin():
|
||||
pass
|
||||
|
||||
eq_(
|
||||
canary.mock_calls,
|
||||
[mock.call(async_session.sync_session)],
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import scoped_session
|
||||
from sqlalchemy.testing import assert_raises_message
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import fixtures
|
||||
from sqlalchemy.testing import mock
|
||||
from sqlalchemy.testing.mock import Mock
|
||||
from sqlalchemy.testing.schema import Column
|
||||
from sqlalchemy.testing.schema import Table
|
||||
@@ -127,3 +128,26 @@ class ScopedSessionTest(fixtures.MappedTest):
|
||||
mock_scope_func.return_value = 1
|
||||
s2 = Session(autocommit=True)
|
||||
assert s2.autocommit == True
|
||||
|
||||
def test_methods_etc(self):
|
||||
mock_session = Mock()
|
||||
mock_session.bind = "the bind"
|
||||
|
||||
sess = scoped_session(lambda: mock_session)
|
||||
|
||||
sess.add("add")
|
||||
sess.delete("delete")
|
||||
|
||||
eq_(sess.bind, "the bind")
|
||||
|
||||
eq_(
|
||||
mock_session.mock_calls,
|
||||
[mock.call.add("add", True), mock.call.delete("delete")],
|
||||
)
|
||||
|
||||
with mock.patch(
|
||||
"sqlalchemy.orm.session.object_session"
|
||||
) as mock_object_session:
|
||||
sess.object_session("foo")
|
||||
|
||||
eq_(mock_object_session.mock_calls, [mock.call("foo")])
|
||||
|
||||
+30
-25
@@ -1,3 +1,5 @@
|
||||
import inspect as _py_inspect
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import ForeignKey
|
||||
@@ -1820,22 +1822,28 @@ class DisposedStates(fixtures.MappedTest):
|
||||
class SessionInterface(fixtures.TestBase):
|
||||
"""Bogus args to Session methods produce actionable exceptions."""
|
||||
|
||||
# TODO: expand with message body assertions.
|
||||
|
||||
_class_methods = set(("connection", "execute", "get_bind", "scalar"))
|
||||
|
||||
def _public_session_methods(self):
|
||||
Session = sa.orm.session.Session
|
||||
|
||||
blacklist = set(("begin", "query"))
|
||||
|
||||
blacklist = {"begin", "query", "bind_mapper", "get", "bind_table"}
|
||||
specials = {"__iter__", "__contains__"}
|
||||
ok = set()
|
||||
for meth in Session.public_methods:
|
||||
if meth in blacklist:
|
||||
continue
|
||||
spec = inspect_getfullargspec(getattr(Session, meth))
|
||||
if len(spec[0]) > 1 or spec[1]:
|
||||
ok.add(meth)
|
||||
for name in dir(Session):
|
||||
if (
|
||||
name in Session.__dict__
|
||||
and (not name.startswith("_") or name in specials)
|
||||
and (
|
||||
_py_inspect.ismethod(getattr(Session, name))
|
||||
or _py_inspect.isfunction(getattr(Session, name))
|
||||
)
|
||||
):
|
||||
if name in blacklist:
|
||||
continue
|
||||
spec = inspect_getfullargspec(getattr(Session, name))
|
||||
if len(spec[0]) > 1 or spec[1]:
|
||||
ok.add(name)
|
||||
return ok
|
||||
|
||||
def _map_it(self, cls):
|
||||
@@ -1866,18 +1874,21 @@ class SessionInterface(fixtures.TestBase):
|
||||
def raises_(method, *args, **kw):
|
||||
x_raises_(create_session(), method, *args, **kw)
|
||||
|
||||
raises_("__contains__", user_arg)
|
||||
|
||||
raises_("add", user_arg)
|
||||
for name in [
|
||||
"__contains__",
|
||||
"is_modified",
|
||||
"merge",
|
||||
"refresh",
|
||||
"add",
|
||||
"delete",
|
||||
"expire",
|
||||
"expunge",
|
||||
"enable_relationship_loading",
|
||||
]:
|
||||
raises_(name, user_arg)
|
||||
|
||||
raises_("add_all", (user_arg,))
|
||||
|
||||
raises_("delete", user_arg)
|
||||
|
||||
raises_("expire", user_arg)
|
||||
|
||||
raises_("expunge", user_arg)
|
||||
|
||||
# flush will no-op without something in the unit of work
|
||||
def _():
|
||||
class OK(object):
|
||||
@@ -1891,12 +1902,6 @@ class SessionInterface(fixtures.TestBase):
|
||||
|
||||
_()
|
||||
|
||||
raises_("is_modified", user_arg)
|
||||
|
||||
raises_("merge", user_arg)
|
||||
|
||||
raises_("refresh", user_arg)
|
||||
|
||||
instance_methods = (
|
||||
self._public_session_methods()
|
||||
- self._class_methods
|
||||
|
||||
Reference in New Issue
Block a user