diff --git a/README.md b/README.md
index 68f780cf..6dc90b0e 100644
--- a/README.md
+++ b/README.md
@@ -96,8 +96,10 @@ HF_HUB_ETAG_TIMEOUT=500
 | `GLOBAL_WORLD_SIZE`   | The size of the global process group.             | `1`           |
 | `GLOBAL_RANK`         | Rank of the process in the global process group.  | `0`           |
 
-### Global Store Configuration
+### Elastic Device Mesh Configuration
 
 | Environment Variable  | Description                                       | Default Value |
 |-----------------------|---------------------------------------------------|---------------|
 | `ZERO_BAND_GLOBAL_STORE_TIMEOUT_SECONDS` | Number of seconds before the global store operations timeout | `300` |
 | `ZERO_BAND_GLOBAL_STORE_POLLING_INTERVAL_SECONDS` | Number of seconds between polls to the store when waiting for values | `0.1` |
+| `ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS` | Interval in seconds between heartbeats | `2` |
+| `ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS` | Time in seconds after which a node is considered dead if no heartbeat is received | `10` |
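Note that `comms.py` parses these variables into module-level constants, so they only take effect if set before the module is imported: either in the launching shell, or programmatically as in this minimal sketch (the values here are arbitrary, and the snippet assumes the `zeroband` package is importable):

```python
# Illustrative only: heartbeat knobs are read once, at import time of
# zeroband.comms, so set them before importing anything from the package.
import os

os.environ["ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS"] = "5"
os.environ["ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS"] = "30"  # should comfortably exceed the interval

from zeroband.comms import ElasticDeviceMesh  # picks up the values above
```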
diff --git a/src/zeroband/comms.py b/src/zeroband/comms.py
index f6d2d215..38753172 100644
--- a/src/zeroband/comms.py
+++ b/src/zeroband/comms.py
@@ -1,19 +1,24 @@
 import sys
 import os
+import time
 from torch.distributed.device_mesh import init_device_mesh
 from zeroband.utils.world_info import get_world_info
 from zeroband.utils.logging import get_logger
 import torch.distributed as dist
 from datetime import timedelta
-import time
 from typing import List, Tuple, Optional
 from torch.testing._internal.distributed.fake_pg import FakeProcessGroup
-
+import multiprocessing as mp
 
 TCPSTORE_TIMEOUT = timedelta(seconds=int(os.getenv("ZERO_BAND_GLOBAL_STORE_TIMEOUT_SECONDS", "300")))
 TCPSTORE_POLLING_INTERVAL = float(os.getenv("ZERO_BAND_GLOBAL_STORE_POLLING_INTERVAL_SECONDS", "0.1"))
 
 MAX_JOINERS = 100  # Maximum number of nodes that can join in a single reinit
-MAX_LEAVERS = 100  # Maximum number of nodes that can leave in a single reinit
+HEARTBEAT_INTERVAL = int(
+    os.getenv("ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS", "2")
+)  # Interval in seconds between heartbeats
+HEARTBEAT_TIMEOUT = int(
+    os.getenv("ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS", "10")
+)  # Time in seconds after which a node is considered dead if no heartbeat is received
 
 
 class ElasticDeviceMesh:
@@ -30,7 +35,6 @@ class ElasticDeviceMesh:
     - rank_{uuid}: The rank of the node with the given uuid
     - rank_map_{rank}: The new rank of the node with the given rank. Used to remap ranks when nodes leave.
     - joiner_{i}: The uuid of the ith joiner. Its a KV implmentation of a queue.
-    - leaver_{i}: The uuid of the ith leaver. Its a KV implmentation of a queue.
     """
 
     local_pg: dist.ProcessGroup
@@ -54,18 +58,21 @@ def __init__(self, backend: str = "cpu:gloo,cuda:nccl"):
         )
         self.local_pg = self.mesh.get_group("intranode")
 
+        # Start heartbeat
+        self._start_heartbeat()
+
         # Logging
         self._logger.info(f"global_pg size : {self.global_pg.size()}, local_pg size: {self.local_pg.size()}")
 
     def __del__(self):
+        self._stop_heartbeat()
         dist.destroy_process_group()
 
     def _init_global_store_and_status(self):
-        """Initialize the global store with mesh_count, joiner_0, leaver_0, and status. Also sets the global status."""
+        """Initialize the global store with mesh_count, joiner_0, and status. Also sets the global status."""
         if self._global_leader:
             self.global_store.set("mesh_count", "0")
             self.global_store.set("joiner_0", "null")
-            self.global_store.set("leaver_0", "null")
             self.global_store.set("status", "init")
             self.global_status = "init"
         else:
@@ -82,38 +89,17 @@ def _queue_join(self):
         else:
             raise RuntimeError("Too many joiners")
 
-    def _queue_leave(self):
-        """Queue a node to leave the mesh."""
-        self.leaving = True
-        for i in range(MAX_LEAVERS):
-            leaver_id = self.global_store.get(f"leaver_{i}").decode("utf-8")
-            if leaver_id == "null":
-                self._logger.debug(f"Queueing leaver {self.world_info.global_unique_id} at index {i}")
-                self.global_store.set(f"leaver_{i}", self.world_info.global_unique_id)
-                self.global_store.set(f"leaver_{i + 1}", "null")
-                break
-        else:
-            raise RuntimeError("Too many leavers")
-
-    def _get_joiners_and_leavers(self) -> Tuple[List[str], List[str]]:
+    def _get_joiners(self) -> List[str]:
         joiners = []
-        leavers = []
         for i in range(MAX_JOINERS):
             joiner_id = self.global_store.get(f"joiner_{i}").decode("utf-8")
             if joiner_id == "null":
                 break
             joiners.append(joiner_id)
-        for i in range(MAX_LEAVERS):
-            leaver_id = self.global_store.get(f"leaver_{i}").decode("utf-8")
-            if leaver_id == "null":
-                break
-            leavers.append(leaver_id)
-        self._logger.debug(f"Joiners: {joiners}, Leavers: {leavers}")
-        return joiners, leavers
+        return joiners
 
-    def _clear_joiners_and_leavers(self):
+    def _clear_joiners(self):
         self.global_store.set("joiner_0", "null")
-        self.global_store.set("leaver_0", "null")
 
     def _wait_for_status(self, status: Optional[str] = None) -> str:
         """Wait for status to be set in the store.
@@ -135,10 +121,6 @@ def _wait_for_status(self, status: Optional[str] = None) -> str:
                 raise e
             time.sleep(0.1)
 
-    def _get_assigned_global_rank_and_size(self) -> Tuple[int, int]:
-        """Get the assigned global rank from the leader."""
-        return
-
    def _init_global_pg(self) -> None:
         # Each rank gets its own global store with global rank 0 as the master
         time_start = time.perf_counter()
@@ -188,6 +170,7 @@ def _init_global_pg(self) -> None:
         # Update global store values
         if self._global_leader:
             self.global_store.set("status", "running")
+            self.global_store.set("resolved_time", str(time.time()))
             self.global_status = "running"
 
         self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(self.world_info.global_rank))
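The `joiner_{i}` keys implement a simple queue on top of the KV store: slot `i` holds the i-th joiner's uuid, and the first `"null"` slot marks the tail. A dict-backed sketch of the pattern (an assumed simplification; the real code reads and writes bytes on a torch `TCPStore`):

```python
# Dict-backed sketch of the KV queue behind the joiner_{i} keys.

def queue_join(store: dict, uid: str, max_joiners: int = 100) -> None:
    """Append uid at the first "null" slot and move the tail sentinel."""
    for i in range(max_joiners):
        if store[f"joiner_{i}"] == "null":
            store[f"joiner_{i}"] = uid
            store[f"joiner_{i + 1}"] = "null"  # new tail sentinel
            return
    raise RuntimeError("Too many joiners")


def get_joiners(store: dict) -> list[str]:
    """Walk the slots from the front until the "null" sentinel."""
    joiners, i = [], 0
    while (uid := store[f"joiner_{i}"]) != "null":
        joiners.append(uid)
        i += 1
    return joiners


store = {"joiner_0": "null"}  # what _init_global_store_and_status seeds
queue_join(store, "node-a")
queue_join(store, "node-b")
assert get_joiners(store) == ["node-a", "node-b"]
```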
@@ -197,20 +180,78 @@
         # We might be able to get away with only doing in joining path.
         # Let's not risk it for now though.
         dist.barrier(self.global_pg)
+        self._last_resolved_time = self.global_store.get("resolved_time").decode("utf-8")
 
         self._logger.info(
             f"Elastic Device mesh init done with {self.global_pg.size()} peers in {time.perf_counter() - time_start} seconds"
         )
 
+    def _start_heartbeat(self):
+        """Start sending heartbeats to the global store in a separate process."""
+        self._heartbeat_stop_event = mp.Event()
+        self._heartbeat_process = mp.Process(target=self._heartbeat_loop, args=(self._heartbeat_stop_event,))
+        self._heartbeat_process.start()
+
+    def _stop_heartbeat(self):
+        """Stop the heartbeat process."""
+        self._send_deathrattle()
+        if hasattr(self, "_heartbeat_stop_event"):
+            self._heartbeat_stop_event.set()
+            self._heartbeat_process.join()
+
+    def _heartbeat_loop(self, stop_event):
+        """Continuously send heartbeats until stopped."""
+        try:
+            while not stop_event.is_set():
+                self._send_heartbeat()
+                time.sleep(HEARTBEAT_INTERVAL)
+        finally:
+            self._send_deathrattle()
+
+    def _send_heartbeat(self):
+        """Send a heartbeat to the global store."""
+        current_time = time.time()
+        try:
+            self.global_store.set(f"heartbeat_{self.world_info.global_rank}", str(current_time))
+        except Exception:
+            pass
+
+    def _send_deathrattle(self):
+        """Send a deathrattle to the global store."""
+        if hasattr(self, "global_store"):
+            self.global_store.set(f"heartbeat_{self.world_info.global_rank}", "-100")
+        else:
+            import warnings
+
+            warnings.warn("global_store garbage collected. Skipping deathrattle.")
+
+    def _check_heartbeats(self) -> List[int]:
+        """Check heartbeats and return a list of nodes that have missed their heartbeats."""
+        dead_nodes = []
+        current_time = time.time()
+        for i in range(self.world_info.global_world_size):
+            try:
+                last_heartbeat = float(self.global_store.get(f"heartbeat_{i}").decode("utf-8"))
+                if current_time - last_heartbeat > HEARTBEAT_TIMEOUT:
+                    dead_nodes.append(i)
+            except dist.DistStoreError:
+                self._logger.warning(f"Node {i} has no heartbeat")
+        return dead_nodes
+
     def _resolve_world(self):
-        """Set the new world size and ranks for all nodes if there are joiners or leavers. Else, do nothing."""
-        # Find joiners and leavers
-        joiners, leavers = self._get_joiners_and_leavers()
-        # If no joiners or leavers, no resolution needed
-        if len(joiners) == 0 and len(leavers) == 0:
+        """Set the new world size and ranks for all nodes if there are joiners or dead nodes. Else, do nothing."""
+        # Find joiners
+        joiners = self._get_joiners()
+
+        # Check for dead nodes
+        dead_nodes = self._check_heartbeats()
+        self._logger.debug(f"Joiners: {joiners}, Dead nodes: {dead_nodes}")
+
+        # If no joiners or dead nodes, no resolution needed
+        if len(joiners) == 0 and len(dead_nodes) == 0:
             return
 
-        # Remap live ranks to smaller world_size caused by leavers
-        leaving_ranks = {int(self.global_store.get(f"rank_{leaver_id}").decode("utf-8")) for leaver_id in leavers}
+        # Remap live ranks to smaller world_size caused by dead nodes
+        leaving_ranks = set(dead_nodes)
         live_ranks = [i for i in range(self.world_info.global_world_size) if i not in leaving_ranks]
         for i, rank in enumerate(live_ranks):
             self.global_store.set(f"rank_map_{rank}", str(i))
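A toy sketch of the heartbeat/deathrattle liveness check above, using a plain dict in place of the TCPStore (assumed interface; timings illustrative). A deathrattle is just a heartbeat set to `-100`, which is always stale, so a cleanly exiting node is picked up on the very next check:

```python
import time

HEARTBEAT_TIMEOUT = 10  # seconds, mirroring ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS

def check_heartbeats(store: dict, world_size: int) -> list[int]:
    """Return ranks whose last heartbeat is missing or older than the timeout."""
    now = time.time()
    dead = []
    for rank in range(world_size):
        last = store.get(f"heartbeat_{rank}")
        if last is None or now - float(last) > HEARTBEAT_TIMEOUT:
            dead.append(rank)
    return dead

store = {"heartbeat_0": str(time.time()), "heartbeat_1": "-100"}  # rank 1 sent a deathrattle
assert check_heartbeats(store, world_size=2) == [1]
```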
@@ -228,12 +269,21 @@ def _resolve_world(self):
         self.global_store.set("status", "reinit")
 
     def maybe_reinit_global_pg(self):
-        """Reinitialize the global_pg if there are joiners or leavers."""
+        """Reinitialize the global_pg if there are joiners or dead nodes."""
+        time_start = time.perf_counter()
+        self._logger.debug("Resolving world")
         if self._global_leader:
             self._resolve_world()
-        dist.barrier(self.global_pg)
+            self.global_store.set("resolved_time", str(time.time()))
+        else:
+            while (ans := self.global_store.get("resolved_time").decode("utf-8")) == self._last_resolved_time:
+                time.sleep(TCPSTORE_POLLING_INTERVAL)
+            self._last_resolved_time = ans
+
+        self._logger.debug("World resolved in %s seconds", time.perf_counter() - time_start)
+
         status = self.global_store.get("status").decode("utf-8")
-        if status == "running":  # No joiners or leavers
+        if status == "running":  # No joiners or dead nodes
             return
 
         # Reinit Path
@@ -267,7 +317,7 @@ def maybe_reinit_global_pg(self):
         )
 
         if self._global_leader:
-            self._clear_joiners_and_leavers()
+            self._clear_joiners()
             self.global_store.set("status", "running")
 
         # Update rank if needed (otherwise, the next remap will do the lookup incorrectly)
diff --git a/tests/test_dist/test_comms.py b/tests/test_dist/test_comms.py
index 40c22059..15e090ec 100644
--- a/tests/test_dist/test_comms.py
+++ b/tests/test_dist/test_comms.py
@@ -117,17 +117,15 @@ def foo(**kwargs):
     assert edm.mesh_count == 1
     assert edm.global_pg.size() == global_world_size + 1
 
-    if test_value == 1:
-        edm._queue_leave()
-
     a = torch.arange(3) * (test_value + 1)
     sum_ints = global_world_size * (global_world_size + 1) // 2 + 100
     dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg)
    assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]))
 
-    edm.maybe_reinit_global_pg()
     if test_value == 1:
         return
 
+    time.sleep(2)
+    edm.maybe_reinit_global_pg()
     assert edm.mesh_count == 2
     assert edm.global_pg.size() == global_world_size
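The `resolved_time` handshake in `maybe_reinit_global_pg` replaces the old global barrier: the leader bumps a store key once the world is resolved, and followers poll until they observe a new value. A self-contained sketch under assumed simplifications (a shared dict and threads stand in for the TCPStore and nodes):

```python
import threading
import time

POLLING_INTERVAL = 0.1  # mirrors TCPSTORE_POLLING_INTERVAL

store = {"resolved_time": "0"}

def follower(last_seen: str) -> str:
    # Poll until the leader publishes a new resolved_time.
    while (ans := store["resolved_time"]) == last_seen:
        time.sleep(POLLING_INTERVAL)
    return ans  # remember this as the new last_seen before proceeding

def leader() -> None:
    time.sleep(0.3)  # pretend to resolve joiners/dead nodes
    store["resolved_time"] = str(time.time())

t = threading.Thread(target=leader)
t.start()
print("follower unblocked at", follower(last_seen="0"))
t.join()
```

Unlike a barrier, this does not require every rank to enter the call collectively, which is what makes it safe in the presence of dead nodes.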