Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Oct 4, 2024
1 parent 67fa953 commit 96a3d96
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions src/zeroband/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ def maybe_reinit_global_pg(self):
"""Reinitialize the global_pg if there are joiners or dead nodes."""
time_start = time.perf_counter()
self._logger.debug("Resolving world")

self.live_recovery.stop_background_loop() # need to stop live recovery loop to avoid deadlocks

if self._global_leader:
self._resolve_world()
self.global_store.set("resolved_time", str(time.time()))
Expand Down Expand Up @@ -329,6 +332,8 @@ def maybe_reinit_global_pg(self):
# Without this barrier, a node might queue leave before the leaving queue is cleared
dist.barrier(self.global_pg)

self.live_recovery.init_background_loop()


class LiveRecoveryModel(BaseModel):
dest_rank: int
Expand Down Expand Up @@ -385,18 +390,24 @@ def __init__(self, global_store: dist.Store):
self.live_ckpt_store = LiveRecoveryStore(dist.PrefixStore("live_ckpt", self.global_store))
self.live_ckpt_store.set(self._live_recovery_key, None)

self._stop_event = mp.Event()
self._dest_rank = mp.Value("i", -1)

self.init_background_loop()

def init_background_loop(self) -> mp.Process:
self._stop_event = mp.Event()
self._live_recovery_process = mp.Process(
target=self._live_recovery_loop, args=(self._stop_event, self._dest_rank)
)
self._live_recovery_process.start()

def __del__(self):
def stop_background_loop(self):
self._stop_event.set()
self._live_recovery_process.join()

def __del__(self):
self.stop_background_loop()

def _live_recovery_loop(self, stop_event: mp.Event, dest_rank: mp.Value) -> None:
while not stop_event.is_set():
data = self.live_ckpt_store.get(self._live_recovery_key)
Expand Down

0 comments on commit 96a3d96

Please sign in to comment.