diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py
index 92828df5cf..9c32d1038b 100644
--- a/distributed/tensor_parallelism/fsdp_tp_example.py
+++ b/distributed/tensor_parallelism/fsdp_tp_example.py
@@ -1,3 +1,4 @@
+import sys
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -13,7 +14,7 @@ from torch.distributed._tensor.device_mesh import init_device_mesh
 
 import os
 
-from log_utils import rank_log, get_logger
+from log_utils import rank_log, get_logger, verify_min_gpu_count
 
 """
 This is the script to test 2D Parallel which combines Tensor/Sequence
@@ -46,6 +47,12 @@ https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
 """
 
+_min_gpu_count = 2
+
+if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+    print(f"This example requires at least {_min_gpu_count} GPUs to run. Exiting.")
+    sys.exit(0)
+
 
 def find_multiple(n: int, k: int) -> int:
     """function to find resizing multiple for SwiGLU MLP"""
 
diff --git a/distributed/tensor_parallelism/log_utils.py b/distributed/tensor_parallelism/log_utils.py
index 611b25a412..f16d46526d 100644
--- a/distributed/tensor_parallelism/log_utils.py
+++ b/distributed/tensor_parallelism/log_utils.py
@@ -1,4 +1,5 @@
 import logging
+import torch
 
 logging.basicConfig(
     format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO
@@ -12,3 +13,10 @@ def rank_log(_rank, logger, msg):
     """helper function to log only on global rank 0"""
     if _rank == 0:
         logger.info(f" {msg}")
+
+
+def verify_min_gpu_count(min_gpus: int = 2) -> bool:
+    """Verify that at least `min_gpus` GPUs are available to run the distributed examples."""
+    has_cuda = torch.cuda.is_available()
+    gpu_count = torch.cuda.device_count()
+    return has_cuda and gpu_count >= min_gpus
diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py
index aa943a7304..6a9de413bb 100644
--- a/distributed/tensor_parallelism/sequence_parallel_example.py
+++ b/distributed/tensor_parallelism/sequence_parallel_example.py
@@ -1,4 +1,5 @@
 import os
+import sys
 
 import torch
 import torch.nn as nn
@@ -11,7 +12,7 @@
     RowwiseParallel,
 )
 
-from log_utils import rank_log, get_logger
+from log_utils import rank_log, get_logger, verify_min_gpu_count
 
 
 """
@@ -29,6 +30,11 @@ now is different so that we need one all-gather for input and one reduce-scatter
 in the end of the second linear layer.
 """
 
+_min_gpu_count = 2
+
+if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+    print(f"This example requires at least {_min_gpu_count} GPUs to run. Exiting.")
+    sys.exit(0)
 
 
 class ToyModel(nn.Module):
diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py
index f25252b998..bc8325d5d7 100755
--- a/distributed/tensor_parallelism/tensor_parallel_example.py
+++ b/distributed/tensor_parallelism/tensor_parallel_example.py
@@ -1,4 +1,5 @@
 import os
+import sys
 
 import torch
 import torch.nn as nn
@@ -11,7 +12,9 @@
 )
 
 
-from log_utils import rank_log, get_logger
+from log_utils import rank_log, get_logger, verify_min_gpu_count
+
+
 
 
 """
@@ -45,6 +48,11 @@ Parallelism APIs in this example to show users how to use them.
 """
 
+_min_gpu_count = 2
+
+if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+    print(f"This example requires at least {_min_gpu_count} GPUs to run. Exiting.")
+    sys.exit(0)
 
 
 class ToyModel(nn.Module):
     """MLP based model"""
diff --git a/run_python_examples.sh b/run_python_examples.sh
index 1b45a281cf..c933cf8f65 100755
--- a/run_python_examples.sh
+++ b/run_python_examples.sh
@@ -63,8 +63,8 @@ function distributed() {
   start
   python tensor_parallelism/tensor_parallel_example.py || error "tensor parallel example failed"
   python tensor_parallelism/sequence_parallel_example.py || error "sequence parallel example failed"
-  python tensor_parallelism/two_d_parallel_example.py || error "2D parallel example failed"
-  python ddp/main.py || error "ddp example failed" 
+  python tensor_parallelism/fsdp_tp_example.py || error "2D parallel example failed"
+  python ddp/main.py || error "ddp example failed"
 }
 
 function fast_neural_style() {
@@ -96,7 +96,7 @@ function mnist() {
   python main.py --epochs 1 --dry-run || error "mnist example failed"
 }
 function mnist_forward_forward() {
-  start 
+  start
   python main.py --epochs 1 --no_mps --no_cuda || error "mnist forward forward failed"
 }
 
@@ -212,9 +212,8 @@ function clean() {
 function run_all() {
   # cpp
   dcgan
-  # distributed
-  fast_neural_style
   distributed
+  fast_neural_style
   imagenet
   mnist
   mnist_forward_forward