EDM Heartbeat (#33)
* first draft

* use lookup barrier to make hub spoke barrier

* test crashing leave

* use deathrattle to fast-track leaving

* reduce sleep

* remove need for leavers

* don't stack trace on interpreter shutdown from heartbeat logic

* make heartbeats env var
Jackmin801 authored Oct 3, 2024
1 parent 0ba28db commit 31caeac
Showing 3 changed files with 100 additions and 50 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -96,8 +96,10 @@ HF_HUB_ETAG_TIMEOUT=500
| `GLOBAL_WORLD_SIZE` | The size of the global process group. | `1` |
| `GLOBAL_RANK` | Rank of the process in the global process group. | `0` |

### Global Store Configuration
### Elastic Device Mesh Configuration
| Environment Variable | Description | Default Value |
|-----------------------|--------------------------------------------------|---------------|
| `ZERO_BAND_GLOBAL_STORE_TIMEOUT_SECONDS` | Number of seconds before the global store operations timeout | `300` |
| `ZERO_BAND_GLOBAL_STORE_POLLING_INTERVAL_SECONDS` | Number of seconds between polls to the store when waiting for values | `0.1` |
| `ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS` | Interval in seconds between heartbeats | `2` |
| `ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS` | Time in seconds after which a node is considered dead if no heartbeat is received | `10` |
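
As a usage sketch (not part of the diff), the heartbeat settings can be tuned per run. This assumes the global process-group variables from the table above are already exported, and that `zeroband.comms` has not been imported yet, since the module reads these values at import time:

```python
import os

# Illustrative values only; the defaults are 2 s (interval) and 10 s (timeout).
os.environ["ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS"] = "5"
os.environ["ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS"] = "30"

# Import after setting the variables: the constants are read at module import time.
from zeroband.comms import ElasticDeviceMesh

edm = ElasticDeviceMesh()  # heartbeats every 5 s; a node is declared dead after 30 s of silence
```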
140 changes: 95 additions & 45 deletions src/zeroband/comms.py
@@ -1,19 +1,24 @@
import sys
import os
import time
from torch.distributed.device_mesh import init_device_mesh
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger
import torch.distributed as dist
from datetime import timedelta
import time
from typing import List, Tuple, Optional
from torch.testing._internal.distributed.fake_pg import FakeProcessGroup

import multiprocessing as mp

TCPSTORE_TIMEOUT = timedelta(seconds=int(os.getenv("ZERO_BAND_GLOBAL_STORE_TIMEOUT_SECONDS", "300")))
TCPSTORE_POLLING_INTERVAL = float(os.getenv("ZERO_BAND_GLOBAL_STORE_POLLING_INTERVAL_SECONDS", "0.1"))
MAX_JOINERS = 100 # Maximum number of nodes that can join in a single reinit
MAX_LEAVERS = 100 # Maximum number of nodes that can leave in a single reinit
HEARTBEAT_INTERVAL = int(
os.getenv("ZERO_BAND_EDM_HEARTBEAT_INTERVAL_SECONDS", "2")
) # Interval in seconds between heartbeats
HEARTBEAT_TIMEOUT = int(
os.getenv("ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS", "10")
) # Time in seconds after which a node is considered dead if no heartbeat is received


class ElasticDeviceMesh:
@@ -30,7 +35,6 @@ class ElasticDeviceMesh:
- rank_{uuid}: The rank of the node with the given uuid
- rank_map_{rank}: The new rank of the node with the given rank. Used to remap ranks when nodes leave.
- joiner_{i}: The uuid of the ith joiner. It's a KV implementation of a queue.
- leaver_{i}: The uuid of the ith leaver. It's a KV implementation of a queue.
"""

local_pg: dist.ProcessGroup
@@ -54,18 +58,21 @@ def __init__(self, backend: str = "cpu:gloo,cuda:nccl"):
)
self.local_pg = self.mesh.get_group("intranode")

# Start heartbeat
self._start_heartbeat()

# Logging
self._logger.info(f"global_pg size : {self.global_pg.size()}, local_pg size: {self.local_pg.size()}")

def __del__(self):
self._stop_heartbeat()
dist.destroy_process_group()

def _init_global_store_and_status(self):
"""Initialize the global store with mesh_count, joiner_0, leaver_0, and status. Also sets the global status."""
"""Initialize the global store with mesh_count, joiner_0, and status. Also sets the global status."""
if self._global_leader:
self.global_store.set("mesh_count", "0")
self.global_store.set("joiner_0", "null")
self.global_store.set("leaver_0", "null")
self.global_store.set("status", "init")
self.global_status = "init"
else:
@@ -82,38 +89,17 @@ def _queue_join(self):
else:
raise RuntimeError("Too many joiners")

def _queue_leave(self):
"""Queue a node to leave the mesh."""
self.leaving = True
for i in range(MAX_LEAVERS):
leaver_id = self.global_store.get(f"leaver_{i}").decode("utf-8")
if leaver_id == "null":
self._logger.debug(f"Queueing leaver {self.world_info.global_unique_id} at index {i}")
self.global_store.set(f"leaver_{i}", self.world_info.global_unique_id)
self.global_store.set(f"leaver_{i + 1}", "null")
break
else:
raise RuntimeError("Too many leavers")

def _get_joiners_and_leavers(self) -> Tuple[List[str], List[str]]:
def _get_joiners(self) -> List[str]:
joiners = []
leavers = []
for i in range(MAX_JOINERS):
joiner_id = self.global_store.get(f"joiner_{i}").decode("utf-8")
if joiner_id == "null":
break
joiners.append(joiner_id)
for i in range(MAX_LEAVERS):
leaver_id = self.global_store.get(f"leaver_{i}").decode("utf-8")
if leaver_id == "null":
break
leavers.append(leaver_id)
self._logger.debug(f"Joiners: {joiners}, Leavers: {leavers}")
return joiners, leavers
return joiners

def _clear_joiners_and_leavers(self):
def _clear_joiners(self):
self.global_store.set("joiner_0", "null")
self.global_store.set("leaver_0", "null")

def _wait_for_status(self, status: Optional[str] = None) -> str:
"""Wait for status to be set in the store.
@@ -135,10 +121,6 @@ def _wait_for_status(self, status: Optional[str] = None) -> str:
raise e
time.sleep(0.1)

def _get_assigned_global_rank_and_size(self) -> Tuple[int, int]:
"""Get the assigned global rank from the leader."""
return

def _init_global_pg(self) -> None:
# Each rank gets its own global store with global rank 0 as the master
time_start = time.perf_counter()
@@ -188,6 +170,7 @@ def _init_global_pg(self) -> None:
# Update global store values
if self._global_leader:
self.global_store.set("status", "running")
self.global_store.set("resolved_time", str(time.time()))
self.global_status = "running"
self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(self.world_info.global_rank))

@@ -197,20 +180,78 @@ def _init_global_pg(self) -> None:
# We might be able to get away with only doing in joining path.
# Let's not risk it for now though.
dist.barrier(self.global_pg)
self._last_resolved_time = self.global_store.get("resolved_time").decode("utf-8")
self._logger.info(
f"Elastic Device mesh init done with {self.global_pg.size()} peers in {time.perf_counter() - time_start} seconds"
)

def _start_heartbeat(self):
"""Start sending heartbeats to the global store in a separate process."""
self._heartbeat_stop_event = mp.Event()
self._heartbeat_process = mp.Process(target=self._heartbeat_loop, args=(self._heartbeat_stop_event,))
self._heartbeat_process.start()

def _stop_heartbeat(self):
"""Stop the heartbeat process."""
self._send_deathrattle()
if hasattr(self, "_heartbeat_stop_event"):
self._heartbeat_stop_event.set()
self._heartbeat_process.join()

def _heartbeat_loop(self, stop_event):
"""Continuously send heartbeats until stopped."""
try:
while not stop_event.is_set():
self._send_heartbeat()
time.sleep(HEARTBEAT_INTERVAL)
finally:
self._send_deathrattle()

def _send_heartbeat(self):
"""Send a heartbeat to the global store."""
current_time = time.time()
try:
self.global_store.set(f"heartbeat_{self.world_info.global_rank}", str(current_time))
except Exception:
pass

def _send_deathrattle(self):
"""Send a deathrattle to the global store."""
if hasattr(self, "global_store"):
self.global_store.set(f"heartbeat_{self.world_info.global_rank}", "-100")
else:
import warnings

warnings.warn("global_store garbage collected. Skipping deathrattle.")

def _check_heartbeats(self) -> List[int]:
"""Check heartbeats and return a list of nodes that have missed their heartbeats."""
dead_nodes = []
current_time = time.time()
for i in range(self.world_info.global_world_size):
try:
last_heartbeat = float(self.global_store.get(f"heartbeat_{i}").decode("utf-8"))
if current_time - last_heartbeat > HEARTBEAT_TIMEOUT:
dead_nodes.append(i)
except dist.DistStoreError:
self._logger.warning(f"Node {i} has no heartbeat")
return dead_nodes

def _resolve_world(self):
"""Set the new world size and ranks for all nodes if there are joiners or leavers. Else, do nothing."""
# Find joiners and leavers
joiners, leavers = self._get_joiners_and_leavers()
# If no joiners or leavers, no resolution needed
if len(joiners) == 0 and len(leavers) == 0:
"""Set the new world size and ranks for all nodes if there are joiners or dead nodes. Else, do nothing."""
# Find joiners
joiners = self._get_joiners()

# Check for dead nodes
dead_nodes = self._check_heartbeats()
self._logger.debug(f"Joiners: {joiners}, Dead nodes: {dead_nodes}")

# If no joiners or dead nodes, no resolution needed
if len(joiners) == 0 and len(dead_nodes) == 0:
return

# Remap live ranks to smaller world_size caused by leavers
leaving_ranks = {int(self.global_store.get(f"rank_{leaver_id}").decode("utf-8")) for leaver_id in leavers}
# Remap live ranks to smaller world_size caused by dead nodes
leaving_ranks = set(dead_nodes)
live_ranks = [i for i in range(self.world_info.global_world_size) if i not in leaving_ranks]
for i, rank in enumerate(live_ranks):
self.global_store.set(f"rank_map_{rank}", str(i))
@@ -228,12 +269,21 @@ def _resolve_world(self):
self.global_store.set("status", "reinit")

def maybe_reinit_global_pg(self):
"""Reinitialize the global_pg if there are joiners or leavers."""
"""Reinitialize the global_pg if there are joiners or dead nodes."""
time_start = time.perf_counter()
self._logger.debug("Resolving world")
if self._global_leader:
self._resolve_world()
dist.barrier(self.global_pg)
self.global_store.set("resolved_time", str(time.time()))
else:
while (ans := self.global_store.get("resolved_time").decode("utf-8")) == self._last_resolved_time:
time.sleep(TCPSTORE_POLLING_INTERVAL)
self._last_resolved_time = ans

self._logger.debug("World resolved in %s seconds", time.perf_counter() - time_start)

status = self.global_store.get("status").decode("utf-8")
if status == "running": # No joiners or leavers
if status == "running": # No joiners or dead nodes
return

# Reinit Path
@@ -267,7 +317,7 @@ def maybe_reinit_global_pg(self):
)

if self._global_leader:
self._clear_joiners_and_leavers()
self._clear_joiners()
self.global_store.set("status", "running")

# Update rank if needed (otherwise, the next remap will do the lookup incorrectly)
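For readers outside the diff context, here is a minimal, self-contained sketch of the liveness protocol these changes implement: each rank periodically writes a wall-clock timestamp under `heartbeat_{rank}`, the leader treats a rank as dead when that timestamp is older than `HEARTBEAT_TIMEOUT`, and a cleanly exiting rank writes the `-100` deathrattle sentinel so the staleness check trips immediately instead of waiting out the timeout. The port number and single-rank store below are illustrative, not part of the commit.

```python
import time
from datetime import timedelta
from torch.distributed import TCPStore

HEARTBEAT_TIMEOUT = 10  # seconds; mirrors ZERO_BAND_EDM_HEARTBEAT_TIMEOUT_SECONDS

# Single-process illustration: one store plays both the worker and the leader role.
store = TCPStore("localhost", 29510, world_size=1, is_master=True, timeout=timedelta(seconds=30))

# Worker side: publish a heartbeat under a per-rank key (the real code does this every
# HEARTBEAT_INTERVAL seconds from a background multiprocessing.Process).
store.set("heartbeat_0", str(time.time()))

# Leader side: a rank is dead if its last heartbeat is older than the timeout.
last = float(store.get("heartbeat_0").decode("utf-8"))
print("dead:", time.time() - last > HEARTBEAT_TIMEOUT)  # False

# Worker side on shutdown: the -100 deathrattle is "infinitely stale", so the very next
# check flags the rank as dead without waiting out the full timeout.
store.set("heartbeat_0", "-100")
last = float(store.get("heartbeat_0").decode("utf-8"))
print("dead:", time.time() - last > HEARTBEAT_TIMEOUT)  # True
```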
6 changes: 2 additions & 4 deletions tests/test_dist/test_comms.py
@@ -117,17 +117,15 @@ def foo(**kwargs):
assert edm.mesh_count == 1
assert edm.global_pg.size() == global_world_size + 1

if test_value == 1:
edm._queue_leave()

a = torch.arange(3) * (test_value + 1)
sum_ints = global_world_size * (global_world_size + 1) // 2 + 100
dist.all_reduce(a, op=dist.ReduceOp.SUM, group=edm.global_pg)
assert torch.allclose(a, torch.tensor([0, sum_ints, 2 * sum_ints]))

edm.maybe_reinit_global_pg()
if test_value == 1:
return
time.sleep(2)
edm.maybe_reinit_global_pg()
assert edm.mesh_count == 2
assert edm.global_pg.size() == global_world_size

