Detection of PEP 604 union syntax.

### Description

Fixes #8478

Handle `UnionType` as arguments to `Mapped`, e.g., `Mapped[str | None]`:

- adds `utils.typing.is_optional_union()` used to detect if a column should be nullable.
- adds `"UnionType"` to `utils.typing.is_optional()` names.
- uses `get_origin()` in `utils.typing.is_origin_of()` as `UnionType` has no `__origin__` attribute.
- tests with runtime type and postponed annotations and guard the tests running with `compat.py310`.

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical error fix
	- Good to go, no issue or tests are needed
- [x] A short code fix
	- please include the issue number, and create an issue if none exists, which
	  must include a complete example of the issue.  one line code fixes without an
	  issue and demonstration will not be accepted.
	- Please include: `Fixes: #<issue number>` in the commit message
	- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
	- please include the issue number, and create an issue if none exists, which must
	  include a complete example of how the feature would look.
	- Please include: `Fixes: #<issue number>` in the commit message
	- please include tests.

**Have a nice day!**

Closes: #8479
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8479
Pull-request-sha: 1241765482

Change-Id: Ib3248043dd4a97324ac592c048385006536b2d49
This commit is contained in:
Peter Schutt
2022-09-01 19:11:40 -04:00
committed by sqla-tester
parent d3e0b8e750
commit c3cfee5b00
3 changed files with 38 additions and 8 deletions
+4 -6
View File
@@ -52,8 +52,8 @@ from ..sql.schema import SchemaConst
from ..util.typing import de_optionalize_union_types
from ..util.typing import de_stringify_annotation
from ..util.typing import is_fwd_ref
from ..util.typing import is_optional_union
from ..util.typing import is_pep593
from ..util.typing import NoneType
from ..util.typing import Self
from ..util.typing import typing_get_args
@@ -652,17 +652,15 @@ class MappedColumn(
) -> None:
sqltype = self.column.type
nullable = False
if is_fwd_ref(argument):
argument = de_stringify_annotation(cls, argument)
if hasattr(argument, "__origin__"):
nullable = NoneType in argument.__args__ # type: ignore
nullable = is_optional_union(argument)
if not self._has_nullable:
self.column.nullable = nullable
our_type = de_optionalize_union_types(argument)
if is_fwd_ref(our_type):
our_type = de_stringify_annotation(cls, our_type)
use_args_from = None
if is_pep593(our_type):
+7 -2
View File
@@ -169,7 +169,7 @@ def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
def expand_unions(
type_: Type[Any], include_union: bool = False, discard_none: bool = False
) -> Tuple[Type[Any], ...]:
"""Return a type as as a tuple of individual types, expanding for
"""Return a type as a tuple of individual types, expanding for
``Union`` types."""
if is_union(type_):
@@ -191,9 +191,14 @@ def is_optional(type_):
type_,
"Optional",
"Union",
"UnionType",
)
def is_optional_union(type_: Any) -> bool:
return is_optional(type_) and NoneType in typing_get_args(type_)
def is_union(type_):
return is_origin_of(type_, "Union")
@@ -204,7 +209,7 @@ def is_origin_of(
"""return True if the given type has an __origin__ with the given name
and optional module."""
origin = getattr(type_, "__origin__", None)
origin = typing_get_origin(type_)
if origin is None:
return False
@@ -52,6 +52,7 @@ from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_not
from sqlalchemy.testing import is_true
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.util import compat
from sqlalchemy.util.typing import Annotated
@@ -858,6 +859,7 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
data: Mapped[Union[float, Decimal]] = mapped_column()
reverse_data: Mapped[Union[Decimal, float]] = mapped_column()
optional_data: Mapped[
Optional[Union[float, Decimal]]
] = mapped_column()
@@ -872,9 +874,22 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
reverse_u_optional_data: Mapped[
Union[Decimal, float, None]
] = mapped_column()
float_data: Mapped[float] = mapped_column()
decimal_data: Mapped[Decimal] = mapped_column()
if compat.py310:
pep604_data: Mapped[float | Decimal] = mapped_column()
pep604_reverse: Mapped[Decimal | float] = mapped_column()
pep604_optional: Mapped[
Decimal | float | None
] = mapped_column()
pep604_data_fwd: Mapped["float | Decimal"] = mapped_column()
pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column()
pep604_optional_fwd: Mapped[
"Decimal | float | None"
] = mapped_column()
is_(User.__table__.c.data.type, our_type)
is_false(User.__table__.c.data.nullable)
is_(User.__table__.c.reverse_data.type, our_type)
@@ -889,6 +904,18 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
is_(User.__table__.c.float_data.type, our_type)
is_(User.__table__.c.decimal_data.type, our_type)
if compat.py310:
for suffix in ("", "_fwd"):
data_col = User.__table__.c[f"pep604_data{suffix}"]
reverse_col = User.__table__.c[f"pep604_reverse{suffix}"]
optional_col = User.__table__.c[f"pep604_optional{suffix}"]
is_(data_col.type, our_type)
is_false(data_col.nullable)
is_(reverse_col.type, our_type)
is_false(reverse_col.nullable)
is_(optional_col.type, our_type)
is_true(optional_col.nullable)
def test_missing_mapped_lhs(self, decl_base):
with expect_raises_message(
ArgumentError,