mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-07 09:20:58 -04:00
02c9d3ab43
Reworked the handling of classes which extend from MappedAsDataclass but are not themselves mapped, i.e. the declarative base as well as any mixins or abstract classes. These classes as before are turned into real dataclasses, however a scan now takes place across the mapped elements such as mapped_column(), relationship(), etc. so that we may also take into account dataclasses.field-specific parameters like init=False, repr, etc. The main use case for this is so that mixin dataclasses may make use of "default" in fields while not being rejected by the dataclasses constructor. The generated classes are more functional as dataclasses in a standalone fashion as well, even though this is not their intended use. As a standalone dataclass, the one feature that does not work is a field that has a default with init=False, because we still need to have a mapped_column() or similar present at the class level for the class to work as a superclass. The change also adds the :func:`_orm.unmapped_dataclass` decorator function, which may be used to create unmapped superclasses in a mapped hierarchy that is using the :func:`_orm.mapped_dataclass` decorator to create mapped dataclasses. Previously there was no way to use unmapped dataclass mixins with the decorator approach. Finally, the warning added in 2.0 for :ticket:`9350` is turned into an error as mentioned for 2.1, since we're deep into dataclass hierarchy changes here. Fixes: #12854 Change-Id: I11cd8c628d49e9ff1bdbda8a09f4112b40d84be7
586 lines
18 KiB
Python
586 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
from sqlalchemy import DDL
|
|
from sqlalchemy import event
|
|
from sqlalchemy import exc
|
|
from sqlalchemy import ForeignKey
|
|
from sqlalchemy import Integer
|
|
from sqlalchemy import select
|
|
from sqlalchemy import String
|
|
from sqlalchemy import testing
|
|
from sqlalchemy.ext.declarative import DeferredReflection
|
|
from sqlalchemy.orm import clear_mappers
|
|
from sqlalchemy.orm import decl_api as decl
|
|
from sqlalchemy.orm import declared_attr
|
|
from sqlalchemy.orm import exc as orm_exc
|
|
from sqlalchemy.orm import Mapped
|
|
from sqlalchemy.orm import mapped_column
|
|
from sqlalchemy.orm import relationship
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.orm.decl_base import _DeferredDeclarativeConfig
|
|
from sqlalchemy.testing import assert_raises_message
|
|
from sqlalchemy.testing import eq_
|
|
from sqlalchemy.testing import expect_raises_message
|
|
from sqlalchemy.testing import fixtures
|
|
from sqlalchemy.testing.entities import ComparableEntity
|
|
from sqlalchemy.testing.fixtures import fixture_session
|
|
from sqlalchemy.testing.schema import Column
|
|
from sqlalchemy.testing.schema import Table
|
|
from sqlalchemy.testing.util import gc_collect
|
|
|
|
|
|
class DeclarativeReflectionBase(fixtures.TablesTest):
    """Common base for the reflection test suites in this module.

    Each test gets a newly created registry and declarative ``Base``,
    both published as module globals, and all mappers are cleared again
    when the test finishes.
    """

    __requires__ = ("reflectable_autoincrement",)

    def setup_test(self):
        """Empty the deferred-config store and rebuild the global ``Base``."""
        global Base, registry

        # Begin every test with an empty deferred-config collection.
        _DeferredDeclarativeConfig._configs.clear()

        # Build the registry locally, then publish it and its generated
        # base under the module-level names that the tests refer to.
        fresh_registry = decl.registry()
        registry = fresh_registry
        Base = fresh_registry.generate_base()

    def teardown_test(self):
        """Clear the mappers once the test has run."""
        clear_mappers()

    @testing.fixture
    def decl_base(self):
        """Fixture that yields the current module-level ``Base``."""
        yield Base
|
|
|
|
|
|
class DeferredReflectBase(DeclarativeReflectionBase):
    """Base for suites that declare classes using ``DeferredReflection``."""

    def teardown_test(self):
        """Run the inherited teardown, then empty the deferred configs."""
        # Parent cleanup runs first; the deferred-config collection is
        # emptied only afterwards.
        super().teardown_test()
        _DeferredDeclarativeConfig._configs.clear()
|
|
|
|
|
|
# Module-level placeholder for the declarative base.  It is rebound for
# each test by DeclarativeReflectionBase.setup_test() via ``global Base``.
Base = None
|
|
|
|
|
|
class DeferredReflectPKFKTest(DeferredReflectBase):
    """Deferred reflection where a primary key column is also a foreign key."""

    @classmethod
    def define_tables(cls, metadata):
        """Create table ``a`` and table ``b`` whose ``id`` points at ``a.id``."""
        a_id = Column(
            "id", Integer, primary_key=True, test_needs_autoincrement=True
        )
        Table("a", metadata, a_id)

        # ``b`` has a two-column primary key; its ``id`` part is also the
        # foreign key to ``a``.
        b_columns = (
            Column("id", Integer, ForeignKey("a.id"), primary_key=True),
            Column("x", Integer, primary_key=True),
        )
        Table("b", metadata, *b_columns)

    def test_pk_fk(self):
        """``prepare()`` completes for the ``a`` / ``b`` class pair."""

        # ``B`` is declared first and names ``A`` only by string.
        class B(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "b"
            a = relationship("A")

        class A(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "a"

        DeferredReflection.prepare(testing.db)
|
|
|
|
|
|
class DeferredReflectionTest(testing.AssertsCompiledSQL, DeferredReflectBase):
    """``DeferredReflection`` tests against ``users`` and ``addresses``."""

    @classmethod
    def define_tables(cls, metadata):
        """Create ``users`` and an ``addresses`` table that references it."""
        user_columns = (
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("name", String(50)),
        )
        Table("users", metadata, *user_columns, test_needs_fk=True)

        address_columns = (
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("email", String(50)),
            Column("user_id", Integer, ForeignKey("users.id")),
        )
        Table("addresses", metadata, *address_columns, test_needs_fk=True)

    def _roundtrip(self):
        """Persist one user with two addresses and check it loads back.

        The ``User`` and ``Address`` classes are looked up by name in the
        class registry of the current module-level ``Base``.
        """
        class_registry = Base.registry._class_registry
        User = class_registry["User"]
        Address = class_registry["Address"]

        def two_addresses():
            # New, unsaved Address objects on every call.
            return [Address(email="one"), Address(email="two")]

        new_user = User(name="u1", addresses=two_addresses())
        with fixture_session() as session:
            session.add(new_user)
            session.commit()

        # Verify using a second session.
        with fixture_session() as session:
            loaded_users = session.query(User).all()
            eq_(loaded_users, [User(name="u1", addresses=two_addresses())])

            address_two = (
                session.query(Address).filter(Address.email == "two").one()
            )
            eq_(address_two, Address(email="two"))
            eq_(address_two.user, User(name="u1"))

    def test_exception_prepare_not_called(self):
        """Querying a deferred class before ``prepare()`` raises."""

        class User(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "users"
            addresses = relationship("Address", backref="user")

        class Address(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "addresses"

        expected_message = (
            "Class test.ext.declarative.test_reflection.User is a "
            "subclass of DeferredReflection. Mappings are not produced "
            r"until the .prepare\(\) method is called on the class "
            "hierarchy."
        )
        assert_raises_message(
            orm_exc.UnmappedClassError,
            expected_message,
            Session().query,
            User,
        )

    @testing.variation("bind", ["engine", "connection", "raise_"])
    def test_basic_deferred(self, bind):
        """``prepare()`` takes an Engine or a Connection; a string fails."""

        class User(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "users"
            addresses = relationship("Address", backref="user")

        class Address(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "addresses"

        if bind.raise_:
            # Error case: no round trip is attempted afterwards.
            with expect_raises_message(
                exc.ArgumentError, "Expected Engine or Connection, got 'foo'"
            ):
                DeferredReflection.prepare("foo")
            return

        if bind.engine:
            DeferredReflection.prepare(testing.db)
        elif bind.connection:
            with testing.db.connect() as conn:
                DeferredReflection.prepare(conn)
        else:
            bind.fail()

        self._roundtrip()

    @testing.requires.view_reflection
    @testing.variation("include_views", [True, False])
    def test_views(self, metadata, connection, include_views):
        """A class over a view prepares only when ``views=True`` is given."""
        Table(
            "test_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
        )

        # The view is created and dropped through DDL hooks on the metadata.
        create_view = DDL(
            "CREATE VIEW view_name AS SELECT id, data FROM test_table"
        )
        drop_view = DDL("DROP VIEW IF EXISTS view_name")
        event.listen(metadata, "after_create", create_view)
        event.listen(metadata, "before_drop", drop_view)
        metadata.create_all(connection)

        class ViewName(DeferredReflection, Base):
            __tablename__ = "view_name"

            id = Column(Integer, primary_key=True)

        if not include_views:
            with expect_raises_message(
                exc.InvalidRequestError, r"Could not reflect: .*view_name"
            ):
                DeferredReflection.prepare(connection)
            return

        DeferredReflection.prepare(connection, views=True)
        self.assert_compile(
            select(ViewName),
            "SELECT view_name.id, view_name.data FROM view_name",
        )

    def test_abstract_base(self):
        """Prepare through one abstract base only, then round trip.

        ``Fake`` hangs off a second abstract base and names a table that
        ``define_tables`` does not create; ``prepare()`` is invoked on
        ``DefBase`` alone.
        """

        class DefBase(DeferredReflection, Base):
            __abstract__ = True

        class OtherDefBase(DeferredReflection, Base):
            __abstract__ = True

        class User(ComparableEntity, DefBase):
            __tablename__ = "users"
            addresses = relationship("Address", backref="user")

        class Address(ComparableEntity, DefBase):
            __tablename__ = "addresses"

        class Fake(OtherDefBase):
            __tablename__ = "nonexistent"

        DefBase.prepare(testing.db)
        self._roundtrip()

    def test_redefine_fk_double(self):
        """Round trip when ``Address`` also declares ``user_id`` itself."""

        class User(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "users"
            addresses = relationship("Address", backref="user")

        class Address(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "addresses"
            # Same column and foreign key that define_tables() put on the
            # ``addresses`` table.
            user_id = Column(Integer, ForeignKey("users.id"))

        DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_mapper_args_deferred(self):
        """Check that ``__mapper_args__`` is only invoked *after* the
        table has been reflected."""

        class User(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "users"

            @declared_attr
            def __mapper_args__(cls):
                return {"primary_key": cls.__table__.c.id}

        DeferredReflection.prepare(testing.db)

        insert_order = ("G", "Q", "A", "C")
        expected_order = ("A", "C", "G", "Q")
        with fixture_session() as session:
            session.add_all([User(name=name) for name in insert_order])
            session.commit()

            by_name = session.query(User).order_by(User.name).all()
            eq_(by_name, [User(name=name) for name in expected_order])

    @testing.requires.predictable_gc
    def test_cls_not_strong_ref(self):
        """Deferred config entries go away with their classes."""

        class User(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "users"

        class Address(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "addresses"

        eq_(len(_DeferredDeclarativeConfig._configs), 2)

        # Drop the only local reference to ``Address`` and collect twice.
        del Address
        for _ in range(2):
            gc_collect()
        eq_(len(_DeferredDeclarativeConfig._configs), 1)

        DeferredReflection.prepare(testing.db)
        gc_collect()
        assert not _DeferredDeclarativeConfig._configs
|
|
|
|
|
|
class DeferredSecondaryReflectionTest(DeferredReflectBase):
    """Deferred reflection with a relationship that uses ``secondary``."""

    @classmethod
    def define_tables(cls, metadata):
        """Create ``users``, ``user_items`` and ``items``, in that order."""

        def id_and_name():
            # ``users`` and ``items`` have the same column layout; each
            # call returns new Column objects.
            return (
                Column(
                    "id",
                    Integer,
                    primary_key=True,
                    test_needs_autoincrement=True,
                ),
                Column("name", String(50)),
            )

        Table("users", metadata, *id_and_name(), test_needs_fk=True)

        link_columns = (
            Column("user_id", ForeignKey("users.id"), primary_key=True),
            Column("item_id", ForeignKey("items.id"), primary_key=True),
        )
        Table("user_items", metadata, *link_columns, test_needs_fk=True)

        Table("items", metadata, *id_and_name(), test_needs_fk=True)

    def _roundtrip(self):
        """Persist a user that has two items and compare the loaded rows."""
        class_registry = Base.registry._class_registry
        User = class_registry["User"]
        Item = class_registry["Item"]

        def user_with_items():
            return User(name="u1", items=[Item(name="i1"), Item(name="i2")])

        new_user = user_with_items()

        with fixture_session() as session:
            session.add(new_user)
            session.commit()

            loaded_users = session.query(User).all()
            eq_(loaded_users, [user_with_items()])

    def test_string_resolution(self):
        """``secondary`` is passed as the table name string."""

        class User(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "users"

            items = relationship("Item", secondary="user_items")

        class Item(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "items"

        DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_table_resolution(self):
        """``secondary`` is passed as a ``Table`` that has no columns yet."""
        # Only the name and metadata are given here.
        secondary_table = Table("user_items", Base.metadata)

        class User(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "users"

            items = relationship("Item", secondary=secondary_table)

        class Item(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "items"

        DeferredReflection.prepare(testing.db)
        self._roundtrip()
|
|
|
|
|
|
class DeferredInhReflectBase(DeferredReflectBase):
    """Shared round trip for the ``Foo`` / ``Bar`` inheritance suites."""

    def _roundtrip(self):
        """Store three ``Bar`` objects and one ``Foo``; query them as ``Foo``."""
        class_registry = Base.registry._class_registry
        Foo = class_registry["Foo"]
        Bar = class_registry["Bar"]

        def sample_objects():
            # New objects on every call: one set is inserted, a second set
            # serves as the expected result.
            return [
                Bar(data="d1", bar_data="b1"),
                Bar(data="d2", bar_data="b2"),
                Bar(data="d3", bar_data="b3"),
                Foo(data="d4"),
            ]

        with fixture_session() as session:
            session.add_all(sample_objects())
            session.commit()

            loaded = session.query(Foo).order_by(Foo.id).all()
            eq_(loaded, sample_objects())
|
|
|
|
|
|
class DeferredSingleInhReflectionTest(DeferredInhReflectBase):
    """Deferred reflection where ``Foo`` and ``Bar`` share table ``foo``."""

    @classmethod
    def define_tables(cls, metadata):
        """Create the single ``foo`` table, including ``bar_data``."""
        foo_columns = (
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("type", String(32)),
            Column("data", String(30)),
            Column("bar_data", String(30)),
        )
        Table("foo", metadata, *foo_columns)

    def test_basic(self, decl_base):
        """Neither class declares any column of its own."""
        foo_args = {"polymorphic_on": "type", "polymorphic_identity": "foo"}
        bar_args = {"polymorphic_identity": "bar"}

        class Foo(DeferredReflection, ComparableEntity, decl_base):
            __tablename__ = "foo"
            __mapper_args__ = foo_args

        class Bar(Foo):
            __mapper_args__ = bar_args

        DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_add_subclass_column(self, decl_base):
        """``Bar`` declares ``bar_data`` with a plain ``Column``."""
        foo_args = {"polymorphic_on": "type", "polymorphic_identity": "foo"}
        bar_args = {"polymorphic_identity": "bar"}

        class Foo(DeferredReflection, ComparableEntity, decl_base):
            __tablename__ = "foo"
            __mapper_args__ = foo_args

        class Bar(Foo):
            __mapper_args__ = bar_args
            bar_data = Column(String(30))

        DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_add_subclass_mapped_column(self, decl_base):
        """``Bar`` declares ``bar_data`` with a ``Mapped[str]`` annotation."""
        foo_args = {"polymorphic_on": "type", "polymorphic_identity": "foo"}
        bar_args = {"polymorphic_identity": "bar"}

        class Foo(DeferredReflection, ComparableEntity, decl_base):
            __tablename__ = "foo"
            __mapper_args__ = foo_args

        class Bar(Foo):
            __mapper_args__ = bar_args
            bar_data: Mapped[str]

        DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_subclass_mapped_column_no_existing(self, decl_base):
        """``use_existing_column=True`` on a deferred subclass raises."""
        foo_args = {"polymorphic_on": "type", "polymorphic_identity": "foo"}
        bar_args = {"polymorphic_identity": "bar"}

        class Foo(DeferredReflection, ComparableEntity, decl_base):
            __tablename__ = "foo"
            __mapper_args__ = foo_args

        # The error is expected while the subclass itself is being declared.
        with expect_raises_message(
            exc.ArgumentError,
            "Can't use use_existing_column with deferred mappers",
        ):

            class Bar(Foo):
                __mapper_args__ = bar_args
                bar_data: Mapped[str] = mapped_column(use_existing_column=True)

    def test_add_pk_column(self, decl_base):
        """``Foo`` declares its ``id`` primary key with a plain ``Column``."""
        foo_args = {"polymorphic_on": "type", "polymorphic_identity": "foo"}
        bar_args = {"polymorphic_identity": "bar"}

        class Foo(DeferredReflection, ComparableEntity, decl_base):
            __tablename__ = "foo"
            __mapper_args__ = foo_args
            id = Column(Integer, primary_key=True)

        class Bar(Foo):
            __mapper_args__ = bar_args

        DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_add_pk_mapped_column(self, decl_base):
        """``Foo`` declares its ``id`` primary key with ``mapped_column()``."""
        foo_args = {"polymorphic_on": "type", "polymorphic_identity": "foo"}
        bar_args = {"polymorphic_identity": "bar"}

        class Foo(DeferredReflection, ComparableEntity, decl_base):
            __tablename__ = "foo"
            __mapper_args__ = foo_args
            id: Mapped[int] = mapped_column(primary_key=True)

        class Bar(Foo):
            __mapper_args__ = bar_args

        DeferredReflection.prepare(testing.db)
        self._roundtrip()
|
|
|
|
|
|
class DeferredJoinedInhReflectionTest(DeferredInhReflectBase):
    """Deferred reflection where ``Bar`` has its own table ``bar``."""

    @classmethod
    def define_tables(cls, metadata):
        """Create ``foo`` and a ``bar`` table whose ``id`` points at it."""
        foo_columns = (
            Column(
                "id", Integer, primary_key=True, test_needs_autoincrement=True
            ),
            Column("type", String(32)),
            Column("data", String(30)),
        )
        Table("foo", metadata, *foo_columns, test_needs_fk=True)

        bar_columns = (
            Column("id", Integer, ForeignKey("foo.id"), primary_key=True),
            Column("bar_data", String(30)),
        )
        Table("bar", metadata, *bar_columns, test_needs_fk=True)

    def test_basic(self):
        """Neither class declares any column of its own."""
        foo_args = {"polymorphic_on": "type", "polymorphic_identity": "foo"}
        bar_args = {"polymorphic_identity": "bar"}

        class Foo(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "foo"
            __mapper_args__ = foo_args

        class Bar(Foo):
            __tablename__ = "bar"
            __mapper_args__ = bar_args

        DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_add_subclass_column(self):
        """``Bar`` declares ``bar_data`` with a plain ``Column``."""
        foo_args = {"polymorphic_on": "type", "polymorphic_identity": "foo"}
        bar_args = {"polymorphic_identity": "bar"}

        class Foo(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "foo"
            __mapper_args__ = foo_args

        class Bar(Foo):
            __tablename__ = "bar"
            __mapper_args__ = bar_args
            bar_data = Column(String(30))

        DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_add_pk_column(self):
        """``Foo`` declares its ``id`` primary key with a plain ``Column``."""
        foo_args = {"polymorphic_on": "type", "polymorphic_identity": "foo"}
        bar_args = {"polymorphic_identity": "bar"}

        class Foo(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "foo"
            __mapper_args__ = foo_args
            id = Column(Integer, primary_key=True)

        class Bar(Foo):
            __tablename__ = "bar"
            __mapper_args__ = bar_args

        DeferredReflection.prepare(testing.db)
        self._roundtrip()

    def test_add_fk_pk_column(self):
        """``Bar`` declares its ``id`` primary key / foreign key itself."""
        foo_args = {"polymorphic_on": "type", "polymorphic_identity": "foo"}
        bar_args = {"polymorphic_identity": "bar"}

        class Foo(DeferredReflection, ComparableEntity, Base):
            __tablename__ = "foo"
            __mapper_args__ = foo_args

        class Bar(Foo):
            __tablename__ = "bar"
            __mapper_args__ = bar_args
            id = Column(Integer, ForeignKey("foo.id"), primary_key=True)

        DeferredReflection.prepare(testing.db)
        self._roundtrip()
|