[ty] Fix bad diagnostic range for incorrect implicit __init_subclass__ calls (#24541)

This commit is contained in:
Alex Waygood
2026-04-10 15:27:13 +01:00
committed by GitHub
parent 2ad94df0e1
commit e89f8ef295
4 changed files with 222 additions and 68 deletions
+7
View File
@@ -3421,6 +3421,13 @@ impl<'a> ArgOrKeyword<'a> {
_ => None,
}
}
pub const fn as_keyword(self) -> Option<&'a Keyword> {
match self {
ArgOrKeyword::Keyword(keyword) => Some(keyword),
ArgOrKeyword::Arg(_) => None,
}
}
}
impl<'a> From<&'a Expr> for ArgOrKeyword<'a> {
@@ -564,9 +564,55 @@ class IncorrectArg(RequiresArg, not_arg="foo"):
h = 8
i = 9
j = 10
class NotCallableInitSubclass:
__init_subclass__ = None
# TODO: this should be an error because `__init_subclass__` on the superclass is not callable
class Bad(NotCallableInitSubclass):
a = 1
b = 2
c = 3
```
#### Multiple inheritance
The `metaclass` keyword is ignored, as it has special meaning and is not passed to
`__init_subclass__` at runtime.
```py
class Base:
def __init_subclass__(cls, arg: int): ...
class Valid(Base, arg=5, metaclass=object): ...
# error: [invalid-argument-type]
class Invalid(Base, metaclass=type, arg="foo"): ...
```
Overload matching is performed correctly:
```py
from typing import Literal, overload
class Base:
@overload
def __init_subclass__(cls, mode: Literal["a"], arg: int) -> None: ...
@overload
def __init_subclass__(cls, mode: Literal["b"], arg: str) -> None: ...
def __init_subclass__(cls, mode: str, arg: int | str) -> None: ...
class Valid(Base, mode="a", arg=5): ...
class Valid(Base, mode="b", arg="foo"): ...
# error: [no-matching-overload]
class InvalidType(Base, mode="b", arg=5):
a = 1
b = 2
c = 3
d = 4
e = 5
```
#### More complex cases
For multiple inheritance, the first resolved `__init_subclass__` method is used.
@@ -650,31 +696,6 @@ class Valid(Base[int], arg=1): ...
class InvalidType(Base[int], arg="x"): ... # error: [invalid-argument-type]
```
So are overloads:
```py
class Base:
@overload
def __init_subclass__(cls, mode: Literal["a"], arg: int) -> None: ...
@overload
def __init_subclass__(cls, mode: Literal["b"], arg: str) -> None: ...
def __init_subclass__(cls, mode: str, arg: int | str) -> None: ...
class Valid(Base, mode="a", arg=5): ...
class Valid(Base, mode="b", arg="foo"): ...
class InvalidType(Base, mode="b", arg=5): ... # error: [no-matching-overload]
```
The `metaclass` keyword is ignored, as it has special meaning and is not passed to
`__init_subclass__` at runtime.
```py
class Base:
def __init_subclass__(cls, arg: int): ...
class Valid(Base, arg=5, metaclass=object): ...
```
## `@staticmethod`
### Basic
@@ -48,6 +48,41 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/call/methods.md
33 | h = 8
34 | i = 9
35 | j = 10
36 |
37 | class NotCallableInitSubclass:
38 | __init_subclass__ = None
39 |
40 | # TODO: this should be an error because `__init_subclass__` on the superclass is not callable
41 | class Bad(NotCallableInitSubclass):
42 | a = 1
43 | b = 2
44 | c = 3
45 | class Base:
46 | def __init_subclass__(cls, arg: int): ...
47 |
48 | class Valid(Base, arg=5, metaclass=object): ...
49 |
50 | # error: [invalid-argument-type]
51 | class Invalid(Base, metaclass=type, arg="foo"): ...
52 | from typing import Literal, overload
53 |
54 | class Base:
55 | @overload
56 | def __init_subclass__(cls, mode: Literal["a"], arg: int) -> None: ...
57 | @overload
58 | def __init_subclass__(cls, mode: Literal["b"], arg: str) -> None: ...
59 | def __init_subclass__(cls, mode: str, arg: int | str) -> None: ...
60 |
61 | class Valid(Base, mode="a", arg=5): ...
62 | class Valid(Base, mode="b", arg="foo"): ...
63 |
64 | # error: [no-matching-overload]
65 | class InvalidType(Base, mode="b", arg=5):
66 | a = 1
67 | b = 2
68 | c = 3
69 | d = 4
70 | e = 5
```
# Diagnostics
@@ -58,7 +93,7 @@ error[missing-argument]: No argument provided for required parameter `arg` of fu
|
18 | # Single-base definitions
19 | class MissingArg(RequiresArg): ... # error: [missing-argument]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
20 | class InvalidType(RequiresArg, arg="foo"): ... # error: [invalid-argument-type]
21 | class Valid(RequiresArg, arg=1): ...
|
@@ -76,12 +111,12 @@ info: Parameter declared here
```
error[invalid-argument-type]: Argument to function `__init_subclass__` is incorrect
--> src/mdtest_snippet.py:20:1
--> src/mdtest_snippet.py:20:32
|
18 | # Single-base definitions
19 | class MissingArg(RequiresArg): ... # error: [missing-argument]
20 | class InvalidType(RequiresArg, arg="foo"): ... # error: [invalid-argument-type]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Expected `int`, found `Literal["foo"]`
| ^^^^^^^^^ Expected `int`, found `Literal["foo"]`
21 | class Valid(RequiresArg, arg=1): ...
|
info: Function defined here
@@ -100,20 +135,12 @@ info: Function defined here
error[missing-argument]: No argument provided for required parameter `arg` of function `__init_subclass__`
--> src/mdtest_snippet.py:25:1
|
23 | # error: [missing-argument]
24 | # error: [unknown-argument]
25 | / class IncorrectArg(RequiresArg, not_arg="foo"):
26 | | a = 1
27 | | b = 2
28 | | c = 3
29 | | d = 4
30 | | e = 5
31 | | f = 6
32 | | g = 7
33 | | h = 8
34 | | i = 9
35 | | j = 10
| |__________^
23 | # error: [missing-argument]
24 | # error: [unknown-argument]
25 | class IncorrectArg(RequiresArg, not_arg="foo"):
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
26 | a = 1
27 | b = 2
|
info: Parameter declared here
--> src/mdtest_snippet.py:13:32
@@ -129,22 +156,14 @@ info: Parameter declared here
```
error[unknown-argument]: Argument `not_arg` does not match any known parameter of function `__init_subclass__`
--> src/mdtest_snippet.py:25:1
--> src/mdtest_snippet.py:25:33
|
23 | # error: [missing-argument]
24 | # error: [unknown-argument]
25 | / class IncorrectArg(RequiresArg, not_arg="foo"):
26 | | a = 1
27 | | b = 2
28 | | c = 3
29 | | d = 4
30 | | e = 5
31 | | f = 6
32 | | g = 7
33 | | h = 8
34 | | i = 9
35 | | j = 10
| |__________^
23 | # error: [missing-argument]
24 | # error: [unknown-argument]
25 | class IncorrectArg(RequiresArg, not_arg="foo"):
| ^^^^^^^^^^^^^
26 | a = 1
27 | b = 2
|
info: Function signature here
--> src/mdtest_snippet.py:13:9
@@ -157,3 +176,61 @@ info: Function signature here
|
```
```
error[invalid-argument-type]: Argument to function `__init_subclass__` is incorrect
--> src/mdtest_snippet.py:51:37
|
50 | # error: [invalid-argument-type]
51 | class Invalid(Base, metaclass=type, arg="foo"): ...
| ^^^^^^^^^ Expected `int`, found `Literal["foo"]`
52 | from typing import Literal, overload
|
info: Function defined here
--> src/mdtest_snippet.py:46:9
|
44 | c = 3
45 | class Base:
46 | def __init_subclass__(cls, arg: int): ...
| ^^^^^^^^^^^^^^^^^ -------- Parameter declared here
47 |
48 | class Valid(Base, arg=5, metaclass=object): ...
|
```
```
error[no-matching-overload]: No overload of function `__init_subclass__` matches arguments
--> src/mdtest_snippet.py:65:1
|
64 | # error: [no-matching-overload]
65 | class InvalidType(Base, mode="b", arg=5):
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
66 | a = 1
67 | b = 2
|
info: First overload defined here
--> src/mdtest_snippet.py:56:9
|
54 | class Base:
55 | @overload
56 | def __init_subclass__(cls, mode: Literal["a"], arg: int) -> None: ...
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
57 | @overload
58 | def __init_subclass__(cls, mode: Literal["b"], arg: str) -> None: ...
|
info: Possible overloads for function `__init_subclass__`:
info: (cls, mode: Literal["a"], arg: int) -> None
info: (cls, mode: Literal["b"], arg: str) -> None
info: Overload implementation defined here
--> src/mdtest_snippet.py:59:9
|
57 | @overload
58 | def __init_subclass__(cls, mode: Literal["b"], arg: str) -> None: ...
59 | def __init_subclass__(cls, mode: str, arg: int | str) -> None: ...
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
60 |
61 | class Valid(Base, mode="a", arg=5): ...
|
```
@@ -17,6 +17,7 @@ use std::fmt;
use itertools::Itertools;
use ruff_db::parsed::parsed_module;
use ruff_python_ast::name::Name;
use ruff_text_size::{Ranged, TextRange};
use rustc_hash::{FxHashMap, FxHashSet};
use smallvec::{SmallVec, smallvec, smallvec_inline};
@@ -59,7 +60,7 @@ use crate::types::{
};
use crate::{DisplaySettings, FxOrderSet, Program};
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
use ruff_python_ast::{self as ast, AnyNodeRef, ArgOrKeyword, PythonVersion};
use ty_module_resolver::KnownModule;
pub(crate) use self::constructor::ConstructorCallableKind;
@@ -908,7 +909,8 @@ impl<'db> Bindings<'db> {
) {
// If all elements are not callable, report that the type as a whole is not callable.
if self.elements.iter().all(|e| !e.is_callable()) {
if let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, node) {
let range = all_arguments_range(node);
if let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, range) {
builder.into_diagnostic(format_args!(
"Object of type `{}` is not callable",
self.callable_type().display(context.db())
@@ -3277,7 +3279,8 @@ impl<'db> CallableBinding<'db> {
compound_diag: Option<&dyn CompoundDiagnostic>,
) {
if !self.is_callable() {
if let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, node) {
let range = all_arguments_range(node);
if let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, range) {
let mut diag = builder.into_diagnostic(format_args!(
"Object of type `{}` is not callable",
self.callable_type.display(context.db()),
@@ -3290,7 +3293,8 @@ impl<'db> CallableBinding<'db> {
}
if self.dunder_call_is_possibly_unbound {
if let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, node) {
let range = all_arguments_range(node);
if let Some(builder) = context.report_lint(&CALL_NON_CALLABLE, range) {
let mut diag = builder.into_diagnostic(format_args!(
"Object of type `{}` is not callable (possibly missing `__call__` method)",
self.callable_type.display(context.db()),
@@ -3380,7 +3384,8 @@ impl<'db> CallableBinding<'db> {
return;
}
let Some(builder) = context.report_lint(&NO_MATCHING_OVERLOAD, node) else {
let range = all_arguments_range(node);
let Some(builder) = context.report_lint(&NO_MATCHING_OVERLOAD, range) else {
return;
};
let callable_description =
@@ -5810,7 +5815,8 @@ impl<'db> BindingError<'db> {
parameters,
paramspec,
} => {
if let Some(builder) = context.report_lint(&MISSING_ARGUMENT, node) {
let range = all_arguments_range(node);
if let Some(builder) = context.report_lint(&MISSING_ARGUMENT, range) {
let s = if parameters.0.len() == 1 { "" } else { "s" };
let mut diag = builder.into_diagnostic(format_args!(
"No argument{s} provided for required parameter{s} {parameters}{}",
@@ -6081,10 +6087,15 @@ impl<'db> BindingError<'db> {
fn get_node(node: ast::AnyNodeRef<'_>, argument_index: Option<usize>) -> ast::AnyNodeRef<'_> {
// If we have a Call node and an argument index, report the diagnostic on the correct
// argument node; otherwise, report it on the entire provided node.
match Self::get_argument_node(node, argument_index) {
Some(ast::ArgOrKeyword::Arg(expr)) => expr.into(),
Some(ast::ArgOrKeyword::Keyword(expr)) => expr.into(),
None => node,
match (Self::get_argument_node(node, argument_index), node) {
(Some(ast::ArgOrKeyword::Arg(expr)), _) => expr.into(),
(Some(ast::ArgOrKeyword::Keyword(expr)), _) => expr.into(),
(None, ast::AnyNodeRef::StmtClassDef(class_def)) => class_def
.arguments
.as_deref()
.map(ast::AnyNodeRef::Arguments)
.unwrap_or(node),
(None, _) => node,
}
}
@@ -6100,6 +6111,22 @@ impl<'db> BindingError<'db> {
.nth(argument_index)
.expect("argument index should not be out of range"),
),
// If we've been passed a `ClassDef` node, it indicates that we're reporting an error
// relating to the class's keyword arguments. Keyword arguments are passed to `__init_subclass__`,
// or `__new__`/`__prepare__` on the metaclass -- but positional arguments are not, and neither
// is the special keyword argument `metaclass`. These need to be excluded from the
// argument index when looking up the relevant keyword-argument node.
(ast::AnyNodeRef::StmtClassDef(class_def), Some(argument_index)) => {
class_def.arguments.as_deref().and_then(|args| {
args.iter_source_order()
.filter_map(ArgOrKeyword::as_keyword)
.filter(|keyword| {
keyword.arg.as_deref().is_none_or(|arg| arg != "metaclass")
})
.nth(argument_index)
.map(ast::ArgOrKeyword::Keyword)
})
}
_ => None,
}
}
@@ -6377,3 +6404,25 @@ fn parse_struct_format<'db>(db: &'db dyn Db, format_string: &str) -> Option<Vec<
Some(elements)
}
/// Return the range for a binding diagnostic that is not related to one specific
/// argument.
///
/// For a normal function call, this is just the range of the entire call.
/// If we're reporting diagnostics for bad arguments in a class definition,
/// however,
/// restrict the range to just the range of the class name + its arguments.
fn all_arguments_range(node: AnyNodeRef) -> TextRange {
node.as_stmt_class_def()
.map(|class| {
TextRange::new(
class.start(),
class
.arguments
.as_deref()
.map(Ranged::end)
.unwrap_or(class.name.end()),
)
})
.unwrap_or(node.range())
}