[ty] Offer string literal completion suggestions based on expected type (#24555)

Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
Anders Brams
2026-04-29 17:16:58 +02:00
committed by GitHub
parent ea4b40641a
commit bfe5b51890
12 changed files with 777 additions and 18 deletions
Generated
+1
View File
@@ -4536,6 +4536,7 @@ dependencies = [
"ruff_python_ast",
"ruff_python_codegen",
"ruff_python_importer",
"ruff_python_literal",
"ruff_python_parser",
"ruff_python_trivia",
"ruff_source_file",
+1
View File
@@ -22,6 +22,7 @@ ruff_memory_usage = { workspace = true }
ruff_python_ast = { workspace = true }
ruff_python_codegen = { workspace = true }
ruff_python_importer = { workspace = true }
ruff_python_literal = { workspace = true }
ruff_python_trivia = { workspace = true }
ruff_source_file = { workspace = true }
ruff_text_size = { workspace = true }
+489 -5
View File
@@ -7,9 +7,11 @@ use ruff_db::source::{SourceText, source_text};
use ruff_diagnostics::Edit;
use ruff_python_ast::find_node::{CoveringNode, covering_node};
use ruff_python_ast::name::{Name, UnqualifiedName};
use ruff_python_ast::str::Quote;
use ruff_python_ast::token::{Token, TokenKind, Tokens};
use ruff_python_ast::{self as ast, AnyNodeRef};
use ruff_python_codegen::Stylist;
use ruff_python_literal::escape::{Escape, UnicodeEscape};
use ruff_text_size::{Ranged, TextRange, TextSize};
use rustc_hash::FxHashSet;
use ty_module_resolver::{KnownModule, Module, ModuleName};
@@ -39,6 +41,25 @@ pub fn completion<'db>(
return vec![];
};
let model = SemanticModel::new(db, file);
if context.cursor.is_in_string() {
let Some(string_expr) = context.cursor.enclosing_string_literal_expr() else {
return vec![];
};
let mut completions =
Completions::new(db, CollectionContext::none(), UserQuery::fuzzy(None));
add_string_literal_completions(
&model,
string_expr,
context.cursor.string_quote_style(),
&mut completions,
);
return completions.into_completions();
}
let query = UserQuery::fuzzy(context.cursor.typed);
let mut completions = Completions::new(db, context.collection_context(db, &model), query);
match context.kind {
@@ -746,7 +767,7 @@ impl<'m> ContextCursor<'m> {
/// Whether the last token is in a place where we should not provide completions.
fn is_in_no_completions_place(&self) -> bool {
self.is_in_comment() || self.is_in_string() || self.is_in_definition_place()
self.is_in_comment() || self.is_in_definition_place()
}
/// Whether the last token is within a comment or not.
@@ -769,6 +790,21 @@ impl<'m> ContextCursor<'m> {
})
}
/// Returns the string literal expression that the cursor is positioned within, if any.
fn enclosing_string_literal_expr(&self) -> Option<&'m ast::ExprStringLiteral> {
match self.covering_node.parent() {
Some(ast::AnyNodeRef::ExprStringLiteral(string_expr)) => Some(string_expr),
_ => None,
}
}
/// Returns the quote style of the string literal that the cursor is positioned within, if any.
fn string_quote_style(&self) -> Option<Quote> {
self.tokens_before
.last()
.map(|token| token.string_quote_style())
}
/// Returns true when the tokens indicate that the definition of a new
/// name is being introduced at the end.
fn is_in_definition_place(&self) -> bool {
@@ -1643,6 +1679,63 @@ fn add_keyword_completions<'db>(db: &'db dyn Db, completions: &mut Completions<'
}
}
fn add_string_literal_completions<'db>(
model: &SemanticModel<'db>,
string_expr: &ast::ExprStringLiteral,
quote_style: Option<Quote>,
completions: &mut Completions<'db>,
) {
fn force_escape_quote(body: &str, quote: Quote) -> String {
let quote_char = quote.as_char();
let mut escaped = String::with_capacity(body.len());
let mut consecutive_backslashes = 0usize;
for ch in body.chars() {
if ch == '\\' {
consecutive_backslashes += 1;
escaped.push(ch);
continue;
}
if ch == quote_char && consecutive_backslashes.is_multiple_of(2) {
escaped.push('\\');
}
consecutive_backslashes = 0;
escaped.push(ch);
}
escaped
}
// When we insert a completion for a string literal, we need to make sure
// to properly escape any special characters in the completion value and
// to use the appropriate quote style.
fn escape_for_quote(value: &str, quote: Quote) -> Option<String> {
let escaped = UnicodeEscape::with_preferred_quote(value, quote);
let mut out = String::new();
escaped.write_body(&mut out).ok()?;
Some(force_escape_quote(&out, quote))
}
let candidates = model.expected_string_literal_completions(string_expr);
if candidates.is_empty() {
return;
}
let quote_style = quote_style.unwrap_or(Quote::Double);
for candidate in candidates {
let Some(insert) = escape_for_quote(&candidate.value, quote_style) else {
continue;
};
completions.add_skip_query(
Completion::builder(candidate.value.as_str())
.insert(insert)
.ty(candidate.ty)
.context_specific(true),
);
}
}
/// Adds completions not in scope.
///
/// `scoped` should be information about the identified scope
@@ -3954,7 +4047,7 @@ quux.<CURSOR>
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().type_signatures().build().snapshot(), @r###"
builder.skip_keywords().skip_builtins().type_signatures().build().snapshot(), @"
bar :: int
baz :: int
foo :: int
@@ -3981,7 +4074,7 @@ quux.<CURSOR>
__sizeof__ :: bound method Quux.__sizeof__() -> int
__str__ :: bound method Quux.__str__() -> str
__subclasshook__ :: bound method type[Quux].__subclasshook__(subclass: type, /) -> bool
"###);
");
}
#[test]
@@ -4000,13 +4093,13 @@ quux.b<CURSOR>
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().type_signatures().build().snapshot(), @r###"
builder.skip_keywords().skip_builtins().type_signatures().build().snapshot(), @"
bar :: int
baz :: int
__getattribute__ :: bound method Quux.__getattribute__(name: str, /) -> Any
__init_subclass__ :: bound method type[Quux].__init_subclass__() -> None
__subclasshook__ :: bound method type[Quux].__subclasshook__(subclass: type, /) -> bool
"###);
");
}
#[test]
@@ -6769,6 +6862,397 @@ print(t'''{Foo} and Foo.zqzq<CURSOR>
);
}
#[test]
fn string_literal_completions_function_argument() {
let builder = completion_test_builder(
r#"
from typing import Literal
A = Literal["a", "b", "c"]
def func(a: A): ...
func("<CURSOR>")
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().type_signatures().build().snapshot(),
@r#"
a :: Literal["a"]
b :: Literal["b"]
c :: Literal["c"]
"#,
);
}
#[test]
fn string_literal_completions_overloaded_function_argument() {
let builder = completion_test_builder(
r#"
from typing import Literal, overload
@overload
def func(mode: Literal["r"]) -> int: ...
@overload
def func(mode: Literal["w"]) -> str: ...
def func(mode: str) -> int | str: ...
func("<CURSOR>")
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().type_signatures().build().snapshot(),
@r#"
r :: Literal["r"]
w :: Literal["w"]
"#,
);
}
#[test]
fn string_literal_completions_annotated_assignment() {
let builder = completion_test_builder(
r#"
from typing import Literal
value: Literal["x", "y"] = "<CURSOR>"
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().type_signatures().build().snapshot(),
@r#"
x :: Literal["x"]
y :: Literal["y"]
"#,
);
}
#[test]
fn string_literal_completions_nested_expected_type() {
let builder = completion_test_builder(
r#"
from typing import Literal
type A = Literal["foo", "bar", "baz"]
xs: list[A] = ["<CURSOR>"]
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().type_signatures().build().snapshot(),
@r#"
bar :: Literal["bar"]
baz :: Literal["baz"]
foo :: Literal["foo"]
"#,
);
}
#[test]
fn string_literal_completions_filter_non_string_literals() {
let builder = completion_test_builder(
r#"
from typing import Literal
Mixed = Literal["left", 1, "right"]
def consume(value: Mixed): ...
consume("<CURSOR>")
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().type_signatures().build().snapshot(),
@r#"
left :: Literal["left"]
right :: Literal["right"]
"#,
);
}
#[test]
fn string_literal_completions_typed_dict_keys() {
let builder = completion_test_builder(
r#"
from typing import TypedDict
class TD(TypedDict):
left: int
right: str
td: TD = {"left": 1, "right": "x"}
td["<CURSOR>"]
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().type_signatures().build().snapshot(),
@r#"
left :: Literal["left"]
right :: Literal["right"]
"#,
);
}
#[test]
fn string_literal_completions_typed_dict_keys_assignment() {
let builder = completion_test_builder(
r#"
from typing import TypedDict
class TD(TypedDict):
left: int
right: str
td: TD = {"left": 1, "right": "x"}
td["<CURSOR>"] = 1
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().type_signatures().build().snapshot(),
@r#"
left :: Literal["left"]
right :: Literal["right"]
"#,
);
}
#[test]
fn string_literal_completions_typed_dict_keys_deletion() {
let builder = completion_test_builder(
r#"
from typing import TypedDict
class TD(TypedDict):
left: int
right: str
td: TD = {"left": 1, "right": "x"}
del td["<CURSOR>"]
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().type_signatures().build().snapshot(),
@r#"
left :: Literal["left"]
right :: Literal["right"]
"#,
);
}
#[test]
fn string_literal_completions_do_not_offer_typed_dict_keys_for_typed_dict_value_context() {
let builder = completion_test_builder(
r#"
from typing import TypedDict
class TD(TypedDict):
left: int
right: str
x: TD = "<CURSOR>"
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().build().snapshot(),
@"<No completions found>",
);
}
#[test]
fn string_literal_completions_typed_dict_union_keys() {
let builder = completion_test_builder(
r#"
from typing import TypedDict
class A(TypedDict):
a: int
both: int
class B(TypedDict):
b: int
both: str
x: A | B = {"both": 1}
x["<CURSOR>"]
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().type_signatures().build().snapshot(),
@r#"
a :: Literal["a"]
b :: Literal["b"]
both :: Literal["both"]
"#,
);
}
#[test]
fn string_literal_completions_intersection_positive_union_semantics() {
let builder = completion_test_builder(
r#"
from typing import Literal
from ty_extensions import Intersection
x: Intersection[Literal["a", "b"], Literal["b", "c"]] = "<CURSOR>"
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().type_signatures().build().snapshot(),
@r#"b :: Literal["b"]"#,
);
}
#[test]
fn string_literal_completions_intersection_excludes_negative_elements() {
let builder = completion_test_builder(
r#"
from typing import Literal
from ty_extensions import Intersection, Not
x: Intersection[Literal["a"], Not[Literal["b"]]] = "<CURSOR>"
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().type_signatures().build().snapshot(),
@r#"a :: Literal["a"]"#,
);
}
#[test]
fn string_literal_completions_type_alias_recursion_safe() {
let builder = completion_test_builder(
r#"
from typing import Literal
type A = Literal["foo", "bar"]
type B = A
x: B = "<CURSOR>"
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().type_signatures().build().snapshot(),
@r#"
bar :: Literal["bar"]
foo :: Literal["foo"]
"#,
);
}
#[test]
fn string_literal_completions_single_quote_escaping() {
let builder = completion_test_builder(
r#"
from typing import Literal
x: Literal["can't", "won't"] = '<CURSOR>'
"#,
);
let builder = builder.skip_keywords().skip_builtins().skip_auto_import();
let test = builder.build();
let inserts = test
.completions()
.iter()
.map(|completion| {
completion
.insert
.as_deref()
.unwrap_or(completion.name.as_str())
.to_string()
})
.collect::<Vec<_>>();
assert_eq!(inserts, vec![r"can\'t", r"won\'t"]);
}
#[test]
fn string_literal_completions_double_quote_escaping() {
let builder = completion_test_builder(
r#"
from typing import Literal
x: Literal['say "hi"', 'say "bye"'] = "<CURSOR>"
"#,
);
let builder = builder.skip_keywords().skip_builtins().skip_auto_import();
let test = builder.build();
let inserts = test
.completions()
.iter()
.map(|completion| {
completion
.insert
.as_deref()
.unwrap_or(completion.name.as_str())
.to_string()
})
.collect::<Vec<_>>();
assert_eq!(inserts, vec![r#"say \"bye\""#, r#"say \"hi\""#]);
}
#[test]
fn string_literal_completions_backslash_escaping() {
let builder = completion_test_builder(
r#"
from typing import Literal
x: Literal["a\\b"] = "<CURSOR>"
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().build().snapshot(),
@r"a\\b",
);
}
#[test]
fn string_literal_completions_in_incomplete_string() {
let builder = completion_test_builder(
r#"
from typing import Literal
x: Literal["a"] = "a<CURSOR>
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().build().snapshot(),
@"a",
);
}
#[test]
fn string_literal_completions_in_standalone_statement() {
let builder = completion_test_builder(
r#"
from collections.abc import Callable
from typing import Literal
def func(callback: Callable[[], Literal["yes", "no"]]) -> None: ...
x = y = func(lambda: "<CURSOR>")
"#,
);
assert_snapshot!(
builder.skip_keywords().skip_builtins().skip_auto_import().type_signatures().build().snapshot(),
@r#"
no :: Literal["no"]
yes :: Literal["yes"]
"#,
);
}
#[test]
fn typevar_with_upper_bound() {
let builder = completion_test_builder(
+2 -2
View File
@@ -16,8 +16,8 @@ use ruff_db::parsed::parsed_module;
use ruff_db::source::{SourceTextError, source_text};
use rustc_hash::FxHasher;
pub use semantic_model::{
Completion, HasDefinition, HasOptionalDefinition, HasType, MemberDefinition, NameKind,
SemanticModel,
Completion, ExpectedStringLiteralCompletion, HasDefinition, HasOptionalDefinition, HasType,
MemberDefinition, NameKind, SemanticModel,
};
use std::hash::BuildHasherDefault;
pub use suppression::{
@@ -15,7 +15,7 @@ use crate::place::implicit_globals::all_implicit_module_globals;
use crate::types::ide_support::{ImportAliasResolution, definition_for_name};
use crate::types::list_members::{Member, all_members, all_reachable_members};
use crate::types::{
Type, TypeQualifiers, binding_type, declaration_type, infer_complete_scope_types,
CycleDetector, Type, TypeQualifiers, binding_type, declaration_type, infer_complete_scope_types,
};
use ty_python_core::definition::Definition;
use ty_python_core::place_table;
@@ -437,6 +437,77 @@ impl<'db> SemanticModel<'db> {
_ => TypeQualifiers::empty(),
}
}
/// Returns completion candidates for a string-literal expression based on its expected type.
pub fn expected_string_literal_completions(
&self,
string_expr: &ast::ExprStringLiteral,
) -> Vec<ExpectedStringLiteralCompletion<'db>> {
struct StringLiteralCandidates;
type StringLiteralCandidatesVisitor<'db> = CycleDetector<
StringLiteralCandidates,
Type<'db>,
Vec<ExpectedStringLiteralCompletion<'db>>,
>;
fn collect<'db>(
db: &'db dyn Db,
ty: Type<'db>,
visitor: &StringLiteralCandidatesVisitor<'db>,
) -> Vec<ExpectedStringLiteralCompletion<'db>> {
match ty {
Type::LiteralValue(literal) => literal
.as_string()
.map(|string_literal| {
let value = string_literal.value(db).to_string();
vec![ExpectedStringLiteralCompletion {
ty: Type::string_literal(db, &value),
value,
}]
})
.unwrap_or_default(),
Type::Union(union) => union
.elements(db)
.iter()
.flat_map(|element| collect(db, *element, visitor))
.collect(),
Type::Intersection(intersection) => intersection
.positive(db)
.iter()
.flat_map(|element| collect(db, *element, visitor))
.collect(),
Type::TypeAlias(alias) => {
visitor.visit(ty, || collect(db, alias.value_type(db), visitor))
}
_ => Vec::new(),
}
}
let Some(expected_ty) = self.string_literal_completion_expected_type(string_expr) else {
return Vec::new();
};
let mut candidates = collect(
self.db,
expected_ty,
&StringLiteralCandidatesVisitor::default(),
);
candidates.sort_unstable_by(|left, right| left.value.cmp(&right.value));
candidates.dedup_by(|left, right| left.value == right.value);
candidates
}
fn string_literal_completion_expected_type(
&self,
string_expr: &ast::ExprStringLiteral,
) -> Option<Type<'db>> {
let expr = ast::ExprRef::from(string_expr);
let index = semantic_index(self.db, self.file);
let file_scope = index.try_expression_scope_id(&self.expr_ref_in_ast(expr))?;
let scope = file_scope.to_scope_id(self.db, self.file);
infer_complete_scope_types(self.db, scope).try_expected_type(expr)
}
}
/// The type and definition of a symbol.
@@ -500,6 +571,12 @@ pub struct Completion<'db> {
pub builtin: bool,
}
#[derive(Clone, Debug)]
pub struct ExpectedStringLiteralCompletion<'db> {
pub value: String,
pub ty: Type<'db>,
}
pub trait HasType {
/// Returns the inferred type of `self`.
///
@@ -698,6 +698,8 @@ pub(crate) struct ScopeInference<'db> {
struct ScopeInferenceExtra<'db> {
/// String annotations found in this region
string_annotations: FxHashSet<ExpressionNodeKey>,
/// Expected types for expression nodes tracked for IDE completion.
expected_types: FxHashMap<ExpressionNodeKey, Type<'db>>,
/// Metadata for type expressions in this region.
type_expression_flags: FxHashMap<ExpressionNodeKey, TypeExpressionFlags>,
@@ -753,6 +755,15 @@ impl<'db> ScopeInference<'db> {
.or_else(|| self.fallback_type())
}
pub(crate) fn try_expected_type(
&self,
expression: impl Into<ExpressionNodeKey>,
) -> Option<Type<'db>> {
self.extra
.as_deref()
.and_then(|extra| extra.expected_types.get(&expression.into()).copied())
}
fn fallback_type(&self) -> Option<Type<'db>> {
self.extra.as_ref().and_then(|extra| extra.cycle_recovery)
}
@@ -810,6 +821,8 @@ pub(crate) struct DefinitionInference<'db> {
struct DefinitionInferenceExtra<'db> {
/// String annotations found in this region
string_annotations: FxHashSet<ExpressionNodeKey>,
/// Expected types for expression nodes tracked for IDE completion.
expected_types: FxHashMap<ExpressionNodeKey, Type<'db>>,
/// Functions called while inferring this definition.
called_functions: Box<[FunctionType<'db>]>,
@@ -1005,6 +1018,8 @@ pub(crate) struct ExpressionInference<'db> {
struct ExpressionInferenceExtra<'db> {
/// String annotations found in this region
string_annotations: FxHashSet<ExpressionNodeKey>,
/// Expected types for expression nodes tracked for IDE completion.
expected_types: FxHashMap<ExpressionNodeKey, Type<'db>>,
/// Metadata for type expressions in this region.
type_expression_flags: FxHashMap<ExpressionNodeKey, TypeExpressionFlags>,
@@ -1130,6 +1145,9 @@ struct StatementInferenceInnerExtra<'db> {
/// String annotations found in this region
string_annotations: FxHashSet<ExpressionNodeKey>,
/// Expected types for expression nodes tracked for IDE completion.
expected_types: FxHashMap<ExpressionNodeKey, Type<'db>>,
/// Functions called while inferring this statement.
called_functions: Box<[FunctionType<'db>]>,
@@ -248,6 +248,8 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
/// Expressions that are string annotations
string_annotations: FxHashSet<ExpressionNodeKey>,
/// Expected types for expression nodes tracked for IDE completion.
expected_types: FxHashMap<ExpressionNodeKey, Type<'db>>,
/// The scope this region is part of.
scope: ScopeId<'db>,
@@ -353,6 +355,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
qualifiers: FxHashMap::default(),
type_expression_flags: FxHashMap::default(),
string_annotations: FxHashSet::default(),
expected_types: FxHashMap::default(),
bindings: VecMap::default(),
declarations: VecMap::default(),
typevar_binding_context: None,
@@ -400,6 +403,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.deferred.extend(extra.deferred.iter().copied());
self.string_annotations
.extend(extra.string_annotations.iter().copied());
self.expected_types.extend(extra.expected_types.iter());
self.qualifiers.extend(extra.qualifiers.iter());
self.type_expression_flags
.extend(extra.type_expression_flags.iter());
@@ -431,6 +435,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.deferred.extend(extra.deferred.iter().copied());
self.string_annotations
.extend(extra.string_annotations.iter().copied());
self.expected_types.extend(extra.expected_types.iter());
self.qualifiers.extend(extra.qualifiers.iter());
self.type_expression_flags
.extend(extra.type_expression_flags.iter());
@@ -452,6 +457,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.extend_cycle_recovery(extra.cycle_recovery);
self.string_annotations
.extend(extra.string_annotations.iter().copied());
self.expected_types.extend(extra.expected_types.iter());
self.type_expression_flags
.extend(extra.type_expression_flags.iter());
@@ -469,6 +475,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.extend_cycle_recovery(extra.cycle_recovery);
self.string_annotations
.extend(extra.string_annotations.iter().copied());
self.expected_types.extend(extra.expected_types.iter());
self.type_expression_flags
.extend(extra.type_expression_flags.iter());
}
@@ -5281,10 +5288,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// We use a speculative builder to silence any diagnostics emitted during multi-inference, as the
// type context is only used as a hint to infer a more assignable argument type, and should not lead
// to diagnostics for non-matching overloads.
let mut speculative_builder = self.speculate();
let inferred_ty = infer_argument_ty(
&mut self.speculate(),
&mut speculative_builder,
(argument_index, ast_argument, parameter_tcx),
);
self.union_expected_types(&speculative_builder.expected_types);
argument_types.insert(parameter.annotated_type(), inferred_ty);
}
@@ -5483,6 +5492,47 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
assert_eq!(previous, None);
}
fn store_maybe_expected_type(
&mut self,
expression: impl Into<ExpressionNodeKey>,
ty: Type<'db>,
) {
if !self.has_string_literal_completion_candidates(ty) {
return;
}
self.store_expected_type(expression, ty);
}
fn store_expected_type(&mut self, expression: impl Into<ExpressionNodeKey>, ty: Type<'db>) {
self.expected_types.insert(expression.into(), ty);
}
fn has_string_literal_completion_candidates(&self, ty: Type<'db>) -> bool {
match ty {
Type::LiteralValue(literal) => literal.as_string().is_some(),
Type::Union(union) => union
.elements(self.db())
.iter()
.any(|ty| self.has_string_literal_completion_candidates(*ty)),
Type::Intersection(intersection) => intersection
.iter_positive(self.db())
.any(|ty| self.has_string_literal_completion_candidates(ty)),
Type::TypeAlias(_) => true,
_ => false,
}
}
fn union_expected_types(&mut self, expected_types: &FxHashMap<ExpressionNodeKey, Type<'db>>) {
let db = self.db();
for (expression, ty) in expected_types {
self.expected_types
.entry(*expression)
.and_modify(|existing| *existing = UnionType::from_two_elements(db, *existing, *ty))
.or_insert(*ty);
}
}
fn infer_number_literal_expression(&self, literal: &ast::ExprNumberLiteral) -> Type<'db> {
let ast::ExprNumberLiteral {
range: _,
@@ -5517,6 +5567,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
literal: &ast::ExprStringLiteral,
tcx: TypeContext<'db>,
) -> Type<'db> {
if let Some(expected) = tcx.annotation {
self.store_maybe_expected_type(ast::ExprRef::from(literal), expected);
}
if tcx.is_typealias() {
let aliased_type = self.infer_string_type_expression(literal);
return Type::KnownInstance(KnownInstanceType::LiteralStringAlias(InternedType::new(
@@ -8993,7 +9047,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
mut expressions,
qualifiers: _,
mut type_expression_flags,
string_annotations,
mut string_annotations,
mut expected_types,
scope,
bindings,
declarations,
@@ -9029,6 +9084,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let extra =
(!string_annotations.is_empty()
|| !type_expression_flags.is_empty()
|| !expected_types.is_empty()
|| cycle_recovery.is_some()
|| !bindings.is_empty()
|| !diagnostics.is_empty()).then(|| {
@@ -9041,8 +9097,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
type_expression_flags.shrink_to_fit();
expected_types.shrink_to_fit();
string_annotations.shrink_to_fit();
Box::new(ExpressionInferenceExtra {
string_annotations,
expected_types,
type_expression_flags,
bindings: bindings.into_boxed_slice(),
diagnostics,
@@ -9068,7 +9127,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
mut expressions,
mut qualifiers,
mut type_expression_flags,
string_annotations,
mut string_annotations,
mut expected_types,
scope,
bindings,
declarations,
@@ -9095,15 +9155,19 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let extra = (!diagnostics.is_empty()
|| !string_annotations.is_empty()
|| cycle_recovery.is_some()
|| !expected_types.is_empty()
|| !deferred.is_empty()
|| !called_functions.is_empty()
|| !qualifiers.is_empty()
|| !type_expression_flags.is_empty())
.then(|| {
qualifiers.shrink_to_fit();
expected_types.shrink_to_fit();
type_expression_flags.shrink_to_fit();
string_annotations.shrink_to_fit();
Box::new(StatementInferenceInnerExtra {
string_annotations,
expected_types,
called_functions: called_functions
.into_iter()
.collect::<Vec<_>>()
@@ -9176,6 +9240,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
deferred: _,
scope: _,
string_annotations: _,
expected_types: _,
return_types_and_ranges: _,
dataclass_field_specifiers: _,
undecorated_type: _,
@@ -9209,7 +9274,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
mut expressions,
mut qualifiers,
mut type_expression_flags,
string_annotations,
mut string_annotations,
mut expected_types,
scope,
bindings,
declarations,
@@ -9233,6 +9299,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let extra = (!diagnostics.is_empty()
|| !string_annotations.is_empty()
|| !expected_types.is_empty()
|| cycle_recovery.is_some()
|| undecorated_type.is_some()
|| !deferred.is_empty()
@@ -9242,8 +9309,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.then(|| {
qualifiers.shrink_to_fit();
type_expression_flags.shrink_to_fit();
expected_types.shrink_to_fit();
string_annotations.shrink_to_fit();
Box::new(DefinitionInferenceExtra {
string_annotations,
expected_types,
called_functions: called_functions
.into_iter()
.collect::<Vec<_>>()
@@ -9290,7 +9360,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let Self {
context,
string_annotations,
mut string_annotations,
mut expected_types,
mut type_expression_flags,
mut expressions,
scope,
@@ -9321,12 +9392,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let extra = (!string_annotations.is_empty()
|| !type_expression_flags.is_empty()
|| !expected_types.is_empty()
|| !diagnostics.is_empty()
|| cycle_recovery.is_some())
.then(|| {
type_expression_flags.shrink_to_fit();
expected_types.shrink_to_fit();
string_annotations.shrink_to_fit();
Box::new(ScopeInferenceExtra {
string_annotations,
expected_types,
type_expression_flags,
cycle_recovery,
diagnostics,
@@ -9363,6 +9438,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
context: _,
expressions: _,
string_annotations: _,
expected_types: _,
scope: _,
bindings: _,
declarations: _,
@@ -9401,6 +9477,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
expressions,
type_expression_flags,
string_annotations,
expected_types,
scope,
bindings,
declarations,
@@ -9439,6 +9516,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.extend_cycle_recovery(cycle_recovery);
self.string_annotations
.extend(string_annotations.iter().copied());
self.expected_types.extend(expected_types.iter());
self.type_expression_flags
.extend(type_expression_flags.iter());
@@ -26,8 +26,8 @@ use crate::types::subscript::{LegacyGenericOrigin, SubscriptError, SubscriptErro
use crate::types::tuple::{Tuple, TupleType};
use crate::types::typed_dict::{TypedDictAssignmentKind, TypedDictKeyAssignment};
use crate::types::{
BoundTypeVarInstance, CallArguments, CallDunderError, DynamicType, InternedType, KnownClass,
KnownInstanceType, LintDiagnosticGuard, Parameter, Parameters, SpecialFormType,
BoundTypeVarInstance, CallArguments, CallDunderError, CycleDetector, DynamicType, InternedType,
KnownClass, KnownInstanceType, LintDiagnosticGuard, Parameter, Parameters, SpecialFormType,
StaticClassLiteral, Type, TypeAliasType, TypeAndQualifiers, TypeContext,
TypeVarBoundOrConstraints, UnionType, UnionTypeInstance, any_over_type, todo_type,
};
@@ -38,6 +38,57 @@ use ty_python_core::place::{PlaceExpr, PlaceExprRef};
use ty_python_core::scope::FileScopeId;
impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn typed_dict_key_expected_type(&self, ty: Type<'db>) -> Option<Type<'db>> {
struct TypedDictKeyExpectedType;
type TypedDictKeyExpectedTypeVisitor<'db> =
CycleDetector<TypedDictKeyExpectedType, Type<'db>, Option<Type<'db>>>;
fn imp<'db>(
db: &'db dyn Db,
ty: Type<'db>,
visitor: &TypedDictKeyExpectedTypeVisitor<'db>,
) -> Option<Type<'db>> {
match ty {
Type::TypedDict(typed_dict) => {
let keys = typed_dict
.items(db)
.keys()
.map(|key| Type::string_literal(db, key.as_str()))
.collect_vec();
(!keys.is_empty()).then(|| UnionType::from_elements(db, keys))
}
Type::Union(union) => {
let keys = union
.elements(db)
.iter()
.filter_map(|element| imp(db, *element, visitor))
.collect_vec();
(!keys.is_empty()).then(|| UnionType::from_elements(db, keys))
}
Type::Intersection(intersection) => {
let keys = intersection
.positive(db)
.iter()
.filter_map(|element| imp(db, *element, visitor))
.collect_vec();
(!keys.is_empty()).then(|| UnionType::from_elements(db, keys))
}
Type::TypeAlias(alias) => {
visitor.visit(ty, || imp(db, alias.value_type(db), visitor))
}
_ => None,
}
}
imp(self.db(), ty, &TypedDictKeyExpectedTypeVisitor::default())
}
fn store_typed_dict_key_expected_type(&mut self, slice: &ast::Expr, value_ty: Type<'db>) {
if let Some(expected_key_ty) = self.typed_dict_key_expected_type(value_ty) {
self.store_expected_type(slice, expected_key_ty);
}
}
pub(super) fn infer_subscript_expression(
&mut self,
subscript: &ast::ExprSubscript,
@@ -54,12 +105,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ExprContext::Load => self.infer_subscript_load(subscript),
ExprContext::Store => {
let value_ty = self.infer_expression(value, TypeContext::default());
self.store_typed_dict_key_expected_type(slice, value_ty);
let slice_ty = self.infer_expression(slice, TypeContext::default());
self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);
Type::Never
}
ExprContext::Del => {
let value_ty = self.infer_expression(value, TypeContext::default());
self.store_typed_dict_key_expected_type(slice, value_ty);
let slice_ty = self.infer_expression(slice, TypeContext::default());
self.validate_subscript_deletion(subscript, value_ty, slice_ty);
Type::Never
@@ -101,6 +154,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ctx,
} = subscript;
self.store_typed_dict_key_expected_type(slice, value_ty);
let mut constraint_keys = vec![];
// If `value` is a valid reference, we attempt type narrowing by assignment.
@@ -1176,6 +1231,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} = target;
let object_ty = self.infer_expression(object, TypeContext::default());
self.store_typed_dict_key_expected_type(slice, object_ty);
let mut infer_slice_ty = |builder: &mut Self, tcx| builder.infer_expression(slice, tcx);
self.validate_subscript_assignment_impl(
+1 -1
View File
@@ -443,7 +443,7 @@ pub(crate) fn server_capabilities(
},
)),
completion_provider: Some(CompletionOptions {
trigger_characters: Some(vec!['.'.to_string()]),
trigger_characters: Some(vec!['.'.to_string(), '"'.to_string(), '\''.to_string()]),
..Default::default()
}),
selection_range_provider: Some(SelectionRangeProviderCapability::Simple(true)),
+40
View File
@@ -346,3 +346,43 @@ re.match('', '', fla<CURSOR>
Ok(())
}
/// Tests the LSP-facing shape for string-literal completions with an already-typed prefix.
///
/// The server intentionally returns the full completion in `insertText`. Without an explicit
/// `textEdit`, LSP clients are allowed to interpret that insert text relative to the current
/// word; for example, VS Code applies `insertText: "apple"` at `app|` as the suffix `le`.
#[test]
fn string_literal_completion_uses_full_lsp_insert_text() -> Result<()> {
let workspace_root = SystemPath::new("src");
let foo = SystemPath::new("src/foo.py");
let foo_content = "\
from typing import Literal
x: Literal[\"apple\"] = \"app\"
";
let mut server = TestServerBuilder::new()?
.with_initialization_options(ClientOptions::default().with_auto_import(false))
.with_workspace(workspace_root, None)?
.with_file(foo, foo_content)?
.build()
.wait_until_workspaces_are_initialized();
server.open_text_document(foo, foo_content, 1);
let completions = server.completion_request(&server.file_uri(foo), Position::new(1, 26));
insta::assert_json_snapshot!(completions, @r#"
[
{
"label": "apple",
"kind": 12,
"detail": "Literal[\"apple\"]",
"sortText": "0",
"insertText": "apple"
}
]
"#);
Ok(())
}
@@ -25,7 +25,9 @@ expression: initialization_result
"hoverProvider": true,
"completionProvider": {
"triggerCharacters": [
"."
".",
"\"",
"'"
]
},
"signatureHelpProvider": {
@@ -25,7 +25,9 @@ expression: initialization_result
"hoverProvider": true,
"completionProvider": {
"triggerCharacters": [
"."
".",
"\"",
"'"
]
},
"signatureHelpProvider": {