[ty] Support includeDeclaration in references

This commit is contained in:
Micha Reiser
2026-05-04 20:38:44 -07:00
parent ca5bb91cee
commit c50d266bcd
6 changed files with 364 additions and 35 deletions
+220
View File
@@ -1930,4 +1930,224 @@ func<CURSOR>_alias()
|
"#);
}
#[test]
fn without_declaration_excludes_initial_assignment() {
let test = cursor_test(
"
x<CURSOR> = 1
print(x)
",
);
assert_snapshot!(test.references_without_declaration(), @"
info[references]: Found 1 references
--> main.py:3:7
|
3 | print(x)
| -
|
");
}
#[test]
fn without_declaration_keeps_reassignment_without_declaration() {
let test = cursor_test(
"
x = 1
x = 2
print(x<CURSOR>)
",
);
assert_snapshot!(test.references_without_declaration(), @"
info[references]: Found 2 references
--> main.py:3:1
|
3 | x = 2
| -
4 | print(x)
| -
|
");
}
#[test]
fn without_declaration_keeps_assignment_after_annotation() {
let test = cursor_test(
"
x<CURSOR>: int
x = 1
print(x)
",
);
assert_snapshot!(test.references_without_declaration(), @"
info[references]: Found 2 references
--> main.py:3:1
|
3 | x = 1
| -
4 | print(x)
| -
|
");
}
#[test]
fn without_declaration_excludes_repeated_annotation() {
let test = cursor_test(
"
x<CURSOR>: int
x: str
print(x)
",
);
assert_snapshot!(test.references_without_declaration(), @"
info[references]: Found 1 references
--> main.py:4:7
|
4 | print(x)
| -
|
");
}
#[test]
fn without_declaration_excludes_type_alias_name() {
let test = cursor_test(
"
type Box<CURSOR> = int | None
value: Box
",
);
assert_snapshot!(test.references_without_declaration(), @"
info[references]: Found 1 references
--> main.py:3:8
|
3 | value: Box
| ---
|
");
}
#[test]
fn without_declaration_control_flow() {
let test = cursor_test(
"
def test(flag: bool):
if flag:
x: int = 1
return
x = 2
print(x<CURSOR>)
",
);
assert_snapshot!(test.references_without_declaration(), @"
info[references]: Found 1 references
--> main.py:8:11
|
8 | print(x)
| -
|
");
}
#[test]
fn without_declaration_keeps_binding_when_declaration_is_partial() {
let test = cursor_test(
"
def f(flag: bool):
if flag:
x: int
x = 1
print(x<CURSOR>)
",
);
assert_snapshot!(test.references_without_declaration(), @"
info[references]: Found 2 references
--> main.py:5:5
|
5 | x = 1
| -
6 | print(x)
| -
|
");
}
#[test]
fn without_declaration_excludes_live_conditional_assignments() {
let test = cursor_test(
"
if flag:
x = 1
else:
x = 2
print(x<CURSOR>)
",
);
assert_snapshot!(test.references_without_declaration(), @"
info[references]: Found 1 references
--> main.py:6:7
|
6 | print(x)
| -
|
");
}
#[test]
fn without_declaration_excludes_initial_attribute_assignment() {
let test = cursor_test(
"
class C:
def __init__(self):
self.x<CURSOR> = 1
def f(self):
print(self.x)
",
);
assert_snapshot!(test.references_without_declaration(), @"
info[references]: Found 1 references
--> main.py:7:20
|
7 | print(self.x)
| -
|
");
}
#[test]
fn without_declaration_excludes_attribute_assignment_after_base_rebind() {
let test = cursor_test(
"
class C:
def f(self, flag: bool):
if flag:
self.x = 1
else:
self = C()
self.x<CURSOR> = 2
print(self.x)
",
);
assert_snapshot!(test.references_without_declaration(), @"
info[references]: Found 1 references
--> main.py:9:20
|
9 | print(self.x)
| -
|
");
}
}
+71 -31
View File
@@ -20,6 +20,7 @@ use ruff_python_ast::{
visitor::source_order::{SourceOrderVisitor, TraversalSignal},
};
use ruff_text_size::Ranged;
use ty_python_core::definition::{Definition, DefinitionState};
use ty_python_semantic::{ImportAliasResolution, ResolvedDefinition, SemanticModel};
/// Mode for references search behavior
@@ -382,7 +383,6 @@ enum OccurrenceKind {
Reference,
/// An identifier that declares a new symbol.
Declaration,
/// An identifier that binds a new value to a symbol.
Binding,
}
@@ -438,13 +438,13 @@ impl<'a> SourceOrderVisitor<'a> for LocalReferencesFinder<'a> {
let kind = OccurrenceKind::from(attr_expr.ctx);
self.check_identifier(&attr_expr.attr, kind);
}
AnyNodeRef::StmtFunctionDef(func) if self.should_include_declaration() => {
AnyNodeRef::StmtFunctionDef(func) => {
self.check_declaration_identifier(&func.name);
}
AnyNodeRef::StmtClassDef(class) if self.should_include_declaration() => {
AnyNodeRef::StmtClassDef(class) => {
self.check_declaration_identifier(&class.name);
}
AnyNodeRef::Parameter(parameter) if self.should_include_declaration() => {
AnyNodeRef::Parameter(parameter) => {
self.check_declaration_identifier(&parameter.name);
}
AnyNodeRef::Keyword(keyword) => {
@@ -452,47 +452,43 @@ impl<'a> SourceOrderVisitor<'a> for LocalReferencesFinder<'a> {
self.check_reference_identifier(arg);
}
}
AnyNodeRef::StmtGlobal(global_stmt) if self.should_include_declaration() => {
AnyNodeRef::StmtGlobal(global_stmt) => {
for name in &global_stmt.names {
self.check_declaration_identifier(name);
}
}
AnyNodeRef::StmtNonlocal(nonlocal_stmt) if self.should_include_declaration() => {
AnyNodeRef::StmtNonlocal(nonlocal_stmt) => {
for name in &nonlocal_stmt.names {
self.check_declaration_identifier(name);
}
}
AnyNodeRef::ExceptHandlerExceptHandler(handler)
if self.should_include_declaration() =>
{
AnyNodeRef::ExceptHandlerExceptHandler(handler) => {
if let Some(name) = &handler.name {
self.check_binding_identifier(name);
}
}
AnyNodeRef::PatternMatchAs(pattern_as) if self.should_include_declaration() => {
AnyNodeRef::PatternMatchAs(pattern_as) => {
if let Some(name) = &pattern_as.name {
self.check_binding_identifier(name);
}
}
AnyNodeRef::PatternMatchStar(pattern_star) if self.should_include_declaration() => {
AnyNodeRef::PatternMatchStar(pattern_star) => {
if let Some(name) = &pattern_star.name {
self.check_binding_identifier(name);
}
}
AnyNodeRef::PatternMatchMapping(pattern_mapping)
if self.should_include_declaration() =>
{
AnyNodeRef::PatternMatchMapping(pattern_mapping) => {
if let Some(rest_name) = &pattern_mapping.rest {
self.check_binding_identifier(rest_name);
}
}
AnyNodeRef::TypeParamParamSpec(param_spec) if self.should_include_declaration() => {
AnyNodeRef::TypeParamParamSpec(param_spec) => {
self.check_declaration_identifier(&param_spec.name);
}
AnyNodeRef::TypeParamTypeVarTuple(param_tuple) if self.should_include_declaration() => {
AnyNodeRef::TypeParamTypeVarTuple(param_tuple) => {
self.check_declaration_identifier(&param_tuple.name);
}
AnyNodeRef::TypeParamTypeVar(param_var) if self.should_include_declaration() => {
AnyNodeRef::TypeParamTypeVar(param_var) => {
self.check_declaration_identifier(&param_var.name);
}
AnyNodeRef::ExprStringLiteral(string_expr) => {
@@ -511,7 +507,7 @@ impl<'a> SourceOrderVisitor<'a> for LocalReferencesFinder<'a> {
sub_finder.visit_expr(sub_ast.expr());
}
}
AnyNodeRef::Alias(alias) if self.should_include_declaration() => {
AnyNodeRef::Alias(alias) => {
// Handle import alias declarations
if let Some(asname) = &alias.asname {
self.check_declaration_identifier(asname);
@@ -554,18 +550,7 @@ impl<'a> SourceOrderVisitor<'a> for KeywordArgumentReferencesFinder<'a> {
}
impl<'a> LocalReferencesFinder<'a> {
/// Check if we should include declarations based on the current mode
fn should_include_declaration(&self) -> bool {
matches!(
self.mode,
ReferencesMode::References
| ReferencesMode::DocumentHighlights
| ReferencesMode::Rename
| ReferencesMode::RenameMultiFile
)
}
/// Checks an identifier of a binding (e.g. `x = 10`).
/// Checks an identifier of a binding (e.g. `x = 10`)
fn check_binding_identifier(&mut self, identifier: &ast::Identifier) {
self.check_identifier(identifier, OccurrenceKind::Binding);
}
@@ -611,7 +596,6 @@ impl<'a> LocalReferencesFinder<'a> {
Some(definitions)
}
/// Pushes a reference target when the covering node resolves to any target definition.
fn check_covering_node(&mut self, covering_node: &CoveringNode<'_>, kind: OccurrenceKind) {
let Some(current_definitions) = self.definitions_for_covering_node(covering_node) else {
return;
@@ -622,6 +606,18 @@ impl<'a> LocalReferencesFinder<'a> {
return;
}
if matches!(self.mode, ReferencesMode::ReferencesSkipDeclaration) {
let is_declaration = match kind {
OccurrenceKind::Declaration => true,
OccurrenceKind::Reference => false,
OccurrenceKind::Binding => self.is_declaration(covering_node),
};
if is_declaration {
return;
}
}
let target = ReferenceTarget::new(
self.model.file(),
covering_node.node().range(),
@@ -629,4 +625,48 @@ impl<'a> LocalReferencesFinder<'a> {
);
self.references.push(target);
}
fn is_declaration(&self, covering_node: &CoveringNode<'_>) -> bool {
let db = self.model.db();
let Some(local_definition) = self.model.first_local_definition(covering_node) else {
return false;
};
let file = local_definition.file(db);
let module = ruff_db::parsed::parsed_module(db, file).load(db);
let kind = local_definition.kind(db);
let category = kind.category(file.is_stub(db), &module);
if category.is_declaration() {
return true;
}
if self.binding_has_reachable_explicit_declaration(local_definition) {
return false;
}
self.binding_is_first_assignment_on_some_path(local_definition)
}
fn binding_has_reachable_explicit_declaration(&self, binding: Definition<'a>) -> bool {
let db = self.model.db();
let use_def = ty_python_core::use_def_map(db, binding.scope(db));
use_def
.declarations_at_binding(binding)
.any(|declaration| declaration.declaration.definition().is_some())
}
fn binding_is_first_assignment_on_some_path(&self, binding: Definition<'a>) -> bool {
let db = self.model.db();
let use_def = ty_python_core::use_def_map(db, binding.scope(db));
use_def
.bindings_at_definition(binding)
.any(|prior_binding| {
matches!(
prior_binding.binding,
DefinitionState::Deleted | DefinitionState::Undefined
)
})
}
}
+6
View File
@@ -1491,6 +1491,12 @@ impl<'db> LoopHeaderDefinitionKind<'db> {
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, salsa::Update, get_size2::GetSize)]
pub struct DefinitionNodeKey(NodeKey);
impl DefinitionNodeKey {
pub(crate) fn from_node_ref(node: ast::AnyNodeRef<'_>) -> Self {
Self(NodeKey::from_node(node))
}
}
impl From<&ast::Alias> for DefinitionNodeKey {
fn from(node: &ast::Alias) -> Self {
Self(NodeKey::from_node(node))
+9
View File
@@ -502,6 +502,15 @@ impl<'db> SemanticIndex<'db> {
&self.definitions_by_node[&definition_key.into()]
}
/// Returns the [`definition::Definition`] salsa ingredient(s) for `definition_node`, if any.
pub fn try_definitions(
&self,
definition_node: ast::AnyNodeRef<'_>,
) -> Option<&Definitions<'db>> {
let definition_key = DefinitionNodeKey::from_node_ref(definition_node);
self.definitions_by_node.get(&definition_key)
}
/// Returns the [`definition::Definition`] salsa ingredient for `definition_key`.
///
/// ## Panics
@@ -1,10 +1,12 @@
use ruff_db::files::{File, FilePath};
use ruff_db::parsed::{parsed_module, parsed_string_annotation};
use ruff_db::source::{line_index, source_text};
use ruff_python_ast::find_node::CoveringNode;
use ruff_python_ast::{self as ast, ExprStringLiteral, ModExpression};
use ruff_python_ast::{Expr, ExprRef, name::Name};
use ruff_python_parser::Parsed;
use ruff_source_file::LineIndex;
use ruff_text_size::Ranged;
use rustc_hash::FxHashMap;
use ty_module_resolver::{
KnownModule, Module, ModuleName, list_modules, resolve_module, resolve_real_shadowable_module,
@@ -324,6 +326,42 @@ impl<'db> SemanticModel<'db> {
}
}
/// Returns the first local definition created by `covering_node`, if any.
///
/// A local definition is a user-visible definition associated with `covering_node` itself, or
/// one of its ancestors, whose focus range covers the queried node. This returns only the first
/// match because one syntax node can represent multiple semantic definitions, for example
/// `from module import *`. This helper is intended for classifying the local occurrence, such as
/// deciding whether it is a binding or declaration, not for enumerating every symbol introduced
/// by the syntax.
pub fn first_local_definition(
&self,
covering_node: &CoveringNode<'_>,
) -> Option<Definition<'db>> {
let index = semantic_index(self.db, self.file);
let parsed = parsed_module(self.db, self.file).load(self.db);
let target_range = covering_node.node().range();
for node in covering_node.ancestors() {
let Some(definitions) = index.try_definitions(node) else {
continue;
};
if let Some(definition) = definitions.iter().copied().find(|definition| {
let kind = definition.kind(self.db);
kind.is_user_visible()
&& definition
.focus_range(self.db, &parsed)
.range()
.contains_range(target_range)
}) {
return Some(definition);
}
}
None
}
/// Get a "safe" [`ast::AnyNodeRef`] to use for referring to the given (sub-)AST node.
///
/// If we're analyzing a string annotation, it will return the string literal's node.
@@ -726,6 +764,9 @@ impl_binding_has_ty_def!(ast::StmtClassDef);
impl_binding_has_ty_def!(ast::Parameter);
impl_binding_has_ty_def!(ast::ParameterWithDefault);
impl_binding_has_ty_def!(ast::TypeParamTypeVar);
impl_binding_has_ty_def!(ast::TypeParamParamSpec);
impl_binding_has_ty_def!(ast::TypeParamTypeVarTuple);
impl_binding_has_ty_def!(ast::StmtTypeAlias);
impl HasType for ast::Alias {
fn inferred_type<'db>(&self, model: &SemanticModel<'db>) -> Option<Type<'db>> {
@@ -740,6 +781,7 @@ impl HasType for ast::Alias {
impl HasOptionalDefinition for ast::ExceptHandlerExceptHandler {
fn optional_definition<'db>(&self, model: &SemanticModel<'db>) -> Option<Definition<'db>> {
self.name.as_ref()?;
let index = semantic_index(model.db, model.file);
Some(index.expect_single_definition(self))
}
@@ -754,11 +796,10 @@ impl HasType for ast::ExceptHandlerExceptHandler {
#[cfg(test)]
mod tests {
use ruff_db::files::system_path_to_file;
use ruff_db::parsed::parsed_module;
use crate::db::tests::TestDbBuilder;
use crate::{HasType, SemanticModel};
use ruff_db::files::system_path_to_file;
use ruff_db::parsed::parsed_module;
#[test]
fn function_type() -> anyhow::Result<()> {
@@ -1276,7 +1276,7 @@ mod resolve_definition {
use crate::Db;
use crate::module_docstring;
use crate::types::binding_type;
use ty_python_core::definition::{Definition, DefinitionKind};
use ty_python_core::definition::{Definition, DefinitionCategory, DefinitionKind};
use ty_python_core::scope::{NodeWithScopeKind, ScopeId};
use ty_python_core::{global_scope, place_table, semantic_index, use_def_map};
@@ -1308,6 +1308,19 @@ mod resolve_definition {
}
}
pub fn category(&self, db: &dyn Db) -> DefinitionCategory {
match self {
ResolvedDefinition::Definition(definition) => {
let file = definition.file(db);
let parsed = parsed_module(db, file).load(db);
definition.kind(db).category(file.is_stub(db), &parsed)
}
ResolvedDefinition::Module(_) | ResolvedDefinition::FileWithRange(_) => {
DefinitionCategory::DeclarationAndBinding
}
}
}
pub fn definition(&self) -> Option<Definition<'db>> {
match self {
ResolvedDefinition::Definition(definition) => Some(*definition),