From 81c81f689213259a3d093a983cc088627c7b2dd3 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Fri, 1 May 2026 12:18:58 -0400 Subject: [PATCH] [ty] Unpack Union of TypedDict in various sites (#24958) ## Summary We already have a helper for this; we just weren't using it everywhere. --- .../resources/mdtest/call/function.md | 31 +++++++++++++++++++ .../mdtest/call/functools_partial.md | 20 ++++++++++++ .../src/types/call/arguments.rs | 12 +++---- .../ty_python_semantic/src/types/call/bind.rs | 20 +++++++----- 4 files changed, 69 insertions(+), 14 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/call/function.md b/crates/ty_python_semantic/resources/mdtest/call/function.md index dfcb40ecef..86ff22376b 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/function.md +++ b/crates/ty_python_semantic/resources/mdtest/call/function.md @@ -1363,6 +1363,37 @@ f(**Foo1(a=1, b="b")) f(**Foo2(a=1)) ``` +### TypedDict union + +```py +from typing_extensions import TypedDict + +class GoodA(TypedDict): + a: int + b: int + +class GoodB(TypedDict): + a: int + b: int + +class BadA(TypedDict): + a: int + b: str + +class BadB(TypedDict): + a: int + b: str + +def needs_known_keys(*, a: int, b: int, c: int) -> None: ... +def takes_int_kwargs(**kwargs: int) -> None: ... +def _(good: GoodA | GoodB, bad: BadA | BadB) -> None: + # error: [missing-argument] "No argument provided for required parameter `c` of function `needs_known_keys`" + needs_known_keys(**good) + + # error: [invalid-argument-type] "Argument to function `takes_int_kwargs` is incorrect: Expected `int`, found `str`" + takes_int_kwargs(**bad) +``` + ### Keys must be strings The keys of the mapping passed to a double-starred argument must be strings. diff --git a/crates/ty_python_semantic/resources/mdtest/call/functools_partial.md b/crates/ty_python_semantic/resources/mdtest/call/functools_partial.md index 5e97227e31..d38ca87fa1 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/functools_partial.md +++ b/crates/ty_python_semantic/resources/mdtest/call/functools_partial.md @@ -706,6 +706,26 @@ p = partial(f, **kwargs) reveal_type(p) # revealed: partial[(a: int, *, b: str = ...) -> bool] ``` +### Kwargs splat with union of TypedDicts + +```py +from functools import partial +from typing import TypedDict + +class KwargsA(TypedDict): + b: str + +class KwargsB(TypedDict): + b: str + +def f(*, b: str) -> bool: + return True + +def make(kwargs: KwargsA | KwargsB) -> None: + p = partial(f, **kwargs) + reveal_type(p) # revealed: partial[(*, b: str = ...) -> bool] +``` + ### Mixed keywords and kwargs splat ```py diff --git a/crates/ty_python_semantic/src/types/call/arguments.rs b/crates/ty_python_semantic/src/types/call/arguments.rs index bd3cc7258a..cd56c2edb2 100644 --- a/crates/ty_python_semantic/src/types/call/arguments.rs +++ b/crates/ty_python_semantic/src/types/call/arguments.rs @@ -8,6 +8,7 @@ use rustc_hash::FxHashMap; use crate::Db; use crate::types::enums::{enum_member_literals, enum_metadata}; use crate::types::tuple::Tuple; +use crate::types::typed_dict::extract_unpacked_typed_dict_keys_from_value_type; use crate::types::{KnownClass, Type, TypeContext}; /// Maximum number of expanded types that can be generated from a single tuple's @@ -266,12 +267,11 @@ impl<'a, 'db> CallArguments<'a, 'db> { ), // Optional TypedDict keys may be absent at runtime, so we can only refine // `partial(...)` when every expanded key is guaranteed to be present. - Argument::Keywords => argument_ty.as_typed_dict().is_none_or(|typed_dict| { - typed_dict - .items(db) - .values() - .any(|field| !field.is_required()) - }), + Argument::Keywords => { + extract_unpacked_typed_dict_keys_from_value_type(db, argument_ty).is_none_or( + |unpacked_keys| unpacked_keys.values().any(|key| !key.is_required), + ) + } Argument::Positional | Argument::Synthetic | Argument::Keyword(_) => false, } }) { diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index c6a0ea537b..7f36e40295 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -50,6 +50,7 @@ use crate::types::signatures::{ PartialApplication, PartialSignatureApplication, }; use crate::types::tuple::{TupleLength, TupleSpec, TupleType}; +use crate::types::typed_dict::extract_unpacked_typed_dict_keys_from_value_type; use crate::types::typevar::BoundTypeVarIdentity; use crate::types::{ BoundMethodType, BoundTypeVarInstance, CallableType, CallableTypes, ClassLiteral, @@ -4343,13 +4344,15 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> { argument_index: usize, argument_type: Option>, ) { - if let Some(Type::TypedDict(typed_dict)) = argument_type { - // Special case TypedDict because we know which keys are present. - for (name, field) in typed_dict.items(db) { + if let Some(unpacked_keys) = + argument_type.and_then(|ty| extract_unpacked_typed_dict_keys_from_value_type(db, ty)) + { + // Special case TypedDict-shaped values because we know which keys are present. + for (name, unpacked_key) in unpacked_keys { let _ = self.match_keyword( argument_index, Argument::Keywords, - Some(field.declared_ty), + Some(unpacked_key.value_ty), name.as_str(), ); } @@ -5205,11 +5208,12 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { argument: Argument<'a>, argument_type: Type<'db>, ) { - if let Type::TypedDict(typed_dict) = argument_type { - for (argument_type, parameter_index) in typed_dict - .items(self.db) + if let Some(unpacked_keys) = + extract_unpacked_typed_dict_keys_from_value_type(self.db, argument_type) + { + for (argument_type, parameter_index) in unpacked_keys .values() - .map(|field| field.declared_ty) + .map(|unpacked_key| unpacked_key.value_ty) .zip(&self.argument_matches[argument_index].parameters) { self.check_argument_type(