From 96a3d96e0f28a2b0842a9751e6a3a34294b3559c Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Fri, 4 Oct 2024 01:41:36 +0000 Subject: [PATCH] wip --- src/zeroband/comms.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py index cf6378df..80250525 100644 --- a/src/zeroband/comms.py +++ b/src/zeroband/comms.py @@ -275,6 +275,9 @@ def maybe_reinit_global_pg(self): """Reinitialize the global_pg if there are joiners or dead nodes.""" time_start = time.perf_counter() self._logger.debug("Resolving world") + + self.live_recovery.stop_background_loop() # need to stop live recovery loop to avoid deadlocks + if self._global_leader: self._resolve_world() self.global_store.set("resolved_time", str(time.time())) @@ -329,6 +332,8 @@ def maybe_reinit_global_pg(self): # Without this barrier, a node might queue leave before the leaving queue is cleared dist.barrier(self.global_pg) + self.live_recovery.init_background_loop() + class LiveRecoveryModel(BaseModel): dest_rank: int @@ -385,18 +390,24 @@ def __init__(self, global_store: dist.Store): self.live_ckpt_store = LiveRecoveryStore(dist.PrefixStore("live_ckpt", self.global_store)) self.live_ckpt_store.set(self._live_recovery_key, None) - self._stop_event = mp.Event() self._dest_rank = mp.Value("i", -1) + self.init_background_loop() + + def init_background_loop(self) -> None: + self._stop_event = mp.Event() self._live_recovery_process = mp.Process( target=self._live_recovery_loop, args=(self._stop_event, self._dest_rank) ) self._live_recovery_process.start() - def __del__(self): + def stop_background_loop(self): self._stop_event.set() self._live_recovery_process.join() + def __del__(self): + self.stop_background_loop() + def _live_recovery_loop(self, stop_event: mp.Event, dest_rank: mp.Value) -> None: while not stop_event.is_set(): data = self.live_ckpt_store.get(self._live_recovery_key)