mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-18 22:52:01 -04:00
Fill-out dataclass-related attr resolution
Fixed issue where mixin attribute rules were not taking effect correctly for attributes pulled from dataclasses using the approach added in #5745. Fixes: #5876 Change-Id: I45099a42de1d9611791e72250fe0edc69bed684c
This commit is contained in:
+110
-21
@@ -325,6 +325,94 @@ class _ClassScanMapperConfig(_MapperConfig):
|
||||
def before_configured():
|
||||
self.cls.__declare_first__()
|
||||
|
||||
def _cls_attr_override_checker(self, cls):
|
||||
"""Produce a function that checks if a class has overridden an
|
||||
attribute, taking SQLAlchemy-enabled dataclass fields into account.
|
||||
|
||||
"""
|
||||
sa_dataclass_metadata_key = _get_immediate_cls_attr(
|
||||
cls, "__sa_dataclass_metadata_key__", None
|
||||
)
|
||||
|
||||
if sa_dataclass_metadata_key is None:
|
||||
|
||||
def attribute_is_overridden(key, obj):
|
||||
return getattr(cls, key) is not obj
|
||||
|
||||
else:
|
||||
|
||||
all_datacls_fields = {
|
||||
f.name: f.metadata[sa_dataclass_metadata_key]
|
||||
for f in util.dataclass_fields(cls)
|
||||
if sa_dataclass_metadata_key in f.metadata
|
||||
}
|
||||
local_datacls_fields = {
|
||||
f.name: f.metadata[sa_dataclass_metadata_key]
|
||||
for f in util.local_dataclass_fields(cls)
|
||||
if sa_dataclass_metadata_key in f.metadata
|
||||
}
|
||||
|
||||
absent = object()
|
||||
|
||||
def attribute_is_overridden(key, obj):
|
||||
# this function likely has some failure modes still if
|
||||
# someone is doing a deep mixing of the same attribute
|
||||
# name as plain Python attribute vs. dataclass field.
|
||||
|
||||
ret = local_datacls_fields.get(key, absent)
|
||||
|
||||
if ret is obj:
|
||||
return False
|
||||
elif ret is not absent:
|
||||
return True
|
||||
|
||||
ret = getattr(cls, key, obj)
|
||||
|
||||
if ret is obj:
|
||||
return False
|
||||
elif ret is not absent:
|
||||
return True
|
||||
|
||||
ret = all_datacls_fields.get(key, absent)
|
||||
|
||||
if ret is obj:
|
||||
return False
|
||||
elif ret is not absent:
|
||||
return True
|
||||
|
||||
# can't find another attribute
|
||||
return False
|
||||
|
||||
return attribute_is_overridden
|
||||
|
||||
def _cls_attr_resolver(self, cls):
|
||||
"""produce a function to iterate the "attributes" of a class,
|
||||
adjusting for SQLAlchemy fields embedded in dataclass fields.
|
||||
|
||||
"""
|
||||
sa_dataclass_metadata_key = _get_immediate_cls_attr(
|
||||
cls, "__sa_dataclass_metadata_key__", None
|
||||
)
|
||||
|
||||
if sa_dataclass_metadata_key is None:
|
||||
|
||||
def local_attributes_for_class():
|
||||
for name, obj in vars(cls).items():
|
||||
yield name, obj
|
||||
|
||||
else:
|
||||
|
||||
def local_attributes_for_class():
|
||||
for name, obj in vars(cls).items():
|
||||
yield name, obj
|
||||
for field in util.local_dataclass_fields(cls):
|
||||
if sa_dataclass_metadata_key in field.metadata:
|
||||
yield field.name, field.metadata[
|
||||
sa_dataclass_metadata_key
|
||||
]
|
||||
|
||||
return local_attributes_for_class
|
||||
|
||||
def _scan_attributes(self):
|
||||
cls = self.cls
|
||||
dict_ = self.dict_
|
||||
@@ -333,9 +421,9 @@ class _ClassScanMapperConfig(_MapperConfig):
|
||||
table_args = inherited_table_args = None
|
||||
tablename = None
|
||||
|
||||
for base in cls.__mro__:
|
||||
attribute_is_overridden = self._cls_attr_override_checker(self.cls)
|
||||
|
||||
sa_dataclass_metadata_key = None
|
||||
for base in cls.__mro__:
|
||||
|
||||
class_mapped = (
|
||||
base is not cls
|
||||
@@ -345,25 +433,14 @@ class _ClassScanMapperConfig(_MapperConfig):
|
||||
)
|
||||
)
|
||||
|
||||
if sa_dataclass_metadata_key is None:
|
||||
sa_dataclass_metadata_key = _get_immediate_cls_attr(
|
||||
base, "__sa_dataclass_metadata_key__", None
|
||||
)
|
||||
|
||||
def attributes_for_class(cls):
|
||||
for name, obj in vars(cls).items():
|
||||
yield name, obj
|
||||
if sa_dataclass_metadata_key:
|
||||
for field in util.dataclass_fields(cls):
|
||||
if sa_dataclass_metadata_key in field.metadata:
|
||||
yield field.name, field.metadata[
|
||||
sa_dataclass_metadata_key
|
||||
]
|
||||
local_attributes_for_class = self._cls_attr_resolver(base)
|
||||
|
||||
if not class_mapped and base is not cls:
|
||||
self._produce_column_copies(attributes_for_class, base)
|
||||
self._produce_column_copies(
|
||||
local_attributes_for_class, attribute_is_overridden
|
||||
)
|
||||
|
||||
for name, obj in attributes_for_class(base):
|
||||
for name, obj in local_attributes_for_class():
|
||||
if name == "__mapper_args__":
|
||||
check_decl = _check_declared_props_nocascade(
|
||||
obj, name, cls
|
||||
@@ -471,6 +548,15 @@ class _ClassScanMapperConfig(_MapperConfig):
|
||||
else:
|
||||
self._warn_for_decl_attributes(base, name, obj)
|
||||
elif name not in dict_ or dict_[name] is not obj:
|
||||
# here, we are definitely looking at the target class
|
||||
# and not a superclass. this is currently a
|
||||
# dataclass-only path. if the name is only
|
||||
# a dataclass field and isn't in local cls.__dict__,
|
||||
# put the object there.
|
||||
|
||||
# assert that the dataclass-enabled resolver agrees
|
||||
# with what we are seeing
|
||||
assert not attribute_is_overridden(name, obj)
|
||||
dict_[name] = obj
|
||||
|
||||
if inherited_table_args and not tablename:
|
||||
@@ -489,14 +575,17 @@ class _ClassScanMapperConfig(_MapperConfig):
|
||||
% (key, cls)
|
||||
)
|
||||
|
||||
def _produce_column_copies(self, attributes_for_class, base):
|
||||
def _produce_column_copies(
|
||||
self, attributes_for_class, attribute_is_overridden
|
||||
):
|
||||
cls = self.cls
|
||||
dict_ = self.dict_
|
||||
column_copies = self.column_copies
|
||||
# copy mixin columns to the mapped class
|
||||
for name, obj in attributes_for_class(base):
|
||||
|
||||
for name, obj in attributes_for_class():
|
||||
if isinstance(obj, Column):
|
||||
if getattr(cls, name) is not obj:
|
||||
if attribute_is_overridden(name, obj):
|
||||
# if column has been overridden
|
||||
# (like by the InstrumentedAttribute of the
|
||||
# superclass), skip
|
||||
|
||||
@@ -552,6 +552,7 @@ class DeclarativeMappedTest(MappedTest):
|
||||
metaclass=FindFixtureDeclarative,
|
||||
cls=DeclarativeBasic,
|
||||
)
|
||||
|
||||
cls.DeclarativeBasic = _DeclBase
|
||||
|
||||
# sets up cls.Basic which is helpful for things like composite
|
||||
|
||||
@@ -66,6 +66,7 @@ from .compat import int_types # noqa
|
||||
from .compat import iterbytes # noqa
|
||||
from .compat import itertools_filter # noqa
|
||||
from .compat import itertools_filterfalse # noqa
|
||||
from .compat import local_dataclass_fields # noqa
|
||||
from .compat import namedtuple # noqa
|
||||
from .compat import next # noqa
|
||||
from .compat import nullcontext # noqa
|
||||
|
||||
@@ -425,17 +425,37 @@ if py37:
|
||||
import dataclasses
|
||||
|
||||
def dataclass_fields(cls):
|
||||
"""Return a sequence of all dataclasses.Field objects associated
|
||||
with a class."""
|
||||
|
||||
if dataclasses.is_dataclass(cls):
|
||||
return dataclasses.fields(cls)
|
||||
else:
|
||||
return []
|
||||
|
||||
def local_dataclass_fields(cls):
|
||||
"""Return a sequence of all dataclasses.Field objects associated with
|
||||
a class, excluding those that originate from a superclass."""
|
||||
|
||||
if dataclasses.is_dataclass(cls):
|
||||
super_fields = set()
|
||||
for sup in cls.__bases__:
|
||||
super_fields.update(dataclass_fields(sup))
|
||||
return [
|
||||
f for f in dataclasses.fields(cls) if f not in super_fields
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
else:
|
||||
|
||||
def dataclass_fields(cls):
|
||||
return []
|
||||
|
||||
def local_dataclass_fields(cls):
|
||||
return []
|
||||
|
||||
|
||||
def raise_from_cause(exception, exc_info=None):
|
||||
r"""legacy. use raise\_()"""
|
||||
|
||||
@@ -6,12 +6,16 @@ from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import testing
|
||||
from sqlalchemy.orm import clear_mappers
|
||||
from sqlalchemy.orm import declared_attr
|
||||
from sqlalchemy.orm import mapper
|
||||
from sqlalchemy.orm import registry as declarative_registry
|
||||
from sqlalchemy.orm import registry
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.testing import eq_
|
||||
from sqlalchemy.testing import fixtures
|
||||
from sqlalchemy.testing.fixtures import fixture_session
|
||||
from sqlalchemy.testing.schema import Column
|
||||
from sqlalchemy.testing.schema import Table
|
||||
|
||||
@@ -171,14 +175,14 @@ class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
|
||||
assert Widget("Foo") != Widget("Bar")
|
||||
assert Widget("Foo") != SpecialWidget("Foo")
|
||||
|
||||
def test_asdict_and_astuple(self):
|
||||
def test_asdict_and_astuple_widget(self):
|
||||
Widget = self.classes.Widget
|
||||
SpecialWidget = self.classes.SpecialWidget
|
||||
|
||||
widget = Widget("Foo")
|
||||
eq_(dataclasses.asdict(widget), {"name": "Foo"})
|
||||
eq_(dataclasses.astuple(widget), ("Foo",))
|
||||
|
||||
def test_asdict_and_astuple_special_widget(self):
|
||||
SpecialWidget = self.classes.SpecialWidget
|
||||
widget = SpecialWidget("Bar", magic=True)
|
||||
eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True})
|
||||
eq_(dataclasses.astuple(widget), ("Bar", True))
|
||||
@@ -187,11 +191,11 @@ class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
|
||||
Account = self.classes.Account
|
||||
account = self.data_fixture()
|
||||
|
||||
with Session(testing.db) as session:
|
||||
with fixture_session() as session:
|
||||
session.add(account)
|
||||
session.commit()
|
||||
|
||||
with Session(testing.db) as session:
|
||||
with fixture_session() as session:
|
||||
a = session.query(Account).get(42)
|
||||
self.check_data_fixture(a)
|
||||
|
||||
@@ -373,14 +377,229 @@ class FieldEmbeddedDeclarativeDataclassesTest(
|
||||
def define_tables(cls, metadata):
|
||||
pass
|
||||
|
||||
def test_asdict_and_astuple(self):
|
||||
def test_asdict_and_astuple_widget(self):
|
||||
Widget = self.classes.Widget
|
||||
SpecialWidget = self.classes.SpecialWidget
|
||||
|
||||
widget = Widget("Foo")
|
||||
eq_(dataclasses.asdict(widget), {"name": "Foo"})
|
||||
eq_(dataclasses.astuple(widget), ("Foo",))
|
||||
|
||||
def test_asdict_and_astuple_special_widget(self):
|
||||
SpecialWidget = self.classes.SpecialWidget
|
||||
widget = SpecialWidget("Bar", magic=True)
|
||||
eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True})
|
||||
eq_(dataclasses.astuple(widget), ("Bar", True))
|
||||
|
||||
|
||||
class FieldEmbeddedWMixinTest(FieldEmbeddedDeclarativeDataclassesTest):
|
||||
__requires__ = ("dataclasses",)
|
||||
|
||||
@classmethod
|
||||
def setup_classes(cls):
|
||||
declarative = cls.DeclarativeBasic.registry.mapped
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SurrogateWidgetPK:
|
||||
|
||||
__sa_dataclass_metadata_key__ = "sa"
|
||||
|
||||
widget_id: int = dataclasses.field(
|
||||
init=False,
|
||||
metadata={"sa": Column(Integer, primary_key=True)},
|
||||
)
|
||||
|
||||
@declarative
|
||||
@dataclasses.dataclass
|
||||
class Widget(SurrogateWidgetPK):
|
||||
__tablename__ = "widgets"
|
||||
__sa_dataclass_metadata_key__ = "sa"
|
||||
|
||||
account_id = Column(
|
||||
Integer,
|
||||
ForeignKey("accounts.account_id"),
|
||||
nullable=False,
|
||||
)
|
||||
type = Column(String(30), nullable=False)
|
||||
|
||||
name: Optional[str] = dataclasses.field(
|
||||
default=None,
|
||||
metadata={"sa": Column(String(30), nullable=False)},
|
||||
)
|
||||
__mapper_args__ = dict(
|
||||
polymorphic_on="type",
|
||||
polymorphic_identity="normal",
|
||||
)
|
||||
|
||||
@declarative
|
||||
@dataclasses.dataclass
|
||||
class SpecialWidget(Widget):
|
||||
__sa_dataclass_metadata_key__ = "sa"
|
||||
|
||||
magic: bool = dataclasses.field(
|
||||
default=False, metadata={"sa": Column(Boolean)}
|
||||
)
|
||||
|
||||
__mapper_args__ = dict(
|
||||
polymorphic_identity="special",
|
||||
)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SurrogateAccountPK:
|
||||
|
||||
__sa_dataclass_metadata_key__ = "sa"
|
||||
|
||||
account_id = Column(
|
||||
"we_dont_want_to_use_this", Integer, primary_key=True
|
||||
)
|
||||
|
||||
@declarative
|
||||
@dataclasses.dataclass
|
||||
class Account(SurrogateAccountPK):
|
||||
__tablename__ = "accounts"
|
||||
__sa_dataclass_metadata_key__ = "sa"
|
||||
|
||||
account_id: int = dataclasses.field(
|
||||
metadata={"sa": Column(Integer, primary_key=True)},
|
||||
)
|
||||
widgets: List[Widget] = dataclasses.field(
|
||||
default_factory=list, metadata={"sa": relationship("Widget")}
|
||||
)
|
||||
widget_count: int = dataclasses.field(
|
||||
init=False,
|
||||
metadata={
|
||||
"sa": Column("widget_count", Integer, nullable=False)
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.widget_count = len(self.widgets)
|
||||
|
||||
def add_widget(self, widget: Widget):
|
||||
self.widgets.append(widget)
|
||||
self.widget_count += 1
|
||||
|
||||
cls.classes.Account = Account
|
||||
cls.classes.Widget = Widget
|
||||
cls.classes.SpecialWidget = SpecialWidget
|
||||
|
||||
def check_widget_dataclass(self, obj):
|
||||
assert dataclasses.is_dataclass(obj)
|
||||
(
|
||||
id_,
|
||||
name,
|
||||
) = dataclasses.fields(obj)
|
||||
eq_(name.name, "name")
|
||||
eq_(id_.name, "widget_id")
|
||||
|
||||
def check_special_widget_dataclass(self, obj):
|
||||
assert dataclasses.is_dataclass(obj)
|
||||
id_, name, magic = dataclasses.fields(obj)
|
||||
eq_(id_.name, "widget_id")
|
||||
eq_(name.name, "name")
|
||||
eq_(magic.name, "magic")
|
||||
|
||||
def test_asdict_and_astuple_widget(self):
|
||||
Widget = self.classes.Widget
|
||||
|
||||
widget = Widget("Foo")
|
||||
eq_(dataclasses.asdict(widget), {"name": "Foo", "widget_id": None})
|
||||
eq_(
|
||||
dataclasses.astuple(widget),
|
||||
(
|
||||
None,
|
||||
"Foo",
|
||||
),
|
||||
)
|
||||
|
||||
def test_asdict_and_astuple_special_widget(self):
|
||||
SpecialWidget = self.classes.SpecialWidget
|
||||
widget = SpecialWidget("Bar", magic=True)
|
||||
eq_(
|
||||
dataclasses.asdict(widget),
|
||||
{"name": "Bar", "magic": True, "widget_id": None},
|
||||
)
|
||||
eq_(dataclasses.astuple(widget), (None, "Bar", True))
|
||||
|
||||
|
||||
class PropagationBlockTest(fixtures.TestBase):
|
||||
__requires__ = ("dataclasses",)
|
||||
|
||||
run_setup_classes = "each"
|
||||
run_setup_mappers = "each"
|
||||
|
||||
def test_propagate_w_plain_mixin_col(self, run_test):
|
||||
@dataclasses.dataclass
|
||||
class CommonMixin:
|
||||
__sa_dataclass_metadata_key__ = "sa"
|
||||
|
||||
@declared_attr
|
||||
def __tablename__(cls):
|
||||
return cls.__name__.lower()
|
||||
|
||||
__table_args__ = {"mysql_engine": "InnoDB"}
|
||||
timestamp = Column(Integer)
|
||||
|
||||
run_test(CommonMixin)
|
||||
|
||||
def test_propagate_w_field_mixin_col(self, run_test):
|
||||
@dataclasses.dataclass
|
||||
class CommonMixin:
|
||||
__sa_dataclass_metadata_key__ = "sa"
|
||||
|
||||
@declared_attr
|
||||
def __tablename__(cls):
|
||||
return cls.__name__.lower()
|
||||
|
||||
__table_args__ = {"mysql_engine": "InnoDB"}
|
||||
|
||||
timestamp: int = dataclasses.field(
|
||||
init=False,
|
||||
metadata={"sa": Column(Integer, nullable=False)},
|
||||
)
|
||||
|
||||
run_test(CommonMixin)
|
||||
|
||||
@testing.fixture()
|
||||
def run_test(self):
|
||||
def go(CommonMixin):
|
||||
declarative = registry().mapped
|
||||
|
||||
@declarative
|
||||
@dataclasses.dataclass
|
||||
class BaseType(CommonMixin):
|
||||
|
||||
discriminator = Column("type", String(50))
|
||||
__mapper_args__ = dict(polymorphic_on=discriminator)
|
||||
id = Column(Integer, primary_key=True)
|
||||
value = Column(Integer())
|
||||
|
||||
@declarative
|
||||
@dataclasses.dataclass
|
||||
class Single(BaseType):
|
||||
|
||||
__tablename__ = None
|
||||
__mapper_args__ = dict(polymorphic_identity="type1")
|
||||
|
||||
@declarative
|
||||
@dataclasses.dataclass
|
||||
class Joined(BaseType):
|
||||
|
||||
__mapper_args__ = dict(polymorphic_identity="type2")
|
||||
id = Column(
|
||||
Integer, ForeignKey("basetype.id"), primary_key=True
|
||||
)
|
||||
|
||||
eq_(BaseType.__table__.name, "basetype")
|
||||
eq_(
|
||||
list(BaseType.__table__.c.keys()),
|
||||
["timestamp", "type", "id", "value"],
|
||||
)
|
||||
eq_(BaseType.__table__.kwargs, {"mysql_engine": "InnoDB"})
|
||||
assert Single.__table__ is BaseType.__table__
|
||||
eq_(Joined.__table__.name, "joined")
|
||||
eq_(list(Joined.__table__.c.keys()), ["id"])
|
||||
eq_(Joined.__table__.kwargs, {"mysql_engine": "InnoDB"})
|
||||
|
||||
yield go
|
||||
|
||||
clear_mappers()
|
||||
|
||||
Reference in New Issue
Block a user