generalize scoped_session proxying and apply to asyncio elements

Reworked the proxy creation used by scoped_session() to be
based on fully copied code with augmented docstrings and
moved it into langhelpers.  asyncio session, engine,
connection can now take
advantage of it so that all non-async methods are available.

Overall implementation of most important accessors / methods
on AsyncConnection, etc., including awaitable versions
of invalidate, execution_options, etc.

In order to support an event dispatcher on the async
classes while still allowing them to hold __slots__,
make some adjustments to the event system to allow
that to be present, at least in a rudimentary way.

Fixes: #5628
Change-Id: I5eb6929fc1e4fdac99e4b767dcfd49672d56e2b2
This commit is contained in:
Mike Bayer
2020-10-08 15:20:48 -04:00
parent bcc17b1d6e
commit 2665a0c4cb
19 changed files with 1070 additions and 390 deletions
+19 -11
View File
@@ -184,14 +184,17 @@ class Connection(Connectable):
r""" Set non-SQL options for the connection which take effect
during execution.
The method returns a copy of this :class:`_engine.Connection`
which references
the same underlying DBAPI connection, but also defines the given
execution options which will take effect for a call to
:meth:`execute`. As the new :class:`_engine.Connection`
references the same
underlying resource, it's usually a good idea to ensure that the copies
will be discarded immediately, which is implicit if used as in::
For a "future" style connection, this method returns this same
:class:`_future.Connection` object with the new options added.
For a legacy connection, this method returns a copy of this
:class:`_engine.Connection` which references the same underlying DBAPI
connection, but also defines the given execution options which will
take effect for a call to
:meth:`execute`. As the new :class:`_engine.Connection` references the
same underlying resource, it's usually a good idea to ensure that
the copies will be discarded immediately, which is implicit if used
as in::
result = connection.execution_options(stream_results=True).\
execute(stmt)
@@ -549,9 +552,10 @@ class Connection(Connectable):
"""Invalidate the underlying DBAPI connection associated with
this :class:`_engine.Connection`.
The underlying DBAPI connection is literally closed (if
possible), and is discarded. Its source connection pool will
typically lazily create a new connection to replace it.
An attempt will be made to close the underlying DBAPI connection
immediately; however if this operation fails, the error is logged
but not raised. The connection is then discarded whether or not
close() succeeded.
Upon the next use (where "use" typically means using the
:meth:`_engine.Connection.execute` method or similar),
@@ -580,6 +584,10 @@ class Connection(Connectable):
will at the connection pool level invoke the
:meth:`_events.PoolEvents.invalidate` event.
:param exception: an optional ``Exception`` instance that's the
reason for the invalidation. is passed along to event handlers
and logging functions.
.. seealso::
:ref:`pool_connection_invalidation`
+33 -2
View File
@@ -195,7 +195,14 @@ def _create_dispatcher_class(cls, classname, bases, dict_):
dispatch_cls._event_names.append(ls.name)
if getattr(cls, "_dispatch_target", None):
cls._dispatch_target.dispatch = dispatcher(cls)
the_cls = cls._dispatch_target
if (
hasattr(the_cls, "__slots__")
and "_slots_dispatch" in the_cls.__slots__
):
cls._dispatch_target.dispatch = slots_dispatcher(cls)
else:
cls._dispatch_target.dispatch = dispatcher(cls)
def _remove_dispatcher(cls):
@@ -304,5 +311,29 @@ class dispatcher(object):
def __get__(self, obj, cls):
    """Return the event dispatch collection for a class or instance.

    Class-level access returns the class-wide dispatch; instance-level
    access creates and memoizes a per-instance dispatch.
    """
    if obj is None:
        # accessed on the class itself
        return self.dispatch
    # NOTE: fix — the flattened source retained the pre-change
    # unguarded ``obj.__dict__["dispatch"] = disp = ...`` line in
    # addition to the guarded form below, performing the assignment
    # twice and bypassing the __slots__ error message; only the
    # guarded form is kept.
    disp = self.dispatch._for_instance(obj)
    try:
        # memoization requires the target to have a __dict__,
        # i.e. not be a pure __slots__ class
        obj.__dict__["dispatch"] = disp
    except AttributeError as ae:
        util.raise_(
            TypeError(
                "target %r doesn't have __dict__, should it be "
                "defining _slots_dispatch?" % (obj,)
            ),
            replace_context=ae,
        )
    return disp
class slots_dispatcher(dispatcher):
    """Dispatch descriptor for event targets that declare ``__slots__``.

    The per-instance dispatch object is cached in the target's
    ``_slots_dispatch`` slot rather than in an instance ``__dict__``.
    """

    def __get__(self, obj, cls):
        # class-level access yields the class-wide dispatch
        if obj is None:
            return self.dispatch

        # EAFP: hand back the cached per-instance dispatch if present
        try:
            return obj._slots_dispatch
        except AttributeError:
            pass

        per_instance = self.dispatch._for_instance(obj)
        obj._slots_dispatch = per_instance
        return per_instance
+2
View File
@@ -2,6 +2,8 @@ from .engine import AsyncConnection # noqa
from .engine import AsyncEngine # noqa
from .engine import AsyncTransaction # noqa
from .engine import create_async_engine # noqa
from .events import AsyncConnectionEvents # noqa
from .events import AsyncSessionEvents # noqa
from .result import AsyncMappingResult # noqa
from .result import AsyncResult # noqa
from .result import AsyncScalarResult # noqa
+140 -13
View File
@@ -8,12 +8,11 @@ from .base import StartableContext
from .result import AsyncResult
from ... import exc
from ... import util
from ...engine import Connection
from ...engine import create_engine as _create_engine
from ...engine import Engine
from ...engine import Result
from ...engine import Transaction
from ...engine.base import OptionEngineMixin
from ...future import Connection
from ...future import Engine
from ...sql import Executable
from ...util.concurrency import greenlet_spawn
@@ -41,7 +40,24 @@ def create_async_engine(*arg, **kw):
return AsyncEngine(sync_engine)
class AsyncConnection(StartableContext):
class AsyncConnectable:
__slots__ = "_slots_dispatch"
@util.create_proxy_methods(
Connection,
":class:`_future.Connection`",
":class:`_asyncio.AsyncConnection`",
classmethods=[],
methods=[],
attributes=[
"closed",
"invalidated",
"dialect",
"default_isolation_level",
],
)
class AsyncConnection(StartableContext, AsyncConnectable):
"""An asyncio proxy for a :class:`_engine.Connection`.
:class:`_asyncio.AsyncConnection` is acquired using the
@@ -58,15 +74,23 @@ class AsyncConnection(StartableContext):
""" # noqa
# AsyncConnection is a thin proxy; no state should be added here
# that is not retrievable from the "sync" engine / connection, e.g.
# current transaction, info, etc. It should be possible to
# create a new AsyncConnection that matches this one given only the
# "sync" elements.
__slots__ = (
"sync_engine",
"sync_connection",
)
def __init__(
self, sync_engine: Engine, sync_connection: Optional[Connection] = None
self,
async_engine: "AsyncEngine",
sync_connection: Optional[Connection] = None,
):
self.sync_engine = sync_engine
self.engine = async_engine
self.sync_engine = async_engine.sync_engine
self.sync_connection = sync_connection
async def start(self):
@@ -79,6 +103,34 @@ class AsyncConnection(StartableContext):
self.sync_connection = await (greenlet_spawn(self.sync_engine.connect))
return self
@property
def connection(self):
    """Not implemented for async; call
    :meth:`_asyncio.AsyncConnection.get_raw_connection`.
    """
    # deliberately raises instead of proxying: resolving the raw
    # connection may need to reconnect an invalidated connection,
    # a blocking operation that cannot occur inside a plain
    # attribute access under asyncio
    raise exc.InvalidRequestError(
        "AsyncConnection.connection accessor is not implemented as the "
        "attribute may need to reconnect on an invalidated connection. "
        "Use the get_raw_connection() method."
    )
async def get_raw_connection(self):
    """Return the pooled DBAPI-level connection in use by this
    :class:`_asyncio.AsyncConnection`.

    The value is normally the connection-pool proxy object maintained
    by SQLAlchemy, whose own ``.connection`` attribute in turn refers
    to the actual DBAPI-level connection.
    """
    sync_conn = self._sync_connection()
    # the attribute access itself may block (e.g. reconnect on an
    # invalidated connection), so perform it inside a greenlet
    return await greenlet_spawn(getattr, sync_conn, "connection")
@property
def _proxied(self):
    # accessor consumed by util.create_proxy_methods(); the generated
    # proxy methods and attributes delegate to this sync Connection
    return self.sync_connection
def _sync_connection(self):
if not self.sync_connection:
self._raise_for_not_started()
@@ -94,6 +146,43 @@ class AsyncConnection(StartableContext):
self._sync_connection()
return AsyncTransaction(self, nested=True)
async def invalidate(self, exception=None):
    """Invalidate the underlying DBAPI connection associated with
    this :class:`_engine.Connection`.

    See the method :meth:`_engine.Connection.invalidate` for full
    detail on this method.

    :param exception: optional ``Exception`` that prompted the
     invalidation; forwarded to the sync method.
    """
    sync_conn = self._sync_connection()
    return await greenlet_spawn(
        sync_conn.invalidate, exception=exception
    )
async def get_isolation_level(self):
    """Return the result of :meth:`_engine.Connection.get_isolation_level`
    for the underlying sync connection, awaited via a greenlet.
    """
    conn = self._sync_connection()
    return await greenlet_spawn(conn.get_isolation_level)
async def set_isolation_level(self, level=None):
    """Set the transaction isolation level for the underlying
    :class:`_engine.Connection`.

    :param level: string isolation level name, applied via
     :meth:`_engine.Connection.execution_options` using the
     ``isolation_level`` option.  When omitted, the current level is
     returned instead (preserving the previous behavior of this
     method).

    NOTE(review): the original body was a copy/paste of
    ``get_isolation_level`` and could never set anything; the
    zero-argument call path keeps that legacy behavior for
    compatibility while ``level`` enables the intended operation.
    """
    conn = self._sync_connection()
    if level is None:
        # backwards compatible: zero-argument calls previously
        # returned the current level due to the copy/paste bug
        return await greenlet_spawn(conn.get_isolation_level)
    await greenlet_spawn(conn.execution_options, isolation_level=level)
async def execution_options(self, **opt):
    r"""Set non-SQL options for the connection which take effect
    during execution.

    This returns this :class:`_asyncio.AsyncConnection` object with
    the new options added.

    See :meth:`_future.Connection.execution_options` for full details
    on this method.
    """
    sync_conn = self._sync_connection()
    updated = await greenlet_spawn(sync_conn.execution_options, **opt)
    # the "future" Connection applies options in place and returns
    # itself; anything else would indicate a contract violation
    assert updated is sync_conn
    return self
async def commit(self):
"""Commit the transaction that is currently in progress.
@@ -287,7 +376,19 @@ class AsyncConnection(StartableContext):
await self.close()
class AsyncEngine:
@util.create_proxy_methods(
Engine,
":class:`_future.Engine`",
":class:`_asyncio.AsyncEngine`",
classmethods=[],
methods=[
"clear_compiled_cache",
"update_execution_options",
"get_execution_options",
],
attributes=["url", "pool", "dialect", "engine", "name", "driver", "echo"],
)
class AsyncEngine(AsyncConnectable):
"""An asyncio proxy for a :class:`_engine.Engine`.
:class:`_asyncio.AsyncEngine` is acquired using the
@@ -301,7 +402,12 @@ class AsyncEngine:
""" # noqa
__slots__ = ("sync_engine",)
# AsyncEngine is a thin proxy; no state should be added here
# that is not retrievable from the "sync" engine / connection, e.g.
# current transaction, info, etc. It should be possible to
# create a new AsyncEngine that matches this one given only the
# "sync" elements.
__slots__ = ("sync_engine", "_proxied")
_connection_cls = AsyncConnection
@@ -327,7 +433,7 @@ class AsyncEngine:
await self.conn.close()
def __init__(self, sync_engine: Engine):
self.sync_engine = sync_engine
self.sync_engine = self._proxied = sync_engine
def begin(self):
"""Return a context manager which when entered will deliver an
@@ -363,7 +469,7 @@ class AsyncEngine:
"""
return self._connection_cls(self.sync_engine)
return self._connection_cls(self)
async def raw_connection(self) -> Any:
"""Return a "raw" DBAPI connection from the connection pool.
@@ -375,12 +481,33 @@ class AsyncEngine:
"""
return await greenlet_spawn(self.sync_engine.raw_connection)
def execution_options(self, **opt):
"""Return a new :class:`_asyncio.AsyncEngine` that will provide
:class:`_asyncio.AsyncConnection` objects with the given execution
options.
class AsyncOptionEngine(OptionEngineMixin, AsyncEngine):
pass
Proxied from :meth:`_future.Engine.execution_options`. See that
method for details.
"""
AsyncEngine._option_cls = AsyncOptionEngine
return AsyncEngine(self.sync_engine.execution_options(**opt))
async def dispose(self):
    """Dispose of the connection pool used by this
    :class:`_asyncio.AsyncEngine`.

    All connection pool connections that are **currently checked in**
    are closed; see :meth:`_future.Engine.dispose` for further notes.

    .. seealso::

        :meth:`_future.Engine.dispose`

    """
    sync_dispose = self.sync_engine.dispose
    return await greenlet_spawn(sync_dispose)
class AsyncTransaction(StartableContext):
+29
View File
@@ -0,0 +1,29 @@
from .engine import AsyncConnectable
from .session import AsyncSession
from ...engine import events as engine_event
from ...orm import events as orm_event
class AsyncConnectionEvents(engine_event.ConnectionEvents):
    """Connection event suite targeted at
    :class:`_asyncio.AsyncConnectable` objects.

    Exists so that attempts to register listeners on the async facade
    raise a helpful error rather than silently failing; asynchronous
    event listeners are not implemented.
    """

    _target_class_doc = "SomeEngine"
    _dispatch_target = AsyncConnectable

    @classmethod
    def _listen(cls, event_key, retval=False):
        # reject listener registration outright; synchronous listeners
        # should be attached to the underlying sync objects instead
        raise NotImplementedError(
            "asynchronous events are not implemented at this time. Apply "
            "synchronous listeners to the AsyncEngine.sync_engine or "
            "AsyncConnection.sync_connection attributes."
        )
class AsyncSessionEvents(orm_event.SessionEvents):
    """Session event suite targeted at :class:`_asyncio.AsyncSession`.

    Exists so that attempts to register listeners on the async facade
    raise a helpful error rather than silently failing; asynchronous
    event listeners are not implemented.
    """

    _target_class_doc = "SomeSession"
    _dispatch_target = AsyncSession

    @classmethod
    def _listen(cls, event_key, retval=False):
        # reject listener registration outright; synchronous listeners
        # should be attached to the underlying sync Session instead
        raise NotImplementedError(
            "asynchronous events are not implemented at this time. Apply "
            "synchronous listeners to the AsyncSession.sync_session."
        )
+62 -35
View File
@@ -1,6 +1,5 @@
from typing import Any
from typing import Callable
from typing import List
from typing import Mapping
from typing import Optional
@@ -15,6 +14,35 @@ from ...sql import Executable
from ...util.concurrency import greenlet_spawn
@util.create_proxy_methods(
Session,
":class:`_orm.Session`",
":class:`_asyncio.AsyncSession`",
classmethods=["object_session", "identity_key"],
methods=[
"__contains__",
"__iter__",
"add",
"add_all",
"delete",
"expire",
"expire_all",
"expunge",
"expunge_all",
"get_bind",
"is_modified",
],
attributes=[
"dirty",
"deleted",
"new",
"identity_map",
"is_active",
"autoflush",
"no_autoflush",
"info",
],
)
class AsyncSession:
"""Asyncio version of :class:`_orm.Session`.
@@ -23,6 +51,16 @@ class AsyncSession:
"""
__slots__ = (
"binds",
"bind",
"sync_session",
"_proxied",
"_slots_dispatch",
)
dispatch = None
def __init__(
self,
bind: AsyncEngine = None,
@@ -31,46 +69,18 @@ class AsyncSession:
):
kw["future"] = True
if bind:
self.bind = engine
bind = engine._get_sync_engine(bind)
if binds:
self.binds = binds
binds = {
key: engine._get_sync_engine(b) for key, b in binds.items()
}
self.sync_session = Session(bind=bind, binds=binds, **kw)
def add(self, instance: object) -> None:
"""Place an object in this :class:`_asyncio.AsyncSession`.
.. seealso::
:meth:`_orm.Session.add`
"""
self.sync_session.add(instance)
def add_all(self, instances: List[object]) -> None:
"""Add the given collection of instances to this
:class:`_asyncio.AsyncSession`."""
self.sync_session.add_all(instances)
def expire_all(self):
"""Expires all persistent instances within this Session.
See :meth:`_orm.Session.expire_all` for usage details.
"""
self.sync_session.expire_all()
def expire(self, instance, attribute_names=None):
"""Expire the attributes on an instance.
See :meth:`._orm.Session.expire` for usage details.
"""
self.sync_session.expire()
self.sync_session = self._proxied = Session(
bind=bind, binds=binds, **kw
)
async def refresh(
self, instance, attribute_names=None, with_for_update=None
@@ -178,8 +188,17 @@ class AsyncSession:
:class:`.Session` object's transactional state.
"""
# POSSIBLY TODO: here, we see that the sync engine / connection
# that are generated from AsyncEngine / AsyncConnection don't
# provide any backlink from those sync objects back out to the
# async ones. it's not *too* big a deal since AsyncEngine/Connection
# are just proxies and all the state is actually in the sync
# version of things. However! it has to stay that way :)
sync_connection = await greenlet_spawn(self.sync_session.connection)
return engine.AsyncConnection(sync_connection.engine, sync_connection)
return engine.AsyncConnection(
engine.AsyncEngine(sync_connection.engine), sync_connection
)
def begin(self, **kw):
"""Return an :class:`_asyncio.AsyncSessionTransaction` object.
@@ -218,14 +237,22 @@ class AsyncSession:
return AsyncSessionTransaction(self, nested=True)
async def rollback(self):
    """Rollback the current transaction in progress."""
    sync_rollback = self.sync_session.rollback
    return await greenlet_spawn(sync_rollback)
async def commit(self):
    """Commit the current transaction in progress."""
    sync_commit = self.sync_session.commit
    return await greenlet_spawn(sync_commit)
async def close(self):
    """Close this :class:`_asyncio.AsyncSession`."""
    sync_close = self.sync_session.close
    return await greenlet_spawn(sync_close)
@classmethod
async def close_all(cls):
    """Close all :class:`_asyncio.AsyncSession` sessions.

    Delegates to :meth:`_orm.Session.close_all`.
    """
    # bug fix: as a classmethod the first argument is the class
    # itself, which has no ``sync_session`` attribute; invoke the
    # classmethod on the underlying Session class directly
    return await greenlet_spawn(Session.close_all)
async def __aenter__(self):
return self
+2 -1
View File
@@ -1371,7 +1371,8 @@ class SessionEvents(event.Events):
elif isinstance(target, Session):
return target
else:
return None
# allows alternate SessionEvents-like-classes to be consulted
return event.Events._accept_with(target)
@classmethod
def _listen(cls, event_key, raw=False, restore_load_context=False, **kw):
+51 -48
View File
@@ -9,14 +9,60 @@ from . import class_mapper
from . import exc as orm_exc
from .session import Session
from .. import exc as sa_exc
from ..util import create_proxy_methods
from ..util import ScopedRegistry
from ..util import ThreadLocalRegistry
from ..util import warn
__all__ = ["scoped_session"]
@create_proxy_methods(
Session,
":class:`_orm.Session`",
":class:`_orm.scoping.scoped_session`",
classmethods=["close_all", "object_session", "identity_key"],
methods=[
"__contains__",
"__iter__",
"add",
"add_all",
"begin",
"begin_nested",
"close",
"commit",
"connection",
"delete",
"execute",
"expire",
"expire_all",
"expunge",
"expunge_all",
"flush",
"get_bind",
"is_modified",
"bulk_save_objects",
"bulk_insert_mappings",
"bulk_update_mappings",
"merge",
"query",
"refresh",
"rollback",
"scalar",
],
attributes=[
"bind",
"dirty",
"deleted",
"new",
"identity_map",
"is_active",
"autoflush",
"no_autoflush",
"info",
"autocommit",
],
)
class scoped_session(object):
"""Provides scoped management of :class:`.Session` objects.
@@ -53,6 +99,10 @@ class scoped_session(object):
else:
self.registry = ThreadLocalRegistry(session_factory)
@property
def _proxied(self):
    # accessor consumed by @create_proxy_methods: resolves to the
    # current scope-local Session held by the registry
    return self.registry()
def __call__(self, **kw):
r"""Return the current :class:`.Session`, creating it
using the :attr:`.scoped_session.session_factory` if not present.
@@ -156,50 +206,3 @@ class scoped_session(object):
ScopedSession = scoped_session
"""Old name for backwards compatibility."""
def instrument(name):
def do(self, *args, **kwargs):
return getattr(self.registry(), name)(*args, **kwargs)
return do
for meth in Session.public_methods:
setattr(scoped_session, meth, instrument(meth))
def makeprop(name):
def set_(self, attr):
setattr(self.registry(), name, attr)
def get(self):
return getattr(self.registry(), name)
return property(get, set_)
for prop in (
"bind",
"dirty",
"deleted",
"new",
"identity_map",
"is_active",
"autoflush",
"no_autoflush",
"info",
"autocommit",
):
setattr(scoped_session, prop, makeprop(prop))
def clslevel(name):
def do(cls, *args, **kwargs):
return getattr(Session, name)(*args, **kwargs)
return classmethod(do)
for prop in ("close_all", "object_session", "identity_key"):
setattr(scoped_session, prop, clslevel(prop))
+8 -30
View File
@@ -835,35 +835,6 @@ class Session(_SessionClassMethods):
"""
public_methods = (
"__contains__",
"__iter__",
"add",
"add_all",
"begin",
"begin_nested",
"close",
"commit",
"connection",
"delete",
"execute",
"expire",
"expire_all",
"expunge",
"expunge_all",
"flush",
"get_bind",
"is_modified",
"bulk_save_objects",
"bulk_insert_mappings",
"bulk_update_mappings",
"merge",
"query",
"refresh",
"rollback",
"scalar",
)
@util.deprecated_params(
autocommit=(
"2.0",
@@ -3028,7 +2999,14 @@ class Session(_SessionClassMethods):
will unexpire attributes on access.
"""
state = attributes.instance_state(obj)
try:
state = attributes.instance_state(obj)
except exc.NO_STATE as err:
util.raise_(
exc.UnmappedInstanceError(obj),
replace_context=err,
)
to_attach = self._before_attach(state, obj)
state._load_pending = True
if to_attach:
+1
View File
@@ -509,6 +509,7 @@ class _ConnectionRecord(object):
"Soft " if soft else "",
self.connection,
)
if soft:
self._soft_invalidate_time = time.time()
else:
@@ -372,7 +372,7 @@ def _pytest_fn_decorator(target):
if add_positional_parameters:
spec.args.extend(add_positional_parameters)
metadata = dict(target="target", fn="fn", name=fn.__name__)
metadata = dict(target="target", fn="__fn", name=fn.__name__)
metadata.update(format_argspec_plus(spec, grouped=False))
code = (
"""\
@@ -382,7 +382,7 @@ def %(name)s(%(args)s):
% metadata
)
decorated = _exec_code_in_env(
code, {"target": target, "fn": fn}, fn.__name__
code, {"target": target, "__fn": fn}, fn.__name__
)
if not add_positional_parameters:
decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
+1
View File
@@ -123,6 +123,7 @@ from .langhelpers import coerce_kw_type # noqa
from .langhelpers import constructor_copy # noqa
from .langhelpers import constructor_key # noqa
from .langhelpers import counter # noqa
from .langhelpers import create_proxy_methods # noqa
from .langhelpers import decode_slice # noqa
from .langhelpers import decorator # noqa
from .langhelpers import dictlike_iteritems # noqa
+150 -11
View File
@@ -9,6 +9,7 @@
modules, classes, hierarchies, attributes, functions, and methods.
"""
from functools import update_wrapper
import hashlib
import inspect
@@ -462,6 +463,8 @@ def format_argspec_plus(fn, grouped=True):
passed positionally.
apply_kw
Like apply_pos, except keyword-ish args are passed as keywords.
apply_pos_proxied
Like apply_pos but omits the self/cls argument
Example::
@@ -478,16 +481,27 @@ def format_argspec_plus(fn, grouped=True):
spec = fn
args = compat.inspect_formatargspec(*spec)
if spec[0]:
self_arg = spec[0][0]
elif spec[1]:
self_arg = "%s[0]" % spec[1]
else:
self_arg = None
apply_pos = compat.inspect_formatargspec(
spec[0], spec[1], spec[2], None, spec[4]
)
if spec[0]:
self_arg = spec[0][0]
apply_pos_proxied = compat.inspect_formatargspec(
spec[0][1:], spec[1], spec[2], None, spec[4]
)
elif spec[1]:
# im not sure what this is
self_arg = "%s[0]" % spec[1]
apply_pos_proxied = apply_pos
else:
self_arg = None
apply_pos_proxied = apply_pos
num_defaults = 0
if spec[3]:
num_defaults += len(spec[3])
@@ -513,6 +527,7 @@ def format_argspec_plus(fn, grouped=True):
self_arg=self_arg,
apply_pos=apply_pos,
apply_kw=apply_kw,
apply_pos_proxied=apply_pos_proxied,
)
else:
return dict(
@@ -520,6 +535,7 @@ def format_argspec_plus(fn, grouped=True):
self_arg=self_arg,
apply_pos=apply_pos[1:-1],
apply_kw=apply_kw[1:-1],
apply_pos_proxied=apply_pos_proxied[1:-1],
)
@@ -534,17 +550,140 @@ def format_argspec_init(method, grouped=True):
"""
if method is object.__init__:
args = grouped and "(self)" or "self"
args = "(self)" if grouped else "self"
proxied = "()" if grouped else ""
else:
try:
return format_argspec_plus(method, grouped=grouped)
except TypeError:
args = (
grouped
and "(self, *args, **kwargs)"
or "self, *args, **kwargs"
"(self, *args, **kwargs)"
if grouped
else "self, *args, **kwargs"
)
return dict(self_arg="self", args=args, apply_pos=args, apply_kw=args)
proxied = "(*args, **kwargs)" if grouped else "*args, **kwargs"
return dict(
self_arg="self",
args=args,
apply_pos=args,
apply_kw=args,
apply_pos_proxied=proxied,
)
def create_proxy_methods(
    target_cls,
    target_cls_sphinx_name,
    proxy_cls_sphinx_name,
    classmethods=(),
    methods=(),
    attributes=(),
):
    """A class decorator that will copy attributes to a proxy class.

    The class to be instrumented must define a single accessor "_proxied".

    :param target_cls: class whose methods / attributes are mirrored
    :param target_cls_sphinx_name: sphinx role text for the target class,
     interpolated into generated docstrings
    :param proxy_cls_sphinx_name: sphinx role text for the proxy class
    :param classmethods: names proxied as classmethods of ``target_cls``
    :param methods: names proxied through ``self._proxied``
    :param attributes: names exposed as read/write properties delegating
     to ``self._proxied``
    """

    def decorate(cls):
        def instrument(name, clslevel=False):
            # build a function with the same signature as
            # target_cls.<name> that delegates either to the target
            # class (clslevel) or to the instance's _proxied object
            fn = getattr(target_cls, name)
            spec = compat.inspect_getfullargspec(fn)
            env = {}
            # default values are hoisted into ``env`` so the generated
            # source can reference them by name
            spec = _update_argspec_defaults_into_env(spec, env)
            caller_argspec = format_argspec_plus(spec, grouped=False)

            metadata = {
                "name": fn.__name__,
                "apply_pos_proxied": caller_argspec["apply_pos_proxied"],
                "args": caller_argspec["args"],
                "self_arg": caller_argspec["self_arg"],
            }

            if clslevel:
                code = (
                    "def %(name)s(%(args)s):\n"
                    "    return target_cls.%(name)s(%(apply_pos_proxied)s)"
                    % metadata
                )
                env["target_cls"] = target_cls
            else:
                code = (
                    "def %(name)s(%(args)s):\n"
                    "    return %(self_arg)s._proxied.%(name)s(%(apply_pos_proxied)s)"  # noqa E501
                    % metadata
                )

            proxy_fn = _exec_code_in_env(code, env, fn.__name__)
            # carry over the original defaults and augment the
            # docstring with a note describing the proxying relationship
            proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__
            proxy_fn.__doc__ = inject_docstring_text(
                fn.__doc__,
                ".. container:: class_bases\n\n    "
                "Proxied for the %s class on behalf of the %s class."
                % (target_cls_sphinx_name, proxy_cls_sphinx_name),
                1,
            )

            if clslevel:
                proxy_fn = classmethod(proxy_fn)

            return proxy_fn

        def makeprop(name):
            # expose an attribute of _proxied as a read/write property,
            # copying the original's docstring (augmented) if available
            attr = target_cls.__dict__.get(name, None)

            if attr is not None:
                doc = inject_docstring_text(
                    attr.__doc__,
                    ".. container:: class_bases\n\n    "
                    "Proxied for the %s class on behalf of the %s class."
                    % (
                        target_cls_sphinx_name,
                        proxy_cls_sphinx_name,
                    ),
                    1,
                )
            else:
                doc = None

            code = (
                "def set_(self, attr):\n"
                "    self._proxied.%(name)s = attr\n"
                "def get(self):\n"
                "    return self._proxied.%(name)s\n"
                "get.__doc__ = doc\n"
                "getset = property(get, set_)"
            ) % {"name": name}

            getset = _exec_code_in_env(code, {"doc": doc}, "getset")

            return getset

        # guard against clobbering anything the proxy class (or its
        # bases) already defines, then install each generated member
        for meth in methods:
            if hasattr(cls, meth):
                raise TypeError(
                    "class %s already has a method %s" % (cls, meth)
                )
            setattr(cls, meth, instrument(meth))

        for prop in attributes:
            if hasattr(cls, prop):
                raise TypeError(
                    "class %s already has a method %s" % (cls, prop)
                )
            setattr(cls, prop, makeprop(prop))

        for prop in classmethods:
            if hasattr(cls, prop):
                raise TypeError(
                    "class %s already has a method %s" % (cls, prop)
                )
            setattr(cls, prop, instrument(prop, clslevel=True))

        return cls

    return decorate
def getargspec_init(method):
+77 -49
View File
@@ -15,7 +15,19 @@ from sqlalchemy.testing.mock import Mock
from sqlalchemy.testing.util import gc_collect
class EventsTest(fixtures.TestBase):
class TearDownLocalEventsFixture(object):
    """Mixin that unregisters event dispatchers declared in this
    test module after each test."""

    def tearDown(self):
        # snapshot first, then remove: _remove_dispatcher mutates the
        # registrar structures we are scanning
        locally_defined = {
            evt_cls
            for entry in event.base._registrars.values()
            for evt_cls in entry
            if evt_cls.__module__ == __name__
        }
        for evt_cls in locally_defined:
            event.base._remove_dispatcher(evt_cls)
class EventsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test class- and instance-level event registration."""
def setUp(self):
@@ -34,9 +46,6 @@ class EventsTest(fixtures.TestBase):
self.Target = Target
def tearDown(self):
event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events)
def test_register_class(self):
def listen(x, y):
pass
@@ -258,7 +267,60 @@ class EventsTest(fixtures.TestBase):
)
class NamedCallTest(fixtures.TestBase):
class SlotsEventsTest(fixtures.TestBase):
    """Tests for event dispatch against targets that use ``__slots__``."""

    @testing.requires.python3
    def test_no_slots_dispatch(self):
        # a pure __slots__ target without a _slots_dispatch slot cannot
        # memoize a per-instance dispatch; listening must raise
        class Target(object):
            __slots__ = ()

        class TargetEvents(event.Events):
            _dispatch_target = Target

            def event_one(self, x, y):
                pass

            def event_two(self, x):
                pass

            def event_three(self, x):
                pass

        t1 = Target()

        with testing.expect_raises_message(
            TypeError,
            r"target .*Target.* doesn't have __dict__, should it "
            "be defining _slots_dispatch",
        ):
            event.listen(t1, "event_one", Mock())

    def test_slots_dispatch(self):
        # declaring the _slots_dispatch slot lets a __slots__ target
        # cache its per-instance dispatch and receive events normally
        class Target(object):
            __slots__ = ("_slots_dispatch",)

        class TargetEvents(event.Events):
            _dispatch_target = Target

            def event_one(self, x, y):
                pass

            def event_two(self, x):
                pass

            def event_three(self, x):
                pass

        t1 = Target()

        m1 = Mock()
        event.listen(t1, "event_one", m1)

        t1.dispatch.event_one(2, 4)

        eq_(m1.mock_calls, [call(2, 4)])
class NamedCallTest(TearDownLocalEventsFixture, fixtures.TestBase):
def _fixture(self):
class TargetEventsOne(event.Events):
def event_one(self, x, y):
@@ -373,7 +435,7 @@ class NamedCallTest(fixtures.TestBase):
eq_(canary.mock_calls, [call({"x": 4, "y": 5, "z": 8, "q": 5})])
class LegacySignatureTest(fixtures.TestBase):
class LegacySignatureTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""test adaption of legacy args"""
def setUp(self):
@@ -397,11 +459,6 @@ class LegacySignatureTest(fixtures.TestBase):
self.TargetOne = TargetOne
def tearDown(self):
event.base._remove_dispatcher(
self.TargetOne.__dict__["dispatch"].events
)
def test_legacy_accept(self):
canary = Mock()
@@ -550,12 +607,7 @@ class LegacySignatureTest(fixtures.TestBase):
)
class ClsLevelListenTest(fixtures.TestBase):
def tearDown(self):
event.base._remove_dispatcher(
self.TargetOne.__dict__["dispatch"].events
)
class ClsLevelListenTest(TearDownLocalEventsFixture, fixtures.TestBase):
def setUp(self):
class TargetEventsOne(event.Events):
def event_one(self, x, y):
@@ -622,7 +674,7 @@ class ClsLevelListenTest(fixtures.TestBase):
assert handler2 not in s2.dispatch.event_one
class AcceptTargetsTest(fixtures.TestBase):
class AcceptTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test default target acceptance."""
def setUp(self):
@@ -643,14 +695,6 @@ class AcceptTargetsTest(fixtures.TestBase):
self.TargetOne = TargetOne
self.TargetTwo = TargetTwo
def tearDown(self):
event.base._remove_dispatcher(
self.TargetOne.__dict__["dispatch"].events
)
event.base._remove_dispatcher(
self.TargetTwo.__dict__["dispatch"].events
)
def test_target_accept(self):
"""Test that events of the same name are routed to the correct
collection based on the type of target given.
@@ -687,7 +731,7 @@ class AcceptTargetsTest(fixtures.TestBase):
eq_(list(t2.dispatch.event_one), [listen_two, listen_four])
class CustomTargetsTest(fixtures.TestBase):
class CustomTargetsTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test custom target acceptance."""
def setUp(self):
@@ -707,9 +751,6 @@ class CustomTargetsTest(fixtures.TestBase):
self.Target = Target
def tearDown(self):
event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events)
def test_indirect(self):
def listen(x, y):
pass
@@ -727,7 +768,7 @@ class CustomTargetsTest(fixtures.TestBase):
)
class SubclassGrowthTest(fixtures.TestBase):
class SubclassGrowthTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""test that ad-hoc subclasses are garbage collected."""
def setUp(self):
@@ -752,7 +793,7 @@ class SubclassGrowthTest(fixtures.TestBase):
eq_(self.Target.__subclasses__(), [])
class ListenOverrideTest(fixtures.TestBase):
class ListenOverrideTest(TearDownLocalEventsFixture, fixtures.TestBase):
"""Test custom listen functions which change the listener function
signature."""
@@ -778,9 +819,6 @@ class ListenOverrideTest(fixtures.TestBase):
self.Target = Target
def tearDown(self):
event.base._remove_dispatcher(self.Target.__dict__["dispatch"].events)
def test_listen_override(self):
listen_one = Mock()
listen_two = Mock()
@@ -816,7 +854,7 @@ class ListenOverrideTest(fixtures.TestBase):
eq_(listen_one.mock_calls, [call(12)])
class PropagateTest(fixtures.TestBase):
class PropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
def setUp(self):
class TargetEvents(event.Events):
def event_one(self, arg):
@@ -850,7 +888,7 @@ class PropagateTest(fixtures.TestBase):
eq_(listen_two.mock_calls, [])
class JoinTest(fixtures.TestBase):
class JoinTest(TearDownLocalEventsFixture, fixtures.TestBase):
def setUp(self):
class TargetEvents(event.Events):
def event_one(self, target, arg):
@@ -875,11 +913,6 @@ class JoinTest(fixtures.TestBase):
self.TargetFactory = TargetFactory
self.TargetElement = TargetElement
def tearDown(self):
for cls in (self.TargetElement, self.TargetFactory, self.BaseTarget):
if "dispatch" in cls.__dict__:
event.base._remove_dispatcher(cls.__dict__["dispatch"].events)
def test_neither(self):
element = self.TargetFactory().create()
element.run_event(1)
@@ -1075,7 +1108,7 @@ class JoinTest(fixtures.TestBase):
)
class DisableClsPropagateTest(fixtures.TestBase):
class DisableClsPropagateTest(TearDownLocalEventsFixture, fixtures.TestBase):
def setUp(self):
class TargetEvents(event.Events):
def event_one(self, target, arg):
@@ -1093,11 +1126,6 @@ class DisableClsPropagateTest(fixtures.TestBase):
self.BaseTarget = BaseTarget
self.SubTarget = SubTarget
def tearDown(self):
for cls in (self.SubTarget, self.BaseTarget):
if "dispatch" in cls.__dict__:
event.base._remove_dispatcher(cls.__dict__["dispatch"].events)
def test_listen_invoke_clslevel(self):
canary = Mock()
@@ -1132,7 +1160,7 @@ class DisableClsPropagateTest(fixtures.TestBase):
eq_(canary.mock_calls, [])
class RemovalTest(fixtures.TestBase):
class RemovalTest(TearDownLocalEventsFixture, fixtures.TestBase):
def _fixture(self):
class TargetEvents(event.Events):
def event_one(self, x, y):
+199 -163
View File
@@ -2300,7 +2300,14 @@ class SymbolTest(fixtures.TestBase):
class _Py3KFixtures(object):
pass
def _kw_only_fixture(self):
pass
def _kw_plus_posn_fixture(self):
pass
def _kw_opt_fixture(self):
pass
if util.py3k:
@@ -2321,9 +2328,193 @@ def _kw_opt_fixture(self, a, *, b, c="c"):
for k in _locals:
setattr(_Py3KFixtures, k, _locals[k])
py3k_fixtures = _Py3KFixtures()
class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
def _test_format_argspec_plus(self, fn, wanted, grouped=None):
@testing.combinations(
(
lambda: None,
{
"args": "()",
"self_arg": None,
"apply_kw": "()",
"apply_pos": "()",
"apply_pos_proxied": "()",
},
True,
),
(
lambda: None,
{
"args": "",
"self_arg": None,
"apply_kw": "",
"apply_pos": "",
"apply_pos_proxied": "",
},
False,
),
(
lambda self: None,
{
"args": "(self)",
"self_arg": "self",
"apply_kw": "(self)",
"apply_pos": "(self)",
"apply_pos_proxied": "()",
},
True,
),
(
lambda self: None,
{
"args": "self",
"self_arg": "self",
"apply_kw": "self",
"apply_pos": "self",
"apply_pos_proxied": "",
},
False,
),
(
lambda *a: None,
{
"args": "(*a)",
"self_arg": "a[0]",
"apply_kw": "(*a)",
"apply_pos": "(*a)",
"apply_pos_proxied": "(*a)",
},
True,
),
(
lambda **kw: None,
{
"args": "(**kw)",
"self_arg": None,
"apply_kw": "(**kw)",
"apply_pos": "(**kw)",
"apply_pos_proxied": "(**kw)",
},
True,
),
(
lambda *a, **kw: None,
{
"args": "(*a, **kw)",
"self_arg": "a[0]",
"apply_kw": "(*a, **kw)",
"apply_pos": "(*a, **kw)",
"apply_pos_proxied": "(*a, **kw)",
},
True,
),
(
lambda a, *b: None,
{
"args": "(a, *b)",
"self_arg": "a",
"apply_kw": "(a, *b)",
"apply_pos": "(a, *b)",
"apply_pos_proxied": "(*b)",
},
True,
),
(
lambda a, **b: None,
{
"args": "(a, **b)",
"self_arg": "a",
"apply_kw": "(a, **b)",
"apply_pos": "(a, **b)",
"apply_pos_proxied": "(**b)",
},
True,
),
(
lambda a, *b, **c: None,
{
"args": "(a, *b, **c)",
"self_arg": "a",
"apply_kw": "(a, *b, **c)",
"apply_pos": "(a, *b, **c)",
"apply_pos_proxied": "(*b, **c)",
},
True,
),
(
lambda a, b=1, **c: None,
{
"args": "(a, b=1, **c)",
"self_arg": "a",
"apply_kw": "(a, b=b, **c)",
"apply_pos": "(a, b, **c)",
"apply_pos_proxied": "(b, **c)",
},
True,
),
(
lambda a=1, b=2: None,
{
"args": "(a=1, b=2)",
"self_arg": "a",
"apply_kw": "(a=a, b=b)",
"apply_pos": "(a, b)",
"apply_pos_proxied": "(b)",
},
True,
),
(
lambda a=1, b=2: None,
{
"args": "a=1, b=2",
"self_arg": "a",
"apply_kw": "a=a, b=b",
"apply_pos": "a, b",
"apply_pos_proxied": "b",
},
False,
),
(
py3k_fixtures._kw_only_fixture,
{
"args": "self, a, *, b, c",
"self_arg": "self",
"apply_pos": "self, a, *, b, c",
"apply_kw": "self, a, b=b, c=c",
"apply_pos_proxied": "a, *, b, c",
},
False,
testing.requires.python3,
),
(
py3k_fixtures._kw_plus_posn_fixture,
{
"args": "self, a, *args, b, c",
"self_arg": "self",
"apply_pos": "self, a, *args, b, c",
"apply_kw": "self, a, b=b, c=c, *args",
"apply_pos_proxied": "a, *args, b, c",
},
False,
testing.requires.python3,
),
(
py3k_fixtures._kw_opt_fixture,
{
"args": "self, a, *, b, c='c'",
"self_arg": "self",
"apply_pos": "self, a, *, b, c",
"apply_kw": "self, a, b=b, c=c",
"apply_pos_proxied": "a, *, b, c",
},
False,
testing.requires.python3,
),
argnames="fn,wanted,grouped",
)
def test_specs(self, fn, wanted, grouped):
# test direct function
if grouped is None:
@@ -2340,167 +2531,6 @@ class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
parsed = util.format_argspec_plus(spec, grouped=grouped)
eq_(parsed, wanted)
def test_specs(self):
self._test_format_argspec_plus(
lambda: None,
{
"args": "()",
"self_arg": None,
"apply_kw": "()",
"apply_pos": "()",
},
)
self._test_format_argspec_plus(
lambda: None,
{"args": "", "self_arg": None, "apply_kw": "", "apply_pos": ""},
grouped=False,
)
self._test_format_argspec_plus(
lambda self: None,
{
"args": "(self)",
"self_arg": "self",
"apply_kw": "(self)",
"apply_pos": "(self)",
},
)
self._test_format_argspec_plus(
lambda self: None,
{
"args": "self",
"self_arg": "self",
"apply_kw": "self",
"apply_pos": "self",
},
grouped=False,
)
self._test_format_argspec_plus(
lambda *a: None,
{
"args": "(*a)",
"self_arg": "a[0]",
"apply_kw": "(*a)",
"apply_pos": "(*a)",
},
)
self._test_format_argspec_plus(
lambda **kw: None,
{
"args": "(**kw)",
"self_arg": None,
"apply_kw": "(**kw)",
"apply_pos": "(**kw)",
},
)
self._test_format_argspec_plus(
lambda *a, **kw: None,
{
"args": "(*a, **kw)",
"self_arg": "a[0]",
"apply_kw": "(*a, **kw)",
"apply_pos": "(*a, **kw)",
},
)
self._test_format_argspec_plus(
lambda a, *b: None,
{
"args": "(a, *b)",
"self_arg": "a",
"apply_kw": "(a, *b)",
"apply_pos": "(a, *b)",
},
)
self._test_format_argspec_plus(
lambda a, **b: None,
{
"args": "(a, **b)",
"self_arg": "a",
"apply_kw": "(a, **b)",
"apply_pos": "(a, **b)",
},
)
self._test_format_argspec_plus(
lambda a, *b, **c: None,
{
"args": "(a, *b, **c)",
"self_arg": "a",
"apply_kw": "(a, *b, **c)",
"apply_pos": "(a, *b, **c)",
},
)
self._test_format_argspec_plus(
lambda a, b=1, **c: None,
{
"args": "(a, b=1, **c)",
"self_arg": "a",
"apply_kw": "(a, b=b, **c)",
"apply_pos": "(a, b, **c)",
},
)
self._test_format_argspec_plus(
lambda a=1, b=2: None,
{
"args": "(a=1, b=2)",
"self_arg": "a",
"apply_kw": "(a=a, b=b)",
"apply_pos": "(a, b)",
},
)
self._test_format_argspec_plus(
lambda a=1, b=2: None,
{
"args": "a=1, b=2",
"self_arg": "a",
"apply_kw": "a=a, b=b",
"apply_pos": "a, b",
},
grouped=False,
)
if util.py3k:
self._test_format_argspec_plus(
self._kw_only_fixture,
{
"args": "self, a, *, b, c",
"self_arg": "self",
"apply_pos": "self, a, *, b, c",
"apply_kw": "self, a, b=b, c=c",
},
grouped=False,
)
self._test_format_argspec_plus(
self._kw_plus_posn_fixture,
{
"args": "self, a, *args, b, c",
"self_arg": "self",
"apply_pos": "self, a, *args, b, c",
"apply_kw": "self, a, b=b, c=c, *args",
},
grouped=False,
)
self._test_format_argspec_plus(
self._kw_opt_fixture,
{
"args": "self, a, *, b, c='c'",
"self_arg": "self",
"apply_pos": "self, a, *, b, c",
"apply_kw": "self, a, b=b, c=c",
},
grouped=False,
)
@testing.requires.cpython
def test_init_grouped(self):
object_spec = {
@@ -2508,17 +2538,20 @@ class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
"self_arg": "self",
"apply_pos": "(self)",
"apply_kw": "(self)",
"apply_pos_proxied": "()",
}
wrapper_spec = {
"args": "(self, *args, **kwargs)",
"self_arg": "self",
"apply_pos": "(self, *args, **kwargs)",
"apply_kw": "(self, *args, **kwargs)",
"apply_pos_proxied": "(*args, **kwargs)",
}
custom_spec = {
"args": "(slef, a=123)",
"self_arg": "slef", # yes, slef
"apply_pos": "(slef, a)",
"apply_pos_proxied": "(a)",
"apply_kw": "(slef, a=a)",
}
@@ -2532,18 +2565,21 @@ class TestFormatArgspec(_Py3KFixtures, fixtures.TestBase):
"self_arg": "self",
"apply_pos": "self",
"apply_kw": "self",
"apply_pos_proxied": "",
}
wrapper_spec = {
"args": "self, *args, **kwargs",
"self_arg": "self",
"apply_pos": "self, *args, **kwargs",
"apply_kw": "self, *args, **kwargs",
"apply_pos_proxied": "*args, **kwargs",
}
custom_spec = {
"args": "slef, a=123",
"self_arg": "slef", # yes, slef
"apply_pos": "slef, a",
"apply_kw": "slef, a=a",
"apply_pos_proxied": "a",
}
self._test_init(False, object_spec, wrapper_spec, custom_spec)
+182
View File
@@ -2,6 +2,7 @@ import asyncio
from sqlalchemy import Column
from sqlalchemy import delete
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import func
from sqlalchemy import Integer
@@ -9,13 +10,19 @@ from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import text
from sqlalchemy import union_all
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import engine as _async_engine
from sqlalchemy.ext.asyncio import exc as asyncio_exc
from sqlalchemy.testing import async_test
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_
from sqlalchemy.testing import is_not
from sqlalchemy.testing import mock
from sqlalchemy.testing.asyncio import assert_raises_message_async
from sqlalchemy.util.concurrency import greenlet_spawn
class EngineFixture(fixtures.TablesTest):
@@ -50,6 +57,117 @@ class EngineFixture(fixtures.TablesTest):
class AsyncEngineTest(EngineFixture):
__backend__ = True
def test_proxied_attrs_engine(self, async_engine):
sync_engine = async_engine.sync_engine
is_(async_engine.url, sync_engine.url)
is_(async_engine.pool, sync_engine.pool)
is_(async_engine.dialect, sync_engine.dialect)
eq_(async_engine.name, sync_engine.name)
eq_(async_engine.driver, sync_engine.driver)
eq_(async_engine.echo, sync_engine.echo)
def test_clear_compiled_cache(self, async_engine):
async_engine.sync_engine._compiled_cache["foo"] = "bar"
eq_(async_engine.sync_engine._compiled_cache["foo"], "bar")
async_engine.clear_compiled_cache()
assert "foo" not in async_engine.sync_engine._compiled_cache
def test_execution_options(self, async_engine):
a2 = async_engine.execution_options(foo="bar")
assert isinstance(a2, _async_engine.AsyncEngine)
eq_(a2.sync_engine._execution_options, {"foo": "bar"})
eq_(async_engine.sync_engine._execution_options, {})
"""
attr uri, pool, dialect, engine, name, driver, echo
methods clear_compiled_cache, update_execution_options,
execution_options, get_execution_options, dispose
"""
@async_test
async def test_proxied_attrs_connection(self, async_engine):
conn = await async_engine.connect()
sync_conn = conn.sync_connection
is_(conn.engine, async_engine)
is_(conn.closed, sync_conn.closed)
is_(conn.dialect, async_engine.sync_engine.dialect)
eq_(conn.default_isolation_level, sync_conn.default_isolation_level)
@async_test
async def test_invalidate(self, async_engine):
conn = await async_engine.connect()
is_(conn.invalidated, False)
connection_fairy = await conn.get_raw_connection()
is_(connection_fairy.is_valid, True)
dbapi_connection = connection_fairy.connection
await conn.invalidate()
assert dbapi_connection._connection.is_closed()
new_fairy = await conn.get_raw_connection()
is_not(new_fairy.connection, dbapi_connection)
is_not(new_fairy, connection_fairy)
is_(new_fairy.is_valid, True)
is_(connection_fairy.is_valid, False)
@async_test
async def test_get_dbapi_connection_raise(self, async_engine):
conn = await async_engine.connect()
with testing.expect_raises_message(
exc.InvalidRequestError,
"AsyncConnection.connection accessor is not "
"implemented as the attribute",
):
conn.connection
@async_test
async def test_get_raw_connection(self, async_engine):
conn = await async_engine.connect()
pooled = await conn.get_raw_connection()
is_(pooled, conn.sync_connection.connection)
@async_test
async def test_isolation_level(self, async_engine):
conn = await async_engine.connect()
sync_isolation_level = await greenlet_spawn(
conn.sync_connection.get_isolation_level
)
isolation_level = await conn.get_isolation_level()
eq_(isolation_level, sync_isolation_level)
await conn.execution_options(isolation_level="SERIALIZABLE")
isolation_level = await conn.get_isolation_level()
eq_(isolation_level, "SERIALIZABLE")
@async_test
async def test_dispose(self, async_engine):
c1 = await async_engine.connect()
c2 = await async_engine.connect()
await c1.close()
await c2.close()
p1 = async_engine.pool
eq_(async_engine.pool.checkedin(), 2)
await async_engine.dispose()
eq_(async_engine.pool.checkedin(), 0)
is_not(p1, async_engine.pool)
@async_test
async def test_init_once_concurrency(self, async_engine):
c1 = async_engine.connect()
@@ -169,6 +287,70 @@ class AsyncEngineTest(EngineFixture):
)
class AsyncEventTest(EngineFixture):
"""The engine events all run in their normal synchronous context.
we do not provide an asyncio event interface at this time.
"""
__backend__ = True
@async_test
async def test_no_async_listeners(self, async_engine):
with testing.expect_raises_message(
NotImplementedError,
"asynchronous events are not implemented "
"at this time. Apply synchronous listeners to the "
"AsyncEngine.sync_engine or "
"AsyncConnection.sync_connection attributes.",
):
event.listen(async_engine, "before_cursor_execute", mock.Mock())
conn = await async_engine.connect()
with testing.expect_raises_message(
NotImplementedError,
"asynchronous events are not implemented "
"at this time. Apply synchronous listeners to the "
"AsyncEngine.sync_engine or "
"AsyncConnection.sync_connection attributes.",
):
event.listen(conn, "before_cursor_execute", mock.Mock())
@async_test
async def test_sync_before_cursor_execute_engine(self, async_engine):
canary = mock.Mock()
event.listen(async_engine.sync_engine, "before_cursor_execute", canary)
async with async_engine.connect() as conn:
sync_conn = conn.sync_connection
await conn.execute(text("select 1"))
eq_(
canary.mock_calls,
[mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)],
)
@async_test
async def test_sync_before_cursor_execute_connection(self, async_engine):
canary = mock.Mock()
async with async_engine.connect() as conn:
sync_conn = conn.sync_connection
event.listen(
async_engine.sync_engine, "before_cursor_execute", canary
)
await conn.execute(text("select 1"))
eq_(
canary.mock_calls,
[mock.call(sync_conn, mock.ANY, "select 1", (), mock.ANY, False)],
)
class AsyncResultTest(EngineFixture):
@testing.combinations(
(None,), ("scalars",), ("mappings",), argnames="filter_"
+58
View File
@@ -1,3 +1,4 @@
from sqlalchemy import event
from sqlalchemy import exc
from sqlalchemy import func
from sqlalchemy import select
@@ -9,6 +10,7 @@ from sqlalchemy.orm import selectinload
from sqlalchemy.testing import async_test
from sqlalchemy.testing import eq_
from sqlalchemy.testing import is_
from sqlalchemy.testing import mock
from ...orm import _fixtures
@@ -139,6 +141,27 @@ class AsyncSessionTransactionTest(AsyncFixture):
eq_(await outer_conn.scalar(select(func.count(User.id))), 1)
@async_test
async def test_delete(self, async_session):
User = self.classes.User
async with async_session.begin():
u1 = User(name="u1")
async_session.add(u1)
await async_session.flush()
conn = await async_session.connection()
eq_(await conn.scalar(select(func.count(User.id))), 1)
async_session.delete(u1)
await async_session.flush()
eq_(await conn.scalar(select(func.count(User.id))), 0)
@async_test
async def test_flush(self, async_session):
User = self.classes.User
@@ -198,3 +221,38 @@ class AsyncSessionTransactionTest(AsyncFixture):
is_(new_u_merged, u1)
eq_(u1.name, "new u1")
class AsyncEventTest(AsyncFixture):
"""The engine events all run in their normal synchronous context.
we do not provide an asyncio event interface at this time.
"""
__backend__ = True
@async_test
async def test_no_async_listeners(self, async_session):
with testing.expect_raises(
NotImplementedError,
"NotImplementedError: asynchronous events are not implemented "
"at this time. Apply synchronous listeners to the "
"AsyncEngine.sync_engine or "
"AsyncConnection.sync_connection attributes.",
):
event.listen(async_session, "before_flush", mock.Mock())
@async_test
async def test_sync_before_commit(self, async_session):
canary = mock.Mock()
event.listen(async_session.sync_session, "before_commit", canary)
async with async_session.begin():
pass
eq_(
canary.mock_calls,
[mock.call(async_session.sync_session)],
)
+24
View File
@@ -10,6 +10,7 @@ from sqlalchemy.orm import scoped_session
from sqlalchemy.testing import assert_raises_message
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import mock
from sqlalchemy.testing.mock import Mock
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
@@ -127,3 +128,26 @@ class ScopedSessionTest(fixtures.MappedTest):
mock_scope_func.return_value = 1
s2 = Session(autocommit=True)
assert s2.autocommit == True
def test_methods_etc(self):
mock_session = Mock()
mock_session.bind = "the bind"
sess = scoped_session(lambda: mock_session)
sess.add("add")
sess.delete("delete")
eq_(sess.bind, "the bind")
eq_(
mock_session.mock_calls,
[mock.call.add("add", True), mock.call.delete("delete")],
)
with mock.patch(
"sqlalchemy.orm.session.object_session"
) as mock_object_session:
sess.object_session("foo")
eq_(mock_object_session.mock_calls, [mock.call("foo")])
+30 -25
View File
@@ -1,3 +1,5 @@
import inspect as _py_inspect
import sqlalchemy as sa
from sqlalchemy import event
from sqlalchemy import ForeignKey
@@ -1820,22 +1822,28 @@ class DisposedStates(fixtures.MappedTest):
class SessionInterface(fixtures.TestBase):
"""Bogus args to Session methods produce actionable exceptions."""
# TODO: expand with message body assertions.
_class_methods = set(("connection", "execute", "get_bind", "scalar"))
def _public_session_methods(self):
Session = sa.orm.session.Session
blacklist = set(("begin", "query"))
blacklist = {"begin", "query", "bind_mapper", "get", "bind_table"}
specials = {"__iter__", "__contains__"}
ok = set()
for meth in Session.public_methods:
if meth in blacklist:
continue
spec = inspect_getfullargspec(getattr(Session, meth))
if len(spec[0]) > 1 or spec[1]:
ok.add(meth)
for name in dir(Session):
if (
name in Session.__dict__
and (not name.startswith("_") or name in specials)
and (
_py_inspect.ismethod(getattr(Session, name))
or _py_inspect.isfunction(getattr(Session, name))
)
):
if name in blacklist:
continue
spec = inspect_getfullargspec(getattr(Session, name))
if len(spec[0]) > 1 or spec[1]:
ok.add(name)
return ok
def _map_it(self, cls):
@@ -1866,18 +1874,21 @@ class SessionInterface(fixtures.TestBase):
def raises_(method, *args, **kw):
x_raises_(create_session(), method, *args, **kw)
raises_("__contains__", user_arg)
raises_("add", user_arg)
for name in [
"__contains__",
"is_modified",
"merge",
"refresh",
"add",
"delete",
"expire",
"expunge",
"enable_relationship_loading",
]:
raises_(name, user_arg)
raises_("add_all", (user_arg,))
raises_("delete", user_arg)
raises_("expire", user_arg)
raises_("expunge", user_arg)
# flush will no-op without something in the unit of work
def _():
class OK(object):
@@ -1891,12 +1902,6 @@ class SessionInterface(fixtures.TestBase):
_()
raises_("is_modified", user_arg)
raises_("merge", user_arg)
raises_("refresh", user_arg)
instance_methods = (
self._public_session_methods()
- self._class_methods