add gpu verification, update run_python_examples.sh
lessw2020 committed Nov 22, 2023
1 parent 836f798 commit 2de0144
Showing 5 changed files with 36 additions and 8 deletions.
9 changes: 8 additions & 1 deletion distributed/tensor_parallelism/fsdp_tp_example.py
@@ -1,3 +1,4 @@
+import sys
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -13,7 +14,7 @@
 
 from torch.distributed._tensor.device_mesh import init_device_mesh
 import os
-from log_utils import rank_log, get_logger
+from log_utils import rank_log, get_logger, verify_min_gpu_count
 
 """
 This is the script to test 2D Parallel which combines Tensor/Sequence
@@ -46,6 +47,12 @@
 https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
 """
 
+_min_gpu_count = 2
+
+if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+    print(f"Unable to locate the required {_min_gpu_count} GPUs to run this example. Exiting.")
+    sys.exit(0)
+
 
 def find_multiple(n: int, k: int) -> int:
     """function to find resizing multiple for SwiGLU MLP"""
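
The pattern above — verify the GPU count at import time, then exit with status zero — is what this commit repeats in each example. Since this file builds its 2D layout with init_device_mesh, here is a minimal sketch of how a data-parallel x tensor-parallel mesh is typically constructed; the helper name, mesh shape, and dimension names are illustrative assumptions, not code from this file:

# Hypothetical sketch: carve world_size ranks into a 2D
# (data-parallel x tensor-parallel) device mesh. The dim names
# and the tp_size default are assumptions for illustration.
from torch.distributed._tensor.device_mesh import init_device_mesh

def build_2d_mesh(world_size: int, tp_size: int = 2):
    assert world_size % tp_size == 0, "world_size must be divisible by tp_size"
    dp_size = world_size // tp_size
    # FSDP shards along the "dp" dimension; tensor parallelism
    # shards along "tp". Together they form the 2D layout.
    return init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))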
8 changes: 8 additions & 0 deletions distributed/tensor_parallelism/log_utils.py
@@ -1,4 +1,5 @@
 import logging
+import torch
 
 logging.basicConfig(
     format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO
@@ -12,3 +13,10 @@ def rank_log(_rank, logger, msg):
     """helper function to log only on global rank 0"""
     if _rank == 0:
         logger.info(f" {msg}")
+
+
+def verify_min_gpu_count(min_gpus: int = 2) -> bool:
+    """verify that at least `min_gpus` GPUs are available for the distributed examples"""
+    has_cuda = torch.cuda.is_available()
+    gpu_count = torch.cuda.device_count()
+    return has_cuda and gpu_count >= min_gpus
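
Read alongside rank_log, the new helper gives every example the same two-step preamble. A usage sketch follows — it is not part of the commit, and the RANK lookup assumes a torchrun-style launcher:

# Usage sketch: gate a distributed example on available GPUs,
# then log from rank 0 only. Not part of the commit itself.
import os
import sys

from log_utils import get_logger, rank_log, verify_min_gpu_count

if not verify_min_gpu_count(min_gpus=2):
    print("Need at least 2 GPUs; exiting with status 0 so the harness treats this as a skip.")
    sys.exit(0)

logger = get_logger()
_rank = int(os.environ.get("RANK", 0))  # torchrun sets RANK for each process
rank_log(_rank, logger, "GPU check passed; starting example.")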
8 changes: 7 additions & 1 deletion distributed/tensor_parallelism/sequence_parallel_example.py
@@ -1,4 +1,5 @@
 import os
+import sys
 import torch
 import torch.nn as nn
 
@@ -11,7 +12,7 @@
     RowwiseParallel,
 )
 
-from log_utils import rank_log, get_logger
+from log_utils import rank_log, get_logger, verify_min_gpu_count
 
 
 """
@@ -29,6 +30,11 @@
 now is different so that we need one all-gather for input and one reduce-scatter
 in the end of the second linear layer.
 """
+_min_gpu_count = 2
+
+if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+    print(f"Unable to locate the required {_min_gpu_count} GPUs to run this example. Exiting.")
+    sys.exit(0)
 
 
 class ToyModel(nn.Module):
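
The all-gather/reduce-scatter pattern that docstring mentions falls out of how the two linear layers are sharded. Below is a hedged sketch under the torch.distributed.tensor.parallel API; the module names ("in_proj"/"out_proj") and the layout kwargs are assumptions that may differ from the actual example file:

# Sequence-parallel sketch: shard the input on the sequence dim,
# colwise-shard the first linear, rowwise-shard the second, so the
# only collectives are one all-gather (input) and one reduce-scatter
# (after the second linear). Names and layouts are illustrative.
import torch
import torch.nn as nn
from torch.distributed._tensor import Shard
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class TwoLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(torch.relu(self.in_proj(x)))

# Run under a 2-rank launcher, e.g. torchrun --nproc_per_node=2.
mesh = init_device_mesh("cuda", (2,))
sp_model = parallelize_module(
    TwoLinear().cuda(),
    mesh,
    {
        # Shard(0): input arrives sharded on the sequence dim, so the
        # colwise layer all-gathers it once...
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        # ...and the rowwise layer reduce-scatters its partial outputs.
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)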
10 changes: 9 additions & 1 deletion distributed/tensor_parallelism/tensor_parallel_example.py
@@ -1,4 +1,5 @@
 import os
+import sys
 import torch
 import torch.nn as nn
 
@@ -11,7 +12,9 @@
 )
 
 
-from log_utils import rank_log, get_logger
+from log_utils import rank_log, get_logger, verify_min_gpu_count
+
+
 
 
 """
@@ -45,6 +48,11 @@
 Parallelism APIs in this example to show users how to use them.
 """
 
+_min_gpu_count = 2
+
+if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+    print(f"Unable to locate the required {_min_gpu_count} GPUs to run this example. Exiting.")
+    sys.exit(0)
 
 class ToyModel(nn.Module):
     """MLP based model"""
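
For readers new to the colwise-then-rowwise trick these tensor-parallel examples rely on, the arithmetic can be checked without any distributed runtime at all. The single-process illustration below is not from the repository:

# CPU sketch of why colwise-then-rowwise sharding needs only one
# collective: each "rank" computes a partial result, and the partials
# sum to the full output. Pure-tensor illustration, no process group.
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)    # batch of 4, feature dim 8
w1 = torch.randn(8, 16)  # first linear weight (colwise-sharded below)
w2 = torch.randn(16, 8)  # second linear weight (rowwise-sharded below)

full = (x @ w1) @ w2     # unsharded reference

# "rank 0" holds the left halves, "rank 1" the right halves.
partials = []
for cols in (slice(0, 8), slice(8, 16)):
    h = x @ w1[:, cols]                # colwise shard: input needs no comm
    partials.append(h @ w2[cols, :])   # rowwise shard: partial output

# The all-reduce in tensor parallelism is exactly this sum.
assert torch.allclose(sum(partials), full, atol=1e-5)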
9 changes: 4 additions & 5 deletions run_python_examples.sh
@@ -63,8 +63,8 @@ function distributed() {
   start
   python tensor_parallelism/tensor_parallel_example.py || error "tensor parallel example failed"
   python tensor_parallelism/sequence_parallel_example.py || error "sequence parallel example failed"
-  python tensor_parallelism/two_d_parallel_example.py || error "2D parallel example failed"
-  python ddp/main.py || error "ddp example failed"
+  python tensor_parallelism/fsdp_tp_parallel_example.py || error "2D parallel example failed"
+  python ddp/main.py || error "ddp example failed"
 }
 
 function fast_neural_style() {
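
One detail that makes the new GPU guard work with this harness: `|| error` fires only on a non-zero exit status, so the examples' sys.exit(0) registers as a pass rather than a failure when GPUs are missing. A small Python sketch of that contract (the script path is illustrative):

# Sketch: the `|| error "..."` in run_python_examples.sh only triggers
# on a non-zero exit status, so the GPU guard's sys.exit(0) reads as a
# pass. The script path below is illustrative.
import subprocess

result = subprocess.run(
    ["python", "tensor_parallelism/tensor_parallel_example.py"]
)
if result.returncode != 0:  # this is what `|| error "..."` checks
    print("tensor parallel example failed")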
@@ -96,7 +96,7 @@ function mnist() {
   python main.py --epochs 1 --dry-run || error "mnist example failed"
 }
 function mnist_forward_forward() {
-  start
+  start
   python main.py --epochs 1 --no_mps --no_cuda || error "mnist forward forward failed"
 
 }
@@ -212,9 +212,8 @@ function clean() {
 function run_all() {
   # cpp
   dcgan
-  # distributed
-  fast_neural_style
+  distributed
+  fast_neural_style
   imagenet
   mnist
   mnist_forward_forward
