use code generation for scoped_session

our decorator thing generates code in any case,
so point it at the file itself to generate real code
for the blocks rather than doing things dynamically.

this will allow typing tools to have no problem
whatsoever and we also reduce import time overhead.
file size will be a lot bigger though, shrugs.

syntax / dupe method / etc. checking will be accomplished
by our existing linting / typing / formatting tools.

As we are also using "from __future__ import annotations",
we also no longer have to apply quotes to generated
annotations.

Change-Id: I20962cb65bda63ff0fb67357ab346e9b1ef4f108
This commit is contained in:
Mike Bayer
2022-04-05 19:00:19 -04:00
committed by mike bayer
parent 15ef11e0ed
commit 98eae4e181
13 changed files with 3936 additions and 127 deletions
+1
View File
@@ -4,6 +4,7 @@
recursive-include doc *.html *.css *.txt *.js *.png *.py Makefile *.rst *.sty
recursive-include examples *.py *.xml
recursive-include test *.py *.dat *.testpatch
recursive-include tools *.py
# for some reason in some environments stale Cython .c files
# are being pulled in, these should never be in a dist
+6 -2
View File
@@ -2672,14 +2672,18 @@ class Engine(
@property
def name(self) -> str:
"""String name of the :class:`~sqlalchemy.engine.interfaces.Dialect`
in use by this :class:`Engine`."""
in use by this :class:`Engine`.
"""
return self.dialect.name
@property
def driver(self) -> str:
"""Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect`
in use by this :class:`Engine`."""
in use by this :class:`Engine`.
"""
return self.dialect.driver
+273 -2
View File
@@ -4,6 +4,10 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Any
from . import exc as async_exc
from .base import ProxyComparable
from .base import StartableContext
@@ -72,7 +76,7 @@ class AsyncConnectable:
@util.create_proxy_methods(
Connection,
":class:`_future.Connection`",
":class:`_engine.Connection`",
":class:`_asyncio.AsyncConnection`",
classmethods=[],
methods=[],
@@ -125,6 +129,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
.. seealso::
:ref:`asyncio_events`
"""
sync_engine: Engine
@@ -137,6 +142,7 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
.. seealso::
:ref:`asyncio_events`
"""
@classmethod
@@ -552,10 +558,100 @@ class AsyncConnection(ProxyComparable, StartableContext, AsyncConnectable):
async def __aexit__(self, type_, value, traceback):
await self.close()
# START PROXY METHODS AsyncConnection
# code within this block is **programmatically,
# statically generated** by tools/generate_proxy_methods.py
@property
def closed(self) -> Any:
r"""Return True if this connection is closed.
.. container:: class_bases
Proxied for the :class:`_engine.Connection` class
on behalf of the :class:`_asyncio.AsyncConnection` class.
""" # noqa: E501
return self._proxied.closed
@property
def invalidated(self) -> Any:
r"""Return True if this connection was invalidated.
.. container:: class_bases
Proxied for the :class:`_engine.Connection` class
on behalf of the :class:`_asyncio.AsyncConnection` class.
This does not indicate whether or not the connection was
invalidated at the pool level, however
""" # noqa: E501
return self._proxied.invalidated
@property
def dialect(self) -> Any:
r"""Proxy for the :attr:`_engine.Connection.dialect` attribute
on behalf of the :class:`_asyncio.AsyncConnection` class.
""" # noqa: E501
return self._proxied.dialect
@dialect.setter
def dialect(self, attr: Any) -> None:
self._proxied.dialect = attr
@property
def default_isolation_level(self) -> Any:
r"""The default isolation level assigned to this
:class:`_engine.Connection`.
.. container:: class_bases
Proxied for the :class:`_engine.Connection` class
on behalf of the :class:`_asyncio.AsyncConnection` class.
This is the isolation level setting that the
:class:`_engine.Connection`
has when first procured via the :meth:`_engine.Engine.connect` method.
This level stays in place until the
:paramref:`.Connection.execution_options.isolation_level` is used
to change the setting on a per-:class:`_engine.Connection` basis.
Unlike :meth:`_engine.Connection.get_isolation_level`,
this attribute is set
ahead of time from the first connection procured by the dialect,
so SQL query is not invoked when this accessor is called.
.. versionadded:: 0.9.9
.. seealso::
:meth:`_engine.Connection.get_isolation_level`
- view current level
:paramref:`_sa.create_engine.isolation_level`
- set per :class:`_engine.Engine` isolation level
:paramref:`.Connection.execution_options.isolation_level`
- set per :class:`_engine.Connection` isolation level
""" # noqa: E501
return self._proxied.default_isolation_level
# END PROXY METHODS AsyncConnection
@util.create_proxy_methods(
Engine,
":class:`_future.Engine`",
":class:`_engine.Engine`",
":class:`_asyncio.AsyncEngine`",
classmethods=[],
methods=[
@@ -701,6 +797,181 @@ class AsyncEngine(ProxyComparable, AsyncConnectable):
return await greenlet_spawn(self.sync_engine.dispose)
# START PROXY METHODS AsyncEngine
# code within this block is **programmatically,
# statically generated** by tools/generate_proxy_methods.py
def clear_compiled_cache(self) -> None:
r"""Clear the compiled cache associated with the dialect.
.. container:: class_bases
Proxied for the :class:`_engine.Engine` class on
behalf of the :class:`_asyncio.AsyncEngine` class.
This applies **only** to the built-in cache that is established
via the :paramref:`_engine.create_engine.query_cache_size` parameter.
It will not impact any dictionary caches that were passed via the
:paramref:`.Connection.execution_options.query_cache` parameter.
.. versionadded:: 1.4
""" # noqa: E501
return self._proxied.clear_compiled_cache()
def update_execution_options(self, **opt: Any) -> None:
r"""Update the default execution_options dictionary
of this :class:`_engine.Engine`.
.. container:: class_bases
Proxied for the :class:`_engine.Engine` class on
behalf of the :class:`_asyncio.AsyncEngine` class.
The given keys/values in \**opt are added to the
default execution options that will be used for
all connections. The initial contents of this dictionary
can be sent via the ``execution_options`` parameter
to :func:`_sa.create_engine`.
.. seealso::
:meth:`_engine.Connection.execution_options`
:meth:`_engine.Engine.execution_options`
""" # noqa: E501
return self._proxied.update_execution_options(**opt)
def get_execution_options(self) -> _ExecuteOptions:
r"""Get the non-SQL options which will take effect during execution.
.. container:: class_bases
Proxied for the :class:`_engine.Engine` class on
behalf of the :class:`_asyncio.AsyncEngine` class.
.. versionadded: 1.3
.. seealso::
:meth:`_engine.Engine.execution_options`
""" # noqa: E501
return self._proxied.get_execution_options()
@property
def url(self) -> URL:
r"""Proxy for the :attr:`_engine.Engine.url` attribute
on behalf of the :class:`_asyncio.AsyncEngine` class.
""" # noqa: E501
return self._proxied.url
@url.setter
def url(self, attr: URL) -> None:
self._proxied.url = attr
@property
def pool(self) -> Pool:
r"""Proxy for the :attr:`_engine.Engine.pool` attribute
on behalf of the :class:`_asyncio.AsyncEngine` class.
""" # noqa: E501
return self._proxied.pool
@pool.setter
def pool(self, attr: Pool) -> None:
self._proxied.pool = attr
@property
def dialect(self) -> Dialect:
r"""Proxy for the :attr:`_engine.Engine.dialect` attribute
on behalf of the :class:`_asyncio.AsyncEngine` class.
""" # noqa: E501
return self._proxied.dialect
@dialect.setter
def dialect(self, attr: Dialect) -> None:
self._proxied.dialect = attr
@property
def engine(self) -> Any:
r""".. container:: class_bases
Proxied for the :class:`_engine.Engine` class
on behalf of the :class:`_asyncio.AsyncEngine` class.
""" # noqa: E501
return self._proxied.engine
@property
def name(self) -> Any:
r"""String name of the :class:`~sqlalchemy.engine.interfaces.Dialect`
in use by this :class:`Engine`.
.. container:: class_bases
Proxied for the :class:`_engine.Engine` class
on behalf of the :class:`_asyncio.AsyncEngine` class.
""" # noqa: E501
return self._proxied.name
@property
def driver(self) -> Any:
r"""Driver name of the :class:`~sqlalchemy.engine.interfaces.Dialect`
in use by this :class:`Engine`.
.. container:: class_bases
Proxied for the :class:`_engine.Engine` class
on behalf of the :class:`_asyncio.AsyncEngine` class.
""" # noqa: E501
return self._proxied.driver
@property
def echo(self) -> Any:
r"""When ``True``, enable log output for this element.
.. container:: class_bases
Proxied for the :class:`_engine.Engine` class
on behalf of the :class:`_asyncio.AsyncEngine` class.
This has the effect of setting the Python logging level for the namespace
of this element's class and object reference. A value of boolean ``True``
indicates that the loglevel ``logging.INFO`` will be set for the logger,
whereas the string value ``debug`` will set the loglevel to
``logging.DEBUG``.
""" # noqa: E501
return self._proxied.echo
@echo.setter
def echo(self, attr: Any) -> None:
self._proxied.echo = attr
# END PROXY METHODS AsyncEngine
class AsyncTransaction(ProxyComparable, StartableContext):
"""An asyncio proxy for a :class:`_engine.Transaction`."""
File diff suppressed because it is too large Load Diff
+509 -1
View File
@@ -4,6 +4,10 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from __future__ import annotations
from typing import Any
from . import engine
from . import result as _result
from .base import ReversibleProxy
@@ -312,7 +316,9 @@ class AsyncSession(ReversibleProxy):
**kw,
):
"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object."""
:class:`_asyncio.AsyncResult` object.
"""
if execution_options:
execution_options = util.immutabledict(execution_options).union(
@@ -629,6 +635,508 @@ class AsyncSession(ReversibleProxy):
# TODO: can this use asynccontextmanager ??
return _AsyncSessionContextManager(self)
# START PROXY METHODS AsyncSession
# code within this block is **programmatically,
# statically generated** by tools/generate_proxy_methods.py
def __contains__(self, instance):
r"""Return True if the instance is associated with this session.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
The instance may be pending or persistent within the Session for a
result of True.
""" # noqa: E501
return self._proxied.__contains__(instance)
def __iter__(self):
r"""Iterate over all pending or persistent instances within this
Session.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
""" # noqa: E501
return self._proxied.__iter__()
def add(self, instance: Any, _warn: bool = True) -> None:
r"""Place an object in the ``Session``.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
Its state will be persisted to the database on the next flush
operation.
Repeated calls to ``add()`` will be ignored. The opposite of ``add()``
is ``expunge()``.
""" # noqa: E501
return self._proxied.add(instance, _warn=_warn)
def add_all(self, instances):
r"""Add the given collection of instances to this ``Session``.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
""" # noqa: E501
return self._proxied.add_all(instances)
def expire(self, instance, attribute_names=None):
r"""Expire the attributes on an instance.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
Marks the attributes of an instance as out of date. When an expired
attribute is next accessed, a query will be issued to the
:class:`.Session` object's current transactional context in order to
load all expired attributes for the given instance. Note that
a highly isolated transaction will return the same values as were
previously read in that same transaction, regardless of changes
in database state outside of that transaction.
To expire all objects in the :class:`.Session` simultaneously,
use :meth:`Session.expire_all`.
The :class:`.Session` object's default behavior is to
expire all state whenever the :meth:`Session.rollback`
or :meth:`Session.commit` methods are called, so that new
state can be loaded for the new transaction. For this reason,
calling :meth:`Session.expire` only makes sense for the specific
case that a non-ORM SQL statement was emitted in the current
transaction.
:param instance: The instance to be refreshed.
:param attribute_names: optional list of string attribute names
indicating a subset of attributes to be expired.
.. seealso::
:ref:`session_expire` - introductory material
:meth:`.Session.expire`
:meth:`.Session.refresh`
:meth:`_orm.Query.populate_existing`
""" # noqa: E501
return self._proxied.expire(instance, attribute_names=attribute_names)
def expire_all(self):
r"""Expires all persistent instances within this Session.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
When any attributes on a persistent instance is next accessed,
a query will be issued using the
:class:`.Session` object's current transactional context in order to
load all expired attributes for the given instance. Note that
a highly isolated transaction will return the same values as were
previously read in that same transaction, regardless of changes
in database state outside of that transaction.
To expire individual objects and individual attributes
on those objects, use :meth:`Session.expire`.
The :class:`.Session` object's default behavior is to
expire all state whenever the :meth:`Session.rollback`
or :meth:`Session.commit` methods are called, so that new
state can be loaded for the new transaction. For this reason,
calling :meth:`Session.expire_all` is not usually needed,
assuming the transaction is isolated.
.. seealso::
:ref:`session_expire` - introductory material
:meth:`.Session.expire`
:meth:`.Session.refresh`
:meth:`_orm.Query.populate_existing`
""" # noqa: E501
return self._proxied.expire_all()
def expunge(self, instance):
r"""Remove the `instance` from this ``Session``.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
This will free all internal references to the instance. Cascading
will be applied according to the *expunge* cascade rule.
""" # noqa: E501
return self._proxied.expunge(instance)
def expunge_all(self):
r"""Remove all object instances from this ``Session``.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
This is equivalent to calling ``expunge(obj)`` on all objects in this
``Session``.
""" # noqa: E501
return self._proxied.expunge_all()
def is_modified(self, instance, include_collections=True):
r"""Return ``True`` if the given instance has locally
modified attributes.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
This method retrieves the history for each instrumented
attribute on the instance and performs a comparison of the current
value to its previously committed value, if any.
It is in effect a more expensive and accurate
version of checking for the given instance in the
:attr:`.Session.dirty` collection; a full test for
each attribute's net "dirty" status is performed.
E.g.::
return session.is_modified(someobject)
A few caveats to this method apply:
* Instances present in the :attr:`.Session.dirty` collection may
report ``False`` when tested with this method. This is because
the object may have received change events via attribute mutation,
thus placing it in :attr:`.Session.dirty`, but ultimately the state
is the same as that loaded from the database, resulting in no net
change here.
* Scalar attributes may not have recorded the previously set
value when a new value was applied, if the attribute was not loaded,
or was expired, at the time the new value was received - in these
cases, the attribute is assumed to have a change, even if there is
ultimately no net change against its database value. SQLAlchemy in
most cases does not need the "old" value when a set event occurs, so
it skips the expense of a SQL call if the old value isn't present,
based on the assumption that an UPDATE of the scalar value is
usually needed, and in those few cases where it isn't, is less
expensive on average than issuing a defensive SELECT.
The "old" value is fetched unconditionally upon set only if the
attribute container has the ``active_history`` flag set to ``True``.
This flag is set typically for primary key attributes and scalar
object references that are not a simple many-to-one. To set this
flag for any arbitrary mapped column, use the ``active_history``
argument with :func:`.column_property`.
:param instance: mapped instance to be tested for pending changes.
:param include_collections: Indicates if multivalued collections
should be included in the operation. Setting this to ``False`` is a
way to detect only local-column based properties (i.e. scalar columns
or many-to-one foreign keys) that would result in an UPDATE for this
instance upon flush.
""" # noqa: E501
return self._proxied.is_modified(
instance, include_collections=include_collections
)
def in_transaction(self):
r"""Return True if this :class:`_orm.Session` has begun a transaction.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
.. versionadded:: 1.4
.. seealso::
:attr:`_orm.Session.is_active`
""" # noqa: E501
return self._proxied.in_transaction()
def in_nested_transaction(self):
r"""Return True if this :class:`_orm.Session` has begun a nested
transaction, e.g. SAVEPOINT.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
.. versionadded:: 1.4
""" # noqa: E501
return self._proxied.in_nested_transaction()
@property
def dirty(self) -> Any:
r"""The set of all persistent instances considered dirty.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class
on behalf of the :class:`_asyncio.AsyncSession` class.
E.g.::
some_mapped_object in session.dirty
Instances are considered dirty when they were modified but not
deleted.
Note that this 'dirty' calculation is 'optimistic'; most
attribute-setting or collection modification operations will
mark an instance as 'dirty' and place it in this set, even if
there is no net change to the attribute's value. At flush
time, the value of each attribute is compared to its
previously saved value, and if there's no net change, no SQL
operation will occur (this is a more expensive operation so
it's only done at flush time).
To check if an instance has actionable net changes to its
attributes, use the :meth:`.Session.is_modified` method.
""" # noqa: E501
return self._proxied.dirty
@property
def deleted(self) -> Any:
r"""The set of all instances marked as 'deleted' within this ``Session``
.. container:: class_bases
Proxied for the :class:`_orm.Session` class
on behalf of the :class:`_asyncio.AsyncSession` class.
""" # noqa: E501
return self._proxied.deleted
@property
def new(self) -> Any:
r"""The set of all instances marked as 'new' within this ``Session``.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class
on behalf of the :class:`_asyncio.AsyncSession` class.
""" # noqa: E501
return self._proxied.new
@property
def identity_map(self) -> identity.IdentityMap:
r"""Proxy for the :attr:`_orm.Session.identity_map` attribute
on behalf of the :class:`_asyncio.AsyncSession` class.
""" # noqa: E501
return self._proxied.identity_map
@identity_map.setter
def identity_map(self, attr: identity.IdentityMap) -> None:
self._proxied.identity_map = attr
@property
def is_active(self) -> Any:
r"""True if this :class:`.Session` not in "partial rollback" state.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class
on behalf of the :class:`_asyncio.AsyncSession` class.
.. versionchanged:: 1.4 The :class:`_orm.Session` no longer begins
a new transaction immediately, so this attribute will be False
when the :class:`_orm.Session` is first instantiated.
"partial rollback" state typically indicates that the flush process
of the :class:`_orm.Session` has failed, and that the
:meth:`_orm.Session.rollback` method must be emitted in order to
fully roll back the transaction.
If this :class:`_orm.Session` is not in a transaction at all, the
:class:`_orm.Session` will autobegin when it is first used, so in this
case :attr:`_orm.Session.is_active` will return True.
Otherwise, if this :class:`_orm.Session` is within a transaction,
and that transaction has not been rolled back internally, the
:attr:`_orm.Session.is_active` will also return True.
.. seealso::
:ref:`faq_session_rollback`
:meth:`_orm.Session.in_transaction`
""" # noqa: E501
return self._proxied.is_active
@property
def autoflush(self) -> bool:
r"""Proxy for the :attr:`_orm.Session.autoflush` attribute
on behalf of the :class:`_asyncio.AsyncSession` class.
""" # noqa: E501
return self._proxied.autoflush
@autoflush.setter
def autoflush(self, attr: bool) -> None:
self._proxied.autoflush = attr
@property
def no_autoflush(self) -> Any:
r"""Return a context manager that disables autoflush.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class
on behalf of the :class:`_asyncio.AsyncSession` class.
e.g.::
with session.no_autoflush:
some_object = SomeClass()
session.add(some_object)
# won't autoflush
some_object.related_thing = session.query(SomeRelated).first()
Operations that proceed within the ``with:`` block
will not be subject to flushes occurring upon query
access. This is useful when initializing a series
of objects which involve existing database queries,
where the uncompleted object should not yet be flushed.
""" # noqa: E501
return self._proxied.no_autoflush
@property
def info(self) -> Any:
r"""A user-modifiable dictionary.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class
on behalf of the :class:`_asyncio.AsyncSession` class.
The initial value of this dictionary can be populated using the
``info`` argument to the :class:`.Session` constructor or
:class:`.sessionmaker` constructor or factory methods. The dictionary
here is always local to this :class:`.Session` and can be modified
independently of all other :class:`.Session` objects.
""" # noqa: E501
return self._proxied.info
@classmethod
def object_session(cls, instance: Any) -> "Session":
r"""Return the :class:`.Session` to which an object belongs.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
This is an alias of :func:`.object_session`.
""" # noqa: E501
return Session.object_session(instance)
@classmethod
def identity_key(
cls,
class_=None,
ident=None,
*,
instance=None,
row=None,
identity_token=None,
) -> _IdentityKeyType:
r"""Return an identity key.
.. container:: class_bases
Proxied for the :class:`_orm.Session` class on
behalf of the :class:`_asyncio.AsyncSession` class.
This is an alias of :func:`.util.identity_key`.
""" # noqa: E501
return Session.identity_key(
class_=class_,
ident=ident,
instance=instance,
row=row,
identity_token=identity_token,
)
# END PROXY METHODS AsyncSession
class _AsyncSessionContextManager:
def __init__(self, async_session):
File diff suppressed because it is too large Load Diff
+2 -4
View File
@@ -56,7 +56,6 @@ from ..sql import coercions
from ..sql import dml
from ..sql import roles
from ..sql import visitors
from ..sql._typing import _ColumnsClauseArgument
from ..sql.base import CompileState
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
from ..util.typing import Literal
@@ -64,6 +63,7 @@ from ..util.typing import Literal
if typing.TYPE_CHECKING:
from .mapper import Mapper
from ..engine import Row
from ..sql._typing import _ColumnsClauseArgument
from ..sql._typing import _ExecuteOptions
from ..sql._typing import _ExecuteParams
from ..sql.base import Executable
@@ -2043,9 +2043,7 @@ class Session(_SessionClassMethods):
% (", ".join(context),),
)
def query(
self, *entities: "_ColumnsClauseArgument", **kwargs: Any
) -> "Query":
def query(self, *entities: _ColumnsClauseArgument, **kwargs: Any) -> Query:
"""Return a new :class:`_query.Query` object corresponding to this
:class:`_orm.Session`.
+10 -7
View File
@@ -5,14 +5,8 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Testing extensions.
from __future__ import annotations
this module is designed to work as a testing-framework-agnostic library,
created so that multiple test frameworks can be supported at once
(mostly so that we can migrate to new ones). The current target
is pytest.
"""
import abc
import configparser
import logging
@@ -23,6 +17,15 @@ from typing import Any
from sqlalchemy.testing import asyncio
"""Testing extensions.
this module is designed to work as a testing-framework-agnostic library,
created so that multiple test frameworks can be supported at once
(mostly so that we can migrate to new ones). The current target
is pytest.
"""
# flag which indicates we are in the SQLAlchemy testing suite,
# and not that of Alembic or a third party dialect.
bootstrapped_as_sqlalchemy = False
+12 -6
View File
@@ -1,9 +1,4 @@
try:
# installed by bootstrap.py
import sqla_plugin_base as plugin_base
except ImportError:
# assume we're a package, use traditional import
from . import plugin_base
from __future__ import annotations
import argparse
import collections
@@ -17,6 +12,13 @@ import uuid
import pytest
try:
# installed by bootstrap.py
import sqla_plugin_base as plugin_base
except ImportError:
# assume we're a package, use traditional import
from . import plugin_base
def pytest_addoption(parser):
group = parser.getgroup("sqlalchemy")
@@ -565,6 +567,10 @@ def _pytest_fn_decorator(target):
from sqlalchemy.util.compat import inspect_getfullargspec
def _exec_code_in_env(code, env, fn_name):
# note this is affected by "from __future__ import annotations" at
# the top; exec'ed code will use non-evaluated annotations
# which allows us to be more flexible with code rendering
# in format_argpsec_plus()
exec(code, env)
return env[fn_name]
+4 -4
View File
@@ -139,17 +139,17 @@ def _formatannotation(annotation, base_module=None):
"""vendored from python 3.7"""
if isinstance(annotation, str):
return f'"{annotation}"'
return annotation
if getattr(annotation, "__module__", None) == "typing":
return f'"{repr(annotation).replace("typing.", "").replace("~", "")}"'
return repr(annotation).replace("typing.", "").replace("~", "")
if isinstance(annotation, type):
if annotation.__module__ in ("builtins", base_module):
return repr(annotation.__qualname__)
return annotation.__module__ + "." + annotation.__qualname__
elif isinstance(annotation, typing.TypeVar):
return f'"{repr(annotation).replace("~", "")}"'
return f'"{repr(annotation).replace("~", "")}"'
return repr(annotation).replace("~", "")
return repr(annotation).replace("~", "")
def inspect_formatargspec(
+6 -100
View File
@@ -678,111 +678,17 @@ def create_proxy_methods(
methods=(),
attributes=(),
):
"""A class decorator that will copy attributes to a proxy class.
"""A class decorator indicating attributes should refer to a proxy
class.
The class to be instrumented must define a single accessor "_proxied".
This decorator is now a "marker" that does nothing at runtime. Instead,
it is consumed by the tools/generate_proxy_methods.py script to
statically generate proxy methods and attributes that are fully
recognized by typing tools such as mypy.
"""
def decorate(cls):
def instrument(name, clslevel=False):
fn = cast(types.FunctionType, getattr(target_cls, name))
spec = compat.inspect_getfullargspec(fn)
env = {"__name__": fn.__module__}
spec = _update_argspec_defaults_into_env(spec, env)
caller_argspec = format_argspec_plus(spec, grouped=False)
metadata = {
"name": fn.__name__,
"apply_pos_proxied": caller_argspec["apply_pos_proxied"],
"apply_kw_proxied": caller_argspec["apply_kw_proxied"],
"grouped_args": caller_argspec["grouped_args"],
"self_arg": caller_argspec["self_arg"],
}
if clslevel:
code = (
"def %(name)s%(grouped_args)s:\n"
" return target_cls.%(name)s(%(apply_kw_proxied)s)"
% metadata
)
env["target_cls"] = target_cls
else:
code = (
"def %(name)s%(grouped_args)s:\n"
" return %(self_arg)s._proxied.%(name)s(%(apply_kw_proxied)s)" # noqa: E501
% metadata
)
proxy_fn = cast(
types.FunctionType, _exec_code_in_env(code, env, fn.__name__)
)
proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__
proxy_fn.__doc__ = inject_docstring_text(
fn.__doc__,
".. container:: class_bases\n\n "
"Proxied for the %s class on behalf of the %s class."
% (target_cls_sphinx_name, proxy_cls_sphinx_name),
1,
)
if clslevel:
return classmethod(proxy_fn)
else:
return proxy_fn
def makeprop(name):
attr = target_cls.__dict__.get(name, None)
if attr is not None:
doc = inject_docstring_text(
attr.__doc__,
".. container:: class_bases\n\n "
"Proxied for the %s class on behalf of the %s class."
% (
target_cls_sphinx_name,
proxy_cls_sphinx_name,
),
1,
)
else:
doc = None
code = (
"def set_(self, attr):\n"
" self._proxied.%(name)s = attr\n"
"def get(self):\n"
" return self._proxied.%(name)s\n"
"get.__doc__ = doc\n"
"getset = property(get, set_)"
) % {"name": name}
getset = _exec_code_in_env(code, {"doc": doc}, "getset")
return getset
for meth in methods:
if hasattr(cls, meth):
raise TypeError(
"class %s already has a method %s" % (cls, meth)
)
setattr(cls, meth, instrument(meth))
for prop in attributes:
if hasattr(cls, prop):
raise TypeError(
"class %s already has a method %s" % (cls, prop)
)
setattr(cls, prop, makeprop(prop))
for prop in classmethods:
if hasattr(cls, prop):
raise TypeError(
"class %s already has a method %s" % (cls, prop)
)
setattr(cls, prop, instrument(prop, clslevel=True))
return cls
return decorate
+2 -1
View File
@@ -37,8 +37,9 @@ class AsyncScopedSessionTest(AsyncFixture):
await AsyncSession.flush()
conn = await AsyncSession.connection()
stmt = select(func.count(User.id)).where(User.name == user_name)
eq_(await conn.scalar(stmt), 1)
eq_(await AsyncSession.scalar(stmt), 1)
await AsyncSession.delete(u1)
await AsyncSession.flush()
+413
View File
@@ -0,0 +1,413 @@
"""Generate static proxy code for SQLAlchemy classes that proxy other
objects.
This tool is run at source code authoring / commit time whenever we add new
methods to engines/connections/sessions that need to be generically proxied by
scoped_session or asyncio. The generated code is part of what's committed
to source just as though we typed it all by hand.
The original "proxy" class was scoped_session. Then with asyncio, all the
asyncio objects are essentially "proxy" objects as well; while all the methods
that are "async" needed to be written by hand, there's lots of other attributes
and methods that are proxied exactly.
To eliminate redundancy, all of these classes made use of the
@langhelpers.create_proxy_methods() decorator which at runtime would read a
selected list of methods and attributes from the proxied class and generate new
methods and properties descriptors on the proxying class; this way the proxy
would have all the same methods signatures / attributes / docstrings consumed
by Sphinx and look just like the proxied class.
Then mypy and typing came along, which don't care about runtime generated code
and never will. So this script takes that same
@langhelpers.create_proxy_methods() decorator, keeps its public interface just
as is, and uses it to generate all the code and docs in those proxy classes
statically, as though we sat there and spent seven hours typing it all by hand.
The runtime code generation part is removed from ``create_proxy_methods()``.
Now we have static code that is perfectly consumable by all the typing tools
and we also reduce import time a bit.
A similar approach is used in Alembic where a dynamic approach towards creating
alembic "ops" was enhanced to generate a .pyi stubs file statically for
consumption by typing tools.
.. versionadded:: 2.0
"""
from __future__ import annotations
from argparse import ArgumentParser
import collections
import importlib
import inspect
import os
from pathlib import Path
import re
import shlex
import shutil
import subprocess
import sys
from tempfile import NamedTemporaryFile
import textwrap
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import TextIO
from typing import Tuple
from typing import Type
from typing import TypeVar
from sqlalchemy import util
from sqlalchemy.util import compat
from sqlalchemy.util import langhelpers
from sqlalchemy.util.langhelpers import format_argspec_plus
from sqlalchemy.util.langhelpers import inject_docstring_text
is_posix = os.name == "posix"
sys.path.append(str(Path(__file__).parent.parent))
class _repr_sym:
__slots__ = ("sym",)
def __init__(self, sym: str):
self.sym = sym
def __repr__(self) -> str:
return self.sym
classes: collections.defaultdict[
str, Dict[str, Tuple[Any, ...]]
] = collections.defaultdict(dict)
_T = TypeVar("_T", bound="Any")
def create_proxy_methods(
target_cls: Type[Any],
target_cls_sphinx_name: str,
proxy_cls_sphinx_name: str,
classmethods: Iterable[str] = (),
methods: Iterable[str] = (),
attributes: Iterable[str] = (),
) -> Callable[[Type[_T]], Type[_T]]:
"""A class decorator that will copy attributes to a proxy class.
The class to be instrumented must define a single accessor "_proxied".
"""
def decorate(cls: Type[_T]) -> Type[_T]:
# collect the class as a separate step. since the decorator
# is called as a result of imports, the order in which classes
# are collected (like in asyncio) can't be well controlled. however,
# the proxies (specifically asyncio session and asyncio scoped_session)
# have to be generated in dependency order, so run them in order in a
# second step.
classes[cls.__module__][cls.__name__] = (
target_cls,
target_cls_sphinx_name,
proxy_cls_sphinx_name,
classmethods,
methods,
attributes,
cls,
)
return cls
return decorate
def process_class(
buf: TextIO,
target_cls: Type[Any],
target_cls_sphinx_name: str,
proxy_cls_sphinx_name: str,
classmethods: Iterable[str],
methods: Iterable[str],
attributes: Iterable[str],
cls: Type[Any],
):
sphinx_symbol_match = re.match(r":class:`(.+)`", target_cls_sphinx_name)
if not sphinx_symbol_match:
raise Exception(
f"Couldn't match sphinx class identifier from "
f"target_cls_sphinx_name f{target_cls_sphinx_name!r}. Currently "
'this program expects the form ":class:`_<prefix>.<clsname>`"'
)
sphinx_symbol = sphinx_symbol_match.group(1)
def instrument(buf: TextIO, name: str, clslevel: bool = False) -> None:
fn = getattr(target_cls, name)
spec = compat.inspect_getfullargspec(fn)
iscoroutine = inspect.iscoroutinefunction(fn)
if spec.defaults:
new_defaults = tuple(
_repr_sym("util.EMPTY_DICT") if df is util.EMPTY_DICT else df
for df in spec.defaults
)
elem = list(spec)
elem[3] = tuple(new_defaults)
spec = compat.FullArgSpec(*elem)
caller_argspec = format_argspec_plus(spec, grouped=False)
metadata = {
"name": fn.__name__,
"async": "async " if iscoroutine else "",
"await": "await " if iscoroutine else "",
"apply_pos_proxied": caller_argspec["apply_pos_proxied"],
"target_cls_name": target_cls.__name__,
"apply_kw_proxied": caller_argspec["apply_kw_proxied"],
"grouped_args": caller_argspec["grouped_args"],
"self_arg": caller_argspec["self_arg"],
"doc": textwrap.indent(
inject_docstring_text(
fn.__doc__,
textwrap.indent(
".. container:: class_bases\n\n"
f" Proxied for the {target_cls_sphinx_name} "
"class on \n"
f" behalf of the {proxy_cls_sphinx_name} "
"class.",
" ",
),
1,
),
" ",
).lstrip(),
}
if clslevel:
code = (
"@classmethod\n"
"%(async)sdef %(name)s%(grouped_args)s:\n"
' r"""%(doc)s\n """ # noqa: E501\n\n'
" return %(await)s%(target_cls_name)s.%(name)s(%(apply_kw_proxied)s)\n\n" # noqa: E501
% metadata
)
else:
code = (
"%(async)sdef %(name)s%(grouped_args)s:\n"
' r"""%(doc)s\n """ # noqa: E501\n\n'
" return %(await)s%(self_arg)s._proxied.%(name)s(%(apply_kw_proxied)s)\n\n" # noqa: E501
% metadata
)
buf.write(textwrap.indent(code, " "))
def makeprop(buf: TextIO, name: str) -> None:
attr = target_cls.__dict__.get(name, None)
return_type = target_cls.__annotations__.get(name, "Any")
assert isinstance(return_type, str), (
"expected string annotations, is from __future__ "
"import annotations set up?"
)
if attr is not None:
if isinstance(attr, property):
readonly = attr.fset is None
elif isinstance(attr, langhelpers.generic_fn_descriptor):
readonly = True
else:
readonly = not hasattr(attr, "__set__")
doc = textwrap.indent(
inject_docstring_text(
attr.__doc__,
textwrap.indent(
".. container:: class_bases\n\n"
f" Proxied for the {target_cls_sphinx_name} "
"class \n"
f" on behalf of the {proxy_cls_sphinx_name} "
"class.",
" ",
),
1,
),
" ",
).lstrip()
else:
readonly = False
doc = (
f"Proxy for the :attr:`{sphinx_symbol}.{name}` "
"attribute \n"
f" on behalf of the {proxy_cls_sphinx_name} "
"class.\n"
)
code = (
"@property\n"
"def %(name)s(self) -> %(return_type)s:\n"
' r"""%(doc)s\n """ # noqa: E501\n\n'
" return self._proxied.%(name)s\n\n"
) % {"name": name, "doc": doc, "return_type": return_type}
if not readonly:
code += (
"@%(name)s.setter\n"
"def %(name)s(self, attr: %(return_type)s) -> None:\n"
" self._proxied.%(name)s = attr\n\n"
) % {"name": name, "doc": doc, "return_type": return_type}
buf.write(textwrap.indent(code, " "))
for meth in methods:
instrument(buf, meth)
for prop in attributes:
makeprop(buf, prop)
for prop in classmethods:
instrument(buf, prop, clslevel=True)
def process_module(modname: str, filename: str) -> str:
class_entries = classes[modname]
# use tempfile in same path as the module, or at least in the
# current working directory, so that black / zimports use
# local pyproject.toml
with NamedTemporaryFile(
mode="w", delete=False, suffix=".py", dir=Path(filename).parent
) as buf, open(filename) as orig_py:
in_block = False
current_clsname = None
for line in orig_py:
m = re.match(r" # START PROXY METHODS (.+)$", line)
if m:
current_clsname = m.group(1)
args = class_entries[current_clsname]
sys.stderr.write(
f"Generating attributes for class {current_clsname}\n"
)
in_block = True
buf.write(line)
buf.write(
"\n # code within this block is "
"**programmatically, \n"
" # statically generated** by"
" tools/generate_proxy_methods.py\n\n"
)
process_class(buf, *args)
if line.startswith(f" # END PROXY METHODS {current_clsname}"):
in_block = False
if not in_block:
buf.write(line)
return buf.name
def console_scripts(
path: str, options: dict, ignore_output: bool = False
) -> None:
entrypoint_name = options["entrypoint"]
for entry in compat.importlib_metadata_get("console_scripts"):
if entry.name == entrypoint_name:
impl = entry
break
else:
raise Exception(
f"Could not find entrypoint console_scripts.{entrypoint_name}"
)
cmdline_options_str = options.get("options", "")
cmdline_options_list = shlex.split(cmdline_options_str, posix=is_posix) + [
path
]
kw = {}
if ignore_output:
kw["stdout"] = kw["stderr"] = subprocess.DEVNULL
subprocess.run(
[
sys.executable,
"-c",
"import %s; %s.%s()" % (impl.module, impl.module, impl.attr),
]
+ cmdline_options_list,
cwd=Path(__file__).parent.parent,
**kw,
)
def run_module(modname, stdout):
sys.stderr.write(f"importing module {modname}\n")
mod = importlib.import_module(modname)
filename = destination_path = mod.__file__
assert filename is not None
tempfile = process_module(modname, filename)
ignore_output = stdout
console_scripts(
str(tempfile),
{"entrypoint": "zimports"},
ignore_output=ignore_output,
)
console_scripts(
str(tempfile),
{"entrypoint": "black"},
ignore_output=ignore_output,
)
if stdout:
with open(tempfile) as tf:
print(tf.read())
os.unlink(tempfile)
else:
sys.stderr.write(f"Writing {destination_path}...\n")
shutil.move(tempfile, destination_path)
def main(args):
from sqlalchemy import util
from sqlalchemy.util import langhelpers
util.create_proxy_methods = (
langhelpers.create_proxy_methods
) = create_proxy_methods
for entry in entries:
if args.module in {"all", entry}:
run_module(entry, args.stdout)
entries = [
"sqlalchemy.orm.scoping",
"sqlalchemy.ext.asyncio.engine",
"sqlalchemy.ext.asyncio.session",
"sqlalchemy.ext.asyncio.scoping",
]
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument(
"--module",
choices=entries + ["all"],
default="all",
help="Which file to generate. Default is to regenerate all files",
)
parser.add_argument(
"--stdout",
action="store_true",
help="Write to stdout instead of saving to file",
)
args = parser.parse_args()
main(args)