Files
sqlalchemy/tools/generate_sql_functions.py
Rebecca Chen 0ab0d0c980 Change typing tests to use assert_type instead of reveal_type
Closes: #12922
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/12922
Pull-request-sha: 580f663816

Change-Id: I9f3bdb4c105971f53fa10ed8a934356203ddb080
2025-10-18 11:23:20 -04:00

262 lines
8.1 KiB
Python

"""Generate inline stubs for generic functions on func"""
# mypy: ignore-errors
from __future__ import annotations
import inspect
import re
from tempfile import NamedTemporaryFile
import textwrap
import typing
import typing_extensions
from sqlalchemy.sql.functions import _registry
from sqlalchemy.sql.functions import ReturnTypeFromArgs
from sqlalchemy.types import TypeEngine
from sqlalchemy.util.tool_support import code_writer_cmd
def _fns_in_deterministic_order():
    """Yield ``(name, function class)`` pairs from the default registry.

    Names are visited in sorted order so that repeated runs of this tool
    emit the generated code in the same sequence.  The ``ReturnTypeFromArgs``
    class itself is skipped; only an identity check is made, so its
    subclasses are still yielded.
    """
    default_registry = _registry["_default"]
    for name in sorted(default_registry):
        fn_cls = default_registry[name]
        if fn_cls is not ReturnTypeFromArgs:
            yield name, fn_cls
def _write_generated_notice(buf: typing.TextIO, indent: str) -> None:
    """Write the "this code is generated" banner comment at ``indent``."""
    buf.write(
        textwrap.indent(
            """
# code within this block is **programmatically,
# statically generated** by tools/generate_sql_functions.py
""",
            indent,
        )
    )


def _write_function_accessors(buf: typing.TextIO, indent: str) -> None:
    """Write one accessor stub per registered function class.

    Subclasses of ``ReturnTypeFromArgs`` get three ``@overload``
    signatures plus the implementation signature; every other class gets
    a ``@property`` returning ``Type[<class>]``.

    :param buf: writable text file receiving the generated code.
    :param indent: prefix added to every non-blank generated line.
    """
    # ``__builtins__`` is the builtins module only inside ``__main__``; in
    # an imported module it is a plain dict, whose dir() lists dict
    # methods rather than builtin names.  Use the real module so that the
    # reserved-word check gives the same answer either way.
    import builtins

    builtin_names = frozenset(dir(builtins))

    for key, fn_class in _fns_in_deterministic_order():
        # an accessor named like a builtin gets a flake8 A001 suppression
        noqa = " # noqa: A001" if key in builtin_names else ""
        class_name = fn_class.__name__

        # The varying pieces are computed outside of the f-strings; the
        # templates are kept free of non-trivial expressions because
        # flake8 lints f-string contents and its indentation rules for
        # them are hard to satisfy.
        if issubclass(fn_class, ReturnTypeFromArgs):
            stub = f"""
# set ColumnElement[_T] as a separate overload, to appease
# mypy which seems to not want to accept _T from
# _ColumnExpressionArgument. Seems somewhat related to the covariant
# _HasClauseElement as of mypy 1.15
@overload
def {key}( {noqa}
    self,
    col: ColumnElement[_T],
    *args: _ColumnExpressionOrLiteralArgument[Any],
    **kwargs: Any,
) -> {class_name}[_T]:
    ...
@overload
def {key}( {noqa}
    self,
    col: _ColumnExpressionArgument[_T],
    *args: _ColumnExpressionOrLiteralArgument[Any],
    **kwargs: Any,
) -> {class_name}[_T]:
    ...
@overload
def {key}( {noqa}
    self,
    col: _T,
    *args: _ColumnExpressionOrLiteralArgument[Any],
    **kwargs: Any,
) -> {class_name}[_T]:
    ...
def {key}( {noqa}
    self,
    col: _ColumnExpressionOrLiteralArgument[_T],
    *args: _ColumnExpressionOrLiteralArgument[Any],
    **kwargs: Any,
) -> {class_name}[_T]:
    ...
"""
        else:
            # a class that declares type parameters is spelled ``X[Any]``
            guess_its_generic = bool(fn_class.__parameters__)
            _type = class_name + ("[Any]" if guess_its_generic else "")
            stub = f"""
@property
def {key}(self) -> Type[{_type}]:{noqa}
    ...
"""
        buf.write(textwrap.indent(stub, indent))


def _write_function_typing_tests(buf: typing.TextIO, indent: str) -> None:
    """Write ``select()`` / ``assert_type()`` pairs for function classes.

    Three kinds of class produce a test; all others are skipped:

    * subclasses of ``ReturnTypeFromArgs``, called with an ``Integer``
      column.  The expected type is ``int``, or ``<origin>[int]`` when the
      type argument of the class's original base is itself generic.
    * the class named ``aggregate_strings``, expected to be ``str``.
    * classes whose ``type`` attribute is a ``TypeEngine`` instance,
      expected to be that type's ``python_type``.

    :param buf: writable text file receiving the generated code.
    :param indent: prefix added to every non-blank generated line.
    """
    count = 0
    for key, fn_class in _fns_in_deterministic_order():
        if issubclass(fn_class, ReturnTypeFromArgs):
            # Would be ReturnTypeFromArgs
            (orig_base,) = typing_extensions.get_original_bases(fn_class)
            # Type parameter of ReturnTypeFromArgs
            (rtype,) = typing.get_args(orig_base)
            # The origin type, if rtype is a generic
            orig_type = typing.get_origin(rtype)
            if orig_type is not None:
                expected = f"{orig_type.__name__}[int]"
            else:
                expected = "int"
            call_args = "column('x', Integer)"
        elif fn_class.__name__ == "aggregate_strings":
            expected = "str"
            call_args = "column('x', String), ','"
        elif hasattr(fn_class, "type") and isinstance(
            fn_class.type, TypeEngine
        ):
            expected = fn_class.type.python_type.__name__
            if fn_class.__name__ == "next_value":
                call_args = "SqlAlchemySequence('x_seq')"
            else:
                # one placeholder column per argument after the first
                # entry of the class's argspec
                argspec = inspect.getfullargspec(fn_class)
                call_args = ", ".join(
                    'column("x")' for _ in argspec.args[1:]
                )
        else:
            continue

        # only classes that actually produce a test consume a number
        count += 1
        buf.write(
            textwrap.indent(
                f"""
stmt{count} = select(func.{key}({call_args}))
assert_type(stmt{count}, Select[{expected}])
""",
                indent,
            )
        )


def process_functions(filename: str, cmd: code_writer_cmd) -> str:
    """Copy ``filename`` to a temp file, regenerating its generated blocks.

    Lines are copied unchanged until a ``# START GENERATED FUNCTION
    ACCESSORS`` or ``# START GENERATED FUNCTION TYPING TESTS`` marker is
    seen.  The marker line is kept, newly generated code is written after
    it using the marker's own indentation, and the existing lines are
    dropped up to a line starting with ``# END GENERATED FUNCTION`` at that
    same indentation, which is kept.

    :param filename: path of the Python file to read.
    :param cmd: not used by this function.
    :return: name of the temporary ``.py`` file that was written.  It is
     created with ``delete=False``, so it still exists on return.
    """
    # compiled once here rather than once per input line
    generators = (
        (
            re.compile(r"^( *)# START GENERATED FUNCTION ACCESSORS"),
            _write_function_accessors,
        ),
        (
            re.compile(r"^( *)# START GENERATED FUNCTION TYPING TESTS"),
            _write_function_typing_tests,
        ),
    )

    with (
        NamedTemporaryFile(
            mode="w",
            delete=False,
            suffix=".py",
        ) as buf,
        open(filename) as orig_py,
    ):
        indent = ""
        in_block = False

        for line in orig_py:
            for start_marker, write_generated in generators:
                m = start_marker.match(line)
                if m:
                    in_block = True
                    buf.write(line)
                    indent = m.group(1)
                    _write_generated_notice(buf, indent)
                    write_generated(buf, indent)

            if in_block and line.startswith(
                f"{indent}# END GENERATED FUNCTION"
            ):
                in_block = False

            # while inside a generated block the old lines are discarded;
            # the start marker itself was already written above
            if not in_block:
                buf.write(line)
    return buf.name
def main(cmd: code_writer_cmd) -> None:
    """Regenerate both target files through ``cmd``.

    For ``functions_py`` and then ``test_functions_py``: build the
    regenerated content in a temporary file, pass that file to
    ``cmd.run_zimports`` and ``cmd.run_black``, and finally hand it to
    ``cmd.write_output_file_from_tempfile`` with the original path as the
    destination.

    :param cmd: the ``code_writer_cmd`` object driving this run.
    """
    targets = (functions_py, test_functions_py)
    for target in targets:
        generated = process_functions(target, cmd)
        cmd.run_zimports(generated)
        cmd.run_black(generated)
        cmd.write_output_file_from_tempfile(generated, target)
# Files processed by main(), in this order.  The paths are relative, so
# they resolve against the current working directory.
functions_py = "lib/sqlalchemy/sql/functions.py"
test_functions_py = "test/typing/plain_files/sql/functions.py"


if __name__ == "__main__":
    # script entry point: build the command helper for this file and run
    # main() inside its run_program() context
    cmd = code_writer_cmd(__file__)

    with cmd.run_program():
        main(cmd)