Skip to content

Commit

Permalink
Move reveal_type, assert_type handling out of CallOutcome
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaReiser committed Feb 12, 2025
1 parent 03f0828 commit daf19cc
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 197 deletions.
52 changes: 2 additions & 50 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use context::InferContext;
use diagnostic::{report_not_iterable, report_not_iterable_possibly_unbound};
use indexmap::IndexSet;
use itertools::Itertools;
use ruff_db::diagnostic::Severity;
use ruff_db::files::File;
use ruff_python_ast as ast;
use type_ordering::union_elements_ordering;
Expand Down Expand Up @@ -35,9 +34,7 @@ use crate::semantic_index::{
use crate::stdlib::{builtins_symbol, known_module_symbol, typing_extensions_symbol};
use crate::suppression::check_suppressions;
use crate::symbol::{Boundness, Symbol};
use crate::types::call::{
bind_call, CallArguments, CallBinding, CallDunderResult, CallOutcome, StaticAssertionErrorKind,
};
use crate::types::call::{bind_call, CallArguments, CallBinding, CallDunderResult, CallOutcome};
use crate::types::class_base::ClassBase;
use crate::types::diagnostic::INVALID_TYPE_FORM;
use crate::types::infer::infer_unpack_types;
Expand Down Expand Up @@ -1944,40 +1941,6 @@ impl<'db> Type<'db> {
Type::FunctionLiteral(function_type) => {
let mut binding = bind_call(db, arguments, function_type.signature(db), Some(self));
match function_type.known(db) {
Some(KnownFunction::RevealType) => {
let revealed_ty = binding.one_parameter_type().unwrap_or(Type::unknown());
CallOutcome::revealed(binding, revealed_ty)
}
Some(KnownFunction::StaticAssert) => {
if let Some((parameter_ty, message)) = binding.two_parameter_types() {
let truthiness = parameter_ty.bool(db);

if truthiness.is_always_true() {
CallOutcome::callable(binding)
} else {
let error_kind = if let Some(message) =
message.into_string_literal().map(|s| &**s.value(db))
{
StaticAssertionErrorKind::CustomError(message)
} else if parameter_ty == Type::BooleanLiteral(false) {
StaticAssertionErrorKind::ArgumentIsFalse
} else if truthiness.is_always_false() {
StaticAssertionErrorKind::ArgumentIsFalsy(parameter_ty)
} else {
StaticAssertionErrorKind::ArgumentTruthinessIsAmbiguous(
parameter_ty,
)
};

CallOutcome::StaticAssertionError {
binding,
error_kind,
}
}
} else {
CallOutcome::callable(binding)
}
}
Some(KnownFunction::IsEquivalentTo) => {
let (ty_a, ty_b) = binding
.two_parameter_types()
Expand Down Expand Up @@ -2052,14 +2015,6 @@ impl<'db> Type<'db> {
CallOutcome::callable(binding)
}

Some(KnownFunction::AssertType) => {
let Some((_, asserted_ty)) = binding.two_parameter_types() else {
return CallOutcome::callable(binding);
};

CallOutcome::asserted(binding, asserted_ty)
}

Some(KnownFunction::Cast) => {
// TODO: Use `.two_parameter_tys()` exclusively
// when overloads are supported.
Expand Down Expand Up @@ -4074,10 +4029,7 @@ impl<'db> Class<'db> {

// TODO we should also check for binding errors that would indicate the metaclass
// does not accept the right arguments
CallOutcome::Callable { binding }
| CallOutcome::RevealType { binding, .. }
| CallOutcome::StaticAssertionError { binding, .. }
| CallOutcome::AssertType { binding, .. } => Ok(binding.return_type()),
CallOutcome::Callable { binding } => Ok(binding.return_type()),
};

return return_ty_result.map(|ty| ty.to_meta_type(db));
Expand Down
148 changes: 6 additions & 142 deletions crates/red_knot_python_semantic/src/types/call.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use super::context::InferContext;
use super::diagnostic::{CALL_NON_CALLABLE, TYPE_ASSERTION_FAILURE};
use super::{Severity, Signature, Type, TypeArrayDisplay, UnionBuilder};
use crate::types::diagnostic::STATIC_ASSERT_ERROR;
use super::diagnostic::CALL_NON_CALLABLE;
use super::{Signature, Type, TypeArrayDisplay, UnionBuilder};
use crate::Db;
use ruff_db::diagnostic::DiagnosticId;
use ruff_python_ast as ast;

mod arguments;
Expand All @@ -13,22 +11,10 @@ pub(super) use arguments::{Argument, CallArguments};
pub(super) use bind::{bind_call, CallBinding};

#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) enum StaticAssertionErrorKind<'db> {
ArgumentIsFalse,
ArgumentIsFalsy(Type<'db>),
ArgumentTruthinessIsAmbiguous(Type<'db>),
CustomError(&'db str),
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) enum CallOutcome<'db> {
pub(crate) enum CallOutcome<'db> {
Callable {
binding: CallBinding<'db>,
},
RevealType {
binding: CallBinding<'db>,
revealed_ty: Type<'db>,
},
NotCallable {
not_callable_ty: Type<'db>,
},
Expand All @@ -40,14 +26,6 @@ pub(super) enum CallOutcome<'db> {
called_ty: Type<'db>,
call_outcome: Box<CallOutcome<'db>>,
},
StaticAssertionError {
binding: CallBinding<'db>,
error_kind: StaticAssertionErrorKind<'db>,
},
AssertType {
binding: CallBinding<'db>,
asserted_ty: Type<'db>,
},
}

impl<'db> CallOutcome<'db> {
Expand All @@ -61,14 +39,6 @@ impl<'db> CallOutcome<'db> {
CallOutcome::NotCallable { not_callable_ty }
}

/// Create a new `CallOutcome::RevealType` with given revealed and return types.
pub(super) fn revealed(binding: CallBinding<'db>, revealed_ty: Type<'db>) -> CallOutcome<'db> {
CallOutcome::RevealType {
binding,
revealed_ty,
}
}

/// Create a new `CallOutcome::Union` with given wrapped outcomes.
pub(super) fn union(
called_ty: Type<'db>,
Expand All @@ -80,22 +50,10 @@ impl<'db> CallOutcome<'db> {
}
}

/// Create a new `CallOutcome::AssertType` with given asserted and return types.
pub(super) fn asserted(binding: CallBinding<'db>, asserted_ty: Type<'db>) -> CallOutcome<'db> {
CallOutcome::AssertType {
binding,
asserted_ty,
}
}

/// Get the return type of the call, or `None` if not callable.
pub(super) fn return_type(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
Self::Callable { binding } => Some(binding.return_type()),
Self::RevealType {
binding,
revealed_ty: _,
} => Some(binding.return_type()),
Self::NotCallable { not_callable_ty: _ } => None,
Self::Union {
outcomes,
Expand All @@ -114,11 +72,6 @@ impl<'db> CallOutcome<'db> {
})
.map(UnionBuilder::build),
Self::PossiblyUnboundDunderCall { call_outcome, .. } => call_outcome.return_type(db),
Self::StaticAssertionError { .. } => Some(Type::none(db)),
Self::AssertType {
binding,
asserted_ty: _,
} => Some(binding.return_type()),
}
}

Expand Down Expand Up @@ -204,22 +157,12 @@ impl<'db> CallOutcome<'db> {
// only non-callable diagnostics in the union case, which is inconsistent.
match self {
Self::Callable { binding } => {
// TODO: Move this out of the `CallOutcome` and into `TypeInferenceBuilder`?
// This check is required everywhere where we call `return_type_result`
// from the TypeInferenceBuilder.
binding.report_diagnostics(context, node);
Ok(binding.return_type())
}
Self::RevealType {
binding,
revealed_ty,
} => {
binding.report_diagnostics(context, node);
context.report_diagnostic(
node,
DiagnosticId::RevealedType,
Severity::Info,
format_args!("Revealed type is `{}`", revealed_ty.display(context.db())),
);
Ok(binding.return_type())
}
Self::NotCallable { not_callable_ty } => Err(NotCallableError::Type {
not_callable_ty: *not_callable_ty,
return_ty: Type::unknown(),
Expand All @@ -239,24 +182,12 @@ impl<'db> CallOutcome<'db> {
} => {
let mut not_callable = vec![];
let mut union_builder = UnionBuilder::new(context.db());
let mut revealed = false;
for outcome in outcomes {
let return_ty = match outcome {
Self::NotCallable { not_callable_ty } => {
not_callable.push(*not_callable_ty);
Type::unknown()
}
Self::RevealType {
binding,
revealed_ty: _,
} => {
if revealed {
binding.return_type()
} else {
revealed = true;
outcome.unwrap_with_diagnostic(context, node)
}
}
_ => outcome.unwrap_with_diagnostic(context, node),
};
union_builder = union_builder.add(return_ty);
Expand All @@ -280,73 +211,6 @@ impl<'db> CallOutcome<'db> {
}),
}
}
Self::StaticAssertionError {
binding,
error_kind,
} => {
binding.report_diagnostics(context, node);

match error_kind {
StaticAssertionErrorKind::ArgumentIsFalse => {
context.report_lint(
&STATIC_ASSERT_ERROR,
node,
format_args!("Static assertion error: argument evaluates to `False`"),
);
}
StaticAssertionErrorKind::ArgumentIsFalsy(parameter_ty) => {
context.report_lint(
&STATIC_ASSERT_ERROR,
node,
format_args!(
"Static assertion error: argument of type `{parameter_ty}` is statically known to be falsy",
parameter_ty=parameter_ty.display(context.db())
),
);
}
StaticAssertionErrorKind::ArgumentTruthinessIsAmbiguous(parameter_ty) => {
context.report_lint(
&STATIC_ASSERT_ERROR,
node,
format_args!(
"Static assertion error: argument of type `{parameter_ty}` has an ambiguous static truthiness",
parameter_ty=parameter_ty.display(context.db())
),
);
}
StaticAssertionErrorKind::CustomError(message) => {
context.report_lint(
&STATIC_ASSERT_ERROR,
node,
format_args!("Static assertion error: {message}"),
);
}
}

Ok(Type::unknown())
}
Self::AssertType {
binding,
asserted_ty,
} => {
let [actual_ty, _asserted] = binding.parameter_types() else {
return Ok(binding.return_type());
};

if !actual_ty.is_gradual_equivalent_to(context.db(), *asserted_ty) {
context.report_lint(
&TYPE_ASSERTION_FAILURE,
node,
format_args!(
"Actual type `{}` is not the same as asserted type `{}`",
actual_ty.display(context.db()),
asserted_ty.display(context.db()),
),
);
}

Ok(binding.return_type())
}
}
}
}
Expand Down
Loading

0 comments on commit daf19cc

Please sign in to comment.