diff --git a/doc/build/changelog/unreleased_20/11176.rst b/doc/build/changelog/unreleased_20/11176.rst new file mode 100644 index 0000000000..cc35ab1d54 --- /dev/null +++ b/doc/build/changelog/unreleased_20/11176.rst @@ -0,0 +1,12 @@ +.. change:: + :tag: bug, sql, regression + :tickets: 11176 + + Fixed regression from the 1.4 series where the refactor of the + :meth:`_types.TypeEngine.with_variant` method introduced at + :ref:`change_6980` failed to accommodate for the ``.copy()`` method, which + will lose the variant mappings that are set up. This becomes an issue for + the very specific case of a "schema" type, which includes types such as + :class:`.Enum` and :class:`.ARRAY`, when they are then used in the context + of an ORM Declarative mapping with mixins where copying of types comes into + play. The variant mapping is now copied as well. diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index b638d6e265..38f96780c2 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -1005,9 +1005,11 @@ class TypeEngine(Visitable, Generic[_T]): types with "implementation" types that are specific to a particular dialect. """ - return util.constructor_copy( + typ = util.constructor_copy( self, cast(Type[TypeEngine[Any]], cls), **kw ) + typ._variant_mapping = self._variant_mapping + return typ def coerce_compared_value( self, op: Optional[OperatorType], value: Any diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 898d6fa0a8..0127004438 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1695,6 +1695,19 @@ class VariantTest(fixtures.TestBase, AssertsCompiledSQL): ) self.composite = self.variant.with_variant(self.UTypeThree(), "mysql") + def test_copy_doesnt_lose_variants(self): + """test #11176""" + + v = self.UTypeOne().with_variant(self.UTypeTwo(), "postgresql") + + v_c = v.copy() + + self.assert_compile(v_c, "UTYPEONE", dialect="default") + + self.assert_compile( + v_c, "UTYPETWO", dialect=dialects.postgresql.dialect() + ) + def test_one_dialect_is_req(self): with expect_raises_message( exc.ArgumentError, "At least one dialect name is required"