Type annotations for sqlalchemy.ext.automap

An attempt to annotate `lib/sqlalchemy/ext/automap.py` with type hints (issue [#6810](https://github.com/sqlalchemy/sqlalchemy/issues/6810#issuecomment-1127062951)).

More info on how I approach it could be found in [the earlier PR](https://github.com/sqlalchemy/sqlalchemy/pull/8775).

This pull request is:

- [ ] A documentation / typographical error fix
  - Good to go, no issue or tests are needed
- [ ] A short code fix
  - please include the issue number, and create an issue if none exists, which
    must include a complete example of the issue. one line code fixes without an
    issue and demonstration will not be accepted.
  - Please include: `Fixes: #<issue number>` in the commit message
  - please include tests. one line code fixes without tests will not be accepted.
- [x] A new feature implementation
  - please include the issue number, and create an issue if none exists, which must
    include a complete example of how the feature would look.
  - Please include: `Fixes: #<issue number>` in the commit message
  - please include tests.

**Have a nice day!**

Closes: #8874
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8874
Pull-request-sha: 834d58d77c

Change-Id: Ie64b2be7a51ddc83ef8f23385fb63db5b5c1bc17
This commit is contained in:
Gleb Kisenkov
2022-12-05 08:45:25 -05:00
committed by Mike Bayer
parent 9058593e0b
commit 422d8d3bcb
3 changed files with 244 additions and 62 deletions
+239 -57
View File
@@ -4,7 +4,6 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""Define an extension to the :mod:`sqlalchemy.ext.declarative` system
which automatically generates mapped classes and relationships from a database
@@ -572,6 +571,22 @@ be applied as::
""" # noqa
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import List
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from .. import util
from ..orm import backref
from ..orm import declarative_base as _declarative_base
@@ -582,9 +597,36 @@ from ..orm.decl_base import _DeferredMapperConfig
from ..orm.mapper import _CONFIGURE_MUTEX
from ..schema import ForeignKeyConstraint
from ..sql import and_
from ..util.typing import Protocol
if TYPE_CHECKING:
from ..engine.base import Engine
from ..orm.base import RelationshipDirection
from ..orm.relationships import ORMBackrefArgument
from ..orm.relationships import Relationship
from ..sql.elements import quoted_name
from ..sql.schema import Column
from ..sql.schema import Table
from ..util import immutabledict
from ..util import Properties
def classname_for_table(base, tablename, table):
_KT = TypeVar("_KT", bound=Any)
_VT = TypeVar("_VT", bound=Any)
class ClassnameForTableType(Protocol):
def __call__(
self, base: Type[Any], tablename: quoted_name, table: Table
) -> str:
...
def classname_for_table(
base: Type[Any],
tablename: quoted_name,
table: Table,
) -> str:
"""Return the class name that should be used, given the name
of a table.
@@ -617,7 +659,23 @@ def classname_for_table(base, tablename, table):
return str(tablename)
def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
class NameForScalarRelationshipType(Protocol):
def __call__(
self,
base: Type[Any],
local_cls: Type[Any],
referred_cls: Type[Any],
constraint: ForeignKeyConstraint,
) -> str:
...
def name_for_scalar_relationship(
base: Type[Any],
local_cls: Type[Any],
referred_cls: Type[Any],
constraint: ForeignKeyConstraint,
) -> str:
"""Return the attribute name that should be used to refer from one
class to another, for a scalar object reference.
@@ -642,9 +700,23 @@ def name_for_scalar_relationship(base, local_cls, referred_cls, constraint):
return referred_cls.__name__.lower()
class NameForCollectionRelationshipType(Protocol):
def __call__(
self,
base: Type[Any],
local_cls: Type[Any],
referred_cls: Type[Any],
constraint: ForeignKeyConstraint,
) -> str:
...
def name_for_collection_relationship(
base, local_cls, referred_cls, constraint
):
base: Type[Any],
local_cls: Type[Any],
referred_cls: Type[Any],
constraint: ForeignKeyConstraint,
) -> str:
"""Return the attribute name that should be used to refer from one
class to another, for a collection reference.
@@ -670,9 +742,85 @@ def name_for_collection_relationship(
return referred_cls.__name__.lower() + "_collection"
class GenerateRelationshipType(Protocol):
@overload
def __call__(
self,
base: Type[Any],
direction: RelationshipDirection,
return_fn: Callable[..., Relationship[Any]],
attrname: str,
local_cls: Type[Any],
referred_cls: Type[Any],
**kw: Any,
) -> Relationship[Any]:
...
@overload
def __call__(
self,
base: Type[Any],
direction: RelationshipDirection,
return_fn: Callable[..., ORMBackrefArgument],
attrname: str,
local_cls: Type[Any],
referred_cls: Type[Any],
**kw: Any,
) -> ORMBackrefArgument:
...
def __call__(
self,
base: Type[Any],
direction: RelationshipDirection,
return_fn: Union[
Callable[..., Relationship[Any]], Callable[..., ORMBackrefArgument]
],
attrname: str,
local_cls: Type[Any],
referred_cls: Type[Any],
**kw: Any,
) -> Union[ORMBackrefArgument, Relationship[Any]]:
...
@overload
def generate_relationship(
base, direction, return_fn, attrname, local_cls, referred_cls, **kw
):
base: Type[Any],
direction: RelationshipDirection,
return_fn: Callable[..., Relationship[Any]],
attrname: str,
local_cls: Type[Any],
referred_cls: Type[Any],
**kw: Any,
) -> Relationship[Any]:
...
@overload
def generate_relationship(
base: Type[Any],
direction: RelationshipDirection,
return_fn: Callable[..., ORMBackrefArgument],
attrname: str,
local_cls: Type[Any],
referred_cls: Type[Any],
**kw: Any,
) -> ORMBackrefArgument:
...
def generate_relationship(
base: Type[Any],
direction: RelationshipDirection,
return_fn: Union[
Callable[..., Relationship[Any]], Callable[..., ORMBackrefArgument]
],
attrname: str,
local_cls: Type[Any],
referred_cls: Type[Any],
**kw: Any,
) -> Union[Relationship[Any], ORMBackrefArgument]:
r"""Generate a :func:`_orm.relationship` or :func:`.backref`
on behalf of two
mapped classes.
@@ -721,6 +869,7 @@ def generate_relationship(
by the :paramref:`.generate_relationship.return_fn` parameter.
"""
if return_fn is backref:
return return_fn(attrname, **kw)
elif return_fn is relationship:
@@ -748,7 +897,7 @@ class AutomapBase:
__abstract__ = True
classes = None
classes: Optional[Properties[Type[Any]]] = None
"""An instance of :class:`.util.Properties` containing classes.
This object behaves much like the ``.c`` collection on a table. Classes
@@ -781,18 +930,24 @@ class AutomapBase:
),
)
def prepare(
cls,
autoload_with=None,
engine=None,
reflect=False,
schema=None,
classname_for_table=None,
collection_class=None,
name_for_scalar_relationship=None,
name_for_collection_relationship=None,
generate_relationship=None,
reflection_options=util.EMPTY_DICT,
):
cls: Type[Any],
autoload_with: Optional[Engine] = None,
engine: Optional[Any] = None,
reflect: bool = False,
schema: Optional[str] = None,
classname_for_table: Optional[ClassnameForTableType] = None,
collection_class: Optional[Any] = None,
name_for_scalar_relationship: Optional[
NameForScalarRelationshipType
] = None,
name_for_collection_relationship: Optional[
NameForCollectionRelationshipType
] = None,
generate_relationship: Optional[GenerateRelationshipType] = None,
reflection_options: Union[
Dict[_KT, _VT], immutabledict[_KT, _VT]
] = util.EMPTY_DICT,
) -> None:
"""Extract mapped classes and relationships from the
:class:`_schema.MetaData` and
perform mappings.
@@ -874,6 +1029,7 @@ class AutomapBase:
autoload_with = engine
if reflect:
assert autoload_with
opts = dict(
schema=schema,
extend_existing=True,
@@ -884,18 +1040,30 @@ class AutomapBase:
cls.metadata.reflect(autoload_with, **opts)
with _CONFIGURE_MUTEX:
table_to_map_config = {
m.local_table: m
table_to_map_config: Union[
Dict[Optional[Table], _DeferredMapperConfig],
Dict[Table, _DeferredMapperConfig],
] = {
cast("Table", m.local_table): m
for m in _DeferredMapperConfig.classes_for_base(
cls, sort=False
)
}
many_to_many = []
many_to_many: list[
tuple[
Table,
Table,
list[ForeignKeyConstraint],
Table,
]
] = []
for table in cls.metadata.tables.values():
lcl_m2m, rem_m2m, m2m_const = _is_many_to_many(cls, table)
if lcl_m2m is not None:
assert rem_m2m is not None
assert m2m_const is not None
many_to_many.append((lcl_m2m, rem_m2m, m2m_const, table))
elif not table.primary_key:
continue
@@ -961,7 +1129,7 @@ class AutomapBase:
"""
@classmethod
def _sa_raise_deferred_config(cls):
def _sa_raise_deferred_config(cls) -> NoReturn:
raise orm_exc.UnmappedClassError(
cls,
msg="Class %s is a subclass of AutomapBase. "
@@ -971,7 +1139,9 @@ class AutomapBase:
)
def automap_base(declarative_base=None, **kw):
def automap_base(
declarative_base: Optional[Type[Any]] = None, **kw: Any
) -> Any:
r"""Produce a declarative automap base.
This function produces a new base class that is a product of the
@@ -1003,7 +1173,11 @@ def automap_base(declarative_base=None, **kw):
)
def _is_many_to_many(automap_base, table):
def _is_many_to_many(
automap_base: Type[Any], table: Table
) -> Tuple[
Optional[Table], Optional[Table], Optional[list[ForeignKeyConstraint]]
]:
fk_constraints = [
const
for const in table.constraints
@@ -1012,7 +1186,7 @@ def _is_many_to_many(automap_base, table):
if len(fk_constraints) != 2:
return None, None, None
cols = sum(
cols: list[Column[Any]] = sum(
[
[fk.parent for fk in fk_constraint.elements]
for fk_constraint in fk_constraints
@@ -1031,16 +1205,21 @@ def _is_many_to_many(automap_base, table):
def _relationships_for_fks(
automap_base,
map_config,
table_to_map_config,
collection_class,
name_for_scalar_relationship,
name_for_collection_relationship,
generate_relationship,
):
local_table = map_config.local_table
local_cls = map_config.cls # derived from a weakref, may be None
automap_base: Type[Any],
map_config: _DeferredMapperConfig,
table_to_map_config: Union[
Dict[Optional[Table], _DeferredMapperConfig],
Dict[Table, _DeferredMapperConfig],
],
collection_class: type,
name_for_scalar_relationship: NameForScalarRelationshipType,
name_for_collection_relationship: NameForCollectionRelationshipType,
generate_relationship: GenerateRelationshipType,
) -> None:
local_table = cast("Optional[Table]", map_config.local_table)
local_cls = cast(
"Optional[Type[Any]]", map_config.cls
) # derived from a weakref, may be None
if local_table is None or local_cls is None:
return
@@ -1065,7 +1244,7 @@ def _relationships_for_fks(
automap_base, referred_cls, local_cls, constraint
)
o2m_kws = {}
o2m_kws: dict[str, Union[str, bool]] = {}
nullable = False not in {fk.parent.nullable for fk in fks}
if not nullable:
o2m_kws["cascade"] = "all, delete-orphan"
@@ -1114,7 +1293,7 @@ def _relationships_for_fks(
if not create_backref:
referred_cfg.properties[
backref_name
].back_populates = relationship_name
].back_populates = relationship_name # type: ignore[union-attr] # noqa: E501
elif create_backref:
rel = generate_relationship(
automap_base,
@@ -1132,21 +1311,24 @@ def _relationships_for_fks(
referred_cfg.properties[backref_name] = rel
map_config.properties[
relationship_name
].back_populates = backref_name
].back_populates = backref_name # type: ignore[union-attr]
def _m2m_relationship(
automap_base,
lcl_m2m,
rem_m2m,
m2m_const,
table,
table_to_map_config,
collection_class,
name_for_scalar_relationship,
name_for_collection_relationship,
generate_relationship,
):
automap_base: Type[Any],
lcl_m2m: Table,
rem_m2m: Table,
m2m_const: List[ForeignKeyConstraint],
table: Table,
table_to_map_config: Union[
Dict[Optional[Table], _DeferredMapperConfig],
Dict[Table, _DeferredMapperConfig],
],
collection_class: type,
name_for_scalar_relationship: NameForCollectionRelationshipType,
name_for_collection_relationship: NameForCollectionRelationshipType,
generate_relationship: GenerateRelationshipType,
) -> None:
map_config = table_to_map_config.get(lcl_m2m, None)
referred_cfg = table_to_map_config.get(rem_m2m, None)
@@ -1196,10 +1378,10 @@ def _m2m_relationship(
secondary=table,
primaryjoin=and_(
fk.column == fk.parent for fk in m2m_const[0].elements
),
), # type: ignore [arg-type]
secondaryjoin=and_(
fk.column == fk.parent for fk in m2m_const[1].elements
),
), # type: ignore [arg-type]
backref=backref_obj,
collection_class=collection_class,
)
@@ -1209,7 +1391,7 @@ def _m2m_relationship(
if not create_backref:
referred_cfg.properties[
backref_name
].back_populates = relationship_name
].back_populates = relationship_name # type: ignore[union-attr] # noqa: E501
elif create_backref:
rel = generate_relationship(
automap_base,
@@ -1222,10 +1404,10 @@ def _m2m_relationship(
secondary=table,
primaryjoin=and_(
fk.column == fk.parent for fk in m2m_const[1].elements
),
), # type: ignore [arg-type]
secondaryjoin=and_(
fk.column == fk.parent for fk in m2m_const[0].elements
),
), # type: ignore [arg-type]
back_populates=relationship_name,
collection_class=collection_class,
)
@@ -1233,4 +1415,4 @@ def _m2m_relationship(
referred_cfg.properties[backref_name] = rel
map_config.properties[
relationship_name
].back_populates = backref_name
].back_populates = backref_name # type: ignore[union-attr]
+3 -3
View File
@@ -57,10 +57,10 @@ if TYPE_CHECKING:
from .mapper import Mapper
from .query import Query
from .relationships import _LazyLoadArgumentType
from .relationships import _ORMBackrefArgument
from .relationships import _ORMColCollectionArgument
from .relationships import _ORMOrderByArgument
from .relationships import _RelationshipJoinConditionArgument
from .relationships import ORMBackrefArgument
from .session import _SessionBind
from ..sql._typing import _ColumnExpressionArgument
from ..sql._typing import _FromClauseArgument
@@ -781,7 +781,7 @@ def relationship(
secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
back_populates: Optional[str] = None,
order_by: _ORMOrderByArgument = False,
backref: Optional[_ORMBackrefArgument] = None,
backref: Optional[ORMBackrefArgument] = None,
overlaps: Optional[str] = None,
post_update: bool = False,
cascade: str = "save-update, merge",
@@ -1898,7 +1898,7 @@ def dynamic_loader(
return relationship(argument, **kw)
def backref(name: str, **kwargs: Any) -> _ORMBackrefArgument:
def backref(name: str, **kwargs: Any) -> ORMBackrefArgument:
"""When using the :paramref:`_orm.relationship.backref` parameter,
provides specific parameters to be used when the new
:func:`_orm.relationship` is generated.
+2 -2
View File
@@ -171,7 +171,7 @@ _ORMOrderByArgument = Union[
Callable[[], Iterable[ColumnElement[Any]]],
Iterable[Union[str, _ColumnExpressionArgument[Any]]],
]
_ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]]
_ORMColCollectionElement = Union[
ColumnClause[Any], _HasClauseElement, roles.DMLColumnRole
@@ -366,7 +366,7 @@ class RelationshipProperty(
secondaryjoin: Optional[_RelationshipJoinConditionArgument] = None,
back_populates: Optional[str] = None,
order_by: _ORMOrderByArgument = False,
backref: Optional[_ORMBackrefArgument] = None,
backref: Optional[ORMBackrefArgument] = None,
overlaps: Optional[str] = None,
post_update: bool = False,
cascade: str = "save-update, merge",