From ade4c79a0e17301cc369d631aa0c0c90cba34bac Mon Sep 17 00:00:00 2001 From: InSyncWithFoo Date: Wed, 26 Feb 2025 13:30:46 +0000 Subject: [PATCH] Recognize named type guards --- crates/red_knot_python_semantic/src/types.rs | 8 +++-- .../src/types/narrow.rs | 36 ++++++++++++++++++- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 7c511a91519b54..35a884e9510e67 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -3583,6 +3583,7 @@ pub enum KnownFunction { Repr, /// `typing(_extensions).final` Final, + // TODO: Move this to `KnownClass` /// `builtins.staticmethod` StaticMethod, @@ -3642,7 +3643,9 @@ impl KnownFunction { /// Return `true` if `self` is defined in `module` at runtime. const fn check_module(self, module: KnownModule) -> bool { match self { - Self::IsInstance | Self::IsSubclass | Self::Len | Self::Repr | Self::StaticMethod => module.is_builtins(), + Self::IsInstance | Self::IsSubclass | Self::Len | Self::Repr | Self::StaticMethod => { + module.is_builtins() + } Self::AssertType | Self::Cast | Self::Overload @@ -4562,7 +4565,8 @@ pub(crate) mod tests { KnownFunction::Len | KnownFunction::Repr | KnownFunction::IsInstance - | KnownFunction::IsSubclass => KnownModule::Builtins, + | KnownFunction::IsSubclass + | KnownFunction::StaticMethod => KnownModule::Builtins, KnownFunction::GetattrStatic => KnownModule::Inspect, diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index 1064aba76084b7..a5c0826590ef4b 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -210,7 +210,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> { is_positive: bool, ) -> Option> { match expression_node { - ast::Expr::Name(name) => Some(self.evaluate_expr_name(name, is_positive)), + ast::Expr::Name(name) => Some(self.evaluate_expr_name(name, expression, is_positive)), ast::Expr::Compare(expr_compare) => { self.evaluate_expr_compare(expr_compare, expression, is_positive) } @@ -257,16 +257,50 @@ impl<'db> NarrowingConstraintsBuilder<'db> { fn evaluate_expr_name( &mut self, expr_name: &ast::ExprName, + expression: Expression<'db>, is_positive: bool, ) -> NarrowingConstraints<'db> { let ast::ExprName { id, .. } = expr_name; + let inference = infer_expression_types(self.db, expression); let symbol = self .symbols() .symbol_id_by_name(id) .expect("Should always have a symbol for every Name node"); + let ty = inference.expression_type(expr_name.scoped_expression_id(self.db, self.scope())); + let mut constraints = NarrowingConstraints::default(); + // TODO: Handle unions and intersections + let mut narrow_by_typeguards = || match ty { + Type::TypeGuard(type_guard) => { + let (_, guarded_symbol, _) = type_guard.symbol_info(self.db)?; + + if !is_positive { + return None; + } + + constraints.insert( + guarded_symbol, + type_guard.ty(self.db).negate_if(self.db, !is_positive), + ); + + Some(()) + } + Type::TypeIs(type_is) => { + let (_, guarded_symbol, _) = type_is.symbol_info(self.db)?; + + constraints.insert( + guarded_symbol, + type_is.ty(self.db).negate_if(self.db, !is_positive), + ); + + Some(()) + } + _ => None, + }; + narrow_by_typeguards(); + constraints.insert( symbol, if is_positive {