[ty] Model functools.partial call results (#24582)

## Summary

This PR adds initial support for `functools.partial`, including:

- Constructor-time checking of bound arguments (e.g., `partial(f, "x")`
should report an immediate error if `"x` is not a valid type for the
parameter)
- Reduced signatures for partials (e.g., `def f(a: int, b: str, *, c:
bool) -> bytes` with `partial(f, 1)` becomes `partial[(b: str, *, c:
bool) -> bytes]`).
- Support for overloads, assignability checks, and more.

There are a few things that are _not_ covered and were instead cordoned
off into separate commits, namely:

- Preserving unprovided generic type variables in the returned partial
signature (fixed in: https://github.com/astral-sh/ruff/pull/24583). As
of this commit, we get:

```python
from functools import partial
from typing import TypeVar

T = TypeVar("T")
U = TypeVar("U")

def combine(a: T, b: U) -> tuple[T, U]:
    return (a, b)

# partial[(b: Unknown) -> tuple[Literal[1], Unknown]]
p = partial(combine, 1)
```

- Keyword overrides in generics (e.g., `partial(combine, b=1)` can later
be called as `p("x", b="y")`, since keyword arguments can be overridden
at call time -- TIL!).
- Constructor modeling (`__new__`, etc.)

But this gets us much of the way there. After this PR, I believe our
handling of `functools.partial` is generally ahead of Mypy and Pyright
with the significant exception of generic modeling, where ty is behind.

(I choose to include tests for the above in
`crates/ty_python_semantic/resources/mdtest/call/functools_partial.md`,
with TODOs, which get resolved in subsequent PRs.)

See: https://github.com/astral-sh/ty/issues/1536.
This commit is contained in:
Charlie Marsh
2026-04-30 11:28:06 -04:00
committed by GitHub
parent 0fb4f62330
commit 95670c1f56
13 changed files with 2691 additions and 59 deletions
File diff suppressed because it is too large Load Diff
+90 -1
View File
@@ -2290,6 +2290,9 @@ impl<'db> Type<'db> {
/// Return true if this type is non-empty and all inhabitants of this type compare equal.
pub(crate) fn is_single_valued(self, db: &'db dyn Db) -> bool {
match self {
// Each `partial()` call creates a distinct object at runtime.
Type::KnownInstance(KnownInstanceType::FunctoolsPartial(_)) => false,
Type::FunctionLiteral(..)
| Type::WrapperDescriptor(_)
| Type::KnownBoundMethod(_)
@@ -3477,6 +3480,39 @@ impl<'db> Type<'db> {
.into()
}
Type::KnownInstance(KnownInstanceType::FunctoolsPartial(partial))
if name_str == "__call__" =>
{
Place::bound(Type::Callable(partial.partial(db))).into()
}
Type::KnownInstance(KnownInstanceType::FunctoolsPartial(partial)) => {
let wrapped = partial.wrapped(db).inner(db);
let nominal_lookup = partial
.partial(db)
.into_functools_partial_instance(db)
.member_lookup_with_policy(db, name.clone(), policy);
if name_str == "func" {
match nominal_lookup.place {
Place::Defined(DefinedPlace {
origin,
definedness,
public_type_policy,
..
}) => Place::Defined(DefinedPlace {
ty: wrapped,
origin,
definedness,
public_type_policy,
})
.into(),
Place::Undefined => Place::bound(wrapped).into(),
}
} else {
nominal_lookup
}
}
Type::NominalInstance(..)
| Type::ProtocolInstance(..)
| Type::NewTypeInstance(..)
@@ -4146,6 +4182,10 @@ impl<'db> Type<'db> {
)
.into(),
Type::KnownInstance(KnownInstanceType::FunctoolsPartial(partial)) => {
Type::Callable(partial.partial(db)).bindings(db)
}
Type::KnownInstance(known_instance) => {
known_instance.instance_fallback(db).bindings(db)
}
@@ -4402,6 +4442,47 @@ impl<'db> Type<'db> {
)
}
KnownClass::FunctoolsPartial => {
// ```py
// class partial(Generic[_T]):
// def __new__(cls, func: Callable[..., _T], /, *args: Any, **kwargs: Any) -> Self: ...
// ```
let return_ty = BoundTypeVarInstance::synthetic(
db,
Name::new_static("_T"),
TypeVarVariance::Covariant,
);
Some(
Binding::single(
self,
Signature::new_generic(
Some(GenericContext::from_typevar_instances(db, [return_ty])),
Parameters::new(
db,
[
Parameter::positional_only(Some(Name::new_static("func")))
.with_annotated_type(Type::single_callable(
db,
Signature::new(
Parameters::gradual_form(),
Type::TypeVar(return_ty),
),
)),
Parameter::variadic(Name::new_static("args"))
.with_annotated_type(Type::any()),
Parameter::keyword_variadic(Name::new_static("kwargs"))
.with_annotated_type(Type::any()),
],
),
KnownClass::FunctoolsPartial
.to_specialized_instance(db, &[Type::TypeVar(return_ty)]),
),
)
.into(),
)
}
KnownClass::Tuple => {
let element_ty = BoundTypeVarInstance::synthetic(
db,
@@ -4531,6 +4612,7 @@ impl<'db> Type<'db> {
KnownClass::Bool
| KnownClass::Type
| KnownClass::Object
| KnownClass::FunctoolsPartial
| KnownClass::Property
| KnownClass::Super
| KnownClass::TypeAliasType
@@ -5328,6 +5410,12 @@ impl<'db> Type<'db> {
}
KnownInstanceType::Callable(callable) => Ok(Type::Callable(*callable)),
KnownInstanceType::LiteralStringAlias(ty) => Ok(ty.inner(db)),
KnownInstanceType::FunctoolsPartial(_) => Err(InvalidTypeExpressionError {
invalid_expressions: smallvec_inline![InvalidTypeExpression::InvalidType(
*self, scope_id
)],
fallback_type: Type::unknown(),
}),
},
Type::SpecialForm(special_form) => special_form
@@ -6036,7 +6124,8 @@ impl<'db> Type<'db> {
| KnownInstanceType::Literal(_)
| KnownInstanceType::LiteralStringAlias(_)
| KnownInstanceType::NamedTupleSpec(_)
| KnownInstanceType::NewType(_) => {
| KnownInstanceType::NewType(_)
| KnownInstanceType::FunctoolsPartial(_) => {
// TODO: For some of these, we may need to try to find legacy typevars in inner types.
}
},
@@ -242,12 +242,45 @@ impl<'a, 'db> CallArguments<'a, 'db> {
}
/// Create a new [`CallArguments`] starting from the specified index.
pub(super) fn start_from(&self, index: usize) -> Self {
pub(crate) fn start_from(&self, index: usize) -> Self {
Self {
items: self.items[index..].to_vec(),
}
}
/// Returns the `functools.partial(...)` bound-argument slice when argument expansion is
/// concrete enough for partial-application analysis.
pub(crate) fn functools_partial_bound_arguments(&self, db: &'db dyn Db) -> Option<Self> {
let bound_call_arguments = self.start_from(1);
// We only handle variadics and keyword-maps that can be normalized to concrete argument
// positions for overload matching.
if bound_call_arguments.iter().any(|(argument, argument_ty)| {
let argument_ty = argument_ty.get_default().unwrap_or_else(Type::unknown);
match argument {
Argument::Variadic => !matches!(
argument_ty
.as_nominal_instance()
.and_then(|nominal| nominal.tuple_spec(db)),
Some(spec) if spec.as_fixed_length().is_some()
),
// Optional TypedDict keys may be absent at runtime, so we can only refine
// `partial(...)` when every expanded key is guaranteed to be present.
Argument::Keywords => argument_ty.as_typed_dict().is_none_or(|typed_dict| {
typed_dict
.items(db)
.values()
.any(|field| !field.is_required())
}),
Argument::Positional | Argument::Synthetic | Argument::Keyword(_) => false,
}
}) {
return None;
}
Some(bound_call_arguments)
}
/// Returns an iterator on performing [argument type expansion].
///
/// Each element of the iterator represents a set of argument lists, where each argument list
@@ -47,16 +47,17 @@ use crate::types::generics::{
use crate::types::known_instance::FieldInstance;
use crate::types::signatures::{
CallableSignature, Parameter, ParameterForm, ParameterKind, Parameters, ParametersKind,
PartialApplication, PartialSignatureApplication,
};
use crate::types::tuple::{TupleLength, TupleSpec, TupleType};
use crate::types::typevar::BoundTypeVarIdentity;
use crate::types::{
BoundMethodType, BoundTypeVarInstance, CallableType, ClassLiteral, DATACLASS_FLAGS,
DataclassFlags, DataclassParams, GenericAlias, InternedConstraintSet, IntersectionType,
KnownBoundMethodType, KnownClass, KnownInstanceType, LiteralValueTypeKind, NominalInstanceType,
PropertyInstanceType, SpecialFormType, TypeAliasType, TypeContext, TypeVarBoundOrConstraints,
TypeVarVariance, UnionAccumulator, UnionBuilder, UnionType, WrapperDescriptorKind, enums,
list_members,
BoundMethodType, BoundTypeVarInstance, CallableType, CallableTypes, ClassLiteral,
DATACLASS_FLAGS, DataclassFlags, DataclassParams, GenericAlias, InternedConstraintSet,
IntersectionType, KnownBoundMethodType, KnownClass, KnownInstanceType, LiteralValueTypeKind,
NominalInstanceType, PropertyInstanceType, SpecialFormType, TypeAliasType, TypeContext,
TypeVarBoundOrConstraints, TypeVarVariance, UnionAccumulator, UnionBuilder, UnionType,
WrapperDescriptorKind, enums, list_members,
};
use crate::{DisplaySettings, FxOrderSet, Program};
use ruff_db::diagnostic::{Annotation, Diagnostic, Span, SubDiagnostic, SubDiagnosticSeverity};
@@ -191,6 +192,26 @@ impl<'db> CallableItem<'db> {
self.callable().callable_type
}
/// Returns the reduced callable synthesized from this callable item.
fn functools_partial_callable<'a>(
&self,
db: &'db dyn Db,
partial_overload: &mut Binding<'db>,
bound_call_arguments: &CallArguments<'a, 'db>,
) -> Option<CallableType<'db>> {
match self {
CallableItem::Regular(binding) => CallableType::partially_apply(
db,
binding.partial_signature_applications(
db,
partial_overload,
bound_call_arguments,
)?,
),
CallableItem::Constructor(_) => None,
}
}
fn map<F>(self, f: &F) -> CallableItem<'db>
where
F: Fn(CallableBinding<'db>) -> CallableBinding<'db>,
@@ -652,6 +673,18 @@ impl<'db> Bindings<'db> {
.filter_map(CallableItem::as_constructor_mut)
}
fn clear_deferred_constructor_errors_for_partial_application(&mut self) {
for binding in self.iter_flat_mut() {
binding.clear_deferred_constructor_errors_for_partial_application();
}
for constructor in self.iter_constructor_items_mut() {
if let Some(downstream) = constructor.downstream_constructor_mut() {
downstream.clear_deferred_constructor_errors_for_partial_application();
}
}
}
/// Visits the callables that should contribute argument type context, including deferred
/// constructor callables that are relevant to the matched upstream constructor path.
pub(crate) fn visit_type_context_callables<'a>(
@@ -707,6 +740,98 @@ impl<'db> Bindings<'db> {
UnionType::from_elements(db, element_types)
}
/// Maps each `CallableItem` to a type and combines results while preserving
/// the union-of-intersections structure:
///
/// - callable items inside an element are intersected
/// - elements are unioned
fn map_item_types(
&self,
db: &'db dyn Db,
mut map: impl FnMut(&CallableItem<'db>) -> Option<Type<'db>>,
) -> Type<'db> {
let mut element_types = Vec::with_capacity(self.elements.len());
for element in &self.elements {
let mut item_types = Vec::new();
for item in element.items() {
if let Some(ty) = map(item) {
item_types.push(ty);
}
}
if !item_types.is_empty() {
element_types.push(IntersectionType::from_elements(db, item_types));
}
}
UnionType::from_elements(db, element_types)
}
/// Builds matched bindings for the callable wrapped by `functools.partial(...)`.
///
/// This handles the shared partial-specific preprocessing (callable validation and argument
/// normalization) used by both inference and known-call evaluation.
pub(crate) fn functools_partial_matched_bindings<'a>(
db: &'db dyn Db,
wrapped_callable_ty: Type<'db>,
call_arguments: &CallArguments<'a, 'db>,
) -> Option<(CallArguments<'a, 'db>, Bindings<'db>)> {
// We can only infer bound-argument context from an actual callable.
wrapped_callable_ty.try_upcast_to_callable(db)?;
let bound_call_arguments = call_arguments.functools_partial_bound_arguments(db)?;
let mut partial_bindings = wrapped_callable_ty
.bindings(db)
.match_parameters(db, &bound_call_arguments);
for binding in partial_bindings.iter_flat_mut() {
binding.clear_missing_argument_errors_for_partial_application();
}
for constructor in partial_bindings.iter_constructor_items_mut() {
if let Some(downstream) = constructor.downstream_constructor_mut() {
downstream.clear_deferred_constructor_errors_for_partial_application();
}
}
Some((bound_call_arguments, partial_bindings))
}
/// Synthesizes the precise `functools.partial(...)` type for the already-matched bindings.
///
/// Wrapped unions and intersections keep their original callable structure by partially
/// applying each callable item independently. A single wrapped callable instead exposes one
/// reduced callable whose overload set is merged before being wrapped as `partial[...]`.
fn functools_partial_type<'a>(
&self,
db: &'db dyn Db,
wrapped_callable_ty: Type<'db>,
partial_overload: &mut Binding<'db>,
bound_call_arguments: &CallArguments<'a, 'db>,
) -> Type<'db> {
if wrapped_callable_ty.is_union() || wrapped_callable_ty.is_intersection() {
return self.map_item_types(db, |partial_item| {
partial_item
.functools_partial_callable(db, partial_overload, bound_call_arguments)
.map(|callable| {
callable.into_precise_functools_partial_instance(db, wrapped_callable_ty)
})
});
}
let partial_callables: SmallVec<[CallableType<'db>; 1]> = self
.iter_callable_items()
.filter_map(|partial_item| {
partial_item.functools_partial_callable(db, partial_overload, bound_call_arguments)
})
.collect();
if partial_callables.is_empty() {
Type::Never
} else {
CallableTypes::from_elements(partial_callables)
.into_precise_functools_partial_instance(db, wrapped_callable_ty)
}
}
fn map_with<F>(self, f: &F) -> Self
where
F: Fn(CallableBinding<'db>) -> CallableBinding<'db>,
@@ -2416,6 +2541,14 @@ impl<'db> Bindings<'db> {
}
}
Some(KnownClass::FunctoolsPartial) => {
if let Some(new_return_type) =
overload.functools_partial_return_type(db, call_arguments)
{
overload.set_return_type(new_return_type);
}
}
Some(KnownClass::Tuple) if overload_index == 1 => {
// `tuple(range(42))` => `tuple[int, ...]`
// BUT `tuple((1, 2))` => `tuple[Literal[1], Literal[2]]` rather than `tuple[Literal[1, 2], ...]`
@@ -2548,6 +2681,24 @@ pub(crate) struct CallableBinding<'db> {
overloads: SmallVec<[Binding<'db>; 1]>,
}
#[derive(Copy, Clone)]
enum FailingOverloadSelection {
/// Consider all errors that participate in overload filtering.
AffectsOverloadResolution,
/// Consider only errors that are reported during `functools.partial(...)` construction.
ReportableForPartial,
}
impl FailingOverloadSelection {
/// Returns whether this selection mode should count the given error.
fn includes(self, error: &BindingError<'_>) -> bool {
match self {
Self::AffectsOverloadResolution => error.affects_overload_resolution(),
Self::ReportableForPartial => error.is_relevant_for_partial_application(),
}
}
}
impl<'db> CallableBinding<'db> {
pub(crate) fn from_overloads(
signature_type: Type<'db>,
@@ -2580,6 +2731,8 @@ impl<'db> CallableBinding<'db> {
}
}
/// Rewrites overload signatures as if an implicit bound receiver argument had already been
/// consumed.
pub(crate) fn bake_bound_type_into_overloads(&mut self, db: &'db dyn Db) {
let Some(bound_self) = self.bound_type.take() else {
return;
@@ -2589,6 +2742,139 @@ impl<'db> CallableBinding<'db> {
}
}
/// Ignore missing-argument errors when constructing `functools.partial(...)`.
///
/// Partial application intentionally leaves some parameters unbound, so we still want to
/// type-check all explicitly bound arguments against each overload.
fn clear_missing_argument_errors_for_partial_application(&mut self) {
for overload in &mut self.overloads {
overload.clear_missing_argument_errors_for_partial_application();
}
}
/// Ignore downstream constructor call-shape errors when constructing
/// `functools.partial(...)`.
///
/// The merged partial signature decides which parameters remain callable, so downstream
/// arity/name mismatches caused by as-yet-unbound constructor parameters should not reject
/// partial construction. Explicit bound-argument type errors are still preserved.
fn clear_deferred_constructor_errors_for_partial_application(&mut self) {
for overload in &mut self.overloads {
overload.clear_deferred_constructor_errors_for_partial_application();
}
}
/// Chooses which overload to use as the source for diagnostics when no overload fully matches.
///
/// If step 1 of overload resolution identified a single arity match, we keep using that
/// overload as the diagnostic source. Otherwise, we rank failing overloads by error quality:
/// fewer unknown-argument errors and fewer relevant errors are preferred.
fn best_failing_overload_index(&self, selection: FailingOverloadSelection) -> Option<usize> {
self.matching_overload_before_type_checking.or_else(|| {
self.overloads
.iter()
.enumerate()
.filter_map(|(index, overload)| {
let mut relevant_count = 0;
let mut unknown_argument_count = 0;
for error in &overload.errors {
if !selection.includes(error) {
continue;
}
relevant_count += 1;
if matches!(error, BindingError::UnknownArgument { .. }) {
unknown_argument_count += 1;
}
}
(relevant_count > 0).then_some((index, unknown_argument_count, relevant_count))
})
.min_by_key(|(_, unknown_argument_count, relevant_count)| {
(*unknown_argument_count, *relevant_count)
})
.map(|(index, _, _)| index)
})
}
/// Returns the matching overload indexes when `functools.partial(...)` ignores errors that are
/// only relevant at invocation time.
fn matching_partial_overload_index(&self) -> MatchingOverloadIndex {
let mut matching_overloads = self.overloads.iter().enumerate().filter(|(_, overload)| {
!overload
.errors
.iter()
.any(BindingError::is_relevant_for_partial_application)
});
match matching_overloads.next() {
None => MatchingOverloadIndex::None,
Some((first, _)) => {
if let Some((second, _)) = matching_overloads.next() {
let mut indexes = vec![first, second];
for (index, _) in matching_overloads {
indexes.push(index);
}
MatchingOverloadIndex::Multiple(indexes)
} else {
MatchingOverloadIndex::Single(first)
}
}
}
}
/// Selects the reduced signature applications for this `functools.partial(...)` binding.
///
/// Diagnostics for invalid bound arguments are still reported back to the outer `partial(...)`
/// overload. Callable construction happens in the callable layer after this summary is built.
fn partial_signature_applications<'a>(
&self,
db: &'db dyn Db,
partial_overload: &mut Binding<'db>,
bound_call_arguments: &CallArguments<'a, 'db>,
) -> Option<SmallVec<[PartialSignatureApplication<'db>; 1]>> {
if self.overloads().is_empty() {
return None;
}
let selected_overload_indexes = match self.matching_partial_overload_index() {
MatchingOverloadIndex::Single(index) => vec![index],
MatchingOverloadIndex::Multiple(indexes) => indexes,
MatchingOverloadIndex::None => {
let source_overload_index = self
.best_failing_overload_index(FailingOverloadSelection::ReportableForPartial)
.unwrap_or(0);
let source_errors = &self.overloads()[source_overload_index].errors;
for error in source_errors {
if error.is_relevant_for_partial_application() {
let error = error.clone().maybe_apply_argument_index_offset(Some(1));
if !partial_overload.errors.contains(&error) {
partial_overload.errors.push(error);
}
}
}
// When no overload is compatible with the bound arguments, don't manufacture a
// precise reduced signature from an arbitrary overloaded callable shape.
if self.overloads().len() > 1 {
return None;
}
vec![source_overload_index]
}
};
let signature_arguments = bound_call_arguments.with_self(self.bound_type);
let applications: SmallVec<_> = selected_overload_indexes
.into_iter()
.filter_map(|index| {
self.overloads().get(index).map(|overload| {
overload.partial_signature_application(signature_arguments.as_ref(), db)
})
})
.collect();
(!applications.is_empty()).then_some(applications)
}
pub(crate) fn with_bound_type(mut self, bound_type: Type<'db>) -> Self {
self.bound_type = Some(bound_type);
self
@@ -4842,7 +5128,9 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
);
} else {
let index = callable_binding
.matching_overload_before_type_checking
.best_failing_overload_index(
FailingOverloadSelection::AffectsOverloadResolution,
)
.unwrap_or(0);
// TODO: We should also update the specialization for the `ParamSpec` to reflect
// the matching overload here.
@@ -5257,6 +5545,131 @@ impl<'db> Binding<'db> {
&self.parameter_tys
}
/// Returns the reduced callable type exposed by this `functools.partial(...)` overload.
fn functools_partial_return_type<'a>(
&mut self,
db: &'db dyn Db,
call_arguments: &CallArguments<'a, 'db>,
) -> Option<Type<'db>> {
// `partial(...)` receives the wrapped callable as its first explicit argument (after
// constructor receiver handling).
let func_ty = match self.parameter_types() {
[Some(func_ty), ..] => *func_ty,
_ => return None,
};
let fallback_return_type =
KnownClass::FunctoolsPartial.to_specialized_instance(db, &[Type::unknown()]);
let (bound_call_arguments, partial_bindings) =
Bindings::functools_partial_matched_bindings(db, func_ty, call_arguments)?;
// Reuse call-binding machinery to resolve which wrapped overloads are compatible with
// bound arguments and to surface binding diagnostics.
let partial_bindings = match partial_bindings.check_types(
db,
&ConstraintSetBuilder::new(),
&bound_call_arguments,
TypeContext::default(),
&[],
) {
Ok(bindings) => bindings,
Err(CallError(_, bindings)) => *bindings,
};
let new_return_type =
partial_bindings.functools_partial_type(db, func_ty, self, &bound_call_arguments);
Some(if new_return_type.is_never() {
fallback_return_type
} else {
new_return_type
})
}
/// `functools.partial(...)` is allowed to leave required parameters unbound.
fn clear_missing_argument_errors_for_partial_application(&mut self) {
self.errors
.retain(|error| !matches!(error, BindingError::MissingArguments { .. }));
}
/// Downstream constructor validation is deferred until after partial signatures are merged.
fn clear_deferred_constructor_errors_for_partial_application(&mut self) {
self.errors.retain(|error| {
!matches!(
error,
BindingError::MissingArguments { .. }
| BindingError::UnknownArgument { .. }
| BindingError::PositionalOnlyParameterAsKwarg { .. }
| BindingError::TooManyPositionalArguments { .. }
| BindingError::ParameterAlreadyAssigned { .. }
)
});
}
/// Collects the parameter-level effects of a `functools.partial(...)` application.
fn partial_application(&self, arguments: &CallArguments<'_, 'db>) -> PartialApplication<'db> {
let parameters = self.signature.parameters().as_slice();
let mut partial_application = PartialApplication::new(parameters.len());
for ((argument, argument_ty), argument_matches) in
arguments.iter().zip(&self.argument_matches)
{
match argument {
Argument::Positional | Argument::Synthetic | Argument::Variadic => {
for (parameter_index, _) in argument_matches.iter() {
let parameter = &parameters[parameter_index];
if parameter.is_positional()
&& parameter.annotated_type() != Type::Never
&& !parameter.is_variadic()
&& !parameter.is_keyword_variadic()
{
partial_application.bind_positionally(parameter_index);
}
}
}
Argument::Keyword(_) | Argument::Keywords => {
for (parameter_index, matched_ty) in argument_matches.iter() {
if partial_application.is_positionally_bound(parameter_index) {
continue;
}
let parameter = &parameters[parameter_index];
if parameter.is_positional_only()
|| parameter.is_variadic()
|| parameter.is_keyword_variadic()
{
continue;
}
partial_application.bind_by_keyword(
parameter_index,
(parameter.annotated_type() != Type::Never).then(|| {
matched_ty.unwrap_or_else(|| {
argument_ty.get_default().unwrap_or_else(Type::unknown)
})
}),
);
}
}
}
}
partial_application
}
/// Packages the information needed to synthesize this overload's reduced partial signature.
fn partial_signature_application(
&self,
arguments: &CallArguments<'_, 'db>,
db: &'db dyn Db,
) -> PartialSignatureApplication<'db> {
PartialSignatureApplication::new(
self.signature.clone(),
self.partial_application(arguments),
self.specialization,
self.unspecialized_return_type(db),
)
}
/// Returns the bound type for the specified parameter, or `None` if no argument was matched to
/// that parameter.
///
@@ -5712,6 +6125,27 @@ pub(crate) enum BindingError<'db> {
}
impl BindingError<'_> {
/// Returns whether this error is relevant to `functools.partial(...)` construction.
///
/// These errors are used both to filter incompatible wrapped overloads and to report
/// statically-detectable call-shape errors at construction time. (Runtime `functools.partial`
/// can defer some call-shape errors until invocation.)
///
/// For example, `partial(f, 1)` should ignore `MissingArguments` for the parameters that stay
/// unbound, while `partial(f, "x")` should still report `InvalidArgumentType` immediately.
fn is_relevant_for_partial_application(&self) -> bool {
matches!(
self,
Self::InvalidArgumentType { .. }
| Self::InvalidKeyType { .. }
| Self::UnknownArgument { .. }
| Self::PositionalOnlyParameterAsKwarg { .. }
| Self::TooManyPositionalArguments { .. }
| Self::ParameterAlreadyAssigned { .. }
| Self::SpecializationError { .. }
)
}
pub(crate) fn maybe_apply_argument_index_offset(mut self, offset: Option<usize>) -> Self {
if let Some(offset) = offset {
self.apply_argument_index_offset(offset);
@@ -1,4 +1,5 @@
use ruff_python_ast::name::Name;
use rustc_hash::FxHashSet;
use smallvec::{SmallVec, smallvec_inline};
use crate::{
@@ -6,12 +7,13 @@ use crate::{
place::Place,
types::{
ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassType, FindLegacyTypeVarsVisitor,
KnownInstanceType, LiteralValueTypeKind, MemberLookupPolicy, Parameter, Parameters,
Signature, SubclassOfInner, Type, TypeContext, TypeMapping, TypeVarBoundOrConstraints,
UnionType,
InternedType, KnownClass, KnownInstanceType, LiteralValueTypeKind, MemberLookupPolicy,
Parameter, Parameters, Signature, SubclassOfInner, Type, TypeContext, TypeMapping,
TypeVarBoundOrConstraints, UnionType,
constraints::{ConstraintSet, IteratorConstraintsExtension},
known_instance::FunctoolsPartialInstance,
relation::{TypeRelation, TypeRelationChecker},
signatures::CallableSignature,
signatures::{CallableSignature, PartialSignatureApplication},
visitor, walk_signature,
},
};
@@ -213,6 +215,10 @@ impl<'db> Type<'db> {
| Type::TypeGuard(_)
| Type::TypedDict(_) => None,
Type::KnownInstance(KnownInstanceType::FunctoolsPartial(partial)) => {
Some(CallableTypes::one(partial.partial(db)))
}
// TODO
Type::DataclassDecorator(_)
| Type::ModuleLiteral(_)
@@ -368,6 +374,35 @@ impl<'db> CallableType<'db> {
CallableType::new(db, self.signatures(db), CallableTypeKind::Regular)
}
/// Returns the reduced callable produced by partially applying selected overloads.
pub(crate) fn partially_apply(
db: &'db dyn Db,
overloads: impl IntoIterator<Item = PartialSignatureApplication<'db>>,
) -> Option<Self> {
Some(Self::new(
db,
CallableSignature::partially_apply(db, overloads)?,
CallableTypeKind::Regular,
))
}
/// Reifies this callable as the nominal `functools.partial[T]` instance for its return type.
pub(crate) fn into_functools_partial_instance(self, db: &'db dyn Db) -> Type<'db> {
let return_ty = self.signatures(db).overload_return_type_or_unknown(db);
KnownClass::FunctoolsPartial.to_specialized_instance(db, &[return_ty])
}
/// Wraps this reduced callable as a synthetic `functools.partial(...)` instance type.
pub(crate) fn into_precise_functools_partial_instance(
self,
db: &'db dyn Db,
wrapped: Type<'db>,
) -> Type<'db> {
Type::KnownInstance(KnownInstanceType::FunctoolsPartial(
FunctoolsPartialInstance::new(db, InternedType::new(db, wrapped), self),
))
}
pub(crate) fn bind_self(
self,
db: &'db dyn Db,
@@ -494,6 +529,35 @@ impl<'db> CallableTypes<'db> {
pub(crate) fn map(self, mut f: impl FnMut(CallableType<'db>) -> CallableType<'db>) -> Self {
Self::from_elements(self.0.iter().map(|element| f(*element)))
}
/// Merges reduced callables into one precise `functools.partial(...)` instance type.
pub(crate) fn into_precise_functools_partial_instance(
self,
db: &'db dyn Db,
wrapped: Type<'db>,
) -> Type<'db> {
let mut overloads = Vec::new();
let mut seen_overloads = FxHashSet::default();
for callable in self.0 {
for signature in callable.signatures(db) {
let signature = signature.clone();
let dedup_key = signature.clone().with_definition(None);
if seen_overloads.insert(dedup_key) {
overloads.push(signature);
}
}
}
debug_assert!(!overloads.is_empty(), "CallableTypes should not be empty");
CallableType::new(
db,
CallableSignature::from_overloads(overloads),
CallableTypeKind::Regular,
)
.into_precise_functools_partial_instance(db, wrapped)
}
}
impl<'a, 'db> IntoIterator for &'a CallableTypes<'db> {
@@ -136,6 +136,8 @@ pub enum KnownClass {
Template,
// pathlib
Path,
// functools
FunctoolsPartial,
// ty_extensions
ConstraintSet,
GenericContext,
@@ -254,6 +256,7 @@ impl KnownClass {
| Self::GenericContext
| Self::Specialization
| Self::ProtocolMeta
| Self::FunctoolsPartial
| Self::TypedDictFallback => Some(Truthiness::Ambiguous),
Self::Tuple => None,
@@ -350,7 +353,8 @@ impl KnownClass {
| KnownClass::BuiltinFunctionType
| KnownClass::ProtocolMeta
| KnownClass::Template
| KnownClass::Path => false,
| KnownClass::Path
| KnownClass::FunctoolsPartial => false,
}
}
@@ -443,7 +447,8 @@ impl KnownClass {
| KnownClass::BuiltinFunctionType
| KnownClass::ProtocolMeta
| KnownClass::Template
| KnownClass::Path => false,
| KnownClass::Path
| KnownClass::FunctoolsPartial => false,
}
}
@@ -535,7 +540,8 @@ impl KnownClass {
| KnownClass::BuiltinFunctionType
| KnownClass::ProtocolMeta
| KnownClass::Template
| KnownClass::Path => false,
| KnownClass::Path
| KnownClass::FunctoolsPartial => false,
}
}
@@ -639,6 +645,7 @@ impl KnownClass {
| Self::ProtocolMeta
| Self::Template
| Self::Path
| Self::FunctoolsPartial
| Self::Mapping
| Self::Sequence => false,
}
@@ -733,6 +740,7 @@ impl KnownClass {
| KnownClass::NamedTupleLike
| KnownClass::Template
| KnownClass::Path
| KnownClass::FunctoolsPartial
| KnownClass::ConstraintSet
| KnownClass::GenericContext
| KnownClass::Specialization => false,
@@ -856,6 +864,7 @@ impl KnownClass {
Self::TypedDictFallback => "TypedDictFallback",
Self::Template => "Template",
Self::Path => "Path",
Self::FunctoolsPartial => "partial",
Self::ProtocolMeta => "_ProtocolMeta",
}
}
@@ -1239,6 +1248,7 @@ impl KnownClass {
| Self::Specialization => KnownModule::TyExtensions,
Self::Template => KnownModule::Templatelib,
Self::Path => KnownModule::Pathlib,
Self::FunctoolsPartial => KnownModule::Functools,
}
}
@@ -1333,7 +1343,8 @@ impl KnownClass {
| Self::BuiltinFunctionType
| Self::ProtocolMeta
| Self::Template
| Self::Path => Some(false),
| Self::Path
| Self::FunctoolsPartial => Some(false),
Self::Tuple => None,
}
@@ -1431,7 +1442,8 @@ impl KnownClass {
| Self::BuiltinFunctionType
| Self::ProtocolMeta
| Self::Template
| Self::Path => false,
| Self::Path
| Self::FunctoolsPartial => false,
}
}
@@ -1542,6 +1554,7 @@ impl KnownClass {
"TypedDictFallback" => &[Self::TypedDictFallback],
"Template" => &[Self::Template],
"Path" => &[Self::Path],
"partial" => &[Self::FunctoolsPartial],
"_ProtocolMeta" => &[Self::ProtocolMeta],
_ => return None,
};
@@ -1627,7 +1640,8 @@ impl KnownClass {
| Self::Generator
| Self::AsyncGenerator
| Self::Template
| Self::Path => module == self.canonical_module(db),
| Self::Path
| Self::FunctoolsPartial => module == self.canonical_module(db),
Self::NoneType => matches!(module, KnownModule::Typeshed | KnownModule::Types),
Self::SpecialForm
| Self::TypeAliasType
@@ -194,7 +194,8 @@ impl<'db> ClassBase<'db> {
// A class inheriting from a newtype would make intuitive sense, but newtype
// wrappers are just identity callables at runtime, so this sort of inheritance
// doesn't work and isn't allowed.
| KnownInstanceType::NewType(_) => None,
| KnownInstanceType::NewType(_)
| KnownInstanceType::FunctoolsPartial(_) => None,
KnownInstanceType::TypeGenericAlias(_) => {
Self::try_from_type(db, KnownClass::Type.to_class_literal(db), subclass)
}
@@ -3110,6 +3110,13 @@ impl<'db> FmtDetailed<'db> for DisplayKnownInstanceRepr<'db> {
f.write_str("'>")
}
KnownInstanceType::NamedTupleSpec(_) => f.write_str("NamedTupleSpec"),
KnownInstanceType::FunctoolsPartial(partial) => {
f.write_str("partial[")?;
Type::Callable(partial.partial(self.db))
.display_with(self.db, DisplaySettings::default().singleline())
.fmt_detailed(f)?;
f.write_str("]")
}
}
}
}
+42 -35
View File
@@ -877,46 +877,53 @@ impl<'db> GenericContext<'db> {
I: IntoIterator<Item = Option<Type<'db>>>,
I::IntoIter: ExactSizeIterator,
{
fn specialize_recursive_impl<'db>(
db: &'db dyn Db,
context: GenericContext<'db>,
mut types: Box<[Type<'db>]>,
) -> Specialization<'db> {
let len = types.len();
loop {
let mut any_changed = false;
for i in 0..len {
let specialization = ApplySpecialization::Partial {
generic_context: context,
types: &types,
// Don't recursively substitute type[i] in itself. Ideally, we could instead
// check if the result is self-referential after we're done applying the
// partial specialization. But when we apply a paramspec, we don't use the
// callable that it maps to directly; we create a new callable that reuses
// parts of it. That means we can't look for the previous type directly.
// Instead we use this to skip specializing the type in itself in the first
// place.
skip: Some(i),
};
let updated = types[i].apply_type_mapping(
db,
&TypeMapping::ApplySpecialization(specialization),
TypeContext::default(),
);
if updated != types[i] {
types[i] = updated;
any_changed = true;
}
let types = self.fill_in_defaults(db, types);
self.specialize_from_types_recursive(db, types)
}
/// Builds a specialization and recursively resolves references between the chosen types.
fn specialize_from_types_recursive(
self,
db: &'db dyn Db,
mut types: Box<[Type<'db>]>,
) -> Specialization<'db> {
let len = types.len();
let variables = self.variables(db).collect_vec();
loop {
let mut any_changed = false;
for i in 0..len {
// Preserve identity mappings for unresolved type variables.
if types[i] == Type::TypeVar(variables[i]) {
continue;
}
if !any_changed {
return Specialization::new(db, context, types, None, None);
let specialization = ApplySpecialization::Partial {
generic_context: self,
types: &types,
// Don't recursively substitute type[i] in itself. Ideally, we could instead
// check if the result is self-referential after we're done applying the
// partial specialization. But when we apply a paramspec, we don't use the
// callable that it maps to directly; we create a new callable that reuses
// parts of it. That means we can't look for the previous type directly.
// Instead we use this to skip specializing the type in itself in the first
// place.
skip: Some(i),
};
let updated = types[i].apply_type_mapping(
db,
&TypeMapping::ApplySpecialization(specialization),
TypeContext::default(),
);
if updated != types[i] {
types[i] = updated;
any_changed = true;
}
}
}
let types = self.fill_in_defaults(db, types);
specialize_recursive_impl(db, self, types)
if !any_changed {
return Specialization::new(db, self, types, None, None);
}
}
}
/// Creates a specialization of this generic context for the `tuple` class.
@@ -1618,6 +1618,15 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
Type::unknown()
}
KnownInstanceType::FunctoolsPartial(_) => {
self.infer_type_expression(&subscript.slice);
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
builder.into_diagnostic(format_args!(
"`functools.partial` instances cannot be specialized",
));
}
Type::unknown()
}
},
Type::Dynamic(DynamicType::UnknownGeneric(_)) => {
self.infer_explicit_type_alias_specialization(subscript, value_ty, true)
@@ -32,6 +32,16 @@ pub struct InternedConstraintSet<'db> {
// The Salsa heap is tracked separately.
impl get_size2::GetSize for InternedConstraintSet<'_> {}
/// A salsa-interned payload for `functools.partial(...)` instances.
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
pub struct FunctoolsPartialInstance<'db> {
pub wrapped: InternedType<'db>,
pub partial: CallableType<'db>,
}
// The Salsa heap is tracked separately.
impl get_size2::GetSize for FunctoolsPartialInstance<'_> {}
/// Singleton types that are heavily special-cased by ty. Despite its name,
/// quite a different type to [`super::NominalInstanceType`].
///
@@ -104,6 +114,10 @@ pub enum KnownInstanceType<'db> {
/// The inferred spec for a functional `NamedTuple` class.
NamedTupleSpec(NamedTupleSpec<'db>),
/// A `functools.partial(func, ...)` call result where we could determine
/// the remaining callable signature after binding some arguments.
FunctoolsPartial(FunctoolsPartialInstance<'db>),
}
pub(super) fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
@@ -159,6 +173,9 @@ pub(super) fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Size
visitor.visit_type(db, field.ty);
}
}
KnownInstanceType::FunctoolsPartial(partial) => {
visitor.visit_callable_type(db, partial.partial(db));
}
}
}
@@ -221,6 +238,9 @@ impl<'db> KnownInstanceType<'db> {
Self::NamedTupleSpec(spec) => spec
.recursive_type_normalized_impl(db, div, true)
.map(Self::NamedTupleSpec),
Self::FunctoolsPartial(partial) => partial
.recursive_type_normalized_impl(db, div, nested)
.map(Self::FunctoolsPartial),
}
}
@@ -248,6 +268,7 @@ impl<'db> KnownInstanceType<'db> {
Self::LiteralStringAlias(_) => KnownClass::Str,
Self::NewType(_) => KnownClass::NewType,
Self::NamedTupleSpec(_) => KnownClass::Sequence,
Self::FunctoolsPartial(_) => KnownClass::FunctoolsPartial,
}
}
@@ -260,7 +281,7 @@ impl<'db> KnownInstanceType<'db> {
/// For example, an alias created using the `type` statement is an instance of
/// `typing.TypeAliasType`, so `KnownInstanceType::TypeAliasType(_).instance_fallback(db)`
/// returns `Type::NominalInstance(NominalInstanceType { class: <typing.TypeAliasType> })`.
pub(super) fn instance_fallback(self, db: &dyn Db) -> Type<'_> {
pub(super) fn instance_fallback(self, db: &'db dyn Db) -> Type<'db> {
self.class(db).to_instance(db)
}
@@ -313,6 +334,11 @@ impl<'db> KnownInstanceType<'db> {
callable_type.apply_type_mapping_impl(db, type_mapping, tcx, visitor),
))
}
KnownInstanceType::FunctoolsPartial(partial) => {
Type::KnownInstance(KnownInstanceType::FunctoolsPartial(
partial.apply_type_mapping_impl(db, type_mapping, tcx, visitor),
))
}
KnownInstanceType::TypeGenericAlias(ty) => {
Type::KnownInstance(KnownInstanceType::TypeGenericAlias(InternedType::new(
db,
@@ -557,6 +583,49 @@ impl<'db> UnionTypeInstance<'db> {
}
}
impl<'db> FunctoolsPartialInstance<'db> {
/// Normalizes both the wrapped callable and the exposed reduced callable recursively.
fn recursive_type_normalized_impl(
self,
db: &'db dyn Db,
div: Type<'db>,
nested: bool,
) -> Option<Self> {
Some(Self::new(
db,
InternedType::new(
db,
self.wrapped(db)
.inner(db)
.recursive_type_normalized_impl(db, div, nested)?,
),
self.partial(db)
.recursive_type_normalized_impl(db, div, nested)?,
))
}
/// Applies a type mapping to both the wrapped callable and the exposed reduced callable.
fn apply_type_mapping_impl(
self,
db: &'db dyn Db,
type_mapping: &TypeMapping<'_, 'db>,
tcx: TypeContext<'db>,
visitor: &ApplyTypeMappingVisitor<'db>,
) -> Self {
Self::new(
db,
InternedType::new(
db,
self.wrapped(db)
.inner(db)
.apply_type_mapping_impl(db, type_mapping, tcx, visitor),
),
self.partial(db)
.apply_type_mapping_impl(db, type_mapping, tcx, visitor),
)
}
}
/// A salsa-interned `Type`
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
pub struct InternedType<'db> {
@@ -980,6 +980,26 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> {
})
}
(
Type::KnownInstance(KnownInstanceType::FunctoolsPartial(source_partial)),
Type::KnownInstance(KnownInstanceType::FunctoolsPartial(target_partial)),
) => self.with_recursion_guard(source, target, || {
self.check_callable_pair(db, source_partial.partial(db), target_partial.partial(db))
}),
// When checking `FunctoolsPartial <: functools.partial[T]`, we need to specialize
// the nominal instance with the partial's return type so the check is precise.
(
Type::KnownInstance(KnownInstanceType::FunctoolsPartial(partial)),
Type::NominalInstance(target_instance),
) if target_instance
.class(db)
.is_known(db, KnownClass::FunctoolsPartial) =>
{
let specialized = partial.partial(db).into_functools_partial_instance(db);
self.check_type_pair(db, specialized, target)
}
// Dynamic is only a subtype of `object` and only a supertype of `Never`; both were
// handled above. It's always assignable, though.
//
@@ -1447,6 +1467,22 @@ impl<'a, 'c, 'db> TypeRelationChecker<'a, 'c, 'db> {
(Type::FunctionLiteral(source_function), Type::FunctionLiteral(target_function)) => {
self.check_function_pair(db, source_function, target_function)
}
(
Type::KnownInstance(KnownInstanceType::FunctoolsPartial(source_partial)),
Type::FunctionLiteral(target_function),
) if matches!(
self.relation,
TypeRelation::Assignability | TypeRelation::ConstraintSetAssignability
) =>
{
self.with_recursion_guard(source, target, || {
self.check_callable_signature_pair(
db,
source_partial.partial(db).signatures(db),
target_function.into_callable_type(db).signatures(db),
)
})
}
(Type::BoundMethod(source_method), Type::BoundMethod(target_method)) => {
self.check_bound_method_pair(db, source_method, target_method)
}
@@ -14,16 +14,18 @@ use std::collections::BTreeMap;
use std::slice::Iter;
use itertools::{Either, EitherOrBoth, Itertools};
use rustc_hash::FxHashMap;
use rustc_hash::{FxHashMap, FxHashSet};
use smallvec::{SmallVec, smallvec_inline};
use super::{DynamicType, Type, TypeVarVariance, semantic_index};
use super::{DynamicType, Type, TypeVarVariance, UnionType, semantic_index};
use crate::types::callable::CallableTypeKind;
use crate::types::constraints::{
ConstraintSet, ConstraintSetBuilder, IteratorConstraintsExtension,
};
use crate::types::cyclic::ActiveRecursionDetector;
use crate::types::generics::{GenericContext, InferableTypeVars, walk_generic_context};
use crate::types::generics::{
ApplySpecialization, GenericContext, InferableTypeVars, Specialization, walk_generic_context,
};
use crate::types::infer::{TypeExpressionFlags, infer_deferred_types};
use crate::types::relation::{
HasRelationToVisitor, IsDisjointVisitor, TypeRelation, TypeRelationChecker,
@@ -32,6 +34,7 @@ use crate::types::typed_dict::{
UnpackedTypedDictKey, extract_unpacked_typed_dict_keys_from_kwargs_annotation,
extract_unpacked_typed_dict_keys_from_value_type,
};
use crate::types::typevar::BoundTypeVarIdentity;
use crate::types::{
ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, CallableType, ErrorContext,
FindLegacyTypeVarsVisitor, KnownClass, MaterializationKind, ParamSpecAttrKind,
@@ -95,6 +98,33 @@ pub struct CallableSignature<'db> {
pub(crate) overloads: SmallVec<[Signature<'db>; 1]>,
}
/// The per-overload information needed to synthesize one reduced signature for
/// `functools.partial(...)`.
#[derive(Clone, Debug)]
pub(crate) struct PartialSignatureApplication<'db> {
signature: Signature<'db>,
partial_application: PartialApplication<'db>,
specialization: Option<Specialization<'db>>,
unspecialized_return_ty: Type<'db>,
}
impl<'db> PartialSignatureApplication<'db> {
/// Creates a new per-overload partial-application summary.
pub(crate) fn new(
signature: Signature<'db>,
partial_application: PartialApplication<'db>,
specialization: Option<Specialization<'db>>,
unspecialized_return_ty: Type<'db>,
) -> Self {
Self {
signature,
partial_application,
specialization,
unspecialized_return_ty,
}
}
}
impl<'db> CallableSignature<'db> {
pub(crate) fn single(signature: Signature<'db>) -> Self {
Self {
@@ -121,6 +151,15 @@ impl<'db> CallableSignature<'db> {
self.overloads.iter()
}
/// Returns the union of all overload return types, or `Unknown` if there are no overloads.
pub(crate) fn overload_return_type_or_unknown(&self, db: &'db dyn Db) -> Type<'db> {
match self.overloads.as_slice() {
[] => Type::unknown(),
[signature] => signature.return_ty,
overloads => UnionType::from_elements(db, overloads.iter().map(|sig| sig.return_ty)),
}
}
pub(crate) fn with_inherited_generic_context(
&self,
db: &'db dyn Db,
@@ -133,6 +172,30 @@ impl<'db> CallableSignature<'db> {
}))
}
/// Returns the reduced overloaded signature exposed by a `functools.partial(...)` object.
pub(crate) fn partially_apply(
db: &'db dyn Db,
overloads: impl IntoIterator<Item = PartialSignatureApplication<'db>>,
) -> Option<Self> {
let mut new_overloads = Vec::new();
let mut seen_overloads = FxHashSet::default();
for overload in overloads {
let signature = overload.signature.partially_apply(
db,
&overload.partial_application,
overload.specialization,
overload.unspecialized_return_ty,
);
let dedup_key = signature.clone().with_definition(None);
if seen_overloads.insert(dedup_key) {
new_overloads.push(signature);
}
}
(!new_overloads.is_empty()).then(|| Self::from_overloads(new_overloads))
}
pub(crate) fn cycle_normalized(
&self,
db: &'db dyn Db,
@@ -465,6 +528,59 @@ pub(super) fn walk_signature<'db, V: super::visitor::TypeVisitor<'db> + ?Sized>(
visitor.visit_type(db, signature.return_ty);
}
/// Describes how a `functools.partial(...)` call binds one overload's parameters.
///
/// `call/bind.rs` computes this from argument matching. Signature rewriting then consumes this
/// summary to synthesize the reduced callable that a partial object exposes.
#[derive(Clone, Debug)]
pub(crate) struct PartialApplication<'db> {
positionally_bound: Box<[bool]>,
keyword_defaults: Box<[Option<Type<'db>>]>,
keyword_bound: Box<[bool]>,
}
impl<'db> PartialApplication<'db> {
/// Creates an empty partial-application summary for a signature with `parameter_count`
/// parameters.
pub(crate) fn new(parameter_count: usize) -> Self {
Self {
positionally_bound: vec![false; parameter_count].into_boxed_slice(),
keyword_defaults: vec![None; parameter_count].into_boxed_slice(),
keyword_bound: vec![false; parameter_count].into_boxed_slice(),
}
}
/// Marks the parameter at `parameter_index` as consumed by a positional binding.
pub(crate) fn bind_positionally(&mut self, parameter_index: usize) {
self.positionally_bound[parameter_index] = true;
}
/// Marks the parameter at `parameter_index` as bound by keyword and records the synthesized
/// default type that should appear in the reduced signature, if any.
pub(crate) fn bind_by_keyword(
&mut self,
parameter_index: usize,
default_ty: Option<Type<'db>>,
) {
self.keyword_bound[parameter_index] = true;
self.keyword_defaults[parameter_index] = default_ty;
}
/// Returns `true` if the parameter at `parameter_index` is removed from the reduced signature
/// because it was already supplied positionally to `functools.partial(...)`.
pub(crate) fn is_positionally_bound(&self, parameter_index: usize) -> bool {
self.positionally_bound[parameter_index]
}
fn keyword_default(&self, parameter_index: usize) -> Option<Type<'db>> {
self.keyword_defaults[parameter_index]
}
fn is_keyword_bound(&self, parameter_index: usize) -> bool {
self.keyword_bound[parameter_index]
}
}
impl<'db> Signature<'db> {
pub(crate) fn new(parameters: Parameters<'db>, return_ty: Type<'db>) -> Self {
Self {
@@ -804,6 +920,152 @@ impl<'db> Signature<'db> {
}
}
/// Returns this signature with the given specialization applied to parameters and return type.
pub(crate) fn apply_specialization(
&self,
db: &'db dyn Db,
specialization: Specialization<'db>,
) -> Self {
let type_mapping =
TypeMapping::ApplySpecialization(ApplySpecialization::Specialization(specialization));
self.apply_type_mapping_impl(
db,
&type_mapping,
TypeContext::default(),
&ApplyTypeMappingVisitor::default(),
)
}
/// Returns the callable signature produced by partially applying this signature.
pub(crate) fn partially_apply(
&self,
db: &'db dyn Db,
partial_application: &PartialApplication<'db>,
specialization: Option<Specialization<'db>>,
unspecialized_return_ty: Type<'db>,
) -> Self {
let signature_specialization =
self.partial_application_specialization(db, partial_application, specialization);
let signature = signature_specialization.map_or_else(
|| self.clone(),
|specialization| self.apply_specialization(db, specialization),
);
let parameters = signature.parameters().as_slice();
let return_ty = specialization.map_or_else(
|| unspecialized_return_ty,
|specialization| {
unspecialized_return_ty
.apply_specialization(db, signature_specialization.unwrap_or(specialization))
},
);
let mut remaining = Vec::with_capacity(parameters.len());
let mut first_keyword_bound_positional_or_keyword = None;
for (index, parameter) in parameters.iter().enumerate() {
if partial_application.is_positionally_bound(index) {
continue;
}
let parameter = partial_application.keyword_default(index).map_or_else(
|| parameter.clone(),
|default_ty| parameter.clone().with_default_type(default_ty),
);
if first_keyword_bound_positional_or_keyword.is_none()
&& partial_application.is_keyword_bound(index)
&& matches!(parameter.kind(), ParameterKind::PositionalOrKeyword { .. })
{
first_keyword_bound_positional_or_keyword = Some(remaining.len());
}
remaining.push(parameter);
}
// Expand `P.args`/`P.kwargs` while the pair is still adjacent. The keyword-only reshuffle
// below can separate them, which would otherwise prevent expansion.
let remaining = Parameters::new(db, remaining).expand_paramspec_variadics(db);
let mut reordered = Vec::with_capacity(remaining.len());
let mut keyword_only = Vec::new();
let mut keyword_variadic = Vec::new();
for (index, parameter) in remaining.iter().cloned().enumerate() {
let parameter = if first_keyword_bound_positional_or_keyword
.is_some_and(|first_bound_index| index >= first_bound_index)
&& matches!(parameter.kind(), ParameterKind::PositionalOrKeyword { .. })
{
parameter.positional_or_keyword_to_keyword_only()
} else {
parameter
};
if parameter.is_keyword_variadic() {
keyword_variadic.push(parameter);
} else if parameter.is_keyword_only() {
keyword_only.push(parameter);
} else {
reordered.push(parameter);
}
}
reordered.extend(keyword_only);
reordered.extend(keyword_variadic);
signature
.with_parameters(Parameters::new(db, reordered))
.with_return_type(return_ty)
}
/// Returns the specialization used for the callable signature exposed by a partial object.
///
/// Surviving type variables that still appear in the reduced parameter list may need a more
/// specific specialization than the plain return-type view.
fn partial_application_specialization(
&self,
db: &'db dyn Db,
partial_application: &PartialApplication<'db>,
specialization: Option<Specialization<'db>>,
) -> Option<Specialization<'db>> {
let specialization = specialization?;
let Some(generic_context) = self.generic_context else {
return Some(specialization);
};
let promoted_typevars: FxHashSet<BoundTypeVarIdentity<'db>> = generic_context
.variables(db)
.filter(|typevar| {
self.parameters
.iter()
.enumerate()
.filter(|(index, _)| !partial_application.is_positionally_bound(*index))
.any(|(_, parameter)| {
parameter
.annotated_type()
.references_typevar(db, typevar.typevar(db).identity(db))
})
})
.map(|typevar| typevar.identity(db))
.collect();
if promoted_typevars.is_empty() {
return Some(specialization);
}
Some(generic_context.specialize_recursive(
db,
generic_context.variables(db).map(|typevar| {
let ty = specialization
.get(db, typevar)
.unwrap_or(Type::TypeVar(typevar));
Some(if promoted_typevars.contains(&typevar.identity(db)) {
ty.promote(db)
} else {
ty
})
}),
))
}
fn inferable_typevars(&self, db: &'db dyn Db) -> InferableTypeVars<'db> {
match self.generic_context {
Some(generic_context) => generic_context.inferable_typevars(db),
@@ -886,6 +1148,11 @@ impl<'db> Signature<'db> {
Self { definition, ..self }
}
/// Create a new signature with the given parameters.
pub(crate) fn with_parameters(self, parameters: Parameters<'db>) -> Self {
Self { parameters, ..self }
}
/// Create a new signature with the given return type.
pub(crate) fn with_return_type(self, return_ty: Type<'db>) -> Self {
Self { return_ty, ..self }
@@ -3137,6 +3404,62 @@ impl<'db> Parameters<'db> {
.enumerate()
.rfind(|(_, parameter)| parameter.is_keyword_variadic())
}
/// Expands adjacent `P.args`/`P.kwargs` placeholders into their mapped parameters.
pub(crate) fn expand_paramspec_variadics(&self, db: &'db dyn Db) -> Self {
let mut variadic_index = None;
let mut paramspec_callable = None;
for (index, parameter) in self.iter().enumerate() {
if !parameter.is_variadic() {
continue;
}
let Type::Callable(callable) = parameter.annotated_type() else {
continue;
};
if callable.kind(db) != CallableTypeKind::ParamSpecValue {
continue;
}
variadic_index = Some(index);
paramspec_callable = Some(callable);
break;
}
let Some(variadic_index) = variadic_index else {
return self.clone();
};
let Some(paramspec_callable) = paramspec_callable else {
return self.clone();
};
let Some(keyword_variadic) = self.get(variadic_index + 1) else {
return self.clone();
};
if !keyword_variadic.is_keyword_variadic() {
return self.clone();
}
let Type::Callable(keyword_callable) = keyword_variadic.annotated_type() else {
return self.clone();
};
if keyword_callable.kind(db) != CallableTypeKind::ParamSpecValue
|| keyword_callable != paramspec_callable
{
return self.clone();
}
let [mapped_signature] = paramspec_callable.signatures(db).overloads.as_slice() else {
return self.clone();
};
let mut expanded = Vec::with_capacity(self.len());
expanded.extend_from_slice(&self.value[..variadic_index]);
expanded.extend_from_slice(mapped_signature.parameters().as_slice());
expanded.extend_from_slice(&self.value[variadic_index + 2..]);
Parameters::new(db, expanded)
}
}
impl<'db, 'a> IntoIterator for &'a Parameters<'db> {
@@ -3587,6 +3910,18 @@ impl<'db> Parameter<'db> {
ParameterKind::Variadic { .. } | ParameterKind::KeywordVariadic { .. } => None,
}
}
/// Rewrites a positional-or-keyword parameter as keyword-only while preserving its metadata.
pub(crate) fn positional_or_keyword_to_keyword_only(&self) -> Self {
let mut result = self.clone();
if let ParameterKind::PositionalOrKeyword { name, default_type } = &self.kind {
result.kind = ParameterKind::KeywordOnly {
name: name.clone(),
default_type: *default_type,
};
}
result
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize)]