From 3d679a5606fe87cb5e79a4e26a352ed8bac07ddf Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Mon, 14 Oct 2024 16:54:58 -0700 Subject: [PATCH 1/2] Set cuda device during initialization of distributed backend. The commit is needed to avoid GPU 0 being set as the compute stream via torch.cuda.current_stream() during initialization across all GPUs. Signed-off-by: Jagadish Krishnamoorthy --- training/cifar/cifar10_deepspeed.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/training/cifar/cifar10_deepspeed.py b/training/cifar/cifar10_deepspeed.py index 521a75cdf..8460feaf4 100755 --- a/training/cifar/cifar10_deepspeed.py +++ b/training/cifar/cifar10_deepspeed.py @@ -1,4 +1,5 @@ import argparse +import os import deepspeed import torch @@ -279,6 +280,9 @@ def test(model_engine, testset, local_device, target_dtype, test_batch_size=4): def main(args): # Initialize DeepSpeed distributed backend. deepspeed.init_distributed() + _local_rank = int(os.environ.get("LOCAL_RANK")) + torch_device = torch.device(f"cuda:{_local_rank}") + torch.cuda.set_device(torch_device) ######################################################################## # Step1. Data Preparation. From 7f9198893cc7edccb17dda11bdf3baf23e60cff0 Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Tue, 15 Oct 2024 15:37:55 -0700 Subject: [PATCH 2/2] Use device-agnostic accelerator API. Signed-off-by: Jagadish Krishnamoorthy --- training/cifar/cifar10_deepspeed.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/training/cifar/cifar10_deepspeed.py b/training/cifar/cifar10_deepspeed.py index 8460feaf4..9888544d5 100755 --- a/training/cifar/cifar10_deepspeed.py +++ b/training/cifar/cifar10_deepspeed.py @@ -281,8 +281,7 @@ def main(args): # Initialize DeepSpeed distributed backend. deepspeed.init_distributed() _local_rank = int(os.environ.get("LOCAL_RANK")) - torch_device = torch.device(f"cuda:{_local_rank}") - torch.cuda.set_device(torch_device) + get_accelerator().set_device(_local_rank) ######################################################################## # Step1. Data Preparation.