Skip to content

Commit

Permalink
fix: when using one for all, stop all procs before restarting
Browse files Browse the repository at this point in the history
  • Loading branch information
vmenge committed Jan 17, 2024
1 parent e17f326 commit dbe071c
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion core/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "speare"
version = "0.1.6"
version = "0.1.7"
edition = "2021"
license = "MIT"
description = "actor-like thin abstraction over tokio::task and flume channels"
Expand Down
111 changes: 92 additions & 19 deletions core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ where
match self {
Self::Handle => write!(f, "manual exit through Handle::stop"),
Self::Parent => write!(f, "exit request from parent supervision strategy"),
Self::Err(e) => write!(f, "error: {e}"),
Self::Err(e) => write!(f, "{e}"),
}
}
}
Expand Down Expand Up @@ -671,7 +671,12 @@ where
handle
}

async fn handle_err(&mut self, e: Box<dyn Any + Send>, proc_id: u64) -> Option<()> {
async fn handle_err(
&mut self,
e: Box<dyn Any + Send>,
proc_id: u64,
err_ack: Sender<()>,
) -> Option<()> {
let directive = self
.supervision
.deciders
Expand All @@ -688,7 +693,9 @@ where
.get_backoff_duration(self.supervision.max_restarts, self.supervision.backoff)
{
Some(delay) => {
let _ = child.send(ProcMsg::FromParent(ProcAction::RestartIn(delay)));
let _ = child.send(ProcMsg::FromParent(ProcAction::Restart(
Restart::without_ack(delay),
)));
}

None => {
Expand All @@ -697,22 +704,43 @@ where
self.children_proc_msg_tx.remove(&proc_id);
}
}

let _ = err_ack.send(());
}

(Strategy::OneForOne { counter }, Directive::Stop) => {
let child = self.children_proc_msg_tx.get(&proc_id)?;
let _ = child.send(ProcMsg::FromParent(ProcAction::Stop));
counter.remove(&proc_id);
self.children_proc_msg_tx.remove(&proc_id);

let _ = err_ack.send(());
}

(Strategy::OneForAll { counter }, Directive::Restart) => {
match counter
.get_backoff_duration(self.supervision.max_restarts, self.supervision.backoff)
{
Some(delay) => {
let _ = err_ack.send(());

let mut exit_ack_rxs = vec![];
let mut can_restart_txs = vec![];

for child in self.children_proc_msg_tx.values() {
let _ = child.send(ProcMsg::FromParent(ProcAction::RestartIn(delay)));
let (restart, exit_ack_rx, can_restart_tx) = Restart::with_ack(delay);
exit_ack_rxs.push(exit_ack_rx);
can_restart_txs.push(can_restart_tx);

let _ = child.send(ProcMsg::FromParent(ProcAction::Restart(restart)));
}

for rx in exit_ack_rxs {
let _ = rx.recv_async().await;
}

for tx in can_restart_txs {
let _ = tx.send(());
}
}

Expand All @@ -727,6 +755,8 @@ where
}

(Strategy::OneForAll { .. }, Directive::Stop) => {
let _ = err_ack.send(());

for child in self.children_proc_msg_tx.values() {
let _ = child.send(ProcMsg::FromParent(ProcAction::Stop));
}
Expand All @@ -741,9 +771,13 @@ where
err: e,
ack: tx,
});

let _ = err_ack.send(());
}

(_, Directive::Resume) => (),
(_, Directive::Resume) => {
let _ = err_ack.send(());
}
};

task::yield_now().await;
Expand All @@ -765,9 +799,48 @@ enum ProcMsg {
FromHandle(ProcAction),
}

#[derive(Debug)]
struct Restart {
delay: Duration,
exit_ack_tx: Option<Sender<()>>,
can_restart_rx: Option<Receiver<()>>,
}

impl Restart {
fn with_ack(delay: Duration) -> (Self, Receiver<()>, Sender<()>) {
let (exit_ack_tx, exit_ack_rx) = flume::unbounded();
let (can_restart_tx, can_restart_rx) = flume::unbounded();

let restart = Restart {
delay,
exit_ack_tx: Some(exit_ack_tx),
can_restart_rx: Some(can_restart_rx),
};

(restart, exit_ack_rx, can_restart_tx)
}

fn without_ack(delay: Duration) -> Restart {
Restart {
delay,
exit_ack_tx: None,
can_restart_rx: None,
}
}

/// Waits for signal from Parent to restart
async fn sync(&self) {
if let (Some(exit_ack_tx), Some(can_restart_rx)) = (&self.exit_ack_tx, &self.can_restart_rx)
{
let _ = exit_ack_tx.send(());
let _ = can_restart_rx.recv_async().await;
}
}
}

#[derive(Debug)]
enum ProcAction {
RestartIn(Duration),
Restart(Restart),
Stop,
}

Expand All @@ -781,7 +854,7 @@ where
time::sleep(d).await;
}

let mut restart_in = None;
let mut restart = None;

match Child::init(&mut ctx).await {
Err(e) => {
Expand All @@ -796,8 +869,8 @@ where
loop {
if let Ok(ProcMsg::FromParent(proc_action)) = ctx.proc_msg_rx.recv_async().await
{
if let ProcAction::RestartIn(dur) = proc_action {
restart_in = Some(dur);
if let ProcAction::Restart(r) = proc_action {
restart = Some(r);
}

break;
Expand All @@ -810,8 +883,9 @@ where

task::yield_now().await;

if restart_in.is_some() {
spawn::<Parent, Child>(ctx, restart_in)
if let Some(r) = restart {
r.sync().await;
spawn::<Parent, Child>(ctx, Some(r.delay))
}
}

Expand All @@ -836,15 +910,14 @@ where
break
},

Ok(ProcMsg::FromParent(ProcAction::RestartIn(dur))) => {
Ok(ProcMsg::FromParent(ProcAction::Restart(r))) => {
exit_reason = exit_reason.or(Some(ExitReason::Parent));
restart_in = Some(dur);
restart = Some(r);
break;
}

Ok(ProcMsg::FromChild { child_id, err, ack }) => {
ctx.handle_err(err, child_id).await;
let _ = ack.send(());
ctx.handle_err(err, child_id, ack).await;
}

Ok(_) => ()
Expand Down Expand Up @@ -882,8 +955,9 @@ where
let exit_reason = exit_reason.unwrap_or(ExitReason::Handle);
process.exit(exit_reason, &mut ctx).await;

if restart_in.is_some() {
spawn::<Parent, Child>(ctx, restart_in)
if let Some(r) = restart {
r.sync().await;
spawn::<Parent, Child>(ctx, Some(r.delay))
}
}
}
Expand Down Expand Up @@ -957,8 +1031,7 @@ impl Node {
},

Ok(ProcMsg::FromChild { child_id, err, ack }) => {
ctx.handle_err(err, child_id).await;
let _ = ack.send(());
ctx.handle_err(err, child_id, ack).await;
}

_ => {}
Expand Down
107 changes: 107 additions & 0 deletions core/tests/supervision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ mod one_for_one {
}

mod one_for_all {
use crate::sync_vec::SyncVec;

use super::*;

#[derive(Clone)]
Expand Down Expand Up @@ -658,4 +660,109 @@ mod one_for_all {
// Assert
assert_eq!(errors, vec!["EscalateChildErr".to_string()])
}

struct Dad {
kid0: Handle<()>,
}

#[async_trait]
impl Process for Dad {
type Props = SyncVec<String>;
type Msg = ();
type Err = ();

async fn init(ctx: &mut Ctx<Self>) -> Result<Self, Self::Err> {
ctx.props().push("Dad::init".to_string()).await;
let kid0 = ctx.spawn::<Kid>((0, ctx.props().clone()));
ctx.spawn::<Kid>((1, ctx.props().clone()));
ctx.spawn::<Kid>((2, ctx.props().clone()));

Ok(Self { kid0 })
}

async fn exit(&mut self, _: ExitReason<Self>, ctx: &mut Ctx<Self>) {
ctx.props().push("Dad::exit".to_string()).await;
}

fn supervision(_: &Self::Props) -> Supervision {
Supervision::one_for_all().directive(Directive::Restart)
}

async fn handle(&mut self, _: Self::Msg, _: &mut Ctx<Self>) -> Result<(), Self::Err> {
self.kid0.send(());
Ok(())
}
}

struct Kid;

#[async_trait]
impl Process for Kid {
type Props = (Id, SyncVec<String>);
type Msg = ();
type Err = ();

async fn init(ctx: &mut Ctx<Self>) -> Result<Self, Self::Err> {
let (id, evts) = ctx.props();
evts.push(format!("Kid{id}::init")).await;
Ok(Self)
}

async fn exit(&mut self, _: ExitReason<Self>, ctx: &mut Ctx<Self>) {
let (id, evts) = ctx.props();
evts.push(format!("Kid{id}::exit")).await;
}

async fn handle(&mut self, _: Self::Msg, _: &mut Ctx<Self>) -> Result<(), Self::Err> {
Err(())
}
}

#[tokio::test]
async fn stops_all_children_before_starting_them_again() {
// Arrange
let evts = SyncVec::default();
let mut node = Node::default();
let dad = node.spawn::<Dad>(evts.clone());

// Act
dad.send(());
time::sleep(Duration::from_millis(1)).await;

// Assert
let evts = evts.clone_vec().await;
println!("{:?}", evts);
assert_eq!(evts.len(), 10, "{:?}", evts);

let evts_clone = evts.clone();
let err_msg = move |idx: usize, actual: &str| {
evts_clone
.iter()
.enumerate()
.map(|(i, x)| {
if i == idx {
format!("{x} -- ACTUAL VALUE. EXPECTED: {actual}")
} else {
x.to_string()
}
})
.collect::<Vec<_>>()
.join(",\n")
};

evts.iter()
.enumerate()
.take(4)
.for_each(|(i, x)| assert!(x.ends_with("init"), "{}", err_msg(i, x)));

evts.iter().enumerate().skip(4).take(3).for_each(|(i, x)| {
assert!(x.starts_with("Kid"), "{}", err_msg(i, "Kid{}::exit"));
assert!(x.ends_with("exit"), "{}", err_msg(i, "Kid{}::exit"));
});

evts.iter().enumerate().skip(7).for_each(|(i, x)| {
assert!(x.starts_with("Kid"), "{}", err_msg(i, "Kid{}::init"));
assert!(x.ends_with("init"), "{}", err_msg(i, "Kid{}::init"));
});
}
}

0 comments on commit dbe071c

Please sign in to comment.