mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-14 20:57:19 -04:00
read from cls.__dict__ so init_subclass works
Modified the :class:`.DeclarativeMeta` metaclass to pass ``cls.__dict__`` into the declarative scanning process to look for attributes, rather than the separate dictionary passed to the type's ``__init__()`` method. This allows user-defined base classes that add attributes within an ``__init_subclass__()`` to work as expected, as ``__init_subclass__()`` can only affect the ``cls.__dict__`` itself and not the other dictionary. This is technically a regression from 1.3 where ``__dict__`` was being used. Additionally makes the reference between ClassManager and the declarative configuration object a weak reference, so that it can be discarded after mappers are set up. Fixes: #7900 Change-Id: I3c2fd4e227cc1891aa4bb3d7d5b43d5686f9f27c
This commit is contained in:
+14
@@ -0,0 +1,14 @@
|
||||
.. change::
|
||||
:tags: bug, orm, declarative
|
||||
:tickets: 7900
|
||||
|
||||
Modified the :class:`.DeclarativeMeta` metaclass to pass ``cls.__dict__``
|
||||
into the declarative scanning process to look for attributes, rather than
|
||||
the separate dictionary passed to the type's ``__init__()`` method. This
|
||||
allows user-defined base classes that add attributes within an
|
||||
``__init_subclass__()`` to work as expected, as ``__init_subclass__()`` can
|
||||
only affect the ``cls.__dict__`` itself and not the other dictionary. This
|
||||
is technically a regression from 1.3 where ``__dict__`` was being used.
|
||||
|
||||
|
||||
|
||||
@@ -109,6 +109,10 @@ class DeclarativeMeta(
|
||||
def __init__(
|
||||
cls, classname: Any, bases: Any, dict_: Any, **kw: Any
|
||||
) -> None:
|
||||
# use cls.__dict__, which can be modified by an
|
||||
# __init_subclass__() method (#7900)
|
||||
dict_ = cls.__dict__
|
||||
|
||||
# early-consume registry from the initial declarative base,
|
||||
# assign privately to not conflict with subclass attributes named
|
||||
# "registry"
|
||||
@@ -293,7 +297,8 @@ class declared_attr(interfaces._MappedAttribute[_T]):
|
||||
|
||||
# here, we are inside of the declarative scan. use the registry
|
||||
# that is tracking the values of these attributes.
|
||||
declarative_scan = manager.declarative_scan
|
||||
declarative_scan = manager.declarative_scan()
|
||||
assert declarative_scan is not None
|
||||
reg = declarative_scan.declared_attr_reg
|
||||
|
||||
if self in reg:
|
||||
|
||||
@@ -161,7 +161,13 @@ def _check_declared_props_nocascade(obj, name, cls):
|
||||
|
||||
|
||||
class _MapperConfig:
|
||||
__slots__ = ("cls", "classname", "properties", "declared_attr_reg")
|
||||
__slots__ = (
|
||||
"cls",
|
||||
"classname",
|
||||
"properties",
|
||||
"declared_attr_reg",
|
||||
"__weakref__",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def setup_mapping(cls, registry, cls_, dict_, table, mapper_kw):
|
||||
@@ -311,13 +317,15 @@ class _ClassScanMapperConfig(_MapperConfig):
|
||||
mapper_kw,
|
||||
):
|
||||
|
||||
# grab class dict before the instrumentation manager has been added.
|
||||
# reduces cycles
|
||||
self.clsdict_view = (
|
||||
util.immutabledict(dict_) if dict_ else util.EMPTY_DICT
|
||||
)
|
||||
super(_ClassScanMapperConfig, self).__init__(registry, cls_, mapper_kw)
|
||||
self.registry = registry
|
||||
self.persist_selectable = None
|
||||
|
||||
self.clsdict_view = (
|
||||
util.immutabledict(dict_) if dict_ else util.EMPTY_DICT
|
||||
)
|
||||
self.collected_attributes = {}
|
||||
self.collected_annotations: Dict[str, Tuple[Any, bool]] = {}
|
||||
self.declared_columns = util.OrderedSet()
|
||||
|
||||
@@ -39,6 +39,7 @@ from typing import Optional
|
||||
from typing import Set
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
import weakref
|
||||
|
||||
from . import base
|
||||
from . import collections
|
||||
@@ -167,7 +168,7 @@ class ClassManager(
|
||||
if registry:
|
||||
registry._add_manager(self)
|
||||
if declarative_scan:
|
||||
self.declarative_scan = declarative_scan
|
||||
self.declarative_scan = weakref.ref(declarative_scan)
|
||||
if expired_attribute_loader:
|
||||
self.expired_attribute_loader = expired_attribute_loader
|
||||
|
||||
|
||||
@@ -43,7 +43,11 @@ Base = None
|
||||
mapper_registry = None
|
||||
|
||||
|
||||
class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
|
||||
class DeclarativeTestBase(
|
||||
testing.AssertsCompiledSQL,
|
||||
fixtures.TestBase,
|
||||
testing.AssertsExecutionResults,
|
||||
):
|
||||
def setup_test(self):
|
||||
global Base, mapper_registry
|
||||
|
||||
@@ -58,6 +62,19 @@ class DeclarativeTestBase(fixtures.TestBase, testing.AssertsExecutionResults):
|
||||
|
||||
|
||||
class DeclarativeMixinTest(DeclarativeTestBase):
|
||||
def test_init_subclass_works(self, registry):
|
||||
class Base:
|
||||
def __init_subclass__(cls):
|
||||
cls.id = Column(Integer, primary_key=True)
|
||||
|
||||
Base = registry.generate_base(cls=Base)
|
||||
|
||||
class Foo(Base):
|
||||
__tablename__ = "foo"
|
||||
name = Column(String)
|
||||
|
||||
self.assert_compile(select(Foo), "SELECT foo.name, foo.id FROM foo")
|
||||
|
||||
def test_simple_wbase(self):
|
||||
class MyMixin:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user