mirror of
https://github.com/sqlalchemy/sqlalchemy.git
synced 2026-05-07 01:10:52 -04:00
0a185a3bb6
Ensure the _FunctionGenerator method do not shadow the function class of the same name Fixed a typing issue where the typed members of :data:`.func` would return the appropriate class of the same name, however this creates an issue for typecheckers such as Zuban and pyrefly that assume :pep:`749` style typechecking even if the file states that it's a :pep:`563` file; they see the returned name as indicating the method object and not the class object. These typecheckers are actually following along with an upcoming test harness that insists on :pep:`749` style name resolution for this case unconditionally. Since :pep:`749` is the way of the future regardless, differently-named type aliases have been added for these return types. Fixes: #13167 Change-Id: If58a3858001c78ab21b2ed343205dfd9ce868576
304 lines
9.5 KiB
Python
304 lines
9.5 KiB
Python
"""Generate inline stubs for generic functions on func"""
|
|
|
|
# mypy: ignore-errors
|
|
|
|
from __future__ import annotations
|
|
|
|
import inspect
|
|
import re
|
|
from tempfile import NamedTemporaryFile
|
|
import textwrap
|
|
import typing
|
|
|
|
import typing_extensions
|
|
|
|
from sqlalchemy.sql.functions import _registry
|
|
from sqlalchemy.sql.functions import ReturnTypeFromArgs
|
|
from sqlalchemy.sql.functions import ReturnTypeFromOptionalArgs
|
|
from sqlalchemy.types import TypeEngine
|
|
from sqlalchemy.util.tool_support import code_writer_cmd
|
|
|
|
|
|
def _fns_in_deterministic_order():
|
|
reg = _registry["_default"]
|
|
for key in sorted(reg):
|
|
cls = reg[key]
|
|
if cls is ReturnTypeFromArgs or cls is ReturnTypeFromOptionalArgs:
|
|
continue
|
|
yield key, cls
|
|
|
|
|
|
def process_functions(filename: str, cmd: code_writer_cmd) -> str:
|
|
with (
|
|
NamedTemporaryFile(
|
|
mode="w",
|
|
delete=False,
|
|
suffix=".py",
|
|
) as buf,
|
|
open(filename) as orig_py,
|
|
):
|
|
indent = ""
|
|
in_block = False
|
|
alias_mapping: dict[str, str] = {}
|
|
|
|
for line in orig_py:
|
|
m = re.match(
|
|
r"^( *)# START GENERATED FUNCTION ACCESSORS",
|
|
line,
|
|
)
|
|
if m:
|
|
in_block = True
|
|
buf.write(line)
|
|
indent = m.group(1)
|
|
buf.write(
|
|
textwrap.indent(
|
|
"""
|
|
# code within this block is **programmatically,
|
|
# statically generated** by tools/generate_sql_functions.py
|
|
""",
|
|
indent,
|
|
)
|
|
)
|
|
|
|
builtins = set(dir(__builtins__))
|
|
for key, fn_class in _fns_in_deterministic_order():
|
|
is_reserved_word = key in builtins
|
|
|
|
class_name = f"_{fn_class.__name__}_func"
|
|
if issubclass(fn_class, ReturnTypeFromArgs):
|
|
guess_its_generic = True
|
|
if issubclass(fn_class, ReturnTypeFromOptionalArgs):
|
|
_TEE = "Optional[_T]"
|
|
else:
|
|
_TEE = "_T"
|
|
|
|
buf.write(
|
|
textwrap.indent(
|
|
f"""
|
|
|
|
# set ColumnElement[_T] as a separate overload, to appease
|
|
# mypy which seems to not want to accept _T from
|
|
# _ColumnExpressionArgument. Seems somewhat related to the covariant
|
|
# _HasClauseElement as of mypy 1.15
|
|
|
|
@overload
|
|
def {key}( {' # noqa: A001' if is_reserved_word else ''}
|
|
self,
|
|
col: ColumnElement[_T],
|
|
*args: _ColumnExpressionOrLiteralArgument[Any],
|
|
**kwargs: Any,
|
|
) -> {class_name}[_T]:
|
|
...
|
|
|
|
@overload
|
|
def {key}( {' # noqa: A001' if is_reserved_word else ''}
|
|
self,
|
|
col: _ColumnExpressionArgument[{_TEE}],
|
|
*args: _ColumnExpressionOrLiteralArgument[Any],
|
|
**kwargs: Any,
|
|
) -> {class_name}[_T]:
|
|
...
|
|
|
|
@overload
|
|
def {key}( {' # noqa: A001' if is_reserved_word else ''}
|
|
self,
|
|
col: {_TEE},
|
|
*args: _ColumnExpressionOrLiteralArgument[Any],
|
|
**kwargs: Any,
|
|
) -> {class_name}[_T]:
|
|
...
|
|
|
|
def {key}( {' # noqa: A001' if is_reserved_word else ''}
|
|
self,
|
|
col: _ColumnExpressionOrLiteralArgument[{_TEE}],
|
|
*args: _ColumnExpressionOrLiteralArgument[Any],
|
|
**kwargs: Any,
|
|
) -> {class_name}[_T]:
|
|
...
|
|
|
|
""",
|
|
indent,
|
|
)
|
|
)
|
|
else:
|
|
guess_its_generic = bool(fn_class.__parameters__)
|
|
|
|
# the latest flake8 is quite broken here:
|
|
# 1. it insists on linting f-strings, no option
|
|
# to turn it off
|
|
# 2. the f-string indentation rules are either broken
|
|
# or completely impossible to figure out
|
|
# 3. there's no way to E501 a too-long f-string,
|
|
# so I can't even put the expressions all one line
|
|
# to get around the indentation errors
|
|
# 4. Therefore here I have to concat part of the
|
|
# string outside of the f-string
|
|
_type = class_name
|
|
_type += "[Any]" if guess_its_generic else ""
|
|
_reserved_word = (
|
|
" # noqa: A001" if is_reserved_word else ""
|
|
)
|
|
|
|
# now the f-string
|
|
buf.write(
|
|
textwrap.indent(
|
|
f"""
|
|
@property
|
|
def {key}(self) -> Type[{_type}]:{_reserved_word}
|
|
...
|
|
|
|
""",
|
|
indent,
|
|
)
|
|
)
|
|
orig_name = fn_class.__name__
|
|
alias_name = class_name
|
|
if guess_its_generic:
|
|
orig_name += "[_T]"
|
|
alias_mapping[orig_name] = alias_name
|
|
|
|
m = re.match(
|
|
r"^( *)# START GENERATED FUNCTION TYPING TESTS",
|
|
line,
|
|
)
|
|
if m:
|
|
in_block = True
|
|
buf.write(line)
|
|
indent = m.group(1)
|
|
|
|
buf.write(
|
|
textwrap.indent(
|
|
"""
|
|
# code within this block is **programmatically,
|
|
# statically generated** by tools/generate_sql_functions.py
|
|
""",
|
|
indent,
|
|
)
|
|
)
|
|
|
|
count = 0
|
|
for key, fn_class in _fns_in_deterministic_order():
|
|
if issubclass(fn_class, ReturnTypeFromArgs):
|
|
count += 1
|
|
|
|
# Would be ReturnTypeFromArgs
|
|
(orig_base,) = typing_extensions.get_original_bases(
|
|
fn_class
|
|
)
|
|
# Type parameter of ReturnTypeFromArgs
|
|
(rtype,) = typing.get_args(orig_base)
|
|
# The origin type, if rtype is a generic
|
|
orig_type = typing.get_origin(rtype)
|
|
if orig_type is not None:
|
|
coltype = rf"{orig_type.__name__}[int]"
|
|
else:
|
|
coltype = "int"
|
|
|
|
buf.write(
|
|
textwrap.indent(
|
|
rf"""
|
|
|
|
# test the {key}() function.
|
|
# this function is a ReturnTypeFromArgs type.
|
|
|
|
fn{count} = func.{key}(column('x', Integer))
|
|
assert_type(fn{count}, functions.{key}[int])
|
|
|
|
stmt{count} = select(func.{key}(column('x', Integer)))
|
|
assert_type(stmt{count}, Select[{coltype}])
|
|
|
|
|
|
""",
|
|
indent,
|
|
)
|
|
)
|
|
elif fn_class.__name__ == "aggregate_strings":
|
|
count += 1
|
|
buf.write(
|
|
textwrap.indent(
|
|
rf"""
|
|
|
|
# test the aggregate_strings() function.
|
|
# this function is somewhat special case.
|
|
|
|
stmt{count} = select(func.{key}(column('x', String), ','))
|
|
assert_type(stmt{count}, Select[str])
|
|
|
|
""",
|
|
indent,
|
|
)
|
|
)
|
|
|
|
elif hasattr(fn_class, "type") and isinstance(
|
|
fn_class.type, TypeEngine
|
|
):
|
|
python_type = fn_class.type.python_type
|
|
python_expr = python_type.__name__
|
|
argspec = inspect.getfullargspec(fn_class)
|
|
if fn_class.__name__ == "next_value":
|
|
args = "SqlAlchemySequence('x_seq')"
|
|
else:
|
|
args = ", ".join(
|
|
'column("x")' for elem in argspec.args[1:]
|
|
)
|
|
count += 1
|
|
|
|
buf.write(
|
|
textwrap.indent(
|
|
rf"""
|
|
|
|
# test the {key}() function.
|
|
# this function is fixed to the SQL {fn_class.type} class, or the {python_expr} type.
|
|
|
|
fn{count} = func.{key}({args})
|
|
assert_type(fn{count}, functions.{key})
|
|
|
|
stmt{count} = select(func.{key}({args}))
|
|
assert_type(stmt{count}, Select[{python_expr}])
|
|
|
|
""", # noqa: E501
|
|
indent,
|
|
)
|
|
)
|
|
|
|
m = re.match(
|
|
r"^( *)# START GENERATED FUNCTION ALIASES",
|
|
line,
|
|
)
|
|
if m:
|
|
in_block = True
|
|
buf.write(line)
|
|
indent = m.group(1)
|
|
|
|
for name, alias in alias_mapping.items():
|
|
buf.write(f"{indent}{alias}: TypeAlias = {name}\n")
|
|
|
|
if in_block and line.startswith(
|
|
f"{indent}# END GENERATED FUNCTION"
|
|
):
|
|
in_block = False
|
|
|
|
if not in_block:
|
|
buf.write(line)
|
|
return buf.name
|
|
|
|
|
|
def main(cmd: code_writer_cmd) -> None:
|
|
for path in [functions_py, test_functions_py]:
|
|
destination_path = path
|
|
tempfile = process_functions(destination_path, cmd)
|
|
cmd.run_zimports(tempfile)
|
|
cmd.run_black(tempfile)
|
|
cmd.write_output_file_from_tempfile(tempfile, destination_path)
|
|
|
|
|
|
functions_py = "lib/sqlalchemy/sql/functions.py"
|
|
test_functions_py = "test/typing/plain_files/sql/functions.py"
|
|
|
|
|
|
if __name__ == "__main__":
|
|
cmd = code_writer_cmd(__file__)
|
|
|
|
with cmd.run_program():
|
|
main(cmd)
|