From 952a00eda797c42bba577497ea3741b8dba2a756 Mon Sep 17 00:00:00 2001
From: Tristan Rice
Date: Wed, 24 Apr 2024 01:44:37 +0000
Subject: [PATCH] torchelastic: change monitor_interval default to 0.1 (#124692)

This reduces the default monitor_interval for torchelastic to 0.1s, as testing shows negligible load for common use cases. Even at the extreme of 100k processes, polling uses only 45.4% of a single core.

Torchelastic's monitor_interval only polls the processes on a single worker, so under typical loads, even for huge jobs, we expect ~8 subprocesses per machine, one per GPU.

As an external data point, Python's subprocess wait polls every 50us-50ms (https://github.com/python/cpython/blob/main/Lib/subprocess.py#L2035).

## Motivation

This setting controls how frequently we poll for failed processes in elastic.

* For some jobs of note we run elastic 3 times per try, so with the previous default interval of 5 seconds this should save ~15 seconds per retry.
* @kiukchung's use case: the long interval is annoying in notebooks etc. since it adds delay to shutdown when testing things.

## Results

CPU utilization is measured in cores (100% is a single core under full load).

| monitor_interval (s) | nproc-per-node | CPU util (highest observed) |
| -------------------- | -------------- | --------------------------- |
| 1.0                  | 10             | 0.2%                        |
| 0.1                  | 1              | 0.4%                        |
| 0.1                  | 10             | 0.4%                        |
| 0.01                 | 10             | 0.9%                        |
| 0.001                | 10             | 4.0%                        |
| 0.1                  | 100            | 0.5%                        |
| 0.1                  | 1000           | 2.2%                        |
| 0.1                  | 10000          | 15.7%                       |
| 0.1                  | 100000         | 45.4%                       |

## Methodology

```sh
# run command
$ LOGLEVEL=INFO torchrun --nnodes 1 --nproc-per-node 10 --monitor-interval 0.1 ~/wait.py

# wait a few seconds for all processes to start and reach steady state, then
# run the following; watch ~30s (3 prints) and take the highest reading
$ top -b -d 10 -c | rg 'torchrun.*wait'
```

wait.py

```py
import time

time.sleep(10 * 60)
```
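For reference, the same interval can also be set programmatically through the launcher API instead of the torchrun flag. The sketch below is illustrative only and not part of this patch: the single-node c10d rendezvous settings, the run_id, and the trainer placeholder are assumptions chosen to mirror the torchrun command above.

```py
# Illustrative sketch only: overriding monitor_interval via the launcher API.
# Assumptions: single-node c10d rendezvous on a free local port, placeholder trainer.
import time

from torch.distributed.launcher.api import LaunchConfig, elastic_launch


def trainer():
    # Stand-in workload, analogous to wait.py above.
    time.sleep(10 * 60)


config = LaunchConfig(
    min_nodes=1,
    max_nodes=1,
    nproc_per_node=10,
    run_id="monitor_interval_demo",  # hypothetical run id
    rdzv_backend="c10d",
    rdzv_endpoint="localhost:0",  # port 0 lets the rendezvous pick a free port
    monitor_interval=0.1,  # seconds between polls for failed workers
    max_restarts=0,
)

if __name__ == "__main__":
    elastic_launch(config, trainer)()
```

Functionally this should behave like the `torchrun ... --monitor-interval 0.1` invocation above.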
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124692
Approved by: https://github.com/kiukchung, https://github.com/kurman
---
 test/distributed/elastic/agent/server/test/api_test.py | 4 ++--
 test/distributed/launcher/api_test.py                  | 2 +-
 torch/distributed/elastic/agent/server/api.py          | 2 +-
 torch/distributed/launcher/api.py                      | 2 +-
 torch/distributed/run.py                               | 2 +-
 5 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/test/distributed/elastic/agent/server/test/api_test.py b/test/distributed/elastic/agent/server/test/api_test.py
index b0d64f98538b46..e57b7b9fcb44c1 100644
--- a/test/distributed/elastic/agent/server/test/api_test.py
+++ b/test/distributed/elastic/agent/server/test/api_test.py
@@ -54,7 +54,7 @@ def test_worker_group_constructor(self):
             args=(),
             rdzv_handler=None,
             max_restarts=50,
-            monitor_interval=1,
+            monitor_interval=0.1,
         )

         worker_group = WorkerGroup(spec)
@@ -157,7 +157,7 @@ class SimpleElasticAgentTest(unittest.TestCase):
     def _get_worker_spec(
         self,
         max_restarts=1,
-        monitor_interval=1.0,
+        monitor_interval=0.1,
         role="test_trainer",
         local_world_size=8,
         local_addr=None,
diff --git a/test/distributed/launcher/api_test.py b/test/distributed/launcher/api_test.py
index 81e9320d1f0629..38e3bc305fa55e 100644
--- a/test/distributed/launcher/api_test.py
+++ b/test/distributed/launcher/api_test.py
@@ -70,7 +70,7 @@ def get_test_launch_config(
         nproc_per_node=nproc_per_node,
         run_id=run_id,
         rdzv_endpoint=rdzv_endpoint,
-        monitor_interval=1,
+        monitor_interval=0.1,
         rdzv_backend=rdzv_backend,
         start_method="spawn",
         max_restarts=0,
diff --git a/torch/distributed/elastic/agent/server/api.py b/torch/distributed/elastic/agent/server/api.py
index 4ebfc5952303bb..dd20703cedb46b 100644
--- a/torch/distributed/elastic/agent/server/api.py
+++ b/torch/distributed/elastic/agent/server/api.py
@@ -85,7 +85,7 @@ class WorkerSpec:
     entrypoint: Union[Callable, str, None] = None
     args: Tuple = ()
     max_restarts: int = 3
-    monitor_interval: float = 30.0
+    monitor_interval: float = 0.1
     master_port: Optional[int] = None
     master_addr: Optional[str] = None
     local_addr: Optional[str] = None
diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py
index f2b4aca644f843..20de0a032713af 100644
--- a/torch/distributed/launcher/api.py
+++ b/torch/distributed/launcher/api.py
@@ -75,7 +75,7 @@ class LaunchConfig:
     rdzv_configs: Dict[str, Any] = field(default_factory=dict)
     rdzv_timeout: int = -1
     max_restarts: int = 3
-    monitor_interval: float = 30
+    monitor_interval: float = 0.1
     start_method: str = "spawn"
     log_line_prefix_template: Optional[str] = None
     metrics_cfg: Dict[str, str] = field(default_factory=dict)
diff --git a/torch/distributed/run.py b/torch/distributed/run.py
index 3352111068d88f..98917a667e0b58 100644
--- a/torch/distributed/run.py
+++ b/torch/distributed/run.py
@@ -499,7 +499,7 @@ def get_args_parser() -> ArgumentParser:
         "--monitor_interval",
         action=env,
         type=float,
-        default=5,
+        default=0.1,
         help="Interval, in seconds, to monitor the state of workers.",
     )
     parser.add_argument(