Skip to content

Commit

Permalink
[WIP] initial fixpoint iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
carljm committed Nov 14, 2024
1 parent a48d779 commit c8fb8b3
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 32 deletions.
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ rand = { version = "0.8.5" }
rayon = { version = "1.10.0" }
regex = { version = "1.10.2" }
rustc-hash = { version = "2.0.0" }
salsa = { git = "https://github.com/salsa-rs/salsa.git", rev = "254c749b02cde2fd29852a7463a33e800b771758" }
salsa = { git = "https://github.com/salsa-rs/salsa.git", rev = "c1bbdcff28c2675f622d7e7fe10f5a0ca073f221" }
schemars = { version = "0.8.16" }
seahash = { version = "4.1.0" }
serde = { version = "1.0.197", features = ["derive"] }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ impl<'db> Definition<'db> {
self.file_scope(db).to_scope_id(db, self.file(db))
}

#[allow(unused)]
pub(crate) fn category(self, db: &'db dyn Db) -> DefinitionCategory {
self.kind(db).category()
}
Expand Down
68 changes: 40 additions & 28 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,30 +81,9 @@ pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Ty
TypeInferenceBuilder::new(db, InferenceRegion::Scope(scope), index).finish()
}

/// Cycle recovery for [`infer_definition_types()`]: for now, just [`Type::Unknown`]
/// TODO fixpoint iteration
fn infer_definition_types_cycle_recovery<'db>(
db: &'db dyn Db,
_cycle: &salsa::Cycle,
input: Definition<'db>,
) -> TypeInference<'db> {
tracing::trace!("infer_definition_types_cycle_recovery");
let mut inference = TypeInference::empty(input.scope(db));
let category = input.category(db);
if category.is_declaration() {
inference.declarations.insert(input, Type::Unknown);
}
if category.is_binding() {
inference.bindings.insert(input, Type::Unknown);
}
// TODO we don't fill in expression types for the cycle-participant definitions, which can
// later cause a panic when looking up an expression type.
inference
}

/// Infer all types for a [`Definition`] (including sub-expressions).
/// Use when resolving a symbol name use or public type of a symbol.
#[salsa::tracked(return_ref, recovery_fn=infer_definition_types_cycle_recovery)]
#[salsa::tracked(return_ref, cycle_fn=cycle_recover, cycle_initial=cycle_initial)]
pub(crate) fn infer_definition_types<'db>(
db: &'db dyn Db,
definition: Definition<'db>,
Expand All @@ -122,6 +101,20 @@ pub(crate) fn infer_definition_types<'db>(
TypeInferenceBuilder::new(db, InferenceRegion::Definition(definition), index).finish()
}

fn cycle_recover<'db>(
_db: &'db dyn Db,
_value: &TypeInference<'db>,
count: u32,
_definition: Definition<'db>,
) -> salsa::CycleRecoveryAction<TypeInference<'db>> {
assert!(count < 10, "cycle did not converge within 10 iterations");
salsa::CycleRecoveryAction::Iterate
}

fn cycle_initial<'db>(db: &'db dyn Db, definition: Definition<'db>) -> TypeInference<'db> {
TypeInference::empty(definition.scope(db), Some(Type::Never))
}

/// Infer types for all deferred type expressions in a [`Definition`].
///
/// Deferred expressions are type expressions (annotations, base classes, aliases...) in a stub
Expand Down Expand Up @@ -218,25 +211,33 @@ pub(crate) struct TypeInference<'db> {
/// Are there deferred type expressions in this region?
has_deferred: bool,

/// The scope belong to this region.
/// The scope this region is part of.
scope: ScopeId<'db>,

/// The fallback type for all expressions/bindings/declarations.
fallback_ty: Option<Type<'db>>,
}

impl<'db> TypeInference<'db> {
pub(crate) fn empty(scope: ScopeId<'db>) -> Self {
pub(crate) fn empty(scope: ScopeId<'db>, fallback_ty: Option<Type<'db>>) -> Self {
Self {
expressions: FxHashMap::default(),
bindings: FxHashMap::default(),
declarations: FxHashMap::default(),
diagnostics: TypeCheckDiagnostics::default(),
has_deferred: false,
scope,
fallback_ty,
}
}

#[track_caller]
pub(crate) fn expression_ty(&self, expression: ScopedExpressionId) -> Type<'db> {
self.expressions[&expression]
if let Some(fallback) = self.fallback_ty {
self.try_expression_ty(expression).unwrap_or(fallback)
} else {
self.expressions[&expression]
}
}

pub(crate) fn try_expression_ty(&self, expression: ScopedExpressionId) -> Option<Type<'db>> {
Expand All @@ -245,12 +246,23 @@ impl<'db> TypeInference<'db> {

#[track_caller]
pub(crate) fn binding_ty(&self, definition: Definition<'db>) -> Type<'db> {
self.bindings[&definition]
if let Some(fallback) = self.fallback_ty {
self.bindings.get(&definition).copied().unwrap_or(fallback)
} else {
self.bindings[&definition]
}
}

#[track_caller]
pub(crate) fn declaration_ty(&self, definition: Definition<'db>) -> Type<'db> {
self.declarations[&definition]
if let Some(fallback) = self.fallback_ty {
self.declarations
.get(&definition)
.copied()
.unwrap_or(fallback)
} else {
self.declarations[&definition]
}
}

pub(crate) fn diagnostics(&self) -> &[std::sync::Arc<TypeCheckDiagnostic>] {
Expand Down Expand Up @@ -358,7 +370,7 @@ impl<'db> TypeInferenceBuilder<'db> {
index,
region,
file,
types: TypeInference::empty(scope),
types: TypeInference::empty(scope, None),
diagnostics: TypeCheckDiagnosticsBuilder::new(db, file),
}
}
Expand Down

0 comments on commit c8fb8b3

Please sign in to comment.