diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/type_guards.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/type_guards.md new file mode 100644 index 00000000000000..8cb1e196b197b4 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/type_guards.md @@ -0,0 +1,227 @@ +# User-defined type guards + +User-defined type guards are functions of which the return type is either `TypeGuard[...]` or +`TypeIs[...]`. + +## Display + +```py +from knot_extensions import Intersection, Not, TypeOf +from typing_extensions import TypeGuard, TypeIs + +def _( + a: TypeGuard[str], + b: TypeIs[str | int], + c: TypeGuard[Intersection[complex, Not[int], Not[float]]], + d: TypeIs[tuple[TypeOf[bytes]]], +): + reveal_type(a) # revealed: TypeGuard[str] + reveal_type(b) # revealed: TypeIs[str | int] + reveal_type(c) # revealed: TypeGuard[complex & ~int & ~float] + reveal_type(d) # revealed: TypeIs[tuple[Literal[bytes]]] + +def f(a) -> TypeGuard[str]: ... +def g(a) -> TypeIs[str]: ... + +def _(a: object): + reveal_type(f(a)) # revealed: TypeGuard[a, str] + reveal_type(g(a)) # revealed: TypeIs[a, str] +``` + +## Parameters + +A user-defined type guard must accept at least one positional argument, (in addition to `self`/`cls` +for non-static methods). + +```py +from typing_extensions import TypeGuard, TypeIs + +# error: [invalid-type-guard-definition] +def _() -> TypeGuard[str]: ... + +# error: [invalid-type-guard-definition] +def _(**kwargs) -> TypeIs[str]: ... + +class _: + def _(self, /, a) -> TypeGuard[str]: ... + @classmethod + def _(cls, a) -> TypeGuard[str]: ... + @staticmethod + def _(a) -> TypeIs[str]: ... + + def _(self) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition] + def _(self, /, *, a) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition] + @classmethod + def _(cls) -> TypeIs[str]: ... # error: [invalid-type-guard-definition] + @classmethod + def _() -> TypeIs[str]: ... # error: [invalid-type-guard-definition] + @staticmethod + def _(*, a) -> TypeGuard[str]: ... # error: [invalid-type-guard-definition] +``` + +For `TypeIs` functions, the narrowed type must be assignable to the declared type of that parameter, +if any. + +```py +from typing import Any +from typing_extensions import TypeGuard, TypeIs + +def _(a: object) -> TypeIs[str]: ... +def _(a: Any) -> TypeIs[str]: ... +def _(a: tuple[object]) -> TypeIs[tuple[str]]: ... +def _(a: str | Any) -> TypeIs[str]: ... +def _(a) -> TypeIs[str]: ... + +# error: [invalid-type-guard-definition] +def _(a: int) -> TypeIs[str]: ... +# error: [invalid-type-guard-definition] +def _(a: bool | str) -> TypeIs[int]: ... +``` + +## Arguments to special forms + +`TypeGuard` and `TypeIs` accept exactly one type argument. + +```py +from typing_extensions import TypeGuard, TypeIs + +a = 123 + +# error: [invalid-type-form] +def f(_) -> TypeGuard[int, str]: ... +# error: [invalid-type-form] +def g(_) -> TypeIs[a, str]: ... + +reveal_type(f(0)) # revealed: Unknown +reveal_type(g(0)) # revealed: Unknown +``` + +## Return types + +All code paths in a type guard function must return booleans. + +```py +from typing_extensions import Literal, TypeGuard, TypeIs, assert_never + +def f(a: object, flag: bool) -> TypeGuard[str]: + if flag: + # TODO: Emit a diagnostic + return 1 + + # TODO: Emit a diagnostic + return '' + +def g(a: Literal['foo', 'bar']) -> TypeIs[Literal['foo']]: + match a: + case 'foo': + # Logically wrong, but allowed regardless + return False + case 'bar': + return False + case _: + assert_never(a) +``` + +## Invalid calls + +```py +from typing import Any +from typing_extensions import TypeGuard, TypeIs + +def f(a: object) -> TypeGuard[str]: ... +def g(a: object) -> TypeIs[int]: ... + +def _(d: Any): + if f(): # error: [missing-argument] + ... + + # TODO: Is this error correct? + if g(*d): # error: [missing-argument] + ... + + if f("foo"): # error: [invalid-type-guard-call] + ... + + if g(a=d): # error: [invalid-type-guard-call] + ... + +def _(a: tuple[str, int] | tuple[int, str]): + if g(a[0]): # error: [invalid-type-guard-call] + # TODO: Should be `tuple[str, int]` + reveal_type(a) # revealed: tuple[str, int] | tuple[int, str] +``` + +## Narrowing + +```py +from typing import Any +from typing_extensions import TypeGuard, TypeIs + +def guard_str(a: object) -> TypeGuard[str]: ... +def is_int(a: object) -> TypeIs[int]: ... + +def _(a: str | int): + if guard_str(a): + reveal_type(a) # revealed: str + else: + reveal_type(a) # revealed: str | int + + if is_int(a): + reveal_type(a) # revealed: int + else: + reveal_type(a) # revealed: str & ~int + +def _(a: str | int): + b = guard_str(a) + c = is_int(a) + + reveal_type(a) # revealed: str | int + reveal_type(b) # revealed: TypeGuard[a, str] + reveal_type(c) # revealed: TypeIs[a, int] + + if b: + reveal_type(a) # revealed: str + else: + reveal_type(a) # revealed: str | int + + if c: + reveal_type(a) # revealed: int + else: + reveal_type(a) # revealed: str + +def _(x: str | int, flag: bool) -> None: + b = is_int(x) + reveal_type(b) # revealed: TypeIs[x, int] + + if flag: + x = '' + + if b: + reveal_type(x) # revealed: str | int +``` + +## `TypeGuard` special cases + +```py +from typing import Any +from typing_extensions import TypeGuard + +def guard_int(a: object) -> TypeGuard[int]: ... +def is_int(a: object) -> TypeGuard[int]: ... + +def does_not_narrow_in_negative_case(a: str | int): + if not guard_int(a): + reveal_type(a) # revealed: str | int + else: + reveal_type(a) # revealed: int + +def narrowed_type_must_be_exact(a: object, b: bool): + if guard_int(b): + reveal_type(b) # revealed: int + + if isinstance(a, bool) and is_int(a): + reveal_type(a) # revealed: bool + + if isinstance(a, bool) and guard_int(a): + reveal_type(a) # revealed: int +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md index 95907565d82892..28bd65fdfa42e8 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md +++ b/crates/red_knot_python_semantic/resources/mdtest/type_properties/is_subtype_of.md @@ -292,6 +292,31 @@ static_assert(not is_subtype_of(str, AlwaysTruthy)) static_assert(not is_subtype_of(str, AlwaysFalsy)) ``` +### `TypeGuard` and `TypeIs` + +`TypeGuard[...]` and `TypeIs[...]` are subtypes of `bool`. + +```py +from knot_extensions import is_subtype_of, static_assert +from typing_extensions import TypeGuard, TypeIs + +static_assert(is_subtype_of(TypeGuard[int], bool)) +static_assert(is_subtype_of(TypeIs[str], bool)) +``` + +`TypeIs` is invariant. `TypeGuard` is covariant. + +```py +from knot_extensions import is_subtype_of, static_assert +from typing_extensions import TypeGuard, TypeIs + +static_assert(is_subtype_of(TypeGuard[bool], TypeGuard[int])) + +static_assert(not is_subtype_of(TypeGuard[int], TypeGuard[bool])) +static_assert(not is_subtype_of(TypeIs[bool], TypeIs[int])) +static_assert(not is_subtype_of(TypeIs[int], TypeIs[bool])) +``` + ### Module literals ```py diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 7c89f13e751ba6..e90f8f869e634a 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -28,7 +28,7 @@ use crate::module_resolver::{file_to_module, resolve_module, KnownModule}; use crate::semantic_index::ast_ids::HasScopedExpressionId; use crate::semantic_index::attribute_assignment::AttributeAssignment; use crate::semantic_index::definition::Definition; -use crate::semantic_index::symbol::ScopeId; +use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId}; use crate::semantic_index::{ attribute_assignments, imported_modules, semantic_index, symbol_table, use_def_map, }; @@ -211,7 +211,7 @@ pub enum Type<'db> { ModuleLiteral(ModuleLiteralType<'db>), /// A specific class object ClassLiteral(ClassLiteralType<'db>), - // The set of all class objects that are subclasses of the given class (C), spelled `type[C]`. + /// The set of all class objects that are subclasses of the given class (C), spelled `type[C]`. SubclassOf(SubclassOfType<'db>), /// The set of Python objects with the given class in their __class__'s method resolution order Instance(InstanceType<'db>), @@ -242,6 +242,10 @@ pub enum Type<'db> { /// A heterogeneous tuple type, with elements of the given types in source order. // TODO: Support variable length homogeneous tuple type like `tuple[int, ...]`. Tuple(TupleType<'db>), + /// A subtype of `bool` that allows narrowing in positive cases. + TypeGuard(TypeGuardType<'db>), + /// A subtype of `bool` that allows narrowing in both positive and negative cases. + TypeIs(TypeIsType<'db>), // TODO protocols, callable types, overloads, generics, type vars } @@ -475,7 +479,9 @@ impl<'db> Type<'db> { | Type::ClassLiteral(_) | Type::KnownInstance(_) | Type::IntLiteral(_) - | Type::SubclassOf(_) => self, + | Type::SubclassOf(_) + | Type::TypeGuard(_) + | Type::TypeIs(_) => self, } } @@ -599,9 +605,17 @@ impl<'db> Type<'db> { (Type::StringLiteral(_) | Type::LiteralString, _) => { KnownClass::Str.to_instance(db).is_subtype_of(db, target) } - (Type::BooleanLiteral(_), _) => { + + // `TypeGuard` is covariant + (Type::TypeGuard(left), Type::TypeGuard(right)) => { + left.ty(db).is_subtype_of(db, *right.ty(db)) + } + + // `TypeGuard[T]` and `TypeIs[T]` are subtypes of `bool`. + (Type::BooleanLiteral(_) | Type::TypeGuard(_) | Type::TypeIs(_), _) => { KnownClass::Bool.to_instance(db).is_subtype_of(db, target) } + (Type::IntLiteral(_), _) => KnownClass::Int.to_instance(db).is_subtype_of(db, target), (Type::BytesLiteral(_), _) => { KnownClass::Bytes.to_instance(db).is_subtype_of(db, target) @@ -1074,14 +1088,21 @@ impl<'db> Type<'db> { known_instance_ty.is_disjoint_from(db, KnownClass::Tuple.to_instance(db)) } - (Type::BooleanLiteral(..), Type::Instance(InstanceType { class })) - | (Type::Instance(InstanceType { class }), Type::BooleanLiteral(..)) => { + ( + Type::BooleanLiteral(..) | Type::TypeGuard(_) | Type::TypeIs(_), + Type::Instance(InstanceType { class }), + ) + | ( + Type::Instance(InstanceType { class }), + Type::BooleanLiteral(..) | Type::TypeGuard(_) | Type::TypeIs(_), + ) => { // A `Type::BooleanLiteral()` must be an instance of exactly `bool` // (it cannot be an instance of a `bool` subclass) !KnownClass::Bool.is_subclass_of(db, class) } - (Type::BooleanLiteral(..), _) | (_, Type::BooleanLiteral(..)) => true, + (Type::BooleanLiteral(..) | Type::TypeGuard(_) | Type::TypeIs(_), _) + | (_, Type::BooleanLiteral(..) | Type::TypeGuard(_) | Type::TypeIs(_)) => true, (Type::IntLiteral(..), Type::Instance(InstanceType { class })) | (Type::Instance(InstanceType { class }), Type::IntLiteral(..)) => { @@ -1255,6 +1276,8 @@ impl<'db> Type<'db> { .elements(db) .iter() .all(|elem| elem.is_fully_static(db)), + Type::TypeGuard(type_guard) => type_guard.ty(db).is_fully_static(db), + Type::TypeIs(type_is) => type_is.ty(db).is_fully_static(db), // TODO: Once we support them, make sure that we return `false` for other types // containing gradual forms such as `tuple[Any, ...]` or `Callable[..., str]`. // Conversely, make sure to return `true` for homogeneous tuples such as @@ -1274,7 +1297,9 @@ impl<'db> Type<'db> { | Type::StringLiteral(..) | Type::BytesLiteral(..) | Type::SliceLiteral(..) - | Type::LiteralString => { + | Type::LiteralString + | Type::TypeGuard(_) + | Type::TypeIs(_) => { // Note: The literal types included in this pattern are not true singletons. // There can be multiple Python objects (at different memory locations) that // are both of type Literal[345], for example. @@ -1361,7 +1386,9 @@ impl<'db> Type<'db> { | Type::Intersection(..) | Type::LiteralString | Type::AlwaysTruthy - | Type::AlwaysFalsy => false, + | Type::AlwaysFalsy + | Type::TypeGuard(_) + | Type::TypeIs(_) => false, } } @@ -1450,6 +1477,10 @@ impl<'db> Type<'db> { _ => KnownClass::Int.to_instance(db).static_member(db, name), }, + Type::TypeGuard(_) | Type::TypeIs(_) => { + KnownClass::Bool.to_instance(db).member(db, name) + } + Type::BooleanLiteral(bool_value) => match name { "real" | "numerator" => Symbol::bound(Type::IntLiteral(i64::from(*bool_value))), _ => KnownClass::Bool.to_instance(db).static_member(db, name), @@ -1608,7 +1639,9 @@ impl<'db> Type<'db> { | Type::SliceLiteral(..) | Type::Tuple(..) | Type::KnownInstance(..) - | Type::FunctionLiteral(..) => { + | Type::FunctionLiteral(..) + | Type::TypeGuard(..) + | Type::TypeIs(..) => { let member = self.static_member(db, name); let instance = Some(*self); @@ -1679,6 +1712,7 @@ impl<'db> Type<'db> { ) -> Result> { let truthiness = match self { Type::Dynamic(_) | Type::Never => Truthiness::Ambiguous, + Type::TypeGuard(_) | Type::TypeIs(_) => Truthiness::Ambiguous, Type::FunctionLiteral(_) => Truthiness::AlwaysTrue, Type::Callable(_) => Truthiness::AlwaysTrue, Type::ModuleLiteral(_) => Truthiness::AlwaysTrue, @@ -2421,7 +2455,9 @@ impl<'db> Type<'db> { | Type::Tuple(_) | Type::LiteralString | Type::AlwaysTruthy - | Type::AlwaysFalsy => Type::unknown(), + | Type::AlwaysFalsy + | Type::TypeGuard(_) + | Type::TypeIs(_) => Type::unknown(), } } @@ -2628,6 +2664,7 @@ impl<'db> Type<'db> { ), }, + Type::TypeGuard(_) | Type::TypeIs(_) => KnownClass::Bool.to_class_literal(db), Type::StringLiteral(_) | Type::LiteralString => KnownClass::Str.to_class_literal(db), Type::Dynamic(dynamic) => SubclassOfType::from(db, ClassBase::Dynamic(*dynamic)), // TODO intersections @@ -3323,6 +3360,10 @@ pub enum KnownInstanceType<'db> { TypeVar(TypeVarInstance<'db>), /// A single instance of `typing.TypeAliasType` (PEP 695 type alias) TypeAliasType(TypeAliasType<'db>), + /// The symbol `typing.TypeGuard` (which can also be found as `typing_extensions.TypeGuard`) + TypeGuard, + /// The symbol `typing.TypeIs` (which can also be found as `typing_extensions.TypeIs`) + TypeIs, /// The symbol `knot_extensions.Unknown` Unknown, /// The symbol `knot_extensions.AlwaysTruthy` @@ -3347,8 +3388,6 @@ pub enum KnownInstanceType<'db> { Required, NotRequired, TypeAlias, - TypeGuard, - TypeIs, ReadOnly, // TODO: fill this enum out with more special forms, etc. } @@ -3981,6 +4020,8 @@ pub enum KnownFunction { Repr, /// `typing(_extensions).final` Final, + /// `builtins.staticmethod` + StaticMethod, /// [`typing(_extensions).no_type_check`](https://typing.readthedocs.io/en/latest/spec/directives.html#no-type-check) NoTypeCheck, @@ -4033,6 +4074,7 @@ impl KnownFunction { "len" => Self::Len, "repr" => Self::Repr, "final" => Self::Final, + "staticmethod" => Self::StaticMethod, "no_type_check" => Self::NoTypeCheck, "assert_type" => Self::AssertType, "cast" => Self::Cast, @@ -4062,7 +4104,7 @@ impl KnownFunction { module.is_builtins() } }, - Self::Len | Self::Repr => module.is_builtins(), + Self::Len | Self::Repr | Self::StaticMethod => module.is_builtins(), Self::AssertType | Self::Cast | Self::RevealType | Self::Final | Self::NoTypeCheck => { matches!(module, KnownModule::Typing | KnownModule::TypingExtensions) } @@ -4104,7 +4146,8 @@ impl KnownFunction { | Self::NoTypeCheck | Self::RevealType | Self::GetattrStatic - | Self::StaticAssert => ParameterExpectations::AllValueExpressions, + | Self::StaticAssert + | Self::StaticMethod => ParameterExpectations::AllValueExpressions, } } } @@ -5339,6 +5382,7 @@ impl SliceLiteralType<'_> { (self.start(db), self.stop(db), self.step(db)) } } + #[salsa::interned] pub struct TupleType<'db> { #[return_ref] @@ -5404,6 +5448,63 @@ impl<'db> TupleType<'db> { } } +macro_rules! type_guard_type_is_impl { + ($type:ident, $struct:ident) => { + impl<'db> $struct<'db> { + pub fn unbound(db: &'db dyn Db, ty: Type<'db>) -> Type<'db> { + Type::$type(Self::new(db, ty, None)) + } + + pub fn bound( + db: &'db dyn Db, + ty: Type<'db>, + scope: ScopeId<'db>, + symbol: ScopedSymbolId, + name: String, + ) -> Type<'db> { + Type::$type(Self::new(db, ty, Some((scope, symbol, name)))) + } + + #[must_use] + pub fn bind( + self, + db: &'db dyn Db, + scope: ScopeId<'db>, + symbol: ScopedSymbolId, + name: String, + ) -> Type<'db> { + Self::bound(db, *self.ty(db), scope, symbol, name) + } + + pub fn is_bound(&self, db: &'db dyn Db) -> bool { + self.symbol_info(db).is_some() + } + + pub fn is_unbound(&self, db: &'db dyn Db) -> bool { + self.symbol_info(db).is_none() + } + } + }; +} + +#[salsa::interned] +pub struct TypeGuardType<'db> { + #[return_ref] + ty: Type<'db>, + symbol_info: Option<(ScopeId<'db>, ScopedSymbolId, String)>, +} + +type_guard_type_is_impl!(TypeGuard, TypeGuardType); + +#[salsa::interned] +pub struct TypeIsType<'db> { + #[return_ref] + ty: Type<'db>, + symbol_info: Option<(ScopeId<'db>, ScopedSymbolId, String)>, +} + +type_guard_type_is_impl!(TypeIs, TypeIsType); + // Make sure that the `Type` enum does not grow unexpectedly. #[cfg(not(debug_assertions))] #[cfg(target_pointer_width = "64")] diff --git a/crates/red_knot_python_semantic/src/types/class_base.rs b/crates/red_knot_python_semantic/src/types/class_base.rs index e3ac7bc7d1f3ad..fad2f1e9581fa3 100644 --- a/crates/red_knot_python_semantic/src/types/class_base.rs +++ b/crates/red_knot_python_semantic/src/types/class_base.rs @@ -82,7 +82,9 @@ impl<'db> ClassBase<'db> { | Type::ModuleLiteral(_) | Type::SubclassOf(_) | Type::AlwaysFalsy - | Type::AlwaysTruthy => None, + | Type::AlwaysTruthy + | Type::TypeGuard(_) + | Type::TypeIs(_) => None, Type::KnownInstance(known_instance) => match known_instance { KnownInstanceType::TypeVar(_) | KnownInstanceType::TypeAliasType(_) diff --git a/crates/red_knot_python_semantic/src/types/diagnostic.rs b/crates/red_knot_python_semantic/src/types/diagnostic.rs index 29538ac353ff8d..18a30f70b4dd87 100644 --- a/crates/red_knot_python_semantic/src/types/diagnostic.rs +++ b/crates/red_knot_python_semantic/src/types/diagnostic.rs @@ -40,6 +40,8 @@ pub(crate) fn register_lints(registry: &mut LintRegistryBuilder) { registry.register_lint(&INVALID_PARAMETER_DEFAULT); registry.register_lint(&INVALID_RAISE); registry.register_lint(&INVALID_TYPE_FORM); + registry.register_lint(&INVALID_TYPE_GUARD_DEFINITION); + registry.register_lint(&INVALID_TYPE_GUARD_CALL); registry.register_lint(&INVALID_TYPE_VARIABLE_CONSTRAINTS); registry.register_lint(&MISSING_ARGUMENT); registry.register_lint(&NON_SUBSCRIPTABLE); @@ -425,6 +427,40 @@ declare_lint! { } } +declare_lint! { + /// ## What it does + /// Checks for type guard functions without + /// a first non-self-like non-keyword-only non-variadic parameter. + /// + /// ## Why is this bad? + /// Type narrowing functions must accept at least one positional argument + /// (non-static methods must accept another in addition to `self`/`cls`). + /// + /// Extra parameters/arguments are allowed but do not affect narrowing. + pub(crate) static INVALID_TYPE_GUARD_DEFINITION = { + summary: "detects malformed type guard functions", + status: LintStatus::preview("1.0.0"), + default_level: Level::Error, + } +} + +declare_lint! { + /// ## What it does + /// Checks for type guard function calls without a valid target. + /// + /// ## Why is this bad? + /// The first non-keyword non-variadic argument to a type guard function + /// is its target and must map to a symbol. + /// + /// Starred (`is_str(*a)`), literal (`is_str(42)`) and other non-symbol-like + /// expressions are invalid as narrowing targets. + pub(crate) static INVALID_TYPE_GUARD_CALL = { + summary: "detects type guard function calls that has no narrowing effect", + status: LintStatus::preview("1.0.0"), + default_level: Level::Error, + } +} + declare_lint! { /// TODO #14889 pub(crate) static INVALID_TYPE_VARIABLE_CONSTRAINTS = { @@ -1151,3 +1187,36 @@ pub(crate) fn report_invalid_arguments_to_annotated<'db>( ), ); } + +pub(crate) fn report_type_guard_function_with_incorrect_arity( + context: &InferContext, + node: AnyNodeRef, + is_non_static_method: bool, +) { + context.report_lint( + &INVALID_TYPE_GUARD_DEFINITION, + node, + format_args!( + "This type guard function must accept at least {} positional arguments", + if is_non_static_method { 2 } else { 1 } + ), + ) +} + +pub(crate) fn report_typeis_function_with_incorrect_types<'db>( + db: &'db dyn Db, + context: &InferContext<'db>, + node: AnyNodeRef, + input_ty: Type, + return_ty: Type, +) { + context.report_lint( + &INVALID_TYPE_GUARD_DEFINITION, + node, + format_args!( + "Return type `{}` is not assignable to input type `{}`", + return_ty.display(db), + input_ty.display(db), + ), + ) +} diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index 95dd427aa64208..9cfeaa4029f0b8 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -152,6 +152,24 @@ impl Display for DisplayRepresentation<'_> { } Type::AlwaysTruthy => f.write_str("AlwaysTruthy"), Type::AlwaysFalsy => f.write_str("AlwaysFalsy"), + Type::TypeGuard(type_guard) => { + f.write_str("TypeGuard[")?; + if let Some((_, _, name)) = type_guard.symbol_info(self.db) { + f.write_str(&name)?; + f.write_str(", ")?; + } + type_guard.ty(self.db).display(self.db).fmt(f)?; + f.write_str("]") + } + Type::TypeIs(type_is) => { + f.write_str("TypeIs[")?; + if let Some((_, _, name)) = type_is.symbol_info(self.db) { + f.write_str(&name)?; + f.write_str(", ")?; + } + type_is.ty(self.db).display(self.db).fmt(f)?; + f.write_str("]") + } } } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 5265597f1c202a..6386b9d515a574 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -46,9 +46,9 @@ use crate::semantic_index::definition::{ ExceptHandlerDefinitionKind, ForStmtDefinitionKind, TargetKind, }; use crate::semantic_index::expression::{Expression, ExpressionKind}; -use crate::semantic_index::semantic_index; use crate::semantic_index::symbol::{FileScopeId, NodeWithScopeKind, NodeWithScopeRef, ScopeId}; use crate::semantic_index::SemanticIndex; +use crate::semantic_index::{semantic_index, symbol_table}; use crate::symbol::{ builtins_module_scope, builtins_symbol, explicit_global_symbol, module_type_implicit_global_symbol, symbol, symbol_from_bindings, symbol_from_declarations, @@ -61,7 +61,7 @@ use crate::types::diagnostic::{ CALL_NON_CALLABLE, CALL_POSSIBLY_UNBOUND_METHOD, CONFLICTING_DECLARATIONS, CONFLICTING_METACLASS, CYCLIC_CLASS_DEFINITION, DIVISION_BY_ZERO, DUPLICATE_BASE, INCONSISTENT_MRO, INVALID_ATTRIBUTE_ACCESS, INVALID_BASE, INVALID_CONTEXT_MANAGER, - INVALID_DECLARATION, INVALID_PARAMETER_DEFAULT, INVALID_TYPE_FORM, + INVALID_DECLARATION, INVALID_PARAMETER_DEFAULT, INVALID_TYPE_FORM, INVALID_TYPE_GUARD_CALL, INVALID_TYPE_VARIABLE_CONSTRAINTS, POSSIBLY_UNBOUND_ATTRIBUTE, POSSIBLY_UNBOUND_IMPORT, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_IMPORT, UNSUPPORTED_OPERATOR, }; @@ -72,8 +72,8 @@ use crate::types::{ IntersectionBuilder, IntersectionType, KnownClass, KnownFunction, KnownInstanceType, MetaclassCandidate, MetaclassErrorKind, SliceLiteralType, SubclassOfType, Symbol, SymbolAndQualifiers, Truthiness, TupleType, Type, TypeAliasType, TypeAndQualifiers, - TypeArrayDisplay, TypeQualifiers, TypeVarBoundOrConstraints, TypeVarInstance, UnionBuilder, - UnionType, + TypeArrayDisplay, TypeGuardType, TypeIsType, TypeQualifiers, TypeVarBoundOrConstraints, + TypeVarInstance, UnionBuilder, UnionType, }; use crate::unpack::Unpack; use crate::util::subscript::{PyIndex, PySlice}; @@ -3418,7 +3418,48 @@ impl<'db> TypeInferenceBuilder<'db> { } } - outcome.return_type(self.db()) + let db = self.db(); + let scope = self.scope(); + let return_ty = outcome.return_type(db); + + let find_narrowed_symbol = || match arguments.args.first() { + None => { + self.context.report_lint( + &INVALID_TYPE_GUARD_CALL, + arguments.range(), + format_args!("Type guard call does not have a target"), + ); + None + } + Some(ast::Expr::Name(ast::ExprName { id, .. })) => { + let name = id.as_str(); + let symbol = symbol_table(db, scope).symbol_id_by_name(name)?; + + Some((symbol, name.to_string())) + } + // TODO: Attribute and subscript narrowing + Some(expr) => { + self.context.report_lint( + &INVALID_TYPE_GUARD_CALL, + expr.range(), + format_args!("Type guard call target is not a symbol"), + ); + None + } + }; + + // TODO: Handle unions/intersections + match return_ty { + Type::TypeGuard(type_guard) => match find_narrowed_symbol() { + Some((symbol, name)) => type_guard.bind(db, scope, symbol, name), + None => return_ty, + }, + Type::TypeIs(type_is) => match find_narrowed_symbol() { + Some((symbol, name)) => type_is.bind(db, scope, symbol, name), + None => return_ty, + }, + _ => return_ty, + } } Err(err) => { // TODO: We currently only report the first error. Ideally, we'd report @@ -3827,7 +3868,9 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::LiteralString | Type::BytesLiteral(_) | Type::SliceLiteral(_) - | Type::Tuple(_), + | Type::Tuple(_) + | Type::TypeGuard(_) + | Type::TypeIs(_), ) => { let unary_dunder_method = match op { ast::UnaryOp::Invert => "__invert__", @@ -4043,7 +4086,9 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::LiteralString | Type::BytesLiteral(_) | Type::SliceLiteral(_) - | Type::Tuple(_), + | Type::Tuple(_) + | Type::TypeGuard(_) + | Type::TypeIs(_), Type::FunctionLiteral(_) | Type::Callable(..) | Type::ModuleLiteral(_) @@ -4060,7 +4105,9 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::LiteralString | Type::BytesLiteral(_) | Type::SliceLiteral(_) - | Type::Tuple(_), + | Type::Tuple(_) + | Type::TypeGuard(_) + | Type::TypeIs(_), op, ) => { // We either want to call lhs.__op__ or rhs.__rop__. The full decision tree from @@ -5977,6 +6024,28 @@ impl<'db> TypeInferenceBuilder<'db> { argument_type } }, + KnownInstanceType::TypeGuard | KnownInstanceType::TypeIs => match arguments_slice { + ast::Expr::Tuple(_) => { + self.context.report_lint( + &INVALID_TYPE_FORM, + subscript, + format_args!( + "Special form `{}` expected exactly one type parameter", + known_instance.repr(self.db()) + ), + ); + Type::unknown() + } + _ => { + let ty = self.infer_type_expression(arguments_slice); + + if matches!(known_instance, KnownInstanceType::TypeGuard) { + TypeGuardType::unbound(self.db(), ty) + } else { + TypeIsType::unbound(self.db(), ty) + } + } + }, // TODO: Generics KnownInstanceType::ChainMap => { @@ -6039,14 +6108,6 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_type_expression(arguments_slice); todo_type!("`Required[]` type qualifier") } - KnownInstanceType::TypeIs => { - self.infer_type_expression(arguments_slice); - todo_type!("`TypeIs[]` special form") - } - KnownInstanceType::TypeGuard => { - self.infer_type_expression(arguments_slice); - todo_type!("`TypeGuard[]` special form") - } KnownInstanceType::Concatenate => { self.infer_type_expression(arguments_slice); todo_type!("`Concatenate[]` special form") diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 8822a3a18f1902..5ffe8f433a3108 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -428,6 +428,45 @@ impl<'db> NarrowingConstraintsBuilder<'db> { // TODO: add support for PEP 604 union types on the right hand side of `isinstance` // and `issubclass`, for example `isinstance(x, str | (int | float))`. match callable_ty { + Type::FunctionLiteral(function_type) if function_type.known(self.db).is_none() => { + let return_ty = + inference.expression_type(expr_call.scoped_expression_id(self.db, scope)); + + // TODO: Handle unions and intersections + let (guarded_ty, symbol, is_typeguard) = match return_ty { + Type::TypeGuard(type_guard) => { + let (_, symbol, _) = type_guard.symbol_info(self.db)?; + (*type_guard.ty(self.db), symbol, true) + } + Type::TypeIs(type_is) => { + let (_, symbol, _) = type_is.symbol_info(self.db)?; + (*type_is.ty(self.db), symbol, false) + } + _ => return None, + }; + + let mut constraints = NarrowingConstraints::default(); + + // `TypeGuard` does not narrow in the negative case. + // ```python + // def f(a) -> TypeGuard[str]: ... + // + // a: str | int + // + // if not f(a): + // reveal_type(a) # str | int + // else: + // reveal_type(a) # str + // ``` + if is_positive || !is_typeguard { + constraints.insert( + symbol, + guarded_ty.negate_if(self.db, !is_positive && !is_typeguard), + ); + } + + Some(constraints) + } Type::FunctionLiteral(function_type) if expr_call.arguments.keywords.is_empty() => { let function = function_type .known(self.db) diff --git a/crates/red_knot_python_semantic/src/types/property_tests.rs b/crates/red_knot_python_semantic/src/types/property_tests.rs index 834e3417672511..23416da7674417 100644 --- a/crates/red_knot_python_semantic/src/types/property_tests.rs +++ b/crates/red_knot_python_semantic/src/types/property_tests.rs @@ -573,7 +573,7 @@ mod flaky { ); // Equal element sets of unions implies equivalence - // flaky at laest in part because of https://github.com/astral-sh/ruff/issues/15513 + // flaky at least in part because of https://github.com/astral-sh/ruff/issues/15513 type_property_test!( union_equivalence_not_order_dependent, db, forall types s, t, u. diff --git a/crates/red_knot_python_semantic/src/types/type_ordering.rs b/crates/red_knot_python_semantic/src/types/type_ordering.rs index 483843feb0a8b0..6bf77dad7608fe 100644 --- a/crates/red_knot_python_semantic/src/types/type_ordering.rs +++ b/crates/red_knot_python_semantic/src/types/type_ordering.rs @@ -110,6 +110,14 @@ pub(super) fn union_elements_ordering<'db>(left: &Type<'db>, right: &Type<'db>) Type::Instance(InstanceType { class: right }), ) => left.cmp(right), + (Type::TypeGuard(left), Type::TypeGuard(right)) => left.cmp(right), + (Type::TypeGuard(_), _) => Ordering::Less, + (_, Type::TypeGuard(_)) => Ordering::Greater, + + (Type::TypeIs(left), Type::TypeIs(right)) => left.cmp(right), + (Type::TypeIs(_), _) => Ordering::Less, + (_, Type::TypeIs(_)) => Ordering::Greater, + (Type::Instance(_), _) => Ordering::Less, (_, Type::Instance(_)) => Ordering::Greater,