From dbe071caa5af586ba1a06fcadf1ceb928a0a04b1 Mon Sep 17 00:00:00 2001 From: Victor Menge Date: Wed, 17 Jan 2024 12:33:57 +0100 Subject: [PATCH] fix: when using one for all, stop all procs before restarting --- Cargo.lock | 2 +- core/Cargo.toml | 2 +- core/src/lib.rs | 111 +++++++++++++++++++++++++++++++------- core/tests/supervision.rs | 107 ++++++++++++++++++++++++++++++++++++ 4 files changed, 201 insertions(+), 21 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b32998a..a085bec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -255,7 +255,7 @@ checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" [[package]] name = "speare" -version = "0.1.6" +version = "0.1.7" dependencies = [ "async-trait", "derive_more", diff --git a/core/Cargo.toml b/core/Cargo.toml index 1b270b3..db65757 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "speare" -version = "0.1.6" +version = "0.1.7" edition = "2021" license = "MIT" description = "actor-like thin abstraction over tokio::task and flume channels" diff --git a/core/src/lib.rs b/core/src/lib.rs index ec2f119..704cfae 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -199,7 +199,7 @@ where match self { Self::Handle => write!(f, "manual exit through Handle::stop"), Self::Parent => write!(f, "exit request from parent supervision strategy"), - Self::Err(e) => write!(f, "error: {e}"), + Self::Err(e) => write!(f, "{e}"), } } } @@ -671,7 +671,12 @@ where handle } - async fn handle_err(&mut self, e: Box, proc_id: u64) -> Option<()> { + async fn handle_err( + &mut self, + e: Box, + proc_id: u64, + err_ack: Sender<()>, + ) -> Option<()> { let directive = self .supervision .deciders @@ -688,7 +693,9 @@ where .get_backoff_duration(self.supervision.max_restarts, self.supervision.backoff) { Some(delay) => { - let _ = child.send(ProcMsg::FromParent(ProcAction::RestartIn(delay))); + let _ = child.send(ProcMsg::FromParent(ProcAction::Restart( + Restart::without_ack(delay), + ))); } None => { @@ -697,6 +704,8 @@ where self.children_proc_msg_tx.remove(&proc_id); } } + + let _ = err_ack.send(()); } (Strategy::OneForOne { counter }, Directive::Stop) => { @@ -704,6 +713,8 @@ where let _ = child.send(ProcMsg::FromParent(ProcAction::Stop)); counter.remove(&proc_id); self.children_proc_msg_tx.remove(&proc_id); + + let _ = err_ack.send(()); } (Strategy::OneForAll { counter }, Directive::Restart) => { @@ -711,8 +722,25 @@ where .get_backoff_duration(self.supervision.max_restarts, self.supervision.backoff) { Some(delay) => { + let _ = err_ack.send(()); + + let mut exit_ack_rxs = vec![]; + let mut can_restart_txs = vec![]; + for child in self.children_proc_msg_tx.values() { - let _ = child.send(ProcMsg::FromParent(ProcAction::RestartIn(delay))); + let (restart, exit_ack_rx, can_restart_tx) = Restart::with_ack(delay); + exit_ack_rxs.push(exit_ack_rx); + can_restart_txs.push(can_restart_tx); + + let _ = child.send(ProcMsg::FromParent(ProcAction::Restart(restart))); + } + + for rx in exit_ack_rxs { + let _ = rx.recv_async().await; + } + + for tx in can_restart_txs { + let _ = tx.send(()); } } @@ -727,6 +755,8 @@ where } (Strategy::OneForAll { .. }, Directive::Stop) => { + let _ = err_ack.send(()); + for child in self.children_proc_msg_tx.values() { let _ = child.send(ProcMsg::FromParent(ProcAction::Stop)); } @@ -741,9 +771,13 @@ where err: e, ack: tx, }); + + let _ = err_ack.send(()); } - (_, Directive::Resume) => (), + (_, Directive::Resume) => { + let _ = err_ack.send(()); + } }; task::yield_now().await; @@ -765,9 +799,48 @@ enum ProcMsg { FromHandle(ProcAction), } +#[derive(Debug)] +struct Restart { + delay: Duration, + exit_ack_tx: Option>, + can_restart_rx: Option>, +} + +impl Restart { + fn with_ack(delay: Duration) -> (Self, Receiver<()>, Sender<()>) { + let (exit_ack_tx, exit_ack_rx) = flume::unbounded(); + let (can_restart_tx, can_restart_rx) = flume::unbounded(); + + let restart = Restart { + delay, + exit_ack_tx: Some(exit_ack_tx), + can_restart_rx: Some(can_restart_rx), + }; + + (restart, exit_ack_rx, can_restart_tx) + } + + fn without_ack(delay: Duration) -> Restart { + Restart { + delay, + exit_ack_tx: None, + can_restart_rx: None, + } + } + + /// Waits for signal from Parent to restart + async fn sync(&self) { + if let (Some(exit_ack_tx), Some(can_restart_rx)) = (&self.exit_ack_tx, &self.can_restart_rx) + { + let _ = exit_ack_tx.send(()); + let _ = can_restart_rx.recv_async().await; + } + } +} + #[derive(Debug)] enum ProcAction { - RestartIn(Duration), + Restart(Restart), Stop, } @@ -781,7 +854,7 @@ where time::sleep(d).await; } - let mut restart_in = None; + let mut restart = None; match Child::init(&mut ctx).await { Err(e) => { @@ -796,8 +869,8 @@ where loop { if let Ok(ProcMsg::FromParent(proc_action)) = ctx.proc_msg_rx.recv_async().await { - if let ProcAction::RestartIn(dur) = proc_action { - restart_in = Some(dur); + if let ProcAction::Restart(r) = proc_action { + restart = Some(r); } break; @@ -810,8 +883,9 @@ where task::yield_now().await; - if restart_in.is_some() { - spawn::(ctx, restart_in) + if let Some(r) = restart { + r.sync().await; + spawn::(ctx, Some(r.delay)) } } @@ -836,15 +910,14 @@ where break }, - Ok(ProcMsg::FromParent(ProcAction::RestartIn(dur))) => { + Ok(ProcMsg::FromParent(ProcAction::Restart(r))) => { exit_reason = exit_reason.or(Some(ExitReason::Parent)); - restart_in = Some(dur); + restart = Some(r); break; } Ok(ProcMsg::FromChild { child_id, err, ack }) => { - ctx.handle_err(err, child_id).await; - let _ = ack.send(()); + ctx.handle_err(err, child_id, ack).await; } Ok(_) => () @@ -882,8 +955,9 @@ where let exit_reason = exit_reason.unwrap_or(ExitReason::Handle); process.exit(exit_reason, &mut ctx).await; - if restart_in.is_some() { - spawn::(ctx, restart_in) + if let Some(r) = restart { + r.sync().await; + spawn::(ctx, Some(r.delay)) } } } @@ -957,8 +1031,7 @@ impl Node { }, Ok(ProcMsg::FromChild { child_id, err, ack }) => { - ctx.handle_err(err, child_id).await; - let _ = ack.send(()); + ctx.handle_err(err, child_id, ack).await; } _ => {} diff --git a/core/tests/supervision.rs b/core/tests/supervision.rs index 44f4d19..b31827b 100644 --- a/core/tests/supervision.rs +++ b/core/tests/supervision.rs @@ -348,6 +348,8 @@ mod one_for_one { } mod one_for_all { + use crate::sync_vec::SyncVec; + use super::*; #[derive(Clone)] @@ -658,4 +660,109 @@ mod one_for_all { // Assert assert_eq!(errors, vec!["EscalateChildErr".to_string()]) } + + struct Dad { + kid0: Handle<()>, + } + + #[async_trait] + impl Process for Dad { + type Props = SyncVec; + type Msg = (); + type Err = (); + + async fn init(ctx: &mut Ctx) -> Result { + ctx.props().push("Dad::init".to_string()).await; + let kid0 = ctx.spawn::((0, ctx.props().clone())); + ctx.spawn::((1, ctx.props().clone())); + ctx.spawn::((2, ctx.props().clone())); + + Ok(Self { kid0 }) + } + + async fn exit(&mut self, _: ExitReason, ctx: &mut Ctx) { + ctx.props().push("Dad::exit".to_string()).await; + } + + fn supervision(_: &Self::Props) -> Supervision { + Supervision::one_for_all().directive(Directive::Restart) + } + + async fn handle(&mut self, _: Self::Msg, _: &mut Ctx) -> Result<(), Self::Err> { + self.kid0.send(()); + Ok(()) + } + } + + struct Kid; + + #[async_trait] + impl Process for Kid { + type Props = (Id, SyncVec); + type Msg = (); + type Err = (); + + async fn init(ctx: &mut Ctx) -> Result { + let (id, evts) = ctx.props(); + evts.push(format!("Kid{id}::init")).await; + Ok(Self) + } + + async fn exit(&mut self, _: ExitReason, ctx: &mut Ctx) { + let (id, evts) = ctx.props(); + evts.push(format!("Kid{id}::exit")).await; + } + + async fn handle(&mut self, _: Self::Msg, _: &mut Ctx) -> Result<(), Self::Err> { + Err(()) + } + } + + #[tokio::test] + async fn stops_all_children_before_starting_them_again() { + // Arrange + let evts = SyncVec::default(); + let mut node = Node::default(); + let dad = node.spawn::(evts.clone()); + + // Act + dad.send(()); + time::sleep(Duration::from_millis(1)).await; + + // Assert + let evts = evts.clone_vec().await; + println!("{:?}", evts); + assert_eq!(evts.len(), 10, "{:?}", evts); + + let evts_clone = evts.clone(); + let err_msg = move |idx: usize, actual: &str| { + evts_clone + .iter() + .enumerate() + .map(|(i, x)| { + if i == idx { + format!("{x} -- ACTUAL VALUE. EXPECTED: {actual}") + } else { + x.to_string() + } + }) + .collect::>() + .join(",\n") + }; + + evts.iter() + .enumerate() + .take(4) + .for_each(|(i, x)| assert!(x.ends_with("init"), "{}", err_msg(i, x))); + + evts.iter().enumerate().skip(4).take(3).for_each(|(i, x)| { + assert!(x.starts_with("Kid"), "{}", err_msg(i, "Kid{}::exit")); + assert!(x.ends_with("exit"), "{}", err_msg(i, "Kid{}::exit")); + }); + + evts.iter().enumerate().skip(7).for_each(|(i, x)| { + assert!(x.starts_with("Kid"), "{}", err_msg(i, "Kid{}::init")); + assert!(x.ends_with("init"), "{}", err_msg(i, "Kid{}::init")); + }); + } }