[ty] Unpack Union of TypedDict in various sites (#24958)

## Summary

We already have a helper for this; we just weren't using it everywhere.
This commit is contained in:
Charlie Marsh
2026-05-01 12:18:58 -04:00
committed by GitHub
parent c6057e034b
commit 81c81f6892
4 changed files with 69 additions and 14 deletions
@@ -1363,6 +1363,37 @@ f(**Foo1(a=1, b="b"))
f(**Foo2(a=1))
```
### TypedDict union
```py
from typing_extensions import TypedDict
class GoodA(TypedDict):
a: int
b: int
class GoodB(TypedDict):
a: int
b: int
class BadA(TypedDict):
a: int
b: str
class BadB(TypedDict):
a: int
b: str
def needs_known_keys(*, a: int, b: int, c: int) -> None: ...
def takes_int_kwargs(**kwargs: int) -> None: ...
def _(good: GoodA | GoodB, bad: BadA | BadB) -> None:
# error: [missing-argument] "No argument provided for required parameter `c` of function `needs_known_keys`"
needs_known_keys(**good)
# error: [invalid-argument-type] "Argument to function `takes_int_kwargs` is incorrect: Expected `int`, found `str`"
takes_int_kwargs(**bad)
```
### Keys must be strings
The keys of the mapping passed to a double-starred argument must be strings.
@@ -706,6 +706,26 @@ p = partial(f, **kwargs)
reveal_type(p) # revealed: partial[(a: int, *, b: str = ...) -> bool]
```
### Kwargs splat with union of TypedDicts
```py
from functools import partial
from typing import TypedDict
class KwargsA(TypedDict):
b: str
class KwargsB(TypedDict):
b: str
def f(*, b: str) -> bool:
return True
def make(kwargs: KwargsA | KwargsB) -> None:
p = partial(f, **kwargs)
reveal_type(p) # revealed: partial[(*, b: str = ...) -> bool]
```
### Mixed keywords and kwargs splat
```py
@@ -8,6 +8,7 @@ use rustc_hash::FxHashMap;
use crate::Db;
use crate::types::enums::{enum_member_literals, enum_metadata};
use crate::types::tuple::Tuple;
use crate::types::typed_dict::extract_unpacked_typed_dict_keys_from_value_type;
use crate::types::{KnownClass, Type, TypeContext};
/// Maximum number of expanded types that can be generated from a single tuple's
@@ -266,12 +267,11 @@ impl<'a, 'db> CallArguments<'a, 'db> {
),
// Optional TypedDict keys may be absent at runtime, so we can only refine
// `partial(...)` when every expanded key is guaranteed to be present.
Argument::Keywords => argument_ty.as_typed_dict().is_none_or(|typed_dict| {
typed_dict
.items(db)
.values()
.any(|field| !field.is_required())
}),
Argument::Keywords => {
extract_unpacked_typed_dict_keys_from_value_type(db, argument_ty).is_none_or(
|unpacked_keys| unpacked_keys.values().any(|key| !key.is_required),
)
}
Argument::Positional | Argument::Synthetic | Argument::Keyword(_) => false,
}
}) {
@@ -50,6 +50,7 @@ use crate::types::signatures::{
PartialApplication, PartialSignatureApplication,
};
use crate::types::tuple::{TupleLength, TupleSpec, TupleType};
use crate::types::typed_dict::extract_unpacked_typed_dict_keys_from_value_type;
use crate::types::typevar::BoundTypeVarIdentity;
use crate::types::{
BoundMethodType, BoundTypeVarInstance, CallableType, CallableTypes, ClassLiteral,
@@ -4343,13 +4344,15 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
argument_index: usize,
argument_type: Option<Type<'db>>,
) {
if let Some(Type::TypedDict(typed_dict)) = argument_type {
// Special case TypedDict because we know which keys are present.
for (name, field) in typed_dict.items(db) {
if let Some(unpacked_keys) =
argument_type.and_then(|ty| extract_unpacked_typed_dict_keys_from_value_type(db, ty))
{
// Special case TypedDict-shaped values because we know which keys are present.
for (name, unpacked_key) in unpacked_keys {
let _ = self.match_keyword(
argument_index,
Argument::Keywords,
Some(field.declared_ty),
Some(unpacked_key.value_ty),
name.as_str(),
);
}
@@ -5205,11 +5208,12 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
argument: Argument<'a>,
argument_type: Type<'db>,
) {
if let Type::TypedDict(typed_dict) = argument_type {
for (argument_type, parameter_index) in typed_dict
.items(self.db)
if let Some(unpacked_keys) =
extract_unpacked_typed_dict_keys_from_value_type(self.db, argument_type)
{
for (argument_type, parameter_index) in unpacked_keys
.values()
.map(|field| field.declared_ty)
.map(|unpacked_key| unpacked_key.value_ty)
.zip(&self.argument_matches[argument_index].parameters)
{
self.check_argument_type(