mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-13 12:17:28 -04:00
Detection of PEP 604 union syntax.
### Description
Fixes #8478
Handle `UnionType` as arguments to `Mapped`, e.g., `Mapped[str | None]`:
- adds `utils.typing.is_optional_union()` used to detect if a column should be nullable.
- adds `"UnionType"` to `utils.typing.is_optional()` names.
- uses `get_origin()` in `utils.typing.is_origin_of()` as `UnionType` has no `__origin__` attribute.
- tests with runtime type and postponed annotations and guard the tests running with `compat.py310`.
### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)
-->
This pull request is:
- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [x] A short code fix
- please include the issue number, and create an issue if none exists, which
must include a complete example of the issue. one line code fixes without an
issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests. one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.
**Have a nice day!**
Closes: #8479
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8479
Pull-request-sha: 1241765482
Change-Id: Ib3248043dd4a97324ac592c048385006536b2d49
This commit is contained in:
committed by
sqla-tester
parent
d3e0b8e750
commit
c3cfee5b00
@@ -52,8 +52,8 @@ from ..sql.schema import SchemaConst
|
||||
from ..util.typing import de_optionalize_union_types
|
||||
from ..util.typing import de_stringify_annotation
|
||||
from ..util.typing import is_fwd_ref
|
||||
from ..util.typing import is_optional_union
|
||||
from ..util.typing import is_pep593
|
||||
from ..util.typing import NoneType
|
||||
from ..util.typing import Self
|
||||
from ..util.typing import typing_get_args
|
||||
|
||||
@@ -652,17 +652,15 @@ class MappedColumn(
|
||||
) -> None:
|
||||
sqltype = self.column.type
|
||||
|
||||
nullable = False
|
||||
if is_fwd_ref(argument):
|
||||
argument = de_stringify_annotation(cls, argument)
|
||||
|
||||
if hasattr(argument, "__origin__"):
|
||||
nullable = NoneType in argument.__args__ # type: ignore
|
||||
nullable = is_optional_union(argument)
|
||||
|
||||
if not self._has_nullable:
|
||||
self.column.nullable = nullable
|
||||
|
||||
our_type = de_optionalize_union_types(argument)
|
||||
if is_fwd_ref(our_type):
|
||||
our_type = de_stringify_annotation(cls, our_type)
|
||||
|
||||
use_args_from = None
|
||||
if is_pep593(our_type):
|
||||
|
||||
@@ -169,7 +169,7 @@ def make_union_type(*types: _AnnotationScanType) -> Type[Any]:
|
||||
def expand_unions(
|
||||
type_: Type[Any], include_union: bool = False, discard_none: bool = False
|
||||
) -> Tuple[Type[Any], ...]:
|
||||
"""Return a type as as a tuple of individual types, expanding for
|
||||
"""Return a type as a tuple of individual types, expanding for
|
||||
``Union`` types."""
|
||||
|
||||
if is_union(type_):
|
||||
@@ -191,9 +191,14 @@ def is_optional(type_):
|
||||
type_,
|
||||
"Optional",
|
||||
"Union",
|
||||
"UnionType",
|
||||
)
|
||||
|
||||
|
||||
def is_optional_union(type_: Any) -> bool:
|
||||
return is_optional(type_) and NoneType in typing_get_args(type_)
|
||||
|
||||
|
||||
def is_union(type_):
|
||||
return is_origin_of(type_, "Union")
|
||||
|
||||
@@ -204,7 +209,7 @@ def is_origin_of(
|
||||
"""return True if the given type has an __origin__ with the given name
|
||||
and optional module."""
|
||||
|
||||
origin = getattr(type_, "__origin__", None)
|
||||
origin = typing_get_origin(type_)
|
||||
if origin is None:
|
||||
return False
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ from sqlalchemy.testing import is_false
|
||||
from sqlalchemy.testing import is_not
|
||||
from sqlalchemy.testing import is_true
|
||||
from sqlalchemy.testing.fixtures import fixture_session
|
||||
from sqlalchemy.util import compat
|
||||
from sqlalchemy.util.typing import Annotated
|
||||
|
||||
|
||||
@@ -858,6 +859,7 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
|
||||
|
||||
data: Mapped[Union[float, Decimal]] = mapped_column()
|
||||
reverse_data: Mapped[Union[Decimal, float]] = mapped_column()
|
||||
|
||||
optional_data: Mapped[
|
||||
Optional[Union[float, Decimal]]
|
||||
] = mapped_column()
|
||||
@@ -872,9 +874,22 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
|
||||
reverse_u_optional_data: Mapped[
|
||||
Union[Decimal, float, None]
|
||||
] = mapped_column()
|
||||
|
||||
float_data: Mapped[float] = mapped_column()
|
||||
decimal_data: Mapped[Decimal] = mapped_column()
|
||||
|
||||
if compat.py310:
|
||||
pep604_data: Mapped[float | Decimal] = mapped_column()
|
||||
pep604_reverse: Mapped[Decimal | float] = mapped_column()
|
||||
pep604_optional: Mapped[
|
||||
Decimal | float | None
|
||||
] = mapped_column()
|
||||
pep604_data_fwd: Mapped["float | Decimal"] = mapped_column()
|
||||
pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column()
|
||||
pep604_optional_fwd: Mapped[
|
||||
"Decimal | float | None"
|
||||
] = mapped_column()
|
||||
|
||||
is_(User.__table__.c.data.type, our_type)
|
||||
is_false(User.__table__.c.data.nullable)
|
||||
is_(User.__table__.c.reverse_data.type, our_type)
|
||||
@@ -889,6 +904,18 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL):
|
||||
is_(User.__table__.c.float_data.type, our_type)
|
||||
is_(User.__table__.c.decimal_data.type, our_type)
|
||||
|
||||
if compat.py310:
|
||||
for suffix in ("", "_fwd"):
|
||||
data_col = User.__table__.c[f"pep604_data{suffix}"]
|
||||
reverse_col = User.__table__.c[f"pep604_reverse{suffix}"]
|
||||
optional_col = User.__table__.c[f"pep604_optional{suffix}"]
|
||||
is_(data_col.type, our_type)
|
||||
is_false(data_col.nullable)
|
||||
is_(reverse_col.type, our_type)
|
||||
is_false(reverse_col.nullable)
|
||||
is_(optional_col.type, our_type)
|
||||
is_true(optional_col.nullable)
|
||||
|
||||
def test_missing_mapped_lhs(self, decl_base):
|
||||
with expect_raises_message(
|
||||
ArgumentError,
|
||||
|
||||
Reference in New Issue
Block a user