Type annotations for sqlalchemy.ext.mutable

The ``sqlalchemy.ext.mutable`` extension is now fully pep-484 typed. Huge
thanks to Gleb Kisenkov for their efforts on this.

Fixes: #8667
Closes: #8775
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8775
Pull-request-sha: b907888ec6

Change-Id: Id9224e03201e6970b1ec56eb546ece4b2f3e0edd
This commit is contained in:
Gleb Kisenkov
2022-11-16 10:23:06 -05:00
committed by Mike Bayer
parent 3fc6c40ea7
commit ba0e508141
2 changed files with 202 additions and 103 deletions
+6
View File
@@ -0,0 +1,6 @@
.. change::
:tags: bug, typing
:tickets: 8667
The ``sqlalchemy.ext.mutable`` extension is now fully pep-484 typed. Huge
thanks to Gleb Kisenkov for their efforts on this.
+196 -103
View File
@@ -4,7 +4,6 @@
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
r"""Provide support for tracking of in-place changes to scalar values,
which are propagated into ORM change events on owning parent objects.
@@ -355,16 +354,47 @@ pickling process of the parent's object-relational state so that the
""" # noqa: E501
from __future__ import annotations
from collections import defaultdict
from typing import AbstractSet
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import overload
from typing import Set
from typing import Tuple
from typing import TypeVar
from typing import Union
import weakref
from weakref import WeakKeyDictionary
from .. import event
from .. import inspect
from .. import types
from ..orm import Mapper
from ..orm._typing import _ExternalEntityType
from ..orm._typing import _O
from ..orm._typing import _T
from ..orm.attributes import AttributeEventToken
from ..orm.attributes import flag_modified
from ..orm.attributes import InstrumentedAttribute
from ..orm.attributes import QueryableAttribute
from ..orm.context import QueryContext
from ..orm.decl_api import DeclarativeAttributeIntercept
from ..orm.state import InstanceState
from ..orm.unitofwork import UOWTransaction
from ..sql.base import SchemaEventTarget
from ..sql.schema import Column
from ..sql.type_api import TypeEngine
from ..util import memoized_property
from ..util.typing import SupportsIndex
from ..util.typing import TypeGuard
_KT = TypeVar("_KT") # Key type.
_VT = TypeVar("_VT") # Value type.
class MutableBase:
@@ -374,7 +404,7 @@ class MutableBase:
"""
@memoized_property
def _parents(self):
def _parents(self) -> WeakKeyDictionary[Any, Any]:
"""Dictionary of parent object's :class:`.InstanceState`->attribute
name on the parent.
@@ -391,7 +421,7 @@ class MutableBase:
return weakref.WeakKeyDictionary()
@classmethod
def coerce(cls, key, value):
def coerce(cls, key: str, value: Any) -> Optional[Any]:
"""Given a value, coerce it into the target type.
Can be overridden by custom subclasses to coerce incoming
@@ -420,7 +450,7 @@ class MutableBase:
raise ValueError(msg % (key, type(value)))
@classmethod
def _get_listen_keys(cls, attribute):
def _get_listen_keys(cls, attribute: QueryableAttribute[Any]) -> Set[str]:
"""Given a descriptor attribute, return a ``set()`` of the attribute
keys which indicate a change in the state of this attribute.
@@ -441,7 +471,12 @@ class MutableBase:
return {attribute.key}
@classmethod
def _listen_on_attribute(cls, attribute, coerce, parent_cls):
def _listen_on_attribute(
cls,
attribute: QueryableAttribute[Any],
coerce: bool,
parent_cls: _ExternalEntityType[Any],
) -> None:
"""Establish this type as a mutation listener for the given
mapped descriptor.
@@ -455,7 +490,7 @@ class MutableBase:
listen_keys = cls._get_listen_keys(attribute)
def load(state, *args):
def load(state: InstanceState[_O], *args: Any) -> None:
"""Listen for objects loaded or refreshed.
Wrap the target data member's value with
@@ -469,11 +504,20 @@ class MutableBase:
state.dict[key] = val
val._parents[state] = key
def load_attrs(state, ctx, attrs):
def load_attrs(
state: InstanceState[_O],
ctx: Union[object, QueryContext, UOWTransaction],
attrs: Iterable[Any],
) -> None:
if not attrs or listen_keys.intersection(attrs):
load(state)
def set_(target, value, oldvalue, initiator):
def set_(
target: InstanceState[_O],
value: MutableBase | None,
oldvalue: MutableBase | None,
initiator: AttributeEventToken,
) -> MutableBase | None:
"""Listen for set/replace events on the target
data member.
@@ -493,14 +537,18 @@ class MutableBase:
oldvalue._parents.pop(inspect(target), None)
return value
def pickle(state, state_dict):
def pickle(
state: InstanceState[_O], state_dict: Dict[str, Any]
) -> None:
val = state.dict.get(key, None)
if val is not None:
if "ext.mutable.values" not in state_dict:
state_dict["ext.mutable.values"] = defaultdict(list)
state_dict["ext.mutable.values"][key].append(val)
def unpickle(state, state_dict):
def unpickle(
state: InstanceState[_O], state_dict: Dict[str, Any]
) -> None:
if "ext.mutable.values" in state_dict:
collection = state_dict["ext.mutable.values"]
if isinstance(collection, list):
@@ -543,14 +591,16 @@ class Mutable(MutableBase):
"""
def changed(self):
def changed(self) -> None:
"""Subclasses should call this method whenever change events occur."""
for parent, key in self._parents.items():
flag_modified(parent.obj(), key)
@classmethod
def associate_with_attribute(cls, attribute):
def associate_with_attribute(
cls, attribute: InstrumentedAttribute[_O]
) -> None:
"""Establish this type as a mutation listener for the given
mapped descriptor.
@@ -558,7 +608,7 @@ class Mutable(MutableBase):
cls._listen_on_attribute(attribute, True, attribute.class_)
@classmethod
def associate_with(cls, sqltype):
def associate_with(cls, sqltype: type) -> None:
"""Associate this wrapper with all future mapped columns
of the given type.
@@ -575,7 +625,7 @@ class Mutable(MutableBase):
"""
def listen_for_type(mapper, class_):
def listen_for_type(mapper: Mapper[_O], class_: type) -> None:
if mapper.non_primary:
return
for prop in mapper.column_attrs:
@@ -585,7 +635,7 @@ class Mutable(MutableBase):
event.listen(Mapper, "mapper_configured", listen_for_type)
@classmethod
def as_mutable(cls, sqltype):
def as_mutable(cls, sqltype: TypeEngine[_T]) -> TypeEngine[_T]:
"""Associate a SQL type with this mutable Python type.
This establishes listeners that will detect ORM mappings against
@@ -625,21 +675,27 @@ class Mutable(MutableBase):
if isinstance(sqltype, SchemaEventTarget):
@event.listens_for(sqltype, "before_parent_attach")
def _add_column_memo(sqltyp, parent):
def _add_column_memo(
sqltyp: TypeEngine[Any],
parent: Column[_T],
) -> None:
parent.info["_ext_mutable_orig_type"] = sqltyp
schema_event_check = True
else:
schema_event_check = False
def listen_for_type(mapper, class_):
def listen_for_type(
mapper: Mapper[_T],
class_: Union[DeclarativeAttributeIntercept, type],
) -> None:
if mapper.non_primary:
return
for prop in mapper.column_attrs:
if (
schema_event_check
and hasattr(prop.expression, "info")
and prop.expression.info.get("_ext_mutable_orig_type")
and prop.expression.info.get("_ext_mutable_orig_type") # type: ignore # noqa: E501 # TODO: https://github.com/python/mypy/issues/1424#issuecomment-1272354487
is sqltype
) or (prop.columns[0].type is sqltype):
cls.associate_with_attribute(getattr(class_, prop.key))
@@ -659,10 +715,10 @@ class MutableComposite(MutableBase):
"""
@classmethod
def _get_listen_keys(cls, attribute):
def _get_listen_keys(cls, attribute: QueryableAttribute[_O]) -> Set[str]:
return {attribute.key}.union(attribute.property._attribute_keys)
def changed(self):
def changed(self) -> None:
"""Subclasses should call this method whenever change events occur."""
for parent, key in self._parents.items():
@@ -675,8 +731,8 @@ class MutableComposite(MutableBase):
setattr(parent.obj(), attr_name, value)
def _setup_composite_listener():
def _listen_for_type(mapper, class_):
def _setup_composite_listener() -> None:
def _listen_for_type(mapper: Mapper[_T], class_: type) -> None:
for prop in mapper.iterate_properties:
if (
hasattr(prop, "composite_class")
@@ -694,7 +750,7 @@ def _setup_composite_listener():
_setup_composite_listener()
class MutableDict(Mutable, dict):
class MutableDict(Mutable, Dict[_KT, _VT]):
"""A dictionary type that implements :class:`.Mutable`.
The :class:`.MutableDict` object implements a dictionary that will
@@ -717,41 +773,69 @@ class MutableDict(Mutable, dict):
"""
def __setitem__(self, key, value):
def __setitem__(self, key: _KT, value: _VT) -> None:
"""Detect dictionary set events and emit change events."""
dict.__setitem__(self, key, value)
super().__setitem__(key, value)
self.changed()
def setdefault(self, key, value):
result = dict.setdefault(self, key, value)
def _exists(self, value: _T | None) -> TypeGuard[_T]:
return value is not None
def _is_none(self, value: _T | None) -> TypeGuard[None]:
return value is None
@overload
def setdefault(self, key: _KT) -> _VT | None:
...
@overload
def setdefault(self, key: _KT, value: _VT) -> _VT:
...
def setdefault(self, key: _KT, value: _VT | None = None) -> _VT | None:
if self._exists(value):
result = super().setdefault(key, value)
else:
result = super().setdefault(key) # type: ignore[call-arg]
self.changed()
return result
def __delitem__(self, key):
def __delitem__(self, key: _KT) -> None:
"""Detect dictionary del events and emit change events."""
dict.__delitem__(self, key)
super().__delitem__(key)
self.changed()
def update(self, *a, **kw):
dict.update(self, *a, **kw)
def update(self, *a: Any, **kw: _VT) -> None:
super().update(*a, **kw)
self.changed()
def pop(self, *arg):
result = dict.pop(self, *arg)
@overload
def pop(self, __key: _KT) -> _VT:
...
@overload
def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T:
...
def pop(self, __key: _KT, __default: _VT | _T | None = None) -> _VT | _T:
if self._exists(__default):
result = super().pop(__key, __default)
else:
result = super().pop(__key)
self.changed()
return result
def popitem(self):
result = dict.popitem(self)
def popitem(self) -> Tuple[_KT, _VT]:
result = super().popitem()
self.changed()
return result
def clear(self):
dict.clear(self)
def clear(self) -> None:
super().clear()
self.changed()
@classmethod
def coerce(cls, key, value):
def coerce(cls, key: str, value: Any) -> MutableDict[_KT, _VT] | None:
"""Convert plain dictionary to instance of this class."""
if not isinstance(value, cls):
if isinstance(value, dict):
@@ -760,14 +844,16 @@ class MutableDict(Mutable, dict):
else:
return value
def __getstate__(self):
def __getstate__(self) -> dict[_KT, _VT]:
return dict(self)
def __setstate__(self, state):
def __setstate__(
self, state: Union[Dict[str, int], Dict[str, str]]
) -> None:
self.update(state)
class MutableList(Mutable, list):
class MutableList(Mutable, List[_T]):
"""A list type that implements :class:`.Mutable`.
The :class:`.MutableList` object implements a list that will
@@ -792,83 +878,88 @@ class MutableList(Mutable, list):
"""
def __reduce_ex__(self, proto):
def __reduce_ex__(
self, proto: SupportsIndex
) -> Tuple[type, Tuple[List[int]]]:
return (self.__class__, (list(self),))
# needed for backwards compatibility with
# older pickles
def __setstate__(self, state):
def __setstate__(self, state: Iterable[_T]) -> None:
self[:] = state
def __setitem__(self, index, value):
def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]:
return not isinstance(value, Iterable)
def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]:
return isinstance(value, Iterable)
def __setitem__(
self, index: SupportsIndex | slice, value: _T | Iterable[_T]
) -> None:
"""Detect list set events and emit change events."""
list.__setitem__(self, index, value)
if isinstance(index, SupportsIndex) and self.is_scalar(value):
super().__setitem__(index, value)
elif isinstance(index, slice) and self.is_iterable(value):
super().__setitem__(index, value)
self.changed()
def __setslice__(self, start, end, value):
"""Detect list set events and emit change events."""
list.__setslice__(self, start, end, value)
self.changed()
def __delitem__(self, index):
def __delitem__(self, index: SupportsIndex | slice) -> None:
"""Detect list del events and emit change events."""
list.__delitem__(self, index)
super().__delitem__(index)
self.changed()
def __delslice__(self, start, end):
"""Detect list del events and emit change events."""
list.__delslice__(self, start, end)
self.changed()
def pop(self, *arg):
result = list.pop(self, *arg)
def pop(self, *arg: SupportsIndex) -> _T:
result = super().pop(*arg)
self.changed()
return result
def append(self, x):
list.append(self, x)
def append(self, x: _T) -> None:
super().append(x)
self.changed()
def extend(self, x):
list.extend(self, x)
def extend(self, x: Iterable[_T]) -> None:
super().extend(x)
self.changed()
def __iadd__(self, x):
def __iadd__(self, x: Iterable[_T]) -> MutableList[_T]: # type: ignore[override,misc] # noqa: E501
self.extend(x)
return self
def insert(self, i, x):
list.insert(self, i, x)
def insert(self, i: SupportsIndex, x: _T) -> None:
super().insert(i, x)
self.changed()
def remove(self, i):
list.remove(self, i)
def remove(self, i: _T) -> None:
super().remove(i)
self.changed()
def clear(self):
list.clear(self)
def clear(self) -> None:
super().clear()
self.changed()
def sort(self, **kw):
list.sort(self, **kw)
def sort(self, **kw: Any) -> None:
super().sort(**kw)
self.changed()
def reverse(self):
list.reverse(self)
def reverse(self) -> None:
super().reverse()
self.changed()
@classmethod
def coerce(cls, index, value):
def coerce(
cls, key: str, value: MutableList[_T] | _T
) -> Optional[MutableList[_T]]:
"""Convert plain list to instance of this class."""
if not isinstance(value, cls):
if isinstance(value, list):
return cls(value)
return Mutable.coerce(index, value)
return Mutable.coerce(key, value)
else:
return value
class MutableSet(Mutable, set):
class MutableSet(Mutable, Set[_T]):
"""A set type that implements :class:`.Mutable`.
The :class:`.MutableSet` object implements a set that will
@@ -894,61 +985,61 @@ class MutableSet(Mutable, set):
"""
def update(self, *arg):
set.update(self, *arg)
def update(self, *arg: Iterable[_T]) -> None:
super().update(*arg)
self.changed()
def intersection_update(self, *arg):
set.intersection_update(self, *arg)
def intersection_update(self, *arg: Iterable[Any]) -> None:
super().intersection_update(*arg)
self.changed()
def difference_update(self, *arg):
set.difference_update(self, *arg)
def difference_update(self, *arg: Iterable[Any]) -> None:
super().difference_update(*arg)
self.changed()
def symmetric_difference_update(self, *arg):
set.symmetric_difference_update(self, *arg)
def symmetric_difference_update(self, *arg: Iterable[_T]) -> None:
super().symmetric_difference_update(*arg)
self.changed()
def __ior__(self, other):
def __ior__(self, other: AbstractSet[_T]) -> MutableSet[_T]: # type: ignore[override,misc] # noqa: E501
self.update(other)
return self
def __iand__(self, other):
def __iand__(self, other: AbstractSet[object]) -> MutableSet[_T]:
self.intersection_update(other)
return self
def __ixor__(self, other):
def __ixor__(self, other: AbstractSet[_T]) -> MutableSet[_T]: # type: ignore[override,misc] # noqa: E501
self.symmetric_difference_update(other)
return self
def __isub__(self, other):
def __isub__(self, other: AbstractSet[object]) -> MutableSet[_T]: # type: ignore[misc] # noqa: E501
self.difference_update(other)
return self
def add(self, elem):
set.add(self, elem)
def add(self, elem: _T) -> None:
super().add(elem)
self.changed()
def remove(self, elem):
set.remove(self, elem)
def remove(self, elem: _T) -> None:
super().remove(elem)
self.changed()
def discard(self, elem):
set.discard(self, elem)
def discard(self, elem: _T) -> None:
super().discard(elem)
self.changed()
def pop(self, *arg):
result = set.pop(self, *arg)
def pop(self, *arg: Any) -> _T:
result = super().pop(*arg)
self.changed()
return result
def clear(self):
set.clear(self)
def clear(self) -> None:
super().clear()
self.changed()
@classmethod
def coerce(cls, index, value):
def coerce(cls, index: str, value: Any) -> Optional[MutableSet[_T]]:
"""Convert plain set to instance of this class."""
if not isinstance(value, cls):
if isinstance(value, set):
@@ -957,11 +1048,13 @@ class MutableSet(Mutable, set):
else:
return value
def __getstate__(self):
def __getstate__(self) -> set[_T]:
return set(self)
def __setstate__(self, state):
def __setstate__(self, state: Iterable[_T]) -> None:
self.update(state)
def __reduce_ex__(self, proto):
def __reduce_ex__(
self, proto: SupportsIndex
) -> Tuple[type, Tuple[List[int]]]:
return (self.__class__, (list(self),))