Merge "improve overloads applied to generic functions" into main

This commit is contained in:
Federico Caselli
2025-04-03 19:22:37 +00:00
committed by Gerrit Code Review
3 changed files with 68 additions and 57 deletions
+57 -50
View File
@@ -5,7 +5,6 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""SQL function API, factories, and built-in functions."""
from __future__ import annotations
@@ -153,7 +152,9 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative):
clause_expr: Grouping[Any]
def __init__(self, *clauses: _ColumnExpressionOrLiteralArgument[Any]):
def __init__(
self, *clauses: _ColumnExpressionOrLiteralArgument[Any]
) -> None:
r"""Construct a :class:`.FunctionElement`.
:param \*clauses: list of column expressions that form the arguments
@@ -777,7 +778,7 @@ class FunctionAsBinary(BinaryExpression[Any]):
def __init__(
self, fn: FunctionElement[Any], left_index: int, right_index: int
):
) -> None:
self.sql_function = fn
self.left_index = left_index
self.right_index = right_index
@@ -829,7 +830,7 @@ class ScalarFunctionColumn(NamedColumn[_T]):
fn: FunctionElement[_T],
name: str,
type_: Optional[_TypeEngineArgument[_T]] = None,
):
) -> None:
self.fn = fn
self.name = name
@@ -928,7 +929,7 @@ class _FunctionGenerator:
""" # noqa
def __init__(self, **opts: Any):
def __init__(self, **opts: Any) -> None:
self.__names: List[str] = []
self.opts = opts
@@ -988,10 +989,10 @@ class _FunctionGenerator:
@property
def ansifunction(self) -> Type[AnsiFunction[Any]]: ...
# set ColumnElement[_T] as a separate overload, to appease mypy
# which seems to not want to accept _T from _ColumnExpressionArgument.
# this is even if all non-generic types are removed from it, so
# reasons remain unclear for why this does not work
# set ColumnElement[_T] as a separate overload, to appease
# mypy which seems to not want to accept _T from
# _ColumnExpressionArgument. Seems somewhat related to the covariant
# _HasClauseElement as of mypy 1.15
@overload
def array_agg(
@@ -1012,7 +1013,7 @@ class _FunctionGenerator:
@overload
def array_agg(
self,
col: _ColumnExpressionOrLiteralArgument[_T],
col: _T,
*args: _ColumnExpressionOrLiteralArgument[Any],
**kwargs: Any,
) -> array_agg[_T]: ...
@@ -1030,10 +1031,10 @@ class _FunctionGenerator:
@property
def char_length(self) -> Type[char_length]: ...
# set ColumnElement[_T] as a separate overload, to appease mypy
# which seems to not want to accept _T from _ColumnExpressionArgument.
# this is even if all non-generic types are removed from it, so
# reasons remain unclear for why this does not work
# set ColumnElement[_T] as a separate overload, to appease
# mypy which seems to not want to accept _T from
# _ColumnExpressionArgument. Seems somewhat related to the covariant
# _HasClauseElement as of mypy 1.15
@overload
def coalesce(
@@ -1054,7 +1055,7 @@ class _FunctionGenerator:
@overload
def coalesce(
self,
col: _ColumnExpressionOrLiteralArgument[_T],
col: _T,
*args: _ColumnExpressionOrLiteralArgument[Any],
**kwargs: Any,
) -> coalesce[_T]: ...
@@ -1105,10 +1106,10 @@ class _FunctionGenerator:
@property
def localtimestamp(self) -> Type[localtimestamp]: ...
# set ColumnElement[_T] as a separate overload, to appease mypy
# which seems to not want to accept _T from _ColumnExpressionArgument.
# this is even if all non-generic types are removed from it, so
# reasons remain unclear for why this does not work
# set ColumnElement[_T] as a separate overload, to appease
# mypy which seems to not want to accept _T from
# _ColumnExpressionArgument. Seems somewhat related to the covariant
# _HasClauseElement as of mypy 1.15
@overload
def max( # noqa: A001
@@ -1129,7 +1130,7 @@ class _FunctionGenerator:
@overload
def max( # noqa: A001
self,
col: _ColumnExpressionOrLiteralArgument[_T],
col: _T,
*args: _ColumnExpressionOrLiteralArgument[Any],
**kwargs: Any,
) -> max[_T]: ...
@@ -1141,10 +1142,10 @@ class _FunctionGenerator:
**kwargs: Any,
) -> max[_T]: ...
# set ColumnElement[_T] as a separate overload, to appease mypy
# which seems to not want to accept _T from _ColumnExpressionArgument.
# this is even if all non-generic types are removed from it, so
# reasons remain unclear for why this does not work
# set ColumnElement[_T] as a separate overload, to appease
# mypy which seems to not want to accept _T from
# _ColumnExpressionArgument. Seems somewhat related to the covariant
# _HasClauseElement as of mypy 1.15
@overload
def min( # noqa: A001
@@ -1165,7 +1166,7 @@ class _FunctionGenerator:
@overload
def min( # noqa: A001
self,
col: _ColumnExpressionOrLiteralArgument[_T],
col: _T,
*args: _ColumnExpressionOrLiteralArgument[Any],
**kwargs: Any,
) -> min[_T]: ...
@@ -1210,10 +1211,10 @@ class _FunctionGenerator:
@property
def session_user(self) -> Type[session_user]: ...
# set ColumnElement[_T] as a separate overload, to appease mypy
# which seems to not want to accept _T from _ColumnExpressionArgument.
# this is even if all non-generic types are removed from it, so
# reasons remain unclear for why this does not work
# set ColumnElement[_T] as a separate overload, to appease
# mypy which seems to not want to accept _T from
# _ColumnExpressionArgument. Seems somewhat related to the covariant
# _HasClauseElement as of mypy 1.15
@overload
def sum( # noqa: A001
@@ -1234,7 +1235,7 @@ class _FunctionGenerator:
@overload
def sum( # noqa: A001
self,
col: _ColumnExpressionOrLiteralArgument[_T],
col: _T,
*args: _ColumnExpressionOrLiteralArgument[Any],
**kwargs: Any,
) -> sum[_T]: ...
@@ -1330,7 +1331,7 @@ class Function(FunctionElement[_T]):
*clauses: _ColumnExpressionOrLiteralArgument[_T],
type_: None = ...,
packagenames: Optional[Tuple[str, ...]] = ...,
): ...
) -> None: ...
@overload
def __init__(
@@ -1339,7 +1340,7 @@ class Function(FunctionElement[_T]):
*clauses: _ColumnExpressionOrLiteralArgument[Any],
type_: _TypeEngineArgument[_T] = ...,
packagenames: Optional[Tuple[str, ...]] = ...,
): ...
) -> None: ...
def __init__(
self,
@@ -1347,7 +1348,7 @@ class Function(FunctionElement[_T]):
*clauses: _ColumnExpressionOrLiteralArgument[Any],
type_: Optional[_TypeEngineArgument[_T]] = None,
packagenames: Optional[Tuple[str, ...]] = None,
):
) -> None:
"""Construct a :class:`.Function`.
The :data:`.func` construct is normally used to construct
@@ -1523,7 +1524,7 @@ class GenericFunction(Function[_T]):
def __init__(
self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any
):
) -> None:
parsed_args = kwargs.pop("_parsed_args", None)
if parsed_args is None:
parsed_args = [
@@ -1570,7 +1571,7 @@ class next_value(GenericFunction[int]):
("sequence", InternalTraversal.dp_named_ddl_element)
]
def __init__(self, seq: schema.Sequence, **kw: Any):
def __init__(self, seq: schema.Sequence, **kw: Any) -> None:
assert isinstance(
seq, schema.Sequence
), "next_value() accepts a Sequence object as input."
@@ -1595,7 +1596,9 @@ class AnsiFunction(GenericFunction[_T]):
inherit_cache = True
def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any):
def __init__(
self, *args: _ColumnExpressionArgument[Any], **kwargs: Any
) -> None:
GenericFunction.__init__(self, *args, **kwargs)
@@ -1606,10 +1609,10 @@ class ReturnTypeFromArgs(GenericFunction[_T]):
inherit_cache = True
# set ColumnElement[_T] as a separate overload, to appease mypy which seems
# to not want to accept _T from _ColumnExpressionArgument. this is even if
# all non-generic types are removed from it, so reasons remain unclear for
# why this does not work
# set ColumnElement[_T] as a separate overload, to appease
# mypy which seems to not want to accept _T from
# _ColumnExpressionArgument. Seems somewhat related to the covariant
# _HasClauseElement as of mypy 1.15
@overload
def __init__(
@@ -1617,7 +1620,7 @@ class ReturnTypeFromArgs(GenericFunction[_T]):
col: ColumnElement[_T],
*args: _ColumnExpressionOrLiteralArgument[Any],
**kwargs: Any,
): ...
) -> None: ...
@overload
def __init__(
@@ -1625,19 +1628,19 @@ class ReturnTypeFromArgs(GenericFunction[_T]):
col: _ColumnExpressionArgument[_T],
*args: _ColumnExpressionOrLiteralArgument[Any],
**kwargs: Any,
): ...
) -> None: ...
@overload
def __init__(
self,
col: _ColumnExpressionOrLiteralArgument[_T],
col: _T,
*args: _ColumnExpressionOrLiteralArgument[Any],
**kwargs: Any,
): ...
) -> None: ...
def __init__(
self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any
):
self, *args: _ColumnExpressionOrLiteralArgument[_T], **kwargs: Any
) -> None:
fn_args: Sequence[ColumnElement[Any]] = [
coercions.expect(
roles.ExpressionElementRole,
@@ -1719,7 +1722,7 @@ class char_length(GenericFunction[int]):
type = sqltypes.Integer()
inherit_cache = True
def __init__(self, arg: _ColumnExpressionArgument[str], **kw: Any):
def __init__(self, arg: _ColumnExpressionArgument[str], **kw: Any) -> None:
# slight hack to limit to just one positional argument
# not sure why this one function has this special treatment
super().__init__(arg, **kw)
@@ -1765,7 +1768,7 @@ class count(GenericFunction[int]):
_ColumnExpressionArgument[Any], _StarOrOne, None
] = None,
**kwargs: Any,
):
) -> None:
if expression is None:
expression = literal_column("*")
super().__init__(expression, **kwargs)
@@ -1854,7 +1857,9 @@ class array_agg(ReturnTypeFromArgs[Sequence[_T]]):
inherit_cache = True
def __init__(self, *args: _ColumnExpressionArgument[Any], **kwargs: Any):
def __init__(
self, *args: _ColumnExpressionArgument[Any], **kwargs: Any
) -> None:
fn_args: Sequence[ColumnElement[Any]] = [
coercions.expect(
roles.ExpressionElementRole, c, apply_propagate_attrs=self
@@ -2081,5 +2086,7 @@ class aggregate_strings(GenericFunction[str]):
_has_args = True
inherit_cache = True
def __init__(self, clause: _ColumnExpressionArgument[Any], separator: str):
def __init__(
self, clause: _ColumnExpressionArgument[Any], separator: str
) -> None:
super().__init__(clause, separator)
@@ -1,4 +1,6 @@
from sqlalchemy import column
from sqlalchemy import func
from sqlalchemy import Integer
from sqlalchemy import select
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
@@ -53,6 +55,10 @@ reveal_type(stmt1)
# test #10818
# EXPECTED_TYPE: coalesce[str]
reveal_type(func.coalesce(Foo.c, "a", "b"))
# EXPECTED_TYPE: coalesce[str]
reveal_type(func.coalesce("a", "b"))
# EXPECTED_TYPE: coalesce[int]
reveal_type(func.coalesce(column("x", Integer), 3))
stmt2 = select(Foo.a, func.coalesce(Foo.c, "a", "b")).group_by(Foo.a)
+5 -7
View File
@@ -67,10 +67,10 @@ def process_functions(filename: str, cmd: code_writer_cmd) -> str:
textwrap.indent(
f"""
# set ColumnElement[_T] as a separate overload, to appease mypy
# which seems to not want to accept _T from _ColumnExpressionArgument.
# this is even if all non-generic types are removed from it, so
# reasons remain unclear for why this does not work
# set ColumnElement[_T] as a separate overload, to appease
# mypy which seems to not want to accept _T from
# _ColumnExpressionArgument. Seems somewhat related to the covariant
# _HasClauseElement as of mypy 1.15
@overload
def {key}( {' # noqa: A001' if is_reserved_word else ''}
@@ -90,17 +90,15 @@ def {key}( {' # noqa: A001' if is_reserved_word else ''}
) -> {fn_class.__name__}[_T]:
...
@overload
def {key}( {' # noqa: A001' if is_reserved_word else ''}
self,
col: _ColumnExpressionOrLiteralArgument[_T],
col: _T,
*args: _ColumnExpressionOrLiteralArgument[Any],
**kwargs: Any,
) -> {fn_class.__name__}[_T]:
...
def {key}( {' # noqa: A001' if is_reserved_word else ''}
self,
col: _ColumnExpressionOrLiteralArgument[_T],