diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 50ed347b9c8e0..50d99eb7a4573 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -5,7 +5,6 @@ use context::InferContext; use diagnostic::{report_not_iterable, report_not_iterable_possibly_unbound}; use indexmap::IndexSet; use itertools::Itertools; -use ruff_db::diagnostic::Severity; use ruff_db::files::File; use ruff_python_ast as ast; use type_ordering::union_elements_ordering; @@ -35,9 +34,7 @@ use crate::semantic_index::{ use crate::stdlib::{builtins_symbol, known_module_symbol, typing_extensions_symbol}; use crate::suppression::check_suppressions; use crate::symbol::{Boundness, Symbol}; -use crate::types::call::{ - bind_call, CallArguments, CallBinding, CallDunderResult, CallOutcome, StaticAssertionErrorKind, -}; +use crate::types::call::{bind_call, CallArguments, CallBinding, CallDunderResult, CallOutcome}; use crate::types::class_base::ClassBase; use crate::types::diagnostic::INVALID_TYPE_FORM; use crate::types::infer::infer_unpack_types; @@ -1944,40 +1941,6 @@ impl<'db> Type<'db> { Type::FunctionLiteral(function_type) => { let mut binding = bind_call(db, arguments, function_type.signature(db), Some(self)); match function_type.known(db) { - Some(KnownFunction::RevealType) => { - let revealed_ty = binding.one_parameter_type().unwrap_or(Type::unknown()); - CallOutcome::revealed(binding, revealed_ty) - } - Some(KnownFunction::StaticAssert) => { - if let Some((parameter_ty, message)) = binding.two_parameter_types() { - let truthiness = parameter_ty.bool(db); - - if truthiness.is_always_true() { - CallOutcome::callable(binding) - } else { - let error_kind = if let Some(message) = - message.into_string_literal().map(|s| &**s.value(db)) - { - StaticAssertionErrorKind::CustomError(message) - } else if parameter_ty == Type::BooleanLiteral(false) { - StaticAssertionErrorKind::ArgumentIsFalse - } else if truthiness.is_always_false() { - StaticAssertionErrorKind::ArgumentIsFalsy(parameter_ty) - } else { - StaticAssertionErrorKind::ArgumentTruthinessIsAmbiguous( - parameter_ty, - ) - }; - - CallOutcome::StaticAssertionError { - binding, - error_kind, - } - } - } else { - CallOutcome::callable(binding) - } - } Some(KnownFunction::IsEquivalentTo) => { let (ty_a, ty_b) = binding .two_parameter_types() @@ -2052,14 +2015,6 @@ impl<'db> Type<'db> { CallOutcome::callable(binding) } - Some(KnownFunction::AssertType) => { - let Some((_, asserted_ty)) = binding.two_parameter_types() else { - return CallOutcome::callable(binding); - }; - - CallOutcome::asserted(binding, asserted_ty) - } - Some(KnownFunction::Cast) => { // TODO: Use `.two_parameter_tys()` exclusively // when overloads are supported. @@ -4074,10 +4029,7 @@ impl<'db> Class<'db> { // TODO we should also check for binding errors that would indicate the metaclass // does not accept the right arguments - CallOutcome::Callable { binding } - | CallOutcome::RevealType { binding, .. } - | CallOutcome::StaticAssertionError { binding, .. } - | CallOutcome::AssertType { binding, .. } => Ok(binding.return_type()), + CallOutcome::Callable { binding } => Ok(binding.return_type()), }; return return_ty_result.map(|ty| ty.to_meta_type(db)); diff --git a/crates/red_knot_python_semantic/src/types/call.rs b/crates/red_knot_python_semantic/src/types/call.rs index 13ab169ede483..8b8f4c3705bbb 100644 --- a/crates/red_knot_python_semantic/src/types/call.rs +++ b/crates/red_knot_python_semantic/src/types/call.rs @@ -1,9 +1,7 @@ use super::context::InferContext; -use super::diagnostic::{CALL_NON_CALLABLE, TYPE_ASSERTION_FAILURE}; -use super::{Severity, Signature, Type, TypeArrayDisplay, UnionBuilder}; -use crate::types::diagnostic::STATIC_ASSERT_ERROR; +use super::diagnostic::CALL_NON_CALLABLE; +use super::{Signature, Type, TypeArrayDisplay, UnionBuilder}; use crate::Db; -use ruff_db::diagnostic::DiagnosticId; use ruff_python_ast as ast; mod arguments; @@ -13,22 +11,10 @@ pub(super) use arguments::{Argument, CallArguments}; pub(super) use bind::{bind_call, CallBinding}; #[derive(Debug, Clone, PartialEq, Eq)] -pub(super) enum StaticAssertionErrorKind<'db> { - ArgumentIsFalse, - ArgumentIsFalsy(Type<'db>), - ArgumentTruthinessIsAmbiguous(Type<'db>), - CustomError(&'db str), -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub(super) enum CallOutcome<'db> { +pub(crate) enum CallOutcome<'db> { Callable { binding: CallBinding<'db>, }, - RevealType { - binding: CallBinding<'db>, - revealed_ty: Type<'db>, - }, NotCallable { not_callable_ty: Type<'db>, }, @@ -40,14 +26,6 @@ pub(super) enum CallOutcome<'db> { called_ty: Type<'db>, call_outcome: Box>, }, - StaticAssertionError { - binding: CallBinding<'db>, - error_kind: StaticAssertionErrorKind<'db>, - }, - AssertType { - binding: CallBinding<'db>, - asserted_ty: Type<'db>, - }, } impl<'db> CallOutcome<'db> { @@ -61,14 +39,6 @@ impl<'db> CallOutcome<'db> { CallOutcome::NotCallable { not_callable_ty } } - /// Create a new `CallOutcome::RevealType` with given revealed and return types. - pub(super) fn revealed(binding: CallBinding<'db>, revealed_ty: Type<'db>) -> CallOutcome<'db> { - CallOutcome::RevealType { - binding, - revealed_ty, - } - } - /// Create a new `CallOutcome::Union` with given wrapped outcomes. pub(super) fn union( called_ty: Type<'db>, @@ -80,22 +50,10 @@ impl<'db> CallOutcome<'db> { } } - /// Create a new `CallOutcome::AssertType` with given asserted and return types. - pub(super) fn asserted(binding: CallBinding<'db>, asserted_ty: Type<'db>) -> CallOutcome<'db> { - CallOutcome::AssertType { - binding, - asserted_ty, - } - } - /// Get the return type of the call, or `None` if not callable. pub(super) fn return_type(&self, db: &'db dyn Db) -> Option> { match self { Self::Callable { binding } => Some(binding.return_type()), - Self::RevealType { - binding, - revealed_ty: _, - } => Some(binding.return_type()), Self::NotCallable { not_callable_ty: _ } => None, Self::Union { outcomes, @@ -114,11 +72,6 @@ impl<'db> CallOutcome<'db> { }) .map(UnionBuilder::build), Self::PossiblyUnboundDunderCall { call_outcome, .. } => call_outcome.return_type(db), - Self::StaticAssertionError { .. } => Some(Type::none(db)), - Self::AssertType { - binding, - asserted_ty: _, - } => Some(binding.return_type()), } } @@ -204,22 +157,12 @@ impl<'db> CallOutcome<'db> { // only non-callable diagnostics in the union case, which is inconsistent. match self { Self::Callable { binding } => { + // TODO: Move this out of the `CallOutcome` and into `TypeInferenceBuilder`? + // This check is required everywhere where we call `return_type_result` + // from the TypeInferenceBuilder. binding.report_diagnostics(context, node); Ok(binding.return_type()) } - Self::RevealType { - binding, - revealed_ty, - } => { - binding.report_diagnostics(context, node); - context.report_diagnostic( - node, - DiagnosticId::RevealedType, - Severity::Info, - format_args!("Revealed type is `{}`", revealed_ty.display(context.db())), - ); - Ok(binding.return_type()) - } Self::NotCallable { not_callable_ty } => Err(NotCallableError::Type { not_callable_ty: *not_callable_ty, return_ty: Type::unknown(), @@ -239,24 +182,12 @@ impl<'db> CallOutcome<'db> { } => { let mut not_callable = vec![]; let mut union_builder = UnionBuilder::new(context.db()); - let mut revealed = false; for outcome in outcomes { let return_ty = match outcome { Self::NotCallable { not_callable_ty } => { not_callable.push(*not_callable_ty); Type::unknown() } - Self::RevealType { - binding, - revealed_ty: _, - } => { - if revealed { - binding.return_type() - } else { - revealed = true; - outcome.unwrap_with_diagnostic(context, node) - } - } _ => outcome.unwrap_with_diagnostic(context, node), }; union_builder = union_builder.add(return_ty); @@ -280,73 +211,6 @@ impl<'db> CallOutcome<'db> { }), } } - Self::StaticAssertionError { - binding, - error_kind, - } => { - binding.report_diagnostics(context, node); - - match error_kind { - StaticAssertionErrorKind::ArgumentIsFalse => { - context.report_lint( - &STATIC_ASSERT_ERROR, - node, - format_args!("Static assertion error: argument evaluates to `False`"), - ); - } - StaticAssertionErrorKind::ArgumentIsFalsy(parameter_ty) => { - context.report_lint( - &STATIC_ASSERT_ERROR, - node, - format_args!( - "Static assertion error: argument of type `{parameter_ty}` is statically known to be falsy", - parameter_ty=parameter_ty.display(context.db()) - ), - ); - } - StaticAssertionErrorKind::ArgumentTruthinessIsAmbiguous(parameter_ty) => { - context.report_lint( - &STATIC_ASSERT_ERROR, - node, - format_args!( - "Static assertion error: argument of type `{parameter_ty}` has an ambiguous static truthiness", - parameter_ty=parameter_ty.display(context.db()) - ), - ); - } - StaticAssertionErrorKind::CustomError(message) => { - context.report_lint( - &STATIC_ASSERT_ERROR, - node, - format_args!("Static assertion error: {message}"), - ); - } - } - - Ok(Type::unknown()) - } - Self::AssertType { - binding, - asserted_ty, - } => { - let [actual_ty, _asserted] = binding.parameter_types() else { - return Ok(binding.return_type()); - }; - - if !actual_ty.is_gradual_equivalent_to(context.db(), *asserted_ty) { - context.report_lint( - &TYPE_ASSERTION_FAILURE, - node, - format_args!( - "Actual type `{}` is not the same as asserted type `{}`", - actual_ty.display(context.db()), - asserted_ty.display(context.db()), - ), - ); - } - - Ok(binding.return_type()) - } } } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 9a5379b6393fc..501ee49dfbffd 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -29,6 +29,7 @@ use std::num::NonZeroU32; use itertools::{Either, Itertools}; +use ruff_db::diagnostic::{DiagnosticId, Severity}; use ruff_db::files::File; use ruff_db::parsed::parsed_module; use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext}; @@ -49,7 +50,7 @@ use crate::semantic_index::semantic_index; use crate::semantic_index::symbol::{NodeWithScopeKind, NodeWithScopeRef, ScopeId}; use crate::semantic_index::SemanticIndex; use crate::stdlib::builtins_module_scope; -use crate::types::call::{Argument, CallArguments}; +use crate::types::call::{Argument, CallArguments, CallOutcome}; use crate::types::diagnostic::{ report_invalid_arguments_to_annotated, report_invalid_assignment, report_invalid_attribute_assignment, report_unresolved_module, TypeCheckDiagnostics, @@ -58,7 +59,8 @@ use crate::types::diagnostic::{ INCONSISTENT_MRO, INVALID_ATTRIBUTE_ACCESS, INVALID_BASE, INVALID_CONTEXT_MANAGER, INVALID_DECLARATION, INVALID_PARAMETER_DEFAULT, INVALID_TYPE_FORM, INVALID_TYPE_VARIABLE_CONSTRAINTS, POSSIBLY_UNBOUND_ATTRIBUTE, POSSIBLY_UNBOUND_IMPORT, - UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_IMPORT, UNSUPPORTED_OPERATOR, + STATIC_ASSERT_ERROR, TYPE_ASSERTION_FAILURE, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, + UNRESOLVED_IMPORT, UNSUPPORTED_OPERATOR, }; use crate::types::mro::MroErrorKind; use crate::types::unpacker::{UnpackResult, Unpacker}; @@ -3245,9 +3247,86 @@ impl<'db> TypeInferenceBuilder<'db> { .unwrap_or_default(); let call_arguments = self.infer_arguments(arguments, parameter_expectations); - function_type - .call(self.db(), &call_arguments) - .unwrap_with_diagnostic(&self.context, call_expression.into()) + + let call_outcome = function_type.call(self.db(), &call_arguments); + + if let CallOutcome::Callable { binding } = &call_outcome { + if let Type::FunctionLiteral(function_type) = function_type { + match function_type.known(self.db()) { + Some(KnownFunction::RevealType) => { + let revealed_ty = binding.one_parameter_type().unwrap_or(Type::unknown()); + self.context.report_diagnostic( + call_expression.into(), + DiagnosticId::RevealedType, + Severity::Info, + format_args!("Revealed type is `{}`", revealed_ty.display(self.db())), + ); + } + Some(KnownFunction::AssertType) => { + if let [actual_ty, asserted_ty] = binding.parameter_types() { + if !actual_ty.is_gradual_equivalent_to(self.db(), *asserted_ty) { + self.context.report_lint( + &TYPE_ASSERTION_FAILURE, + call_expression.into(), + format_args!( + "Actual type `{}` is not the same as asserted type `{}`", + actual_ty.display(self.db()), + asserted_ty.display(self.db()), + ), + ); + } + }; + } + + Some(KnownFunction::StaticAssert) => { + if let Some((parameter_ty, message)) = binding.two_parameter_types() { + let truthiness = parameter_ty.bool(self.db()); + + if !truthiness.is_always_true() { + if let Some(message) = + message.into_string_literal().map(|s| &**s.value(self.db())) + { + self.context.report_lint( + &STATIC_ASSERT_ERROR, + call_expression.into(), + format_args!("Static assertion error: {message}"), + ); + } else if parameter_ty == Type::BooleanLiteral(false) { + self.context.report_lint( + &STATIC_ASSERT_ERROR, + call_expression.into(), + format_args!( + "Static assertion error: argument evaluates to `False`" + ), + ); + } else if truthiness.is_always_false() { + self.context.report_lint( + &STATIC_ASSERT_ERROR, + call_expression.into(), + format_args!( + "Static assertion error: argument of type `{parameter_ty}` is statically known to be falsy", + parameter_ty=parameter_ty.display(self.db()) + ), + ); + } else { + self.context.report_lint( + &STATIC_ASSERT_ERROR, + call_expression.into(), + format_args!( + "Static assertion error: argument of type `{parameter_ty}` has an ambiguous static truthiness", + parameter_ty=parameter_ty.display(self.db()) + ), + ); + }; + } + } + } + _ => {} + } + } + } + + call_outcome.unwrap_with_diagnostic(&self.context, call_expression.into()) } fn infer_starred_expression(&mut self, starred: &ast::ExprStarred) -> Type<'db> {