From f44f2f7054cdac20a94a299dd801ac967ef1496f Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Thu, 31 Oct 2024 17:14:01 -0700 Subject: [PATCH] pass inputs to cycle recovery functions --- .../salsa-macro-rules/src/setup_tracked_fn.rs | 7 +++-- .../src/unexpected_cycle_recovery.rs | 10 ++++--- src/function.rs | 3 +- src/function/execute.rs | 7 ++++- src/function/fetch.rs | 2 +- src/function/memo.rs | 8 +++-- tests/cycle/dataflow.rs | 30 ++++++++++++++++--- tests/cycle/main.rs | 8 ++--- 8 files changed, 55 insertions(+), 20 deletions(-) diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 184ccb50c..72a38039c 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -177,16 +177,17 @@ macro_rules! setup_tracked_fn { $inner($db, $($input_id),*) } - fn cycle_initial<$db_lt>(db: &$db_lt dyn $Db) -> Self::Output<$db_lt> { - $($cycle_recovery_initial)*(db) + fn cycle_initial<$db_lt>(db: &$db_lt dyn $Db, ($($input_id),*): ($($input_ty),*)) -> Self::Output<$db_lt> { + $($cycle_recovery_initial)*(db, $($input_id),*) } fn recover_from_cycle<$db_lt>( db: &$db_lt dyn $Db, value: &Self::Output<$db_lt>, count: u32, + ($($input_id),*): ($($input_ty),*) ) -> $zalsa::CycleRecoveryAction> { - $($cycle_recovery_fn)*(db, value, count) + $($cycle_recovery_fn)*(db, value, count, $($input_id),*) } fn id_to_input<$db_lt>(db: &$db_lt Self::DbView, key: salsa::Id) -> Self::Input<$db_lt> { diff --git a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs index cf8bbce13..a1cd1e73f 100644 --- a/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs +++ b/components/salsa-macro-rules/src/unexpected_cycle_recovery.rs @@ -1,18 +1,20 @@ // Macro that generates the body of the cycle recovery function -// for the case where no cycle recovery is possible. Must be a macro -// because the signature types must match the particular tracked function. +// for the case where no cycle recovery is possible. This has to be +// a macro because it can take a variadic number of arguments. #[macro_export] macro_rules! unexpected_cycle_recovery { - ($db:ident, $value:ident, $count:ident) => {{ + ($db:ident, $value:ident, $count:ident, $($other_inputs:ident),*) => {{ std::mem::drop($db); + std::mem::drop(($($other_inputs),*)); panic!("cannot recover from cycle") }}; } #[macro_export] macro_rules! unexpected_cycle_initial { - ($db:ident) => {{ + ($db:ident, $($other_inputs:ident),*) => {{ std::mem::drop($db); + std::mem::drop(($($other_inputs),*)); panic!("no cycle initial value") }}; } diff --git a/src/function.rs b/src/function.rs index 225a9dead..97511d223 100644 --- a/src/function.rs +++ b/src/function.rs @@ -67,13 +67,14 @@ pub trait Configuration: Any { fn execute<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; /// Get the cycle recovery initial value. - fn cycle_initial(db: &Self::DbView) -> Self::Output<'_>; + fn cycle_initial<'db>(db: &'db Self::DbView, input: Self::Input<'db>) -> Self::Output<'db>; /// Decide whether to iterate a cycle again or fallback. fn recover_from_cycle<'db>( db: &'db Self::DbView, value: &Self::Output<'db>, count: u32, + input: Self::Input<'db>, ) -> CycleRecoveryAction>; } diff --git a/src/function/execute.rs b/src/function/execute.rs index 158f94a02..ef43fc2c8 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -90,7 +90,12 @@ where if !C::values_equal(&new_value, last_provisional_value) { // We are in a cycle that hasn't converged; ask the user's // cycle-recovery function what to do: - match C::recover_from_cycle(db, &new_value, iteration_count) { + match C::recover_from_cycle( + db, + &new_value, + iteration_count, + C::id_to_input(db, id), + ) { crate::CycleRecoveryAction::Iterate => { tracing::debug!("{database_key_index:?}: execute: iterate again"); iteration_count = iteration_count.checked_add(1).expect( diff --git a/src/function/fetch.rs b/src/function/fetch.rs index 6dcc4e3ac..9fa08afa1 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -78,7 +78,7 @@ where ClaimResult::Retry => return None, ClaimResult::Cycle => { return self - .initial_value(db) + .initial_value(db, database_key_index.key_index) .map(|initial_value| { tracing::debug!( "hit cycle at {database_key_index:#?}, \ diff --git a/src/function/memo.rs b/src/function/memo.rs index 8bddd5fa1..8e167ccbd 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -88,9 +88,13 @@ impl IngredientImpl { } } - pub(super) fn initial_value<'db>(&'db self, db: &'db C::DbView) -> Option> { + pub(super) fn initial_value<'db>( + &'db self, + db: &'db C::DbView, + key: Id, + ) -> Option> { match C::CYCLE_STRATEGY { - CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db)), + CycleRecoveryStrategy::Fixpoint => Some(C::cycle_initial(db, C::id_to_input(db, key))), CycleRecoveryStrategy::Panic => None, } } diff --git a/tests/cycle/dataflow.rs b/tests/cycle/dataflow.rs index 53dc301b7..d8ef4cf3a 100644 --- a/tests/cycle/dataflow.rs +++ b/tests/cycle/dataflow.rs @@ -47,7 +47,7 @@ impl Type { } } -#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_fn=use_cycle_recover, cycle_initial=use_cycle_initial)] fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { let defs = u.reaching_definitions(db); match defs[..] { @@ -57,7 +57,7 @@ fn infer_use<'db>(db: &'db dyn Db, u: Use) -> Type { } } -#[salsa::tracked(cycle_fn=cycle_recover, cycle_initial=cycle_initial)] +#[salsa::tracked(cycle_fn=def_cycle_recover, cycle_initial=def_cycle_initial)] fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { let increment_ty = Type::Values(Box::from([def.increment(db)])); if let Some(base) = def.base(db) { @@ -68,11 +68,33 @@ fn infer_definition<'db>(db: &'db dyn Db, def: Definition) -> Type { } } -fn cycle_initial(_db: &dyn Db) -> Type { +fn def_cycle_initial(_db: &dyn Db, _def: Definition) -> Type { Type::Bottom } -fn cycle_recover(_db: &dyn Db, value: &Type, count: u32) -> CycleRecoveryAction { +fn def_cycle_recover( + _db: &dyn Db, + value: &Type, + count: u32, + _def: Definition, +) -> CycleRecoveryAction { + cycle_recover(value, count) +} + +fn use_cycle_initial(_db: &dyn Db, _use: Use) -> Type { + Type::Bottom +} + +fn use_cycle_recover( + _db: &dyn Db, + value: &Type, + count: u32, + _use: Use, +) -> CycleRecoveryAction { + cycle_recover(value, count) +} + +fn cycle_recover(value: &Type, count: u32) -> CycleRecoveryAction { match value { Type::Bottom => CycleRecoveryAction::Iterate, Type::Values(_) => { diff --git a/tests/cycle/main.rs b/tests/cycle/main.rs index 05100f650..09cb6e830 100644 --- a/tests/cycle/main.rs +++ b/tests/cycle/main.rs @@ -76,7 +76,7 @@ const MIN_COUNT_FALLBACK: u8 = 100; const MIN_VALUE_FALLBACK: u8 = 5; const MIN_VALUE: u8 = 10; -fn min_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction { +fn min_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction { if *value < MIN_VALUE { CycleRecoveryAction::Fallback(MIN_VALUE_FALLBACK) } else if count > 10 { @@ -86,7 +86,7 @@ fn min_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction } } -fn min_initial(_db: &dyn Db) -> u8 { +fn min_initial(_db: &dyn Db, _inputs: Inputs) -> u8 { 255 } @@ -99,7 +99,7 @@ const MAX_COUNT_FALLBACK: u8 = 200; const MAX_VALUE_FALLBACK: u8 = 250; const MAX_VALUE: u8 = 245; -fn max_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction { +fn max_recover(_db: &dyn Db, value: &u8, count: u32, _inputs: Inputs) -> CycleRecoveryAction { if *value > MAX_VALUE { CycleRecoveryAction::Fallback(MAX_VALUE_FALLBACK) } else if count > 10 { @@ -109,7 +109,7 @@ fn max_recover(_db: &dyn Db, value: &u8, count: u32) -> CycleRecoveryAction } } -fn max_initial(_db: &dyn Db) -> u8 { +fn max_initial(_db: &dyn Db, _inputs: Inputs) -> u8 { 0 }