pep484 - SQL internals

non-strict checking for mostly internal or semi-internal
code

Change-Id: Ib91b47f1a8ccc15e666b94bad1ce78c4ab15b0ec
This commit is contained in:
Mike Bayer
2022-03-20 16:39:36 -04:00
parent c565c47051
commit 6f02d5edd8
34 changed files with 1084 additions and 483 deletions
+3 -2
View File
@@ -7,6 +7,8 @@
from __future__ import annotations
from typing import Any
from . import util as _util
from .engine import AdaptedConnection as AdaptedConnection
from .engine import BaseRow as BaseRow
@@ -191,7 +193,6 @@ from .sql.expression import tuple_ as tuple_
from .sql.expression import type_coerce as type_coerce
from .sql.expression import TypeClause as TypeClause
from .sql.expression import TypeCoerce as TypeCoerce
from .sql.expression import typing as typing
from .sql.expression import UnaryExpression as UnaryExpression
from .sql.expression import union as union
from .sql.expression import union_all as union_all
@@ -254,7 +255,7 @@ from .types import VARCHAR as VARCHAR
__version__ = "2.0.0b1"
def __go(lcls):
def __go(lcls: Any) -> None:
from . import util as _sa_util
_sa_util.preloaded.import_prefix("sqlalchemy")
@@ -26,6 +26,10 @@ cdef class OrderedSet(set):
cdef list _list
@classmethod
def __class_getitem__(cls, key):
return cls
def __init__(self, d=None):
set.__init__(self)
if d is not None:
+3 -2
View File
@@ -73,6 +73,7 @@ if typing.TYPE_CHECKING:
from ..sql.functions import FunctionElement
from ..sql.schema import ColumnDefault
from ..sql.schema import HasSchemaAttr
from ..sql.schema import SchemaItem
"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`.
@@ -2004,7 +2005,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
def _run_ddl_visitor(
self,
visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
element: DDLElement,
element: SchemaItem,
**kwargs: Any,
) -> None:
"""run a DDL visitor.
@@ -2749,7 +2750,7 @@ class Engine(
def _run_ddl_visitor(
self,
visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
element: DDLElement,
element: SchemaItem,
**kwargs: Any,
) -> None:
with self.begin() as conn:
+6 -10
View File
@@ -54,7 +54,9 @@ from ..sql import expression
from ..sql._typing import is_tuple_type
from ..sql.compiler import DDLCompiler
from ..sql.compiler import SQLCompiler
from ..sql.elements import ColumnClause
from ..sql.elements import quoted_name
from ..sql.schema import default_is_scalar
if typing.TYPE_CHECKING:
from types import ModuleType
@@ -1164,7 +1166,7 @@ class DefaultExecutionContext(ExecutionContext):
return ()
@util.memoized_property
def returning_cols(self) -> Optional[Sequence[Column[Any]]]:
def returning_cols(self) -> Optional[Sequence[ColumnClause[Any]]]:
if TYPE_CHECKING:
assert isinstance(self.compiled, SQLCompiler)
return self.compiled.returning
@@ -1778,15 +1780,11 @@ class DefaultExecutionContext(ExecutionContext):
# to avoid many calls of get_insert_default()/
# get_update_default()
for c in insert_prefetch:
if c.default and not c.default.is_sequence and c.default.is_scalar:
if TYPE_CHECKING:
assert isinstance(c.default, ColumnDefault)
if c.default and default_is_scalar(c.default):
scalar_defaults[c] = c.default.arg
for c in update_prefetch:
if c.onupdate and c.onupdate.is_scalar:
if TYPE_CHECKING:
assert isinstance(c.onupdate, ColumnDefault)
if c.onupdate and default_is_scalar(c.onupdate):
scalar_defaults[c] = c.onupdate.arg
for param in self.compiled_parameters:
@@ -1817,9 +1815,7 @@ class DefaultExecutionContext(ExecutionContext):
) = self.compiled_parameters[0]
for c in compiled.insert_prefetch:
if c.default and not c.default.is_sequence and c.default.is_scalar:
if TYPE_CHECKING:
assert isinstance(c.default, ColumnDefault)
if c.default and default_is_scalar(c.default):
val = c.default.arg
else:
val = self.get_insert_default(c)
+2 -1
View File
@@ -32,6 +32,7 @@ if typing.TYPE_CHECKING:
from ..sql.ddl import SchemaDropper
from ..sql.ddl import SchemaGenerator
from ..sql.schema import HasSchemaAttr
from ..sql.schema import SchemaItem
class MockConnection:
@@ -55,7 +56,7 @@ class MockConnection:
def _run_ddl_visitor(
self,
visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
element: DDLElement,
element: SchemaItem,
**kwargs: Any,
) -> None:
kwargs["checkfirst"] = False
-1
View File
@@ -317,7 +317,6 @@ class Inspector(inspection.Inspectable["Inspector"]):
with an already-given :class:`_schema.MetaData`.
"""
with self._operation_context() as conn:
tnames = self.dialect.get_table_names(
conn, schema, info_cache=self.info_cache
+1 -1
View File
@@ -88,7 +88,7 @@ class SQLAlchemyAttribute:
return cls(typ=typ, info=info, **data)
def name_is_dunder(name):
def name_is_dunder(name: str) -> bool:
return bool(re.match(r"^__.+?__$", name))
+2 -2
View File
@@ -225,7 +225,7 @@ def instance_logger(
else:
name = _qual_logger_name_for_cls(instance.__class__)
instance._echo = echoflag
instance._echo = echoflag # type: ignore
logger: Union[logging.Logger, InstanceLogger]
@@ -239,7 +239,7 @@ def instance_logger(
# levels by calling logger._log()
logger = InstanceLogger(echoflag, name)
instance.logger = logger
instance.logger = logger # type: ignore
class echo_property:
+2 -1
View File
@@ -511,7 +511,8 @@ class Composite(
"""
__hash__ = None
# https://github.com/python/mypy/issues/4266
__hash__ = None # type: ignore
@util.memoized_property
def clauses(self):
+2 -1
View File
@@ -476,7 +476,8 @@ class Relationship(
"the set of foreign key values."
)
__hash__ = None
# https://github.com/python/mypy/issues/4266
__hash__ = None # type: ignore
def __eq__(self, other):
"""Implement the ``==`` operator.
+2 -1
View File
@@ -4,6 +4,7 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
from typing import Any
from .base import Executable as Executable
from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS
@@ -97,7 +98,7 @@ from .expression import within_group as within_group
from .visitors import ClauseVisitor as ClauseVisitor
def __go(lcls):
def __go(lcls: Any) -> None:
from .. import util as _sa_util
from . import base
@@ -8,7 +8,7 @@
from __future__ import annotations
from typing import Any
from typing import Union
from typing import Optional
from . import coercions
from . import roles
@@ -23,8 +23,6 @@ from .selectable import Select
from .selectable import TableClause
from .selectable import TableSample
from .selectable import Values
from ..util.typing import _LiteralStar
from ..util.typing import Literal
def alias(selectable, name=None, flat=False):
@@ -283,9 +281,7 @@ def outerjoin(left, right, onclause=None, full=False):
return Join(left, right, onclause, isouter=True, full=full)
def select(
*entities: Union[_LiteralStar, Literal[1], _ColumnsClauseElement]
) -> "Select":
def select(*entities: _ColumnsClauseElement) -> Select:
r"""Construct a new :class:`_expression.Select`.
@@ -326,7 +322,7 @@ def select(
return Select(*entities)
def table(name: str, *columns: ColumnClause, **kw: Any) -> "TableClause":
def table(name: str, *columns: ColumnClause[Any], **kw: Any) -> TableClause:
"""Produce a new :class:`_expression.TableClause`.
The object returned is an instance of
@@ -435,7 +431,11 @@ def union_all(*selects):
return CompoundSelect._create_union_all(*selects)
def values(*columns, name=None, literal_binds=False) -> "Values":
def values(
*columns: ColumnClause[Any],
name: Optional[str] = None,
literal_binds: bool = False,
) -> Values:
r"""Construct a :class:`_expression.Values` construct.
The column expressions and the actual data for
+4 -2
View File
@@ -9,6 +9,7 @@ from typing import Union
from . import roles
from .. import util
from ..inspection import Inspectable
from ..util.typing import Literal
if TYPE_CHECKING:
from .elements import quoted_name
@@ -24,12 +25,13 @@ if TYPE_CHECKING:
_T = TypeVar("_T", bound=Any)
_ColumnsClauseElement = Union[
Literal["*", 1],
roles.ColumnsClauseRole,
Type,
Type[Any],
Inspectable[roles.HasColumnElementClauseElement],
]
_FromClauseElement = Union[
roles.FromClauseRole, Type, Inspectable[roles.HasFromClauseElement]
roles.FromClauseRole, Type[Any], Inspectable[roles.HasFromClauseElement]
]
_ColumnExpression = Union[
+173 -96
View File
@@ -12,22 +12,32 @@
from __future__ import annotations
import collections.abc as collections_abc
from enum import Enum
from functools import reduce
import itertools
from itertools import zip_longest
import operator
import re
import typing
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import FrozenSet
from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import NoReturn
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from . import roles
from . import visitors
@@ -36,17 +46,26 @@ from .cache_key import MemoizedHasCacheKey # noqa
from .traversals import HasCopyInternals # noqa
from .visitors import ClauseVisitor
from .visitors import ExtendedInternalTraversal
from .visitors import ExternallyTraversible
from .visitors import InternalTraversal
from .. import event
from .. import exc
from .. import util
from ..util import HasMemoized as HasMemoized
from ..util import hybridmethod
from ..util import typing as compat_typing
from ..util.typing import Protocol
from ..util.typing import Self
from ..util.typing import TypeGuard
if typing.TYPE_CHECKING:
if TYPE_CHECKING:
from . import coercions
from . import elements
from . import type_api
from .elements import BindParameter
from .elements import ColumnClause
from .elements import ColumnElement
from .elements import SQLCoreOperations
from ..engine import Connection
from ..engine import Result
from ..engine.base import _CompiledCacheType
@@ -58,10 +77,12 @@ if typing.TYPE_CHECKING:
from ..engine.interfaces import CacheStats
from ..engine.interfaces import Compiled
from ..engine.interfaces import Dialect
from ..event import dispatcher
coercions = None
elements = None
type_api = None
if not TYPE_CHECKING:
coercions = None # noqa
elements = None # noqa
type_api = None # noqa
class _NoArg(Enum):
@@ -70,13 +91,24 @@ class _NoArg(Enum):
NO_ARG = _NoArg.NO_ARG
# if I use sqlalchemy.util.typing, which has the exact same
# symbols, mypy reports: "error: _Fn? not callable"
_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
_Fn = TypeVar("_Fn", bound=Callable[..., Any])
_AmbiguousTableNameMap = MutableMapping[str, str]
class _EntityNamespace(Protocol):
def __getattr__(self, key: str) -> SQLCoreOperations[Any]:
...
class _HasEntityNamespace(Protocol):
entity_namespace: _EntityNamespace
def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]:
return hasattr(element, "entity_namespace")
class Immutable:
"""mark a ClauseElement as 'immutable' when expressions are cloned."""
@@ -107,10 +139,14 @@ class SingletonConstant(Immutable):
def __new__(cls, *arg, **kw):
return cls._singleton
@util.non_memoized_property
def proxy_set(self) -> FrozenSet[ColumnElement[Any]]:
raise NotImplementedError()
@classmethod
def _create_singleton(cls):
obj = object.__new__(cls)
obj.__init__()
obj.__init__() # type: ignore
# for a long time this was an empty frozenset, meaning
# a SingletonConstant would never be a "corresponding column" in
@@ -139,12 +175,11 @@ def _select_iterables(elements):
)
_Self = typing.TypeVar("_Self", bound="_GenerativeType")
_Args = compat_typing.ParamSpec("_Args")
_SelfGenerativeType = TypeVar("_SelfGenerativeType", bound="_GenerativeType")
class _GenerativeType(compat_typing.Protocol):
def _generate(self: "_Self") -> "_Self":
def _generate(self: _SelfGenerativeType) -> _SelfGenerativeType:
...
@@ -158,8 +193,8 @@ def _generative(fn: _Fn) -> _Fn:
@util.decorator
def _generative(
fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs
) -> _Self:
fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any
) -> _SelfGenerativeType:
"""Mark a method as generative."""
self = self._generate()
@@ -167,9 +202,9 @@ def _generative(fn: _Fn) -> _Fn:
assert x is self, "generative methods must return self"
return self
decorated = _generative(fn)
decorated.non_generative = fn
return decorated
decorated = _generative(fn) # type: ignore
decorated.non_generative = fn # type: ignore
return decorated # type: ignore
def _exclusive_against(*names, **kw):
@@ -233,7 +268,7 @@ def _cloned_difference(a, b):
)
class _DialectArgView(collections_abc.MutableMapping):
class _DialectArgView(MutableMapping[str, Any]):
"""A dictionary view of dialect-level arguments in the form
<dialectname>_<argument_name>.
@@ -290,7 +325,7 @@ class _DialectArgView(collections_abc.MutableMapping):
)
class _DialectArgDict(collections_abc.MutableMapping):
class _DialectArgDict(MutableMapping[str, Any]):
"""A dictionary view of dialect-level arguments for a specific
dialect.
@@ -343,6 +378,8 @@ class DialectKWArgs:
"""
__slots__ = ()
_dialect_kwargs_traverse_internals = [
("dialect_options", InternalTraversal.dp_dialect_options)
]
@@ -534,7 +571,7 @@ class CompileState:
__slots__ = ("statement", "_ambiguous_table_name_map")
plugins = {}
plugins: Dict[Tuple[str, str], Type[CompileState]] = {}
_ambiguous_table_name_map: Optional[_AmbiguousTableNameMap]
@@ -639,9 +676,9 @@ class InPlaceGenerative(HasMemoized):
class HasCompileState(Generative):
"""A class that has a :class:`.CompileState` associated with it."""
_compile_state_plugin = None
_compile_state_plugin: Optional[Type[CompileState]] = None
_attributes = util.immutabledict()
_attributes: util.immutabledict[str, Any] = util.EMPTY_DICT
_compile_state_factory = CompileState.create_for_statement
@@ -655,6 +692,8 @@ class _MetaOptions(type):
"""
_cache_attrs: Tuple[str, ...]
def __add__(self, other):
o1 = self()
@@ -674,6 +713,8 @@ class Options(metaclass=_MetaOptions):
__slots__ = ()
_cache_attrs: Tuple[str, ...]
def __init_subclass__(cls) -> None:
dict_ = cls.__dict__
cls._cache_attrs = tuple(
@@ -732,13 +773,13 @@ class Options(metaclass=_MetaOptions):
return self + {name: getattr(self, name) + value}
@hybridmethod
def _state_dict(self):
def _state_dict_inst(self) -> Mapping[str, Any]:
return self.__dict__
_state_dict_const = util.immutabledict()
_state_dict_const: util.immutabledict[str, Any] = util.EMPTY_DICT
@_state_dict.classlevel
def _state_dict(cls):
@_state_dict_inst.classlevel
def _state_dict(cls) -> Mapping[str, Any]:
return cls._state_dict_const
@classmethod
@@ -825,10 +866,10 @@ class CacheableOptions(Options, HasCacheKey):
__slots__ = ()
@hybridmethod
def _gen_cache_key(self, anon_map, bindparams):
def _gen_cache_key_inst(self, anon_map, bindparams):
return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
@_gen_cache_key.classlevel
@_gen_cache_key_inst.classlevel
def _gen_cache_key(cls, anon_map, bindparams):
return (cls, ())
@@ -849,11 +890,11 @@ class ExecutableOption(HasCopyInternals):
def _clone(self, **kw):
"""Create a shallow copy of this ExecutableOption."""
c = self.__class__.__new__(self.__class__)
c.__dict__ = dict(self.__dict__)
c.__dict__ = dict(self.__dict__) # type: ignore
return c
SelfExecutable = typing.TypeVar("SelfExecutable", bound="Executable")
SelfExecutable = TypeVar("SelfExecutable", bound="Executable")
class Executable(roles.StatementRole, Generative):
@@ -866,9 +907,12 @@ class Executable(roles.StatementRole, Generative):
"""
supports_execution: bool = True
_execution_options: _ImmutableExecuteOptions = util.immutabledict()
_with_options = ()
_with_context_options = ()
_execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT
_with_options: Tuple[ExecutableOption, ...] = ()
_with_context_options: Tuple[
Tuple[Callable[[CompileState], None], Any], ...
] = ()
_compile_options: Optional[CacheableOptions]
_executable_traverse_internals = [
("_with_options", InternalTraversal.dp_executable_options),
@@ -886,7 +930,9 @@ class Executable(roles.StatementRole, Generative):
is_delete = False
is_dml = False
if typing.TYPE_CHECKING:
if TYPE_CHECKING:
__visit_name__: str
def _compile_w_cache(
self,
@@ -916,11 +962,13 @@ class Executable(roles.StatementRole, Generative):
raise NotImplementedError()
@property
def _effective_plugin_target(self):
def _effective_plugin_target(self) -> str:
return self.__visit_name__
@_generative
def options(self: SelfExecutable, *options) -> SelfExecutable:
def options(
self: SelfExecutable, *options: ExecutableOption
) -> SelfExecutable:
"""Apply options to this statement.
In the general sense, options are any kind of Python object
@@ -957,7 +1005,7 @@ class Executable(roles.StatementRole, Generative):
@_generative
def _set_compile_options(
self: SelfExecutable, compile_options
self: SelfExecutable, compile_options: CacheableOptions
) -> SelfExecutable:
"""Assign the compile options to a new value.
@@ -970,16 +1018,19 @@ class Executable(roles.StatementRole, Generative):
@_generative
def _update_compile_options(
self: SelfExecutable, options
self: SelfExecutable, options: CacheableOptions
) -> SelfExecutable:
"""update the _compile_options with new keys."""
assert self._compile_options is not None
self._compile_options += options
return self
@_generative
def _add_context_option(
self: SelfExecutable, callable_, cache_args
self: SelfExecutable,
callable_: Callable[[CompileState], None],
cache_args: Any,
) -> SelfExecutable:
"""Add a context option to this statement.
@@ -995,7 +1046,7 @@ class Executable(roles.StatementRole, Generative):
return self
@_generative
def execution_options(self: SelfExecutable, **kw) -> SelfExecutable:
def execution_options(self: SelfExecutable, **kw: Any) -> SelfExecutable:
"""Set non-SQL options for the statement which take effect during
execution.
@@ -1112,7 +1163,7 @@ class Executable(roles.StatementRole, Generative):
self._execution_options = self._execution_options.union(kw)
return self
def get_execution_options(self):
def get_execution_options(self) -> _ExecuteOptions:
"""Get the non-SQL options which will take effect during execution.
.. versionadded:: 1.3
@@ -1124,7 +1175,7 @@ class Executable(roles.StatementRole, Generative):
return self._execution_options
class SchemaEventTarget:
class SchemaEventTarget(event.EventTarget):
"""Base class for elements that are the targets of :class:`.DDLEvents`
events.
@@ -1132,6 +1183,8 @@ class SchemaEventTarget:
"""
dispatch: dispatcher[SchemaEventTarget]
def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
"""Associate with this SchemaEvent's parent object."""
@@ -1149,7 +1202,10 @@ class SchemaVisitor(ClauseVisitor):
__traverse_options__ = {"schema_visitor": True}
class ColumnCollection:
_COL = TypeVar("_COL", bound="ColumnClause[Any]")
class ColumnCollection(Generic[_COL]):
"""Collection of :class:`_expression.ColumnElement` instances,
typically for
:class:`_sql.FromClause` objects.
@@ -1260,32 +1316,36 @@ class ColumnCollection:
__slots__ = "_collection", "_index", "_colset"
def __init__(self, columns=None):
_collection: List[Tuple[str, _COL]]
_index: Dict[Union[str, int], _COL]
_colset: Set[_COL]
def __init__(self, columns: Optional[Iterable[Tuple[str, _COL]]] = None):
object.__setattr__(self, "_colset", set())
object.__setattr__(self, "_index", {})
object.__setattr__(self, "_collection", [])
if columns:
self._initial_populate(columns)
def _initial_populate(self, iter_):
def _initial_populate(self, iter_: Iterable[Tuple[str, _COL]]) -> None:
self._populate_separate_keys(iter_)
@property
def _all_columns(self):
def _all_columns(self) -> List[_COL]:
return [col for (k, col) in self._collection]
def keys(self):
def keys(self) -> List[str]:
"""Return a sequence of string key names for all columns in this
collection."""
return [k for (k, col) in self._collection]
def values(self):
def values(self) -> List[_COL]:
"""Return a sequence of :class:`_sql.ColumnClause` or
:class:`_schema.Column` objects for all columns in this
collection."""
return [col for (k, col) in self._collection]
def items(self):
def items(self) -> List[Tuple[str, _COL]]:
"""Return a sequence of (key, column) tuples for all columns in this
collection each consisting of a string key name and a
:class:`_sql.ColumnClause` or
@@ -1294,17 +1354,17 @@ class ColumnCollection:
return list(self._collection)
def __bool__(self):
def __bool__(self) -> bool:
return bool(self._collection)
def __len__(self):
def __len__(self) -> int:
return len(self._collection)
def __iter__(self):
def __iter__(self) -> Iterator[_COL]:
# turn to a list first to maintain over a course of changes
return iter([col for k, col in self._collection])
def __getitem__(self, key):
def __getitem__(self, key: Union[str, int]) -> _COL:
try:
return self._index[key]
except KeyError as err:
@@ -1313,13 +1373,13 @@ class ColumnCollection:
else:
raise
def __getattr__(self, key):
def __getattr__(self, key: str) -> _COL:
try:
return self._index[key]
except KeyError as err:
raise AttributeError(key) from err
def __contains__(self, key):
def __contains__(self, key: str) -> bool:
if key not in self._index:
if not isinstance(key, str):
raise exc.ArgumentError(
@@ -1329,7 +1389,7 @@ class ColumnCollection:
else:
return True
def compare(self, other):
def compare(self, other: ColumnCollection[Any]) -> bool:
"""Compare this :class:`_expression.ColumnCollection` to another
based on the names of the keys"""
@@ -1339,10 +1399,10 @@ class ColumnCollection:
else:
return True
def __eq__(self, other):
def __eq__(self, other: Any) -> bool:
return self.compare(other)
def get(self, key, default=None):
def get(self, key: str, default: Optional[_COL] = None) -> Optional[_COL]:
"""Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object
based on a string key name from this
:class:`_expression.ColumnCollection`."""
@@ -1352,39 +1412,40 @@ class ColumnCollection:
else:
return default
def __str__(self):
def __str__(self) -> str:
return "%s(%s)" % (
self.__class__.__name__,
", ".join(str(c) for c in self),
)
def __setitem__(self, key, value):
def __setitem__(self, key: str, value: Any) -> NoReturn:
raise NotImplementedError()
def __delitem__(self, key):
def __delitem__(self, key: str) -> NoReturn:
raise NotImplementedError()
def __setattr__(self, key, obj):
def __setattr__(self, key: str, obj: Any) -> NoReturn:
raise NotImplementedError()
def clear(self):
def clear(self) -> NoReturn:
"""Dictionary clear() is not implemented for
:class:`_sql.ColumnCollection`."""
raise NotImplementedError()
def remove(self, column):
"""Dictionary remove() is not implemented for
:class:`_sql.ColumnCollection`."""
def remove(self, column: Any) -> None:
raise NotImplementedError()
def update(self, iter_):
def update(self, iter_: Any) -> NoReturn:
"""Dictionary update() is not implemented for
:class:`_sql.ColumnCollection`."""
raise NotImplementedError()
__hash__ = None
# https://github.com/python/mypy/issues/4266
__hash__ = None # type: ignore
def _populate_separate_keys(self, iter_):
def _populate_separate_keys(
self, iter_: Iterable[Tuple[str, _COL]]
) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
self._collection[:] = cols
@@ -1394,7 +1455,7 @@ class ColumnCollection:
)
self._index.update({k: col for k, col in reversed(self._collection)})
def add(self, column, key=None):
def add(self, column: _COL, key: Optional[str] = None) -> None:
"""Add a column to this :class:`_sql.ColumnCollection`.
.. note::
@@ -1416,17 +1477,17 @@ class ColumnCollection:
if key not in self._index:
self._index[key] = column
def __getstate__(self):
def __getstate__(self) -> Dict[str, Any]:
return {"_collection": self._collection, "_index": self._index}
def __setstate__(self, state):
def __setstate__(self, state: Dict[str, Any]) -> None:
object.__setattr__(self, "_index", state["_index"])
object.__setattr__(self, "_collection", state["_collection"])
object.__setattr__(
self, "_colset", {col for k, col in self._collection}
)
def contains_column(self, col):
def contains_column(self, col: _COL) -> bool:
"""Checks if a column object exists in this collection"""
if col not in self._colset:
if isinstance(col, str):
@@ -1438,13 +1499,15 @@ class ColumnCollection:
else:
return True
def as_immutable(self):
def as_immutable(self) -> ImmutableColumnCollection[_COL]:
"""Return an "immutable" form of this
:class:`_sql.ColumnCollection`."""
return ImmutableColumnCollection(self)
def corresponding_column(self, column, require_embedded=False):
def corresponding_column(
self, column: _COL, require_embedded: bool = False
) -> Optional[_COL]:
"""Given a :class:`_expression.ColumnElement`, return the exported
:class:`_expression.ColumnElement` object from this
:class:`_expression.ColumnCollection`
@@ -1497,7 +1560,7 @@ class ColumnCollection:
not require_embedded
or embedded(expanded_proxy_set, target_set)
):
if col is None:
if col is None or intersect is None:
# no corresponding column yet, pick this one.
@@ -1542,7 +1605,7 @@ class ColumnCollection:
return col
class DedupeColumnCollection(ColumnCollection):
class DedupeColumnCollection(ColumnCollection[_COL]):
"""A :class:`_expression.ColumnCollection`
that maintains deduplicating behavior.
@@ -1555,7 +1618,7 @@ class DedupeColumnCollection(ColumnCollection):
"""
def add(self, column, key=None):
def add(self, column: _COL, key: Optional[str] = None) -> None:
if key is not None and column.key != key:
raise exc.ArgumentError(
@@ -1589,7 +1652,9 @@ class DedupeColumnCollection(ColumnCollection):
self._index[l] = column
self._index[key] = column
def _populate_separate_keys(self, iter_):
def _populate_separate_keys(
self, iter_: Iterable[Tuple[str, _COL]]
) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
@@ -1614,10 +1679,10 @@ class DedupeColumnCollection(ColumnCollection):
for col in replace_col:
self.replace(col)
def extend(self, iter_):
def extend(self, iter_: Iterable[_COL]) -> None:
self._populate_separate_keys((col.key, col) for col in iter_)
def remove(self, column):
def remove(self, column: _COL) -> None:
if column not in self._colset:
raise ValueError(
"Can't remove column %r; column is not in this collection"
@@ -1634,7 +1699,7 @@ class DedupeColumnCollection(ColumnCollection):
# delete higher index
del self._index[len(self._collection)]
def replace(self, column):
def replace(self, column: _COL) -> None:
"""add the given column to this collection, removing unaliased
versions of this column as well as existing columns with the
same key.
@@ -1687,7 +1752,9 @@ class DedupeColumnCollection(ColumnCollection):
self._index.update(self._collection)
class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
class ImmutableColumnCollection(
util.ImmutableContainer, ColumnCollection[_COL]
):
__slots__ = ("_parent",)
def __init__(self, collection):
@@ -1701,12 +1768,19 @@ class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
def __setstate__(self, state):
parent = state["_parent"]
self.__init__(parent)
self.__init__(parent) # type: ignore
add = extend = remove = util.ImmutableContainer._immutable
def add(self, column: Any, key: Any = ...) -> Any:
self._immutable()
def extend(self, elements: Any) -> None:
self._immutable()
def remove(self, item: Any) -> None:
self._immutable()
class ColumnSet(util.ordered_column_set):
class ColumnSet(util.OrderedSet["ColumnClause[Any]"]):
def contains_column(self, col):
return col in self
@@ -1714,9 +1788,6 @@ class ColumnSet(util.ordered_column_set):
for col in cols:
self.add(col)
def __add__(self, other):
return list(self) + list(other)
def __eq__(self, other):
l = []
for c in other:
@@ -1729,7 +1800,9 @@ class ColumnSet(util.ordered_column_set):
return hash(tuple(x for x in self))
def _entity_namespace(entity):
def _entity_namespace(
entity: Union[_HasEntityNamespace, ExternallyTraversible]
) -> _EntityNamespace:
"""Return the nearest .entity_namespace for the given entity.
If not immediately available, does an iterate to find a sub-element
@@ -1737,16 +1810,20 @@ def _entity_namespace(entity):
"""
try:
return entity.entity_namespace
return cast(_HasEntityNamespace, entity).entity_namespace
except AttributeError:
for elem in visitors.iterate(entity):
if hasattr(elem, "entity_namespace"):
for elem in visitors.iterate(cast(ExternallyTraversible, entity)):
if _is_has_entity_namespace(elem):
return elem.entity_namespace
else:
raise
def _entity_namespace_key(entity, key, default=NO_ARG):
def _entity_namespace_key(
entity: Union[_HasEntityNamespace, ExternallyTraversible],
key: str,
default: Union[SQLCoreOperations[Any], _NoArg] = NO_ARG,
) -> SQLCoreOperations[Any]:
"""Return an entry from an entity_namespace.
@@ -1760,7 +1837,7 @@ def _entity_namespace_key(entity, key, default=NO_ARG):
if default is not NO_ARG:
return getattr(ns, key, default)
else:
return getattr(ns, key)
return getattr(ns, key) # type: ignore
except AttributeError as err:
raise exc.InvalidRequestError(
'Entity namespace for "%s" has no property "%s"' % (entity, key)
+38 -21
View File
@@ -71,6 +71,7 @@ from .schema import Column
from .sqltypes import TupleType
from .type_api import TypeEngine
from .visitors import prefix_anon_map
from .visitors import Visitable
from .. import exc
from .. import util
from ..util.typing import Literal
@@ -614,10 +615,10 @@ class Compiled:
raise NotImplementedError()
def process(self, obj, **kwargs):
def process(self, obj: Visitable, **kwargs: Any) -> str:
return obj._compiler_dispatch(self, **kwargs)
def __str__(self):
def __str__(self) -> str:
"""Return the string text of the generated SQL or DDL."""
return self.string or ""
@@ -723,7 +724,7 @@ class SQLCompiler(Compiled):
"""list of columns for which onupdate default values should be evaluated
before an UPDATE takes place"""
returning: Optional[List[Column[Any]]]
returning: Optional[List[ColumnClause[Any]]]
"""list of columns that will be delivered to cursor.description or
dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE
@@ -1485,15 +1486,12 @@ class SQLCompiler(Compiled):
self._result_columns
)
_key_getters_for_crud_column: Tuple[
Callable[[Union[str, Column[Any]]], str],
Callable[[Column[Any]], str],
Callable[[Column[Any]], str],
]
# assigned by crud.py for insert/update statements
_get_bind_name_for_col: _BindNameForColProtocol
@util.memoized_property
def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
getter = self._key_getters_for_crud_column[2]
getter = self._get_bind_name_for_col
if self.escaped_bind_names:
def _get(obj):
@@ -4098,7 +4096,9 @@ class SQLCompiler(Compiled):
def for_update_clause(self, select, **kw):
return " FOR UPDATE"
def returning_clause(self, stmt, returning_cols):
def returning_clause(
self, stmt: UpdateBase, returning_cols: List[ColumnClause[Any]]
) -> str:
raise exc.CompileError(
"RETURNING is not supported by this "
"dialect's statement compiler."
@@ -4243,12 +4243,13 @@ class SQLCompiler(Compiled):
}
)
crud_params = crud._get_crud_params(
crud_params_struct = crud._get_crud_params(
self, insert_stmt, compile_state, **kw
)
crud_params_single = crud_params_struct.single_params
if (
not crud_params
not crud_params_single
and not self.dialect.supports_default_values
and not self.dialect.supports_default_metavalue
and not self.dialect.supports_empty_insert
@@ -4266,9 +4267,9 @@ class SQLCompiler(Compiled):
"version settings does not support "
"in-place multirow inserts." % self.dialect.name
)
crud_params_single = crud_params[0]
crud_params_single = crud_params_struct.single_params
else:
crud_params_single = crud_params
crud_params_single = crud_params_struct.single_params
preparer = self.preparer
supports_default_values = self.dialect.supports_default_values
@@ -4293,7 +4294,7 @@ class SQLCompiler(Compiled):
if crud_params_single or not supports_default_values:
text += " (%s)" % ", ".join(
[expr for c, expr, value in crud_params_single]
[expr for _, expr, _ in crud_params_single]
)
if self.returning or insert_stmt._returning:
@@ -4323,19 +4324,24 @@ class SQLCompiler(Compiled):
)
else:
text += " %s" % select_text
elif not crud_params and supports_default_values:
elif not crud_params_single and supports_default_values:
text += " DEFAULT VALUES"
elif compile_state._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
"(%s)"
% (", ".join(value for c, expr, value in crud_param_set))
for crud_param_set in crud_params
% (", ".join(value for _, _, value in crud_param_set))
for crud_param_set in crud_params_struct.all_multi_params
)
)
else:
insert_single_values_expr = ", ".join(
[value for c, expr, value in crud_params]
[
value
for _, _, value in cast(
"List[Tuple[Any, Any, str]]", crud_params_single
)
]
)
text += " VALUES (%s)" % insert_single_values_expr
if toplevel and insert_stmt._post_values_clause is None:
@@ -4443,9 +4449,10 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(
update_stmt, update_stmt.table, render_extra_froms, **kw
)
crud_params = crud._get_crud_params(
crud_params_struct = crud._get_crud_params(
self, update_stmt, compile_state, **kw
)
crud_params = crud_params_struct.single_params
if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
@@ -4460,7 +4467,12 @@ class SQLCompiler(Compiled):
text += table_text
text += " SET "
text += ", ".join(expr + "=" + value for c, expr, value in crud_params)
text += ", ".join(
expr + "=" + value
for _, expr, value in cast(
"List[Tuple[Any, str, str]]", crud_params
)
)
if self.returning or update_stmt._returning:
if self.returning_precedes_values:
@@ -5446,6 +5458,11 @@ class _SchemaForObjectCallable(Protocol):
...
class _BindNameForColProtocol(Protocol):
def __call__(self, col: ColumnClause[Any]) -> str:
...
class IdentifierPreparer:
"""Handle quoting and case-folding of identifiers based on options."""
+243 -60
View File
@@ -13,13 +13,44 @@ from __future__ import annotations
import functools
import operator
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import List
from typing import MutableMapping
from typing import NamedTuple
from typing import Optional
from typing import overload
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from . import coercions
from . import dml
from . import elements
from . import roles
from .schema import default_is_clause_element
from .schema import default_is_sequence
from .. import exc
from .. import util
from ..util.typing import Literal
if TYPE_CHECKING:
from .compiler import _BindNameForColProtocol
from .compiler import SQLCompiler
from .dml import DMLState
from .dml import Insert
from .dml import Update
from .dml import UpdateDMLState
from .dml import ValuesBase
from .elements import ClauseElement
from .elements import ColumnClause
from .elements import ColumnElement
from .elements import TextClause
from .schema import _SQLExprDefault
from .schema import Column
from .selectable import TableClause
REQUIRED = util.symbol(
"REQUIRED",
@@ -36,7 +67,27 @@ values present.
)
def _get_crud_params(compiler, stmt, compile_state, **kw):
class _CrudParams(NamedTuple):
single_params: List[
Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
]
all_multi_params: List[
List[
Tuple[
ColumnClause[Any],
str,
str,
]
]
]
def _get_crud_params(
compiler: SQLCompiler,
stmt: ValuesBase,
compile_state: DMLState,
**kw: Any,
) -> _CrudParams:
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
@@ -59,24 +110,32 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
_column_as_key,
_getattr_col_key,
_col_bind_name,
) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)
) = _key_getters_for_crud_column(compiler, stmt, compile_state)
compiler._key_getters_for_crud_column = getters
compiler._get_bind_name_for_col = _col_bind_name
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if compiler.column_keys is None and compile_state._no_parameters:
return [
(
c,
compiler.preparer.format_column(c),
_create_bind_param(compiler, c, None, required=True),
)
for c in stmt.table.columns
]
return _CrudParams(
[
(
c,
compiler.preparer.format_column(c),
_create_bind_param(compiler, c, None, required=True),
)
for c in stmt.table.columns
],
[],
)
stmt_parameter_tuples: Optional[List[Any]]
spd: Optional[MutableMapping[str, Any]]
if compile_state._has_multi_parameters:
spd = compile_state._multi_parameters[0]
mp = compile_state._multi_parameters
assert mp is not None
spd = mp[0]
stmt_parameter_tuples = list(spd.items())
elif compile_state._ordered_values:
spd = compile_state._dict_parameters
@@ -92,6 +151,7 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
if compiler.column_keys is None:
parameters = {}
elif stmt_parameter_tuples:
assert spd is not None
parameters = dict(
(_column_as_key(key), REQUIRED)
for key in compiler.column_keys
@@ -103,7 +163,9 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
)
# create a list of column assignment clauses as tuples
values = []
values: List[
Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
] = []
if stmt_parameter_tuples is not None:
_get_stmt_parameter_tuples_params(
@@ -116,11 +178,11 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
kw,
)
check_columns = {}
check_columns: Dict[str, ColumnClause[Any]] = {}
# special logic that only occurs for multi-table UPDATE
# statements
if compile_state.isupdate and compile_state.is_multitable:
if dml.isupdate(compile_state) and compile_state.is_multitable:
_get_update_multitable_params(
compiler,
stmt,
@@ -134,6 +196,10 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
)
if compile_state.isinsert and stmt._select_names:
# is an insert from select, is not a multiparams
assert not compile_state._has_multi_parameters
_scan_insert_from_select_cols(
compiler,
stmt,
@@ -173,14 +239,17 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
)
if compile_state._has_multi_parameters:
values = _extend_values_for_multiparams(
# is a multiparams, is not an insert from a select
assert not stmt._select_names
multi_extended_values = _extend_values_for_multiparams(
compiler,
stmt,
compile_state,
values,
_column_as_key,
cast("List[Tuple[ColumnClause[Any], str, str]]", values),
cast("Callable[..., str]", _column_as_key),
kw,
)
return _CrudParams(values, multi_extended_values)
elif (
not values
and compiler.for_executemany
@@ -198,12 +267,41 @@ def _get_crud_params(compiler, stmt, compile_state, **kw):
)
]
return values
return _CrudParams(values, [])
@overload
def _create_bind_param(
compiler: SQLCompiler,
col: ColumnElement[Any],
value: Any,
process: Literal[True] = ...,
required: bool = False,
name: Optional[str] = None,
**kw: Any,
) -> str:
...
@overload
def _create_bind_param(
compiler: SQLCompiler,
col: ColumnElement[Any],
value: Any,
**kw: Any,
) -> str:
...
def _create_bind_param(
compiler, col, value, process=True, required=False, name=None, **kw
):
compiler: SQLCompiler,
col: ColumnElement[Any],
value: Any,
process: bool = True,
required: bool = False,
name: Optional[str] = None,
**kw: Any,
) -> Union[str, elements.BindParameter[Any]]:
if name is None:
name = col.key
bindparam = elements.BindParameter(
@@ -211,8 +309,9 @@ def _create_bind_param(
)
bindparam._is_crud = True
if process:
bindparam = bindparam._compiler_dispatch(compiler, **kw)
return bindparam
return bindparam._compiler_dispatch(compiler, **kw)
else:
return bindparam
def _handle_values_anonymous_param(compiler, col, value, name, **kw):
@@ -253,8 +352,14 @@ def _handle_values_anonymous_param(compiler, col, value, name, **kw):
return value._compiler_dispatch(compiler, **kw)
def _key_getters_for_crud_column(compiler, stmt, compile_state):
if compile_state.isupdate and compile_state._extra_froms:
def _key_getters_for_crud_column(
compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState
) -> Tuple[
Callable[[Union[str, Column[Any]]], Union[str, Tuple[str, str]]],
Callable[[Column[Any]], Union[str, Tuple[str, str]]],
_BindNameForColProtocol,
]:
if dml.isupdate(compile_state) and compile_state._extra_froms:
# when extra tables are present, refer to the columns
# in those extra tables as table-qualified, including in
# dictionaries and when rendering bind param names.
@@ -267,30 +372,36 @@ def _key_getters_for_crud_column(compiler, stmt, compile_state):
coercions.expect_as_key, roles.DMLColumnRole
)
def _column_as_key(key):
def _column_as_key(
key: Union[ColumnClause[Any], str]
) -> Union[str, Tuple[str, str]]:
str_key = c_key_role(key)
if hasattr(key, "table") and key.table in _et:
return (key.table.name, str_key)
if hasattr(key, "table") and key.table in _et: # type: ignore
return (key.table.name, str_key) # type: ignore
else:
return str_key
return str_key # type: ignore
def _getattr_col_key(col):
def _getattr_col_key(
col: ColumnClause[Any],
) -> Union[str, Tuple[str, str]]:
if col.table in _et:
return (col.table.name, col.key)
return (col.table.name, col.key) # type: ignore
else:
return col.key
def _col_bind_name(col):
def _col_bind_name(col: ColumnClause[Any]) -> str:
if col.table in _et:
if TYPE_CHECKING:
assert isinstance(col.table, TableClause)
return "%s_%s" % (col.table.name, col.key)
else:
return col.key
else:
_column_as_key = functools.partial(
_column_as_key = functools.partial( # type: ignore
coercions.expect_as_key, roles.DMLColumnRole
)
_getattr_col_key = _col_bind_name = operator.attrgetter("key")
_getattr_col_key = _col_bind_name = operator.attrgetter("key") # type: ignore # noqa E501
return _column_as_key, _getattr_col_key, _col_bind_name
@@ -321,7 +432,7 @@ def _scan_insert_from_select_cols(
compiler.stack[-1]["insert_from_select"] = stmt.select
add_select_cols = []
add_select_cols: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]] = []
if stmt.include_insert_from_select_defaults:
col_set = set(cols)
for col in stmt.table.columns:
@@ -707,16 +818,22 @@ def _append_param_insert_hasdefault(
)
def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
def _append_param_insert_select_hasdefault(
compiler: SQLCompiler,
stmt: ValuesBase,
c: ColumnClause[Any],
values: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]],
kw: Dict[str, Any],
) -> None:
if c.default.is_sequence:
if default_is_sequence(c.default):
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
):
values.append(
(c, compiler.preparer.format_column(c), c.default.next_value())
)
elif c.default.is_clause_element:
elif default_is_clause_element(c.default):
values.append(
(c, compiler.preparer.format_column(c), c.default.arg.self_group())
)
@@ -777,28 +894,76 @@ def _append_param_update(
compiler.returning.append(c)
@overload
def _create_insert_prefetch_bind_param(
compiler, c, process=True, name=None, **kw
):
compiler: SQLCompiler,
c: ColumnElement[Any],
process: Literal[True] = ...,
**kw: Any,
) -> str:
...
@overload
def _create_insert_prefetch_bind_param(
compiler: SQLCompiler,
c: ColumnElement[Any],
process: Literal[False],
**kw: Any,
) -> elements.BindParameter[Any]:
...
def _create_insert_prefetch_bind_param(
compiler: SQLCompiler,
c: ColumnElement[Any],
process: bool = True,
name: Optional[str] = None,
**kw: Any,
) -> Union[elements.BindParameter[Any], str]:
param = _create_bind_param(
compiler, c, None, process=process, name=name, **kw
)
compiler.insert_prefetch.append(c)
compiler.insert_prefetch.append(c) # type: ignore
return param
@overload
def _create_update_prefetch_bind_param(
compiler: SQLCompiler,
c: ColumnElement[Any],
process: Literal[True] = ...,
**kw: Any,
) -> str:
...
@overload
def _create_update_prefetch_bind_param(
compiler: SQLCompiler,
c: ColumnElement[Any],
process: Literal[False],
**kw: Any,
) -> elements.BindParameter[Any]:
...
def _create_update_prefetch_bind_param(
compiler, c, process=True, name=None, **kw
):
compiler: SQLCompiler,
c: ColumnElement[Any],
process: bool = True,
name: Optional[str] = None,
**kw: Any,
) -> Union[elements.BindParameter[Any], str]:
param = _create_bind_param(
compiler, c, None, process=process, name=name, **kw
)
compiler.update_prefetch.append(c)
compiler.update_prefetch.append(c) # type: ignore
return param
class _multiparam_column(elements.ColumnElement):
class _multiparam_column(elements.ColumnElement[Any]):
_is_multiparam_column = True
def __init__(self, original, index):
@@ -822,14 +987,20 @@ class _multiparam_column(elements.ColumnElement):
)
def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
def _process_multiparam_default_bind(
compiler: SQLCompiler,
stmt: ValuesBase,
c: ColumnClause[Any],
index: int,
kw: Dict[str, Any],
) -> str:
if not c.default:
raise exc.CompileError(
"INSERT value for column %s is explicitly rendered as a bound"
"parameter in the VALUES clause; "
"a Python-side value or SQL expression is required" % c
)
elif c.default.is_clause_element:
elif default_is_clause_element(c.default):
return compiler.process(c.default.arg.self_group(), **kw)
elif c.default.is_sequence:
# these conditions would have been established
@@ -844,9 +1015,13 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
else:
col = _multiparam_column(c, index)
if isinstance(stmt, dml.Insert):
return _create_insert_prefetch_bind_param(compiler, col, **kw)
return _create_insert_prefetch_bind_param(
compiler, col, process=True, **kw
)
else:
return _create_update_prefetch_bind_param(compiler, col, **kw)
return _create_update_prefetch_bind_param(
compiler, col, process=True, **kw
)
def _get_update_multitable_params(
@@ -926,18 +1101,26 @@ def _get_update_multitable_params(
def _extend_values_for_multiparams(
compiler,
stmt,
compile_state,
values,
_column_as_key,
kw,
):
values_0 = values
values = [values]
compiler: SQLCompiler,
stmt: ValuesBase,
compile_state: DMLState,
initial_values: List[Tuple[ColumnClause[Any], str, str]],
_column_as_key: Callable[..., str],
kw: Dict[str, Any],
) -> List[List[Tuple[ColumnClause[Any], str, str]]]:
values_0 = initial_values
values = [initial_values]
for i, row in enumerate(compile_state._multi_parameters[1:]):
extension = []
mp = compile_state._multi_parameters
assert mp is not None
for i, row in enumerate(mp[1:]):
extension: List[
Tuple[
ColumnClause[Any],
str,
str,
]
] = []
row = {_column_as_key(key): v for key, v in row.items()}
+50 -28
View File
@@ -26,6 +26,7 @@ from . import roles
from . import type_api
from .elements import and_
from .elements import BinaryExpression
from .elements import ClauseElement
from .elements import ClauseList
from .elements import CollationClause
from .elements import CollectionAggregate
@@ -43,7 +44,7 @@ _T = typing.TypeVar("_T", bound=Any)
if typing.TYPE_CHECKING:
from .elements import ColumnElement
from .operators import custom_op
from .sqltypes import TypeEngine
from .type_api import TypeEngine
def _boolean_compare(
@@ -53,10 +54,10 @@ def _boolean_compare(
*,
negate_op: Optional[OperatorType] = None,
reverse: bool = False,
_python_is_types=(util.NoneType, bool),
_any_all_expr=False,
_python_is_types: Tuple[Type[Any], ...] = (type(None), bool),
_any_all_expr: bool = False,
result_type: Optional[
Union[Type["TypeEngine[bool]"], "TypeEngine[bool]"]
Union[Type[TypeEngine[bool]], TypeEngine[bool]]
] = None,
**kwargs: Any,
) -> BinaryExpression[bool]:
@@ -165,7 +166,7 @@ def _custom_op_operate(
def _binary_operate(
expr: ColumnElement[Any],
op: OperatorType,
obj: roles.BinaryElementRole,
obj: roles.BinaryElementRole[Any],
*,
reverse: bool = False,
result_type: Optional[
@@ -192,7 +193,7 @@ def _binary_operate(
def _conjunction_operate(
expr: ColumnElement[Any], op: OperatorType, other, **kw
expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
) -> ColumnElement[Any]:
if op is operators.and_:
return and_(expr, other)
@@ -203,7 +204,10 @@ def _conjunction_operate(
def _scalar(
expr: ColumnElement[Any], op: OperatorType, fn, **kw
expr: ColumnElement[Any],
op: OperatorType,
fn: Callable[[ColumnElement[Any]], ColumnElement[Any]],
**kw: Any,
) -> ColumnElement[Any]:
return fn(expr)
@@ -211,9 +215,9 @@ def _scalar(
def _in_impl(
expr: ColumnElement[Any],
op: OperatorType,
seq_or_selectable,
seq_or_selectable: ClauseElement,
negate_op: OperatorType,
**kw,
**kw: Any,
) -> ColumnElement[Any]:
seq_or_selectable = coercions.expect(
roles.InElementRole, seq_or_selectable, expr=expr, operator=op
@@ -227,7 +231,7 @@ def _in_impl(
def _getitem_impl(
expr: ColumnElement[Any], op: OperatorType, other, **kw
expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
) -> ColumnElement[Any]:
if isinstance(expr.type, type_api.INDEXABLE):
other = coercions.expect(
@@ -239,7 +243,7 @@ def _getitem_impl(
def _unsupported_impl(
expr: ColumnElement[Any], op: OperatorType, *arg, **kw
expr: ColumnElement[Any], op: OperatorType, *arg: Any, **kw: Any
) -> NoReturn:
raise NotImplementedError(
"Operator '%s' is not supported on " "this expression" % op.__name__
@@ -247,7 +251,7 @@ def _unsupported_impl(
def _inv_impl(
expr: ColumnElement[Any], op: OperatorType, **kw
expr: ColumnElement[Any], op: OperatorType, **kw: Any
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.__inv__`."""
@@ -260,14 +264,14 @@ def _inv_impl(
def _neg_impl(
expr: ColumnElement[Any], op: OperatorType, **kw
expr: ColumnElement[Any], op: OperatorType, **kw: Any
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.__neg__`."""
return UnaryExpression(expr, operator=operators.neg, type_=expr.type)
def _match_impl(
expr: ColumnElement[Any], op: OperatorType, other, **kw
expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.match`."""
@@ -289,7 +293,7 @@ def _match_impl(
def _distinct_impl(
expr: ColumnElement[Any], op: OperatorType, **kw
expr: ColumnElement[Any], op: OperatorType, **kw: Any
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.distinct`."""
return UnaryExpression(
@@ -298,7 +302,11 @@ def _distinct_impl(
def _between_impl(
expr: ColumnElement[Any], op: OperatorType, cleft, cright, **kw
expr: ColumnElement[Any],
op: OperatorType,
cleft: Any,
cright: Any,
**kw: Any,
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.between`."""
return BinaryExpression(
@@ -329,26 +337,32 @@ def _between_impl(
def _collate_impl(
expr: ColumnElement[Any], op: OperatorType, collation, **kw
) -> ColumnElement[Any]:
expr: ColumnElement[str], op: OperatorType, collation: str, **kw: Any
) -> ColumnElement[str]:
return CollationClause._create_collation_expression(expr, collation)
def _regexp_match_impl(
expr: ColumnElement[Any], op: OperatorType, pattern, flags, **kw
expr: ColumnElement[str],
op: OperatorType,
pattern: Any,
flags: Optional[str],
**kw: Any,
) -> ColumnElement[Any]:
if flags is not None:
flags = coercions.expect(
flags_expr = coercions.expect(
roles.BinaryElementRole,
flags,
expr=expr,
operator=operators.regexp_replace_op,
)
else:
flags_expr = None
return _boolean_compare(
expr,
op,
pattern,
flags=flags,
flags=flags_expr,
negate_op=operators.not_regexp_match_op
if op is operators.regexp_match_op
else operators.regexp_match_op,
@@ -359,10 +373,10 @@ def _regexp_match_impl(
def _regexp_replace_impl(
expr: ColumnElement[Any],
op: OperatorType,
pattern,
replacement,
flags,
**kw,
pattern: Any,
replacement: Any,
flags: Optional[str],
**kw: Any,
) -> ColumnElement[Any]:
replacement = coercions.expect(
roles.BinaryElementRole,
@@ -371,21 +385,29 @@ def _regexp_replace_impl(
operator=operators.regexp_replace_op,
)
if flags is not None:
flags = coercions.expect(
flags_expr = coercions.expect(
roles.BinaryElementRole,
flags,
expr=expr,
operator=operators.regexp_replace_op,
)
else:
flags_expr = None
return _binary_operate(
expr, op, pattern, replacement=replacement, flags=flags, **kw
expr, op, pattern, replacement=replacement, flags=flags_expr, **kw
)
# a mapping of operators with the method they use, along with
# additional keyword arguments to be passed
operator_lookup: Dict[
str, Tuple[Callable[..., ColumnElement[Any]], util.immutabledict]
str,
Tuple[
Callable[..., ColumnElement[Any]],
util.immutabledict[
str, Union[OperatorType, Callable[..., ColumnElement[Any]]]
],
],
] = {
"and_": (_conjunction_operate, util.EMPTY_DICT),
"or_": (_conjunction_operate, util.EMPTY_DICT),
+24
View File
@@ -12,11 +12,13 @@ Provide :class:`_expression.Insert`, :class:`_expression.Update` and
from __future__ import annotations
import collections.abc as collections_abc
import operator
import typing
from typing import Any
from typing import List
from typing import MutableMapping
from typing import Optional
from typing import TYPE_CHECKING
from . import coercions
from . import roles
@@ -36,10 +38,29 @@ from .elements import Null
from .selectable import HasCTE
from .selectable import HasPrefixes
from .selectable import ReturnsRows
from .selectable import TableClause
from .sqltypes import NullType
from .visitors import InternalTraversal
from .. import exc
from .. import util
from ..util.typing import TypeGuard
if TYPE_CHECKING:
def isupdate(dml) -> TypeGuard[UpdateDMLState]:
...
def isdelete(dml) -> TypeGuard[DeleteDMLState]:
...
def isinsert(dml) -> TypeGuard[InsertDMLState]:
...
else:
isupdate = operator.attrgetter("isupdate")
isdelete = operator.attrgetter("isdelete")
isinsert = operator.attrgetter("isinsert")
class DMLState(CompileState):
@@ -49,6 +70,7 @@ class DMLState(CompileState):
_ordered_values = None
_parameter_ordering = None
_has_multi_parameters = False
isupdate = False
isdelete = False
isinsert = False
@@ -237,6 +259,8 @@ class UpdateBase(
_hints = util.immutabledict()
named_with_column = False
table: TableClause
_return_defaults = False
_return_defaults_columns = None
_returning = ()
+12 -9
View File
@@ -18,6 +18,7 @@ import itertools
import operator
import re
import typing
from typing import AbstractSet
from typing import Any
from typing import Callable
from typing import cast
@@ -83,6 +84,7 @@ if typing.TYPE_CHECKING:
from .operators import OperatorType
from .schema import Column
from .schema import DefaultGenerator
from .schema import FetchedValue
from .schema import ForeignKey
from .selectable import FromClause
from .selectable import NamedFromClause
@@ -290,7 +292,7 @@ class ClauseElement(
"""
@util.memoized_property
@util.ro_memoized_property
def description(self) -> Optional[str]:
return None
@@ -319,7 +321,7 @@ class ClauseElement(
_cache_key_traversal = None
negation_clause: ClauseElement
negation_clause: ColumnElement[bool]
if typing.TYPE_CHECKING:
@@ -1153,9 +1155,7 @@ class ColumnElement(
primary_key: bool = False
_is_clone_of: Optional[ColumnElement[_T]]
@util.memoized_property
def foreign_keys(self) -> Iterable[ForeignKey]:
return []
foreign_keys: AbstractSet[ForeignKey] = frozenset()
@util.memoized_property
def _proxies(self) -> List[ColumnElement[Any]]:
@@ -1494,6 +1494,8 @@ class ColumnElement(
else:
key = name
assert key is not None
co: ColumnClause[_T] = ColumnClause(
coercions.expect(roles.TruncatedLabelRole, name)
if name_is_truncatable
@@ -1506,7 +1508,6 @@ class ColumnElement(
co._proxies = [self]
if selectable._is_clone_of is not None:
co._is_clone_of = selectable._is_clone_of.columns.get(key)
assert key is not None
return key, co
def cast(self, type_: TypeEngine[_T]) -> Cast[_T]:
@@ -4050,13 +4051,14 @@ class NamedColumn(ColumnElement[_T]):
is_literal = False
table: Optional[FromClause] = None
name: str
key: str
def _compare_name_for_result(self, other):
return (hasattr(other, "name") and self.name == other.name) or (
hasattr(other, "_label") and self._label == other._label
)
@util.memoized_property
@util.ro_memoized_property
def description(self) -> str:
return self.name
@@ -4125,6 +4127,7 @@ class NamedColumn(ColumnElement[_T]):
_selectable=selectable,
is_literal=False,
)
c._propagate_attrs = selectable._propagate_attrs
if name is None:
c.key = self.key
@@ -4192,8 +4195,8 @@ class ColumnClause(
onupdate: Optional[DefaultGenerator] = None
default: Optional[DefaultGenerator] = None
server_default: Optional[DefaultGenerator] = None
server_onupdate: Optional[DefaultGenerator] = None
server_default: Optional[FetchedValue] = None
server_onupdate: Optional[FetchedValue] = None
_is_multiparam_column = False
+37 -9
View File
@@ -7,11 +7,23 @@
from __future__ import annotations
from typing import Any
from typing import TYPE_CHECKING
from .base import SchemaEventTarget
from .. import event
if TYPE_CHECKING:
from .schema import Column
from .schema import Constraint
from .schema import SchemaItem
from .schema import Table
from ..engine.base import Connection
from ..engine.interfaces import ReflectedColumn
from ..engine.reflection import Inspector
class DDLEvents(event.Events):
class DDLEvents(event.Events[SchemaEventTarget]):
"""
Define event listeners for schema objects,
that is, :class:`.SchemaItem` and other :class:`.SchemaEventTarget`
@@ -93,7 +105,9 @@ class DDLEvents(event.Events):
_target_class_doc = "SomeSchemaClassOrObject"
_dispatch_target = SchemaEventTarget
def before_create(self, target, connection, **kw):
def before_create(
self, target: SchemaEventTarget, connection: Connection, **kw: Any
) -> None:
r"""Called before CREATE statements are emitted.
:param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
@@ -120,7 +134,9 @@ class DDLEvents(event.Events):
"""
def after_create(self, target, connection, **kw):
def after_create(
self, target: SchemaEventTarget, connection: Connection, **kw: Any
) -> None:
r"""Called after CREATE statements are emitted.
:param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
@@ -142,7 +158,9 @@ class DDLEvents(event.Events):
"""
def before_drop(self, target, connection, **kw):
def before_drop(
self, target: SchemaEventTarget, connection: Connection, **kw: Any
) -> None:
r"""Called before DROP statements are emitted.
:param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
@@ -164,7 +182,9 @@ class DDLEvents(event.Events):
"""
def after_drop(self, target, connection, **kw):
def after_drop(
self, target: SchemaEventTarget, connection: Connection, **kw: Any
) -> None:
r"""Called after DROP statements are emitted.
:param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
@@ -186,7 +206,9 @@ class DDLEvents(event.Events):
"""
def before_parent_attach(self, target, parent):
def before_parent_attach(
self, target: SchemaEventTarget, parent: SchemaItem
) -> None:
"""Called before a :class:`.SchemaItem` is associated with
a parent :class:`.SchemaItem`.
@@ -201,7 +223,9 @@ class DDLEvents(event.Events):
"""
def after_parent_attach(self, target, parent):
def after_parent_attach(
self, target: SchemaEventTarget, parent: SchemaItem
) -> None:
"""Called after a :class:`.SchemaItem` is associated with
a parent :class:`.SchemaItem`.
@@ -216,13 +240,17 @@ class DDLEvents(event.Events):
"""
def _sa_event_column_added_to_pk_constraint(self, const, col):
def _sa_event_column_added_to_pk_constraint(
self, const: Constraint, col: Column[Any]
) -> None:
"""internal event hook used for primary key naming convention
updates.
"""
def column_reflect(self, inspector, table, column_info):
def column_reflect(
self, inspector: Inspector, table: Table, column_info: ReflectedColumn
) -> None:
"""Called for each unit of 'column info' retrieved when
a :class:`_schema.Table` is being reflected.
-1
View File
@@ -43,7 +43,6 @@ from ._elements_constructors import text as text
from ._elements_constructors import true as true
from ._elements_constructors import tuple_ as tuple_
from ._elements_constructors import type_coerce as type_coerce
from ._elements_constructors import typing as typing
from ._elements_constructors import within_group as within_group
from ._selectable_constructors import alias as alias
from ._selectable_constructors import cte as cte
+4 -2
View File
@@ -211,9 +211,11 @@ class StrictFromClauseRole(FromClauseRole):
__slots__ = ()
# does not allow text() or select() objects
c: ColumnCollection
c: ColumnCollection[Any]
@property
# this should be ->str , however, working around:
# https://github.com/python/mypy/issues/12440
@util.ro_non_memoized_property
def description(self) -> str:
raise NotImplementedError()
+293 -143
View File
@@ -30,16 +30,22 @@ as components in SQL expressions.
"""
from __future__ import annotations
from abc import ABC
import collections
import operator
import typing
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import MutableMapping
from typing import Optional
from typing import overload
from typing import Sequence as _typing_Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
@@ -48,6 +54,7 @@ from . import ddl
from . import roles
from . import type_api
from . import visitors
from .base import ColumnCollection
from .base import DedupeColumnCollection
from .base import DialectKWArgs
from .base import Executable
@@ -67,12 +74,15 @@ from .. import exc
from .. import inspection
from .. import util
from ..util.typing import Literal
from ..util.typing import Protocol
from ..util.typing import TypeGuard
if typing.TYPE_CHECKING:
from .type_api import TypeEngine
from ..engine import Connection
from ..engine import Engine
from ..engine.interfaces import ExecutionContext
from ..engine.mock import MockConnection
_T = TypeVar("_T", bound="Any")
_ServerDefaultType = Union["FetchedValue", str, TextClause, ColumnElement]
_TAB = TypeVar("_TAB", bound="Table")
@@ -102,7 +112,7 @@ NULL_UNSPECIFIED = util.symbol(
)
def _get_table_key(name, schema):
def _get_table_key(name: str, schema: Optional[str]) -> str:
if schema is None:
return name
else:
@@ -207,7 +217,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
__visit_name__ = "table"
constraints = None
constraints: Set[Constraint]
"""A collection of all :class:`_schema.Constraint` objects associated with
this :class:`_schema.Table`.
@@ -235,7 +245,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
"""
indexes = None
indexes: Set[Index]
"""A collection of all :class:`_schema.Index` objects associated with this
:class:`_schema.Table`.
@@ -249,6 +259,14 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
("schema", InternalTraversal.dp_string)
]
if TYPE_CHECKING:
@util.non_memoized_property
def columns(self) -> ColumnCollection[Column[Any]]:
...
c: ColumnCollection[Column[Any]]
def _gen_cache_key(self, anon_map, bindparams):
if self._annotations:
return (self,) + self._annotations_cache_key
@@ -736,11 +754,12 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
)
@property
def _sorted_constraints(self):
def _sorted_constraints(self) -> List[Constraint]:
"""Return the set of constraints as a list, sorted by creation
order.
"""
return sorted(self.constraints, key=lambda c: c._creation_order)
@property
@@ -801,6 +820,8 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
)
self.info = kwargs.pop("info", self.info)
exclude_columns: _typing_Sequence[str]
if autoload:
if not autoload_replace:
# don't replace columns already present.
@@ -1074,8 +1095,8 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
return metadata.tables[key]
args = []
for c in self.columns:
args.append(c._copy(schema=schema))
for col in self.columns:
args.append(col._copy(schema=schema))
table = Table(
name,
metadata,
@@ -1084,28 +1105,30 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause):
*args,
**self.kwargs,
)
for c in self.constraints:
if isinstance(c, ForeignKeyConstraint):
referred_schema = c._referred_schema
for const in self.constraints:
if isinstance(const, ForeignKeyConstraint):
referred_schema = const._referred_schema
if referred_schema_fn:
fk_constraint_schema = referred_schema_fn(
self, schema, c, referred_schema
self, schema, const, referred_schema
)
else:
fk_constraint_schema = (
schema if referred_schema == self.schema else None
)
table.append_constraint(
c._copy(schema=fk_constraint_schema, target_table=table)
const._copy(
schema=fk_constraint_schema, target_table=table
)
)
elif not c._type_bound:
elif not const._type_bound:
# skip unique constraints that would be generated
# by the 'unique' flag on Column
if c._column_flag:
if const._column_flag:
continue
table.append_constraint(
c._copy(schema=schema, target_table=table)
const._copy(schema=schema, target_table=table)
)
for index in self.indexes:
# skip indexes that would be generated
@@ -1734,23 +1757,25 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
name = kwargs.pop("name", None)
type_ = kwargs.pop("type_", None)
args = list(args)
if args:
if isinstance(args[0], str):
l_args = list(args)
del args
if l_args:
if isinstance(l_args[0], str):
if name is not None:
raise exc.ArgumentError(
"May not pass name positionally and as a keyword."
)
name = args.pop(0)
if args:
coltype = args[0]
name = l_args.pop(0)
if l_args:
coltype = l_args[0]
if hasattr(coltype, "_sqla_type"):
if type_ is not None:
raise exc.ArgumentError(
"May not pass type_ positionally and as a keyword."
)
type_ = args.pop(0)
type_ = l_args.pop(0)
if name is not None:
name = quoted_name(name, kwargs.pop("quote", None))
@@ -1772,7 +1797,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
else:
self.nullable = not primary_key
self.default = kwargs.pop("default", None)
default = kwargs.pop("default", None)
onupdate = kwargs.pop("onupdate", None)
self.server_default = kwargs.pop("server_default", None)
self.server_onupdate = kwargs.pop("server_onupdate", None)
@@ -1784,7 +1811,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
self.system = kwargs.pop("system", False)
self.doc = kwargs.pop("doc", None)
self.onupdate = kwargs.pop("onupdate", None)
self.autoincrement = kwargs.pop("autoincrement", "auto")
self.constraints = set()
self.foreign_keys = set()
@@ -1803,32 +1829,38 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
if isinstance(impl, SchemaEventTarget):
impl._set_parent_with_dispatch(self)
if self.default is not None:
if isinstance(self.default, (ColumnDefault, Sequence)):
args.append(self.default)
else:
args.append(ColumnDefault(self.default))
if default is not None:
if not isinstance(default, (ColumnDefault, Sequence)):
default = ColumnDefault(default)
self.default = default
l_args.append(default)
else:
self.default = None
if onupdate is not None:
if not isinstance(onupdate, (ColumnDefault, Sequence)):
onupdate = ColumnDefault(onupdate, for_update=True)
self.onupdate = onupdate
l_args.append(onupdate)
else:
self.onpudate = None
if self.server_default is not None:
if isinstance(self.server_default, FetchedValue):
args.append(self.server_default._as_for_update(False))
l_args.append(self.server_default._as_for_update(False))
else:
args.append(DefaultClause(self.server_default))
if self.onupdate is not None:
if isinstance(self.onupdate, (ColumnDefault, Sequence)):
args.append(self.onupdate)
else:
args.append(ColumnDefault(self.onupdate, for_update=True))
l_args.append(DefaultClause(self.server_default))
if self.server_onupdate is not None:
if isinstance(self.server_onupdate, FetchedValue):
args.append(self.server_onupdate._as_for_update(True))
l_args.append(self.server_onupdate._as_for_update(True))
else:
args.append(
l_args.append(
DefaultClause(self.server_onupdate, for_update=True)
)
self._init_items(*args)
self._init_items(*l_args)
util.set_creation_order(self)
@@ -1837,7 +1869,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
self._extra_kwargs(**kwargs)
foreign_keys = None
table: Table
constraints: Set[Constraint]
foreign_keys: Set[ForeignKey]
"""A collection of all :class:`_schema.ForeignKey` marker objects
associated with this :class:`_schema.Column`.
@@ -1850,7 +1886,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
"""
index = None
index: bool
"""The value of the :paramref:`_schema.Column.index` parameter.
Does not indicate if this :class:`_schema.Column` is actually indexed
@@ -1861,7 +1897,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
:attr:`_schema.Table.indexes`
"""
unique = None
unique: bool
"""The value of the :paramref:`_schema.Column.unique` parameter.
Does not indicate if this :class:`_schema.Column` is actually subject to
@@ -2074,8 +2110,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]):
server_default = self.server_default
server_onupdate = self.server_onupdate
if isinstance(server_default, (Computed, Identity)):
args.append(server_default._copy(**kw))
server_default = server_onupdate = None
args.append(self.server_default._copy(**kw))
type_ = self.type
if isinstance(type_, SchemaEventTarget):
@@ -2203,9 +2239,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
__visit_name__ = "foreign_key"
parent: Column[Any]
def __init__(
self,
column: Union[str, Column, SQLCoreOperations],
column: Union[str, Column[Any], SQLCoreOperations[Any]],
_constraint: Optional["ForeignKeyConstraint"] = None,
use_alter: bool = False,
name: Optional[str] = None,
@@ -2296,7 +2334,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
self._table_column = self._colspec
if not isinstance(
self._table_column.table, (util.NoneType, TableClause)
self._table_column.table, (type(None), TableClause)
):
raise exc.ArgumentError(
"ForeignKey received Column not bound "
@@ -2309,7 +2347,10 @@ class ForeignKey(DialectKWArgs, SchemaItem):
# object passes itself in when creating ForeignKey
# markers.
self.constraint = _constraint
self.parent = None
# .parent is not Optional under normal use
self.parent = None # type: ignore
self.use_alter = use_alter
self.name = name
self.onupdate = onupdate
@@ -2501,19 +2542,18 @@ class ForeignKey(DialectKWArgs, SchemaItem):
return parenttable, tablekey, colname
def _link_to_col_by_colstring(self, parenttable, table, colname):
if not hasattr(self.constraint, "_referred_table"):
self.constraint._referred_table = table
else:
assert self.constraint._referred_table is table
_column = None
if colname is None:
# colname is None in the case that ForeignKey argument
# was specified as table name only, in which case we
# match the column name to the same column on the
# parent.
key = self.parent
_column = table.c.get(self.parent.key, None)
# this use case wasn't working in later 1.x series
# as it had no test coverage; fixed in 2.0
parent = self.parent
assert parent is not None
key = parent.key
_column = table.c.get(key, None)
elif self.link_to_name:
key = colname
for c in table.c:
@@ -2533,10 +2573,10 @@ class ForeignKey(DialectKWArgs, SchemaItem):
key,
)
self._set_target_column(_column)
return _column
def _set_target_column(self, column):
assert isinstance(self.parent.table, Table)
assert self.parent is not None
# propagate TypeEngine to parent if it didn't have one
if self.parent.type._isnull:
@@ -2561,11 +2601,6 @@ class ForeignKey(DialectKWArgs, SchemaItem):
If no target column has been established, an exception
is raised.
.. versionchanged:: 0.9.0
Foreign key target column resolution now occurs as soon as both
the ForeignKey object and the remote Column to which it refers
are both associated with the same MetaData object.
"""
if isinstance(self._colspec, str):
@@ -2586,14 +2621,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
"parent MetaData" % parenttable
)
else:
raise exc.NoReferencedColumnError(
"Could not initialize target column for "
"ForeignKey '%s' on table '%s': "
"table '%s' has no column named '%s'"
% (self._colspec, parenttable.name, tablekey, colname),
tablekey,
colname,
table = parenttable.metadata.tables[tablekey]
return self._link_to_col_by_colstring(
parenttable, table, colname
)
elif hasattr(self._colspec, "__clause_element__"):
_column = self._colspec.__clause_element__()
return _column
@@ -2601,18 +2633,22 @@ class ForeignKey(DialectKWArgs, SchemaItem):
_column = self._colspec
return _column
def _set_parent(self, column, **kw):
if self.parent is not None and self.parent is not column:
def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
assert isinstance(parent, Column)
if self.parent is not None and self.parent is not parent:
raise exc.InvalidRequestError(
"This ForeignKey already has a parent !"
)
self.parent = column
self.parent = parent
self.parent.foreign_keys.add(self)
self.parent._on_table_attach(self._set_table)
def _set_remote_table(self, table):
parenttable, tablekey, colname = self._resolve_col_tokens()
self._link_to_col_by_colstring(parenttable, table, colname)
parenttable, _, colname = self._resolve_col_tokens()
_column = self._link_to_col_by_colstring(parenttable, table, colname)
self._set_target_column(_column)
assert self.constraint is not None
self.constraint._validate_dest_table(table)
def _remove_from_metadata(self, metadata):
@@ -2651,10 +2687,15 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if table_key in parenttable.metadata.tables:
table = parenttable.metadata.tables[table_key]
try:
self._link_to_col_by_colstring(parenttable, table, colname)
_column = self._link_to_col_by_colstring(
parenttable, table, colname
)
except exc.NoReferencedColumnError:
# this is OK, we'll try later
pass
else:
self._set_target_column(_column)
parenttable.metadata._fk_memos[fk_key].append(self)
elif hasattr(self._colspec, "__clause_element__"):
_column = self._colspec.__clause_element__()
@@ -2664,6 +2705,31 @@ class ForeignKey(DialectKWArgs, SchemaItem):
self._set_target_column(_column)
if TYPE_CHECKING:
def default_is_sequence(
obj: Optional[DefaultGenerator],
) -> TypeGuard[Sequence]:
...
def default_is_clause_element(
obj: Optional[DefaultGenerator],
) -> TypeGuard[ColumnElementColumnDefault]:
...
def default_is_scalar(
obj: Optional[DefaultGenerator],
) -> TypeGuard[ScalarElementColumnDefault]:
...
else:
default_is_sequence = operator.attrgetter("is_sequence")
default_is_clause_element = operator.attrgetter("is_clause_element")
default_is_scalar = operator.attrgetter("is_scalar")
class DefaultGenerator(Executable, SchemaItem):
"""Base class for column *default* values."""
@@ -2671,18 +2737,18 @@ class DefaultGenerator(Executable, SchemaItem):
is_sequence = False
is_server_default = False
is_clause_element = False
is_callable = False
is_scalar = False
column = None
column: Optional[Column[Any]]
def __init__(self, for_update=False):
self.for_update = for_update
@util.memoized_property
def is_callable(self):
raise NotImplementedError()
def _set_parent(self, column, **kw):
self.column = column
def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
if TYPE_CHECKING:
assert isinstance(parent, Column)
self.column = parent
if self.for_update:
self.column.onupdate = self
else:
@@ -2696,7 +2762,7 @@ class DefaultGenerator(Executable, SchemaItem):
)
class ColumnDefault(DefaultGenerator):
class ColumnDefault(DefaultGenerator, ABC):
"""A plain default value on a column.
This could correspond to a constant, a callable function,
@@ -2718,7 +2784,30 @@ class ColumnDefault(DefaultGenerator):
"""
def __init__(self, arg, **kwargs):
arg: Any
@overload
def __new__(
cls, arg: Callable[..., Any], for_update: bool = ...
) -> CallableColumnDefault:
...
@overload
def __new__(
cls, arg: ColumnElement[Any], for_update: bool = ...
) -> ColumnElementColumnDefault:
...
# if I return ScalarElementColumnDefault here, which is what's actually
# returned, mypy complains that
# overloads overlap w/ incompatible return types.
@overload
def __new__(cls, arg: object, for_update: bool = ...) -> ColumnDefault:
...
def __new__(
cls, arg: Any = None, for_update: bool = False
) -> ColumnDefault:
"""Construct a new :class:`.ColumnDefault`.
@@ -2744,70 +2833,121 @@ class ColumnDefault(DefaultGenerator):
statement and parameters.
"""
super(ColumnDefault, self).__init__(**kwargs)
if isinstance(arg, FetchedValue):
raise exc.ArgumentError(
"ColumnDefault may not be a server-side default type."
)
if callable(arg):
arg = self._maybe_wrap_callable(arg)
elif callable(arg):
cls = CallableColumnDefault
elif isinstance(arg, ClauseElement):
cls = ColumnElementColumnDefault
elif arg is not None:
cls = ScalarElementColumnDefault
return object.__new__(cls)
def __repr__(self):
return f"{self.__class__.__name__}({self.arg!r})"
class ScalarElementColumnDefault(ColumnDefault):
"""default generator for a fixed scalar Python value
.. versionadded: 2.0
"""
is_scalar = True
def __init__(self, arg: Any, for_update: bool = False):
self.for_update = for_update
self.arg = arg
@util.memoized_property
def is_callable(self):
return callable(self.arg)
@util.memoized_property
def is_clause_element(self):
return isinstance(self.arg, ClauseElement)
# _SQLExprDefault = Union["ColumnElement[Any]", "TextClause", "SelectBase"]
_SQLExprDefault = Union["ColumnElement[Any]", "TextClause"]
@util.memoized_property
def is_scalar(self):
return (
not self.is_callable
and not self.is_clause_element
and not self.is_sequence
)
class ColumnElementColumnDefault(ColumnDefault):
"""default generator for a SQL expression
.. versionadded:: 2.0
"""
is_clause_element = True
arg: _SQLExprDefault
def __init__(
self,
arg: _SQLExprDefault,
for_update: bool = False,
):
self.for_update = for_update
self.arg = arg
@util.memoized_property
@util.preload_module("sqlalchemy.sql.sqltypes")
def _arg_is_typed(self):
sqltypes = util.preloaded.sql_sqltypes
if self.is_clause_element:
return not isinstance(self.arg.type, sqltypes.NullType)
else:
return False
return not isinstance(self.arg.type, sqltypes.NullType)
def _maybe_wrap_callable(self, fn):
class _CallableColumnDefaultProtocol(Protocol):
def __call__(self, context: ExecutionContext) -> Any:
...
class CallableColumnDefault(ColumnDefault):
"""default generator for a callable Python function
.. versionadded:: 2.0
"""
is_callable = True
arg: _CallableColumnDefaultProtocol
def __init__(
self,
arg: Union[_CallableColumnDefaultProtocol, Callable[[], Any]],
for_update: bool = False,
):
self.for_update = for_update
self.arg = self._maybe_wrap_callable(arg)
def _maybe_wrap_callable(
self, fn: Union[_CallableColumnDefaultProtocol, Callable[[], Any]]
) -> _CallableColumnDefaultProtocol:
"""Wrap callables that don't accept a context.
This is to allow easy compatibility with default callables
that aren't specific to accepting of a context.
"""
try:
argspec = util.get_callable_argspec(fn, no_self=True)
except TypeError:
return util.wrap_callable(lambda ctx: fn(), fn)
return util.wrap_callable(lambda ctx: fn(), fn) # type: ignore
defaulted = argspec[3] is not None and len(argspec[3]) or 0
positionals = len(argspec[0]) - defaulted
if positionals == 0:
return util.wrap_callable(lambda ctx: fn(), fn)
return util.wrap_callable(lambda ctx: fn(), fn) # type: ignore
elif positionals == 1:
return fn
return fn # type: ignore
else:
raise exc.ArgumentError(
"ColumnDefault Python function takes zero or one "
"positional arguments"
)
def __repr__(self):
return "ColumnDefault(%r)" % (self.arg,)
class IdentityOptions:
"""Defines options for a named database sequence or an identity column.
@@ -2899,6 +3039,8 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
is_sequence = True
column: Optional[Column[Any]] = None
def __init__(
self,
name,
@@ -3087,14 +3229,6 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator):
else:
self.data_type = None
@util.memoized_property
def is_callable(self):
return False
@util.memoized_property
def is_clause_element(self):
return False
@util.preload_module("sqlalchemy.sql.functions")
def next_value(self):
"""Return a :class:`.next_value` function element
@@ -3235,6 +3369,9 @@ class Constraint(DialectKWArgs, SchemaItem):
__visit_name__ = "constraint"
_creation_order: int
_column_flag: bool
def __init__(
self,
name=None,
@@ -3316,8 +3453,6 @@ class Constraint(DialectKWArgs, SchemaItem):
class ColumnCollectionMixin:
columns = None
"""A :class:`_expression.ColumnCollection` of :class:`_schema.Column`
objects.
@@ -3326,8 +3461,17 @@ class ColumnCollectionMixin:
"""
columns: ColumnCollection[Column[Any]]
_allow_multiple_tables = False
if TYPE_CHECKING:
def _set_parent_with_dispatch(
self, parent: SchemaEventTarget, **kw: Any
) -> None:
...
def __init__(self, *columns, **kw):
_autoattach = kw.pop("_autoattach", True)
self._column_flag = kw.pop("_column_flag", False)
@@ -3404,14 +3548,16 @@ class ColumnCollectionMixin:
)
)
def _col_expressions(self, table):
def _col_expressions(self, table: Table) -> List[Column[Any]]:
return [
table.c[col] if isinstance(col, str) else col
for col in self._pending_colargs
]
def _set_parent(self, table, **kw):
for col in self._col_expressions(table):
def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
if TYPE_CHECKING:
assert isinstance(parent, Table)
for col in self._col_expressions(parent):
if col is not None:
self.columns.add(col)
@@ -3446,7 +3592,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
self, *columns, _autoattach=_autoattach, _column_flag=_column_flag
)
columns = None
columns: DedupeColumnCollection[Column[Any]]
"""A :class:`_expression.ColumnCollection` representing the set of columns
for this constraint.
@@ -3568,7 +3714,7 @@ class CheckConstraint(ColumnCollectionConstraint):
"""
self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext)
columns = []
columns: List[Column[Any]] = []
visitors.traverse(self.sqltext, {}, {"column": columns.append})
super(CheckConstraint, self).__init__(
@@ -3779,17 +3925,17 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
assert table is self.parent
self._set_parent_with_dispatch(table)
def _append_element(self, column, fk):
def _append_element(self, column: Column[Any], fk: ForeignKey) -> None:
self.columns.add(column)
self.elements.append(fk)
columns = None
columns: DedupeColumnCollection[Column[Any]]
"""A :class:`_expression.ColumnCollection` representing the set of columns
for this constraint.
"""
elements = None
elements: List[ForeignKey]
"""A sequence of :class:`_schema.ForeignKey` objects.
Each :class:`_schema.ForeignKey`
@@ -4271,7 +4417,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
self._validate_dialect_kwargs(kw)
self.expressions = []
self.expressions: List[ColumnElement[Any]] = []
# will call _set_parent() if table-bound column
# objects are present
ColumnCollectionMixin.__init__(
@@ -4501,11 +4647,13 @@ class MetaData(HasSchemaAttr):
)
if info:
self.info = info
self._schemas = set()
self._sequences = {}
self._fk_memos = collections.defaultdict(list)
self._schemas: Set[str] = set()
self._sequences: Dict[str, Sequence] = {}
self._fk_memos: Dict[
Tuple[str, str], List[ForeignKey]
] = collections.defaultdict(list)
tables: Dict[str, Table]
tables: util.FacadeDict[str, Table]
"""A dictionary of :class:`_schema.Table`
objects keyed to their name or "table key".
@@ -4539,7 +4687,7 @@ class MetaData(HasSchemaAttr):
def _remove_table(self, name, schema):
key = _get_table_key(name, schema)
removed = dict.pop(self.tables, key, None)
removed = dict.pop(self.tables, key, None) # type: ignore
if removed is not None:
for fk in removed.foreign_keys:
fk._remove_from_metadata(self)
@@ -4634,12 +4782,12 @@ class MetaData(HasSchemaAttr):
"""
return ddl.sort_tables(
sorted(self.tables.values(), key=lambda t: t.key)
sorted(self.tables.values(), key=lambda t: t.key) # type: ignore
)
def reflect(
self,
bind: Union["Engine", "Connection"],
bind: Union[Engine, Connection],
schema: Optional[str] = None,
views: bool = False,
only: Optional[_typing_Sequence[str]] = None,
@@ -4647,7 +4795,7 @@ class MetaData(HasSchemaAttr):
autoload_replace: bool = True,
resolve_fks: bool = True,
**dialect_kwargs: Any,
):
) -> None:
r"""Load all available table definitions from the database.
Automatically creates ``Table`` entries in this ``MetaData`` for any
@@ -4748,12 +4896,14 @@ class MetaData(HasSchemaAttr):
if schema is not None:
reflect_opts["schema"] = schema
available = util.OrderedSet(insp.get_table_names(schema))
available: util.OrderedSet[str] = util.OrderedSet(
insp.get_table_names(schema)
)
if views:
available.update(insp.get_view_names(schema))
if schema is not None:
available_w_schema = util.OrderedSet(
available_w_schema: util.OrderedSet[str] = util.OrderedSet(
["%s.%s" % (schema, name) for name in available]
)
else:
@@ -4796,10 +4946,10 @@ class MetaData(HasSchemaAttr):
def create_all(
self,
bind: Union["Engine", "Connection"],
bind: Union[Engine, Connection, MockConnection],
tables: Optional[_typing_Sequence[Table]] = None,
checkfirst: bool = True,
):
) -> None:
"""Create all tables stored in this metadata.
Conditional by default, will not attempt to recreate tables already
@@ -4824,10 +4974,10 @@ class MetaData(HasSchemaAttr):
def drop_all(
self,
bind: Union["Engine", "Connection"],
bind: Union[Engine, Connection, MockConnection],
tables: Optional[_typing_Sequence[Table]] = None,
checkfirst: bool = True,
):
) -> None:
"""Drop all tables stored in this metadata.
Conditional by default, will not attempt to drop tables not present in
+6 -4
View File
@@ -463,7 +463,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
_is_clone_of: Optional[FromClause]
schema = None
schema: Optional[str] = None
"""Define the 'schema' attribute for this :class:`_expression.FromClause`.
This is typically ``None`` for most objects except that of
@@ -673,7 +673,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
"""
return self._cloned_set.intersection(other._cloned_set)
@property
@util.non_memoized_property
def description(self) -> str:
"""A brief description of this :class:`_expression.FromClause`.
@@ -710,7 +710,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
return self.columns
@util.memoized_property
def columns(self) -> ColumnCollection:
def columns(self) -> ColumnCollection[Any]:
"""A named-based collection of :class:`_expression.ColumnElement`
objects maintained by this :class:`_expression.FromClause`.
@@ -796,7 +796,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
# this is awkward. maybe there's a better way
if TYPE_CHECKING:
c: ColumnCollection
c: ColumnCollection[Any]
else:
c = property(
attrgetter("columns"),
@@ -2399,6 +2399,8 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause):
_is_table = True
fullname: str
implicit_returning = False
""":class:`_expression.TableClause`
doesn't support having a primary key or column
+8 -2
View File
@@ -345,6 +345,12 @@ class Integer(HasExpressionLookup, TypeEngine[int]):
__visit_name__ = "integer"
if TYPE_CHECKING:
@util.ro_memoized_property
def _type_affinity(self) -> Type[Integer]:
...
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
@@ -1892,8 +1898,8 @@ class _AbstractInterval(HasExpressionLookup, TypeEngine[dt.timedelta]):
operators.truediv: {Numeric: self.__class__},
}
@util.non_memoized_property
def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]:
@util.ro_non_memoized_property
def _type_affinity(self) -> Type[Interval]:
return Interval
+3 -3
View File
@@ -705,7 +705,7 @@ class TypeEngine(Visitable, Generic[_T]):
"""
return self
@util.memoized_property
@util.ro_memoized_property
def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]:
"""Return a rudimental 'affinity' value expressing the general class
of type."""
@@ -719,7 +719,7 @@ class TypeEngine(Visitable, Generic[_T]):
else:
return self.__class__
@util.memoized_property
@util.ro_memoized_property
def _generic_type_affinity(
self,
) -> Type[TypeEngine[_T]]:
@@ -1694,7 +1694,7 @@ class TypeDecorator(SchemaEventTarget, ExternalType, TypeEngine[_T]):
tt.impl = tt.impl_instance = typedesc
return tt
@util.non_memoized_property
@util.ro_non_memoized_property
def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]:
return self.impl_instance._type_affinity
+2
View File
@@ -130,6 +130,8 @@ from .langhelpers import (
from .langhelpers import PluginLoader as PluginLoader
from .langhelpers import portable_instancemethod as portable_instancemethod
from .langhelpers import quoted_token_parser as quoted_token_parser
from .langhelpers import ro_memoized_property as ro_memoized_property
from .langhelpers import ro_non_memoized_property as ro_non_memoized_property
from .langhelpers import safe_reraise as safe_reraise
from .langhelpers import set_creation_order as set_creation_order
from .langhelpers import string_or_unprintable as string_or_unprintable
+1 -1
View File
@@ -135,7 +135,7 @@ def coerce_to_immutabledict(d):
EMPTY_DICT: immutabledict[Any, Any] = immutabledict()
class FacadeDict(ImmutableDictBase[Any, Any]):
class FacadeDict(ImmutableDictBase[_KT, _VT]):
"""A dictionary that is not publicly mutable."""
def __new__(cls, *args):
+28 -15
View File
@@ -55,8 +55,8 @@ _T_co = TypeVar("_T_co", covariant=True)
_F = TypeVar("_F", bound=Callable[..., Any])
_MP = TypeVar("_MP", bound="memoized_property[Any]")
_MA = TypeVar("_MA", bound="HasMemoized.memoized_attribute[Any]")
_HP = TypeVar("_HP", bound="hybridproperty")
_HM = TypeVar("_HM", bound="hybridmethod")
_HP = TypeVar("_HP", bound="hybridproperty[Any]")
_HM = TypeVar("_HM", bound="hybridmethod[Any]")
if compat.py310:
@@ -1234,12 +1234,23 @@ class _memoized_property(generic_fn_descriptor[_T_co]):
# superclass has non-memoized, the class hierarchy of the descriptors
# would need to be reversed; "class non_memoized(memoized)". so there's no
# way to achieve this.
# additional issues, RO properties:
# https://github.com/python/mypy/issues/12440
if TYPE_CHECKING:
# allow memoized and non-memoized to be freely mixed by having them
# be the same class
memoized_property = generic_fn_descriptor
non_memoized_property = generic_fn_descriptor
# for read only situations, mypy only sees @property as read only.
# read only is needed when a subtype specializes the return type
# of a property, meaning assignment needs to be disallowed
ro_memoized_property = property
ro_non_memoized_property = property
else:
memoized_property = _memoized_property
non_memoized_property = _non_memoized_property
memoized_property = ro_memoized_property = _memoized_property
non_memoized_property = ro_non_memoized_property = _non_memoized_property
def memoized_instancemethod(fn: _F) -> _F:
@@ -1515,7 +1526,9 @@ def duck_type_collection(
return default
def assert_arg_type(arg: Any, argtype: Type[Any], name: str) -> Any:
def assert_arg_type(
arg: Any, argtype: Union[Tuple[Type[Any], ...], Type[Any]], name: str
) -> Any:
if isinstance(arg, argtype):
return arg
else:
@@ -1576,37 +1589,37 @@ class classproperty(property):
return self.fget(cls) # type: ignore
class hybridproperty:
def __init__(self, func):
class hybridproperty(Generic[_T]):
def __init__(self, func: Callable[..., _T]):
self.func = func
self.clslevel = func
def __get__(self, instance, owner):
def __get__(self, instance: Any, owner: Any) -> _T:
if instance is None:
clsval = self.clslevel(owner)
return clsval
else:
return self.func(instance)
def classlevel(self, func):
def classlevel(self, func: Callable[..., Any]) -> hybridproperty[_T]:
self.clslevel = func
return self
class hybridmethod:
class hybridmethod(Generic[_T]):
"""Decorate a function as cls- or instance- level."""
def __init__(self, func):
def __init__(self, func: Callable[..., _T]):
self.func = self.__func__ = func
self.clslevel = func
def __get__(self, instance, owner):
def __get__(self, instance: Any, owner: Any) -> Callable[..., _T]:
if instance is None:
return self.clslevel.__get__(owner, owner.__class__)
return self.clslevel.__get__(owner, owner.__class__) # type:ignore
else:
return self.func.__get__(instance, owner)
return self.func.__get__(instance, owner) # type:ignore
def classlevel(self, func):
def classlevel(self, func: Callable[..., Any]) -> hybridmethod[_T]:
self.clslevel = func
return self
+2
View File
@@ -34,8 +34,10 @@ else:
if compat.py310:
from typing import TypeGuard as TypeGuard
from typing import TypeAlias as TypeAlias
else:
from typing_extensions import TypeGuard as TypeGuard
from typing_extensions import TypeAlias as TypeAlias
if typing.TYPE_CHECKING or compat.py38:
from typing import SupportsIndex as SupportsIndex
+58 -38
View File
@@ -51,57 +51,73 @@ reportTypedDictNotRequiredAccess = "warning"
[tool.mypy]
mypy_path = "./lib/"
show_error_codes = true
strict = false
strict = true
incremental = true
# disabled checking
[[tool.mypy.overrides]]
module="sqlalchemy.*"
ignore_errors = true
warn_unused_ignores = false
strict = true
# some debate at
# https://github.com/python/mypy/issues/8754.
# implicit_reexport = true
# individual packages or even modules should be listed here
# with strictness-specificity set up. there's no way we are going to get
# the whole library 100% strictly typed, so we have to tune this based on
# the type of module or package we are dealing with
[[tool.mypy.overrides]]
# ad-hoc ignores
module = [
"sqlalchemy.engine.reflection", # interim, should be strict
# TODO for strict:
"sqlalchemy.ext.asyncio.*",
"sqlalchemy.ext.automap",
"sqlalchemy.ext.compiler",
"sqlalchemy.ext.declarative.*",
"sqlalchemy.ext.mutable",
"sqlalchemy.ext.horizontal_shard",
"sqlalchemy.sql._selectable_constructors",
"sqlalchemy.sql._dml_constructors",
# TODO for non-strict:
"sqlalchemy.ext.baked",
"sqlalchemy.ext.instrumentation",
"sqlalchemy.ext.indexable",
"sqlalchemy.ext.orderinglist",
"sqlalchemy.ext.serializer",
"sqlalchemy.sql.selectable", # would be nice as strict
"sqlalchemy.sql.ddl",
"sqlalchemy.sql.functions", # would be nice as strict
"sqlalchemy.sql.lambdas",
"sqlalchemy.sql.dml", # would be nice as strict
"sqlalchemy.sql.util",
# not yet classified:
"sqlalchemy.orm.*",
"sqlalchemy.dialects.*",
"sqlalchemy.cyextension.*",
"sqlalchemy.future.*",
"sqlalchemy.testing.*",
]
warn_unused_ignores = false
ignore_errors = true
# strict checking
[[tool.mypy.overrides]]
module = [
"sqlalchemy.sql.annotation",
"sqlalchemy.sql.cache_key",
"sqlalchemy.sql._elements_constructors",
"sqlalchemy.sql.operators",
"sqlalchemy.sql.type_api",
"sqlalchemy.sql.roles",
"sqlalchemy.sql.visitors",
"sqlalchemy.sql._py_util",
# packages
"sqlalchemy.connectors.*",
"sqlalchemy.engine.*",
"sqlalchemy.ext.hybrid",
"sqlalchemy.ext.associationproxy",
"sqlalchemy.pool.*",
"sqlalchemy.event.*",
"sqlalchemy.ext.*",
"sqlalchemy.sql.*",
"sqlalchemy.engine.*",
"sqlalchemy.pool.*",
# modules
"sqlalchemy.events",
"sqlalchemy.exc",
"sqlalchemy.inspection",
"sqlalchemy.schema",
"sqlalchemy.types",
]
warn_unused_ignores = false
ignore_errors = false
strict = true
@@ -109,20 +125,24 @@ strict = true
[[tool.mypy.overrides]]
module = [
#"sqlalchemy.sql.*",
"sqlalchemy.sql.sqltypes",
"sqlalchemy.sql.elements",
"sqlalchemy.sql.coercions",
"sqlalchemy.sql.compiler",
#"sqlalchemy.sql.default_comparator",
"sqlalchemy.sql.naming",
"sqlalchemy.sql.traversals",
"sqlalchemy.util.*",
"sqlalchemy.engine.cursor",
"sqlalchemy.engine.default",
"sqlalchemy.sql.base",
"sqlalchemy.sql.coercions",
"sqlalchemy.sql.compiler",
"sqlalchemy.sql.crud",
"sqlalchemy.sql.elements", # would be nice as strict
"sqlalchemy.sql.naming",
"sqlalchemy.sql.schema", # would be nice as strict
"sqlalchemy.sql.sqltypes", # would be nice as strict
"sqlalchemy.sql.traversals",
"sqlalchemy.util.*",
]
warn_unused_ignores = false
ignore_errors = false
# mostly strict without requiring totally untyped things to be
+17 -17
View File
@@ -1,29 +1,29 @@
# /home/classic/dev/sqlalchemy/test/profiles.txt
# This file is written out on a per-environment basis.
# For each test in aaa_profiling, the corresponding function and
# For each test in aaa_profiling, the corresponding function and
# environment is located within this file. If it doesn't exist,
# the test is skipped.
# If a callcount does exist, it is compared to what we received.
# If a callcount does exist, it is compared to what we received.
# assertions are raised if the counts do not match.
#
# To add a new callcount test, apply the function_call_count
# decorator and re-run the tests using the --write-profiles
#
# To add a new callcount test, apply the function_call_count
# decorator and re-run the tests using the --write-profiles
# option - this file will be rewritten including the new count.
#
#
# TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 72
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 72
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 72
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 72
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 72
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 72
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 72
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 72
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 72
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 70
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 70
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 75
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 75
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 75
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 75
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 75
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 75
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 75
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 75
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 75
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 75
test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 75
# TEST: test.aaa_profiling.test_compiler.CompileTest.test_select
+2 -1
View File
@@ -354,7 +354,8 @@ class DefaultObjectTest(fixtures.TestBase):
assert_raises_message(
sa.exc.ArgumentError,
r"SQL expression for WHERE/HAVING role expected, "
r"got (?:Sequence|ColumnDefault|DefaultClause)\('y'.*\)",
r"got (?:Sequence|(?:ScalarElement)ColumnDefault|"
r"DefaultClause)\('y'.*\)",
t.select().where,
const,
)
+44 -1
View File
@@ -760,7 +760,10 @@ class MetaDataTest(fixtures.TestBase, ComparesTables):
"%s"
", name='someconstraint')" % repr(ck.sqltext),
),
(ColumnDefault(("foo", "bar")), "ColumnDefault(('foo', 'bar'))"),
(
ColumnDefault(("foo", "bar")),
"ScalarElementColumnDefault(('foo', 'bar'))",
),
):
eq_(repr(const), exp)
@@ -916,6 +919,46 @@ class ToMetaDataTest(fixtures.TestBase, AssertsCompiledSQL, ComparesTables):
a2 = a.to_metadata(m2)
assert b2.c.y.references(a2.c.x)
def test_fk_w_no_colname(self):
"""test a ForeignKey that refers to table name only. the column
name is assumed to be the same col name on parent table.
this is a little used feature from long ago that nonetheless is
still in the code.
The feature was found to be not working but is repaired for
SQLAlchemy 2.0.
"""
m1 = MetaData()
a = Table("a", m1, Column("x", Integer))
b = Table("b", m1, Column("x", Integer, ForeignKey("a")))
assert b.c.x.references(a.c.x)
m2 = MetaData()
b2 = b.to_metadata(m2)
a2 = a.to_metadata(m2)
assert b2.c.x.references(a2.c.x)
def test_fk_w_no_colname_name_missing(self):
"""test a ForeignKey that refers to table name only. the column
name is assumed to be the same col name on parent table.
this is a little used feature from long ago that nonetheless is
still in the code.
"""
m1 = MetaData()
a = Table("a", m1, Column("x", Integer))
b = Table("b", m1, Column("y", Integer, ForeignKey("a")))
with expect_raises_message(
exc.NoReferencedColumnError,
"Could not initialize target column for ForeignKey 'a' on "
"table 'b': table 'a' has no column named 'y'",
):
assert b.c.y.references(a.c.x)
def test_column_collection_constraint_w_ad_hoc_columns(self):
"""Test ColumnCollectionConstraint that has columns that aren't
part of the Table.