Fill-out dataclass-related attr resolution

Fixed issue where mixin attribute rules were not taking
effect correctly for attributes pulled from dataclasses
using the approach added in #5745.

Fixes: #5876
Change-Id: I45099a42de1d9611791e72250fe0edc69bed684c
This commit is contained in:
Mike Bayer
2021-01-25 17:59:35 -05:00
parent 57db20a187
commit 9205e9171c
5 changed files with 358 additions and 28 deletions
+110 -21
View File
@@ -325,6 +325,94 @@ class _ClassScanMapperConfig(_MapperConfig):
def before_configured():
self.cls.__declare_first__()
def _cls_attr_override_checker(self, cls):
"""Produce a function that checks if a class has overridden an
attribute, taking SQLAlchemy-enabled dataclass fields into account.
"""
sa_dataclass_metadata_key = _get_immediate_cls_attr(
cls, "__sa_dataclass_metadata_key__", None
)
if sa_dataclass_metadata_key is None:
def attribute_is_overridden(key, obj):
return getattr(cls, key) is not obj
else:
all_datacls_fields = {
f.name: f.metadata[sa_dataclass_metadata_key]
for f in util.dataclass_fields(cls)
if sa_dataclass_metadata_key in f.metadata
}
local_datacls_fields = {
f.name: f.metadata[sa_dataclass_metadata_key]
for f in util.local_dataclass_fields(cls)
if sa_dataclass_metadata_key in f.metadata
}
absent = object()
def attribute_is_overridden(key, obj):
# this function likely has some failure modes still if
# someone is doing a deep mixing of the same attribute
# name as plain Python attribute vs. dataclass field.
ret = local_datacls_fields.get(key, absent)
if ret is obj:
return False
elif ret is not absent:
return True
ret = getattr(cls, key, obj)
if ret is obj:
return False
elif ret is not absent:
return True
ret = all_datacls_fields.get(key, absent)
if ret is obj:
return False
elif ret is not absent:
return True
# can't find another attribute
return False
return attribute_is_overridden
def _cls_attr_resolver(self, cls):
"""produce a function to iterate the "attributes" of a class,
adjusting for SQLAlchemy fields embedded in dataclass fields.
"""
sa_dataclass_metadata_key = _get_immediate_cls_attr(
cls, "__sa_dataclass_metadata_key__", None
)
if sa_dataclass_metadata_key is None:
def local_attributes_for_class():
for name, obj in vars(cls).items():
yield name, obj
else:
def local_attributes_for_class():
for name, obj in vars(cls).items():
yield name, obj
for field in util.local_dataclass_fields(cls):
if sa_dataclass_metadata_key in field.metadata:
yield field.name, field.metadata[
sa_dataclass_metadata_key
]
return local_attributes_for_class
def _scan_attributes(self):
cls = self.cls
dict_ = self.dict_
@@ -333,9 +421,9 @@ class _ClassScanMapperConfig(_MapperConfig):
table_args = inherited_table_args = None
tablename = None
for base in cls.__mro__:
attribute_is_overridden = self._cls_attr_override_checker(self.cls)
sa_dataclass_metadata_key = None
for base in cls.__mro__:
class_mapped = (
base is not cls
@@ -345,25 +433,14 @@ class _ClassScanMapperConfig(_MapperConfig):
)
)
if sa_dataclass_metadata_key is None:
sa_dataclass_metadata_key = _get_immediate_cls_attr(
base, "__sa_dataclass_metadata_key__", None
)
def attributes_for_class(cls):
for name, obj in vars(cls).items():
yield name, obj
if sa_dataclass_metadata_key:
for field in util.dataclass_fields(cls):
if sa_dataclass_metadata_key in field.metadata:
yield field.name, field.metadata[
sa_dataclass_metadata_key
]
local_attributes_for_class = self._cls_attr_resolver(base)
if not class_mapped and base is not cls:
self._produce_column_copies(attributes_for_class, base)
self._produce_column_copies(
local_attributes_for_class, attribute_is_overridden
)
for name, obj in attributes_for_class(base):
for name, obj in local_attributes_for_class():
if name == "__mapper_args__":
check_decl = _check_declared_props_nocascade(
obj, name, cls
@@ -471,6 +548,15 @@ class _ClassScanMapperConfig(_MapperConfig):
else:
self._warn_for_decl_attributes(base, name, obj)
elif name not in dict_ or dict_[name] is not obj:
# here, we are definitely looking at the target class
# and not a superclass. this is currently a
# dataclass-only path. if the name is only
# a dataclass field and isn't in local cls.__dict__,
# put the object there.
# assert that the dataclass-enabled resolver agrees
# with what we are seeing
assert not attribute_is_overridden(name, obj)
dict_[name] = obj
if inherited_table_args and not tablename:
@@ -489,14 +575,17 @@ class _ClassScanMapperConfig(_MapperConfig):
% (key, cls)
)
def _produce_column_copies(self, attributes_for_class, base):
def _produce_column_copies(
self, attributes_for_class, attribute_is_overridden
):
cls = self.cls
dict_ = self.dict_
column_copies = self.column_copies
# copy mixin columns to the mapped class
for name, obj in attributes_for_class(base):
for name, obj in attributes_for_class():
if isinstance(obj, Column):
if getattr(cls, name) is not obj:
if attribute_is_overridden(name, obj):
# if column has been overridden
# (like by the InstrumentedAttribute of the
# superclass), skip
+1
View File
@@ -552,6 +552,7 @@ class DeclarativeMappedTest(MappedTest):
metaclass=FindFixtureDeclarative,
cls=DeclarativeBasic,
)
cls.DeclarativeBasic = _DeclBase
# sets up cls.Basic which is helpful for things like composite
+1
View File
@@ -66,6 +66,7 @@ from .compat import int_types # noqa
from .compat import iterbytes # noqa
from .compat import itertools_filter # noqa
from .compat import itertools_filterfalse # noqa
from .compat import local_dataclass_fields # noqa
from .compat import namedtuple # noqa
from .compat import next # noqa
from .compat import nullcontext # noqa
+20
View File
@@ -425,17 +425,37 @@ if py37:
import dataclasses
def dataclass_fields(cls):
"""Return a sequence of all dataclasses.Field objects associated
with a class."""
if dataclasses.is_dataclass(cls):
return dataclasses.fields(cls)
else:
return []
def local_dataclass_fields(cls):
"""Return a sequence of all dataclasses.Field objects associated with
a class, excluding those that originate from a superclass."""
if dataclasses.is_dataclass(cls):
super_fields = set()
for sup in cls.__bases__:
super_fields.update(dataclass_fields(sup))
return [
f for f in dataclasses.fields(cls) if f not in super_fields
]
else:
return []
else:
def dataclass_fields(cls):
return []
def local_dataclass_fields(cls):
return []
def raise_from_cause(exception, exc_info=None):
r"""legacy. use raise\_()"""
+226 -7
View File
@@ -6,12 +6,16 @@ from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import testing
from sqlalchemy.orm import clear_mappers
from sqlalchemy.orm import declared_attr
from sqlalchemy.orm import mapper
from sqlalchemy.orm import registry as declarative_registry
from sqlalchemy.orm import registry
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
@@ -171,14 +175,14 @@ class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
assert Widget("Foo") != Widget("Bar")
assert Widget("Foo") != SpecialWidget("Foo")
def test_asdict_and_astuple(self):
def test_asdict_and_astuple_widget(self):
Widget = self.classes.Widget
SpecialWidget = self.classes.SpecialWidget
widget = Widget("Foo")
eq_(dataclasses.asdict(widget), {"name": "Foo"})
eq_(dataclasses.astuple(widget), ("Foo",))
def test_asdict_and_astuple_special_widget(self):
SpecialWidget = self.classes.SpecialWidget
widget = SpecialWidget("Bar", magic=True)
eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True})
eq_(dataclasses.astuple(widget), ("Bar", True))
@@ -187,11 +191,11 @@ class DataclassesTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
Account = self.classes.Account
account = self.data_fixture()
with Session(testing.db) as session:
with fixture_session() as session:
session.add(account)
session.commit()
with Session(testing.db) as session:
with fixture_session() as session:
a = session.query(Account).get(42)
self.check_data_fixture(a)
@@ -373,14 +377,229 @@ class FieldEmbeddedDeclarativeDataclassesTest(
def define_tables(cls, metadata):
pass
def test_asdict_and_astuple(self):
def test_asdict_and_astuple_widget(self):
Widget = self.classes.Widget
SpecialWidget = self.classes.SpecialWidget
widget = Widget("Foo")
eq_(dataclasses.asdict(widget), {"name": "Foo"})
eq_(dataclasses.astuple(widget), ("Foo",))
def test_asdict_and_astuple_special_widget(self):
SpecialWidget = self.classes.SpecialWidget
widget = SpecialWidget("Bar", magic=True)
eq_(dataclasses.asdict(widget), {"name": "Bar", "magic": True})
eq_(dataclasses.astuple(widget), ("Bar", True))
class FieldEmbeddedWMixinTest(FieldEmbeddedDeclarativeDataclassesTest):
__requires__ = ("dataclasses",)
@classmethod
def setup_classes(cls):
declarative = cls.DeclarativeBasic.registry.mapped
@dataclasses.dataclass
class SurrogateWidgetPK:
__sa_dataclass_metadata_key__ = "sa"
widget_id: int = dataclasses.field(
init=False,
metadata={"sa": Column(Integer, primary_key=True)},
)
@declarative
@dataclasses.dataclass
class Widget(SurrogateWidgetPK):
__tablename__ = "widgets"
__sa_dataclass_metadata_key__ = "sa"
account_id = Column(
Integer,
ForeignKey("accounts.account_id"),
nullable=False,
)
type = Column(String(30), nullable=False)
name: Optional[str] = dataclasses.field(
default=None,
metadata={"sa": Column(String(30), nullable=False)},
)
__mapper_args__ = dict(
polymorphic_on="type",
polymorphic_identity="normal",
)
@declarative
@dataclasses.dataclass
class SpecialWidget(Widget):
__sa_dataclass_metadata_key__ = "sa"
magic: bool = dataclasses.field(
default=False, metadata={"sa": Column(Boolean)}
)
__mapper_args__ = dict(
polymorphic_identity="special",
)
@dataclasses.dataclass
class SurrogateAccountPK:
__sa_dataclass_metadata_key__ = "sa"
account_id = Column(
"we_dont_want_to_use_this", Integer, primary_key=True
)
@declarative
@dataclasses.dataclass
class Account(SurrogateAccountPK):
__tablename__ = "accounts"
__sa_dataclass_metadata_key__ = "sa"
account_id: int = dataclasses.field(
metadata={"sa": Column(Integer, primary_key=True)},
)
widgets: List[Widget] = dataclasses.field(
default_factory=list, metadata={"sa": relationship("Widget")}
)
widget_count: int = dataclasses.field(
init=False,
metadata={
"sa": Column("widget_count", Integer, nullable=False)
},
)
def __post_init__(self):
self.widget_count = len(self.widgets)
def add_widget(self, widget: Widget):
self.widgets.append(widget)
self.widget_count += 1
cls.classes.Account = Account
cls.classes.Widget = Widget
cls.classes.SpecialWidget = SpecialWidget
def check_widget_dataclass(self, obj):
assert dataclasses.is_dataclass(obj)
(
id_,
name,
) = dataclasses.fields(obj)
eq_(name.name, "name")
eq_(id_.name, "widget_id")
def check_special_widget_dataclass(self, obj):
assert dataclasses.is_dataclass(obj)
id_, name, magic = dataclasses.fields(obj)
eq_(id_.name, "widget_id")
eq_(name.name, "name")
eq_(magic.name, "magic")
def test_asdict_and_astuple_widget(self):
Widget = self.classes.Widget
widget = Widget("Foo")
eq_(dataclasses.asdict(widget), {"name": "Foo", "widget_id": None})
eq_(
dataclasses.astuple(widget),
(
None,
"Foo",
),
)
def test_asdict_and_astuple_special_widget(self):
SpecialWidget = self.classes.SpecialWidget
widget = SpecialWidget("Bar", magic=True)
eq_(
dataclasses.asdict(widget),
{"name": "Bar", "magic": True, "widget_id": None},
)
eq_(dataclasses.astuple(widget), (None, "Bar", True))
class PropagationBlockTest(fixtures.TestBase):
__requires__ = ("dataclasses",)
run_setup_classes = "each"
run_setup_mappers = "each"
def test_propagate_w_plain_mixin_col(self, run_test):
@dataclasses.dataclass
class CommonMixin:
__sa_dataclass_metadata_key__ = "sa"
@declared_attr
def __tablename__(cls):
return cls.__name__.lower()
__table_args__ = {"mysql_engine": "InnoDB"}
timestamp = Column(Integer)
run_test(CommonMixin)
def test_propagate_w_field_mixin_col(self, run_test):
@dataclasses.dataclass
class CommonMixin:
__sa_dataclass_metadata_key__ = "sa"
@declared_attr
def __tablename__(cls):
return cls.__name__.lower()
__table_args__ = {"mysql_engine": "InnoDB"}
timestamp: int = dataclasses.field(
init=False,
metadata={"sa": Column(Integer, nullable=False)},
)
run_test(CommonMixin)
@testing.fixture()
def run_test(self):
def go(CommonMixin):
declarative = registry().mapped
@declarative
@dataclasses.dataclass
class BaseType(CommonMixin):
discriminator = Column("type", String(50))
__mapper_args__ = dict(polymorphic_on=discriminator)
id = Column(Integer, primary_key=True)
value = Column(Integer())
@declarative
@dataclasses.dataclass
class Single(BaseType):
__tablename__ = None
__mapper_args__ = dict(polymorphic_identity="type1")
@declarative
@dataclasses.dataclass
class Joined(BaseType):
__mapper_args__ = dict(polymorphic_identity="type2")
id = Column(
Integer, ForeignKey("basetype.id"), primary_key=True
)
eq_(BaseType.__table__.name, "basetype")
eq_(
list(BaseType.__table__.c.keys()),
["timestamp", "type", "id", "value"],
)
eq_(BaseType.__table__.kwargs, {"mysql_engine": "InnoDB"})
assert Single.__table__ is BaseType.__table__
eq_(Joined.__table__.name, "joined")
eq_(list(Joined.__table__.c.keys()), ["id"])
eq_(Joined.__table__.kwargs, {"mysql_engine": "InnoDB"})
yield go
clear_mappers()