Commit 6f627a8 (1 parent: 59cdbae)

* MXNet distributed training
* change apiVersion
* Addressed some review comments (newline-related)
* Revert "change apiVersion" (reverts commit 163aed7)

Showing 3 changed files with 248 additions and 0 deletions.
Dockerfile (new file):
FROM horovod/horovod:0.16.2-tf1.12.0-torch1.1.0-mxnet1.4.1-py3.5 AS build

# Create a wrapper for OpenMPI to allow running as root by default
RUN mv /usr/local/bin/mpirun /usr/local/bin/mpirun.real && \
    echo '#!/bin/bash' > /usr/local/bin/mpirun && \
    echo 'mpirun.real --allow-run-as-root "$@"' >> /usr/local/bin/mpirun && \
    chmod a+x /usr/local/bin/mpirun

# Configure OpenMPI with sensible defaults
RUN echo "hwloc_base_binding_policy = none" >> /usr/local/etc/openmpi-mca-params.conf && \
    echo "rmaps_base_mapping_policy = slot" >> /usr/local/etc/openmpi-mca-params.conf && \
    echo "btl_tcp_if_exclude = lo,docker0" >> /usr/local/etc/openmpi-mca-params.conf

# Set default NCCL parameters
RUN echo NCCL_DEBUG=INFO >> /etc/nccl.conf && \
    echo NCCL_SOCKET_IFNAME=^docker0 >> /etc/nccl.conf

# --------------------------------------------------------------------

# Other packages needed for running the examples
RUN pip install gluoncv

# Add the example script to the examples folder
ADD mxnet_mnist.py /examples/mxnet_mnist.py

WORKDIR "/"
CMD ["/bin/bash"]
MPIJob manifest (new file):
apiVersion: kubeflow.org/v1alpha1
kind: MPIJob
metadata:
  labels:
    ksonnet.io/component: mxnet-mnist-horovod-job
  name: mxnet-mnist-horovod-job
  namespace: default
spec:
  replicas: 2
  template:
    spec:
      containers:
      - command:
        - mpirun
        - -mca
        - btl_tcp_if_exclude
        - lo
        - -mca
        - pml
        - ob1
        - -mca
        - btl
        - ^openib
        - --bind-to
        - none
        - -map-by
        - slot
        - -x
        - LD_LIBRARY_PATH
        - -x
        - PATH
        - -x
        - NCCL_SOCKET_IFNAME=eth0
        - -x
        - NCCL_DEBUG=INFO
        - -x
        - MXNET_CUDNN_AUTOTUNE_DEFAULT=0
        - python
        - /examples/mxnet_mnist.py
        - --batch-size
        - "64"
        - --epochs
        - "5"
        image: mpioperator/mxnet-horovod:latest
        name: mxnet-mnist-horovod-job
        resources:
          limits:
            nvidia.com/gpu: 4
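Assuming the mpi-operator and its v1alpha1 MPIJob CRD are already installed in the cluster, the job can be submitted and inspected with kubectl as sketched below. The manifest filename is illustrative, and the launcher name assumes the operator's <job-name>-launcher convention:

# Submit the MPIJob (filename is illustrative)
kubectl apply -f mxnet-mnist-horovod-job.yaml
# Check job status
kubectl get mpijob mxnet-mnist-horovod-job
# Follow training output from the launcher (assumes the <job-name>-launcher naming convention)
kubectl logs -f job/mxnet-mnist-horovod-job-launcher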
mxnet_mnist.py (new file):
import argparse
import logging
import os
import zipfile
import time

import mxnet as mx
import horovod.mxnet as hvd
from mxnet import autograd, gluon, nd
from mxnet.test_utils import download

# Training settings
parser = argparse.ArgumentParser(description='Apache MXNet MNIST Example')

parser.add_argument('--batch-size', type=int, default=64,
                    help='training batch size (default: 64)')
parser.add_argument('--dtype', type=str, default='float32',
                    help='training data type (default: float32)')
parser.add_argument('--epochs', type=int, default=5,
                    help='number of training epochs (default: 5)')
parser.add_argument('--lr', type=float, default=0.01,
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disable training on GPU (default: False)')
args = parser.parse_args()

if not args.no_cuda:
    # Disable CUDA if there are no GPUs.
    if mx.context.num_gpus() == 0:
        args.no_cuda = True

logging.basicConfig(level=logging.INFO)
logging.info(args)

# Function to get an MNIST iterator for a given rank
def get_mnist_iterator(rank):
    data_dir = "data-%d" % rank
    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)
    zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
                             dirname=data_dir)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(data_dir)

    input_shape = (1, 28, 28)
    batch_size = args.batch_size

    # Shard the training data across workers; each worker reads only its part.
    train_iter = mx.io.MNISTIter(
        image="%s/train-images-idx3-ubyte" % data_dir,
        label="%s/train-labels-idx1-ubyte" % data_dir,
        input_shape=input_shape,
        batch_size=batch_size,
        shuffle=True,
        flat=False,
        num_parts=hvd.size(),
        part_index=hvd.rank()
    )

    val_iter = mx.io.MNISTIter(
        image="%s/t10k-images-idx3-ubyte" % data_dir,
        label="%s/t10k-labels-idx1-ubyte" % data_dir,
        input_shape=input_shape,
        batch_size=batch_size,
        flat=False,
    )

    return train_iter, val_iter

# Function to define the neural network
def conv_nets():
    net = gluon.nn.HybridSequential()
    with net.name_scope():
        net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
        net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
        net.add(gluon.nn.Flatten())
        net.add(gluon.nn.Dense(512, activation="relu"))
        net.add(gluon.nn.Dense(10))
    return net


# Function to evaluate accuracy for a model
def evaluate(model, data_iter, context):
    data_iter.reset()
    metric = mx.metric.Accuracy()
    for _, batch in enumerate(data_iter):
        data = batch.data[0].as_in_context(context)
        label = batch.label[0].as_in_context(context)
        output = model(data.astype(args.dtype, copy=False))
        metric.update([label], [output])
    return metric.get()

# Initialize Horovod
hvd.init()

# Horovod: pin context to local rank
context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(hvd.local_rank())
num_workers = hvd.size()

# Load training and validation data
train_data, val_data = get_mnist_iterator(hvd.rank())

# Build model
model = conv_nets()
model.cast(args.dtype)
model.hybridize()

# Create optimizer, scaling the learning rate by the number of workers
optimizer_params = {'momentum': args.momentum,
                    'learning_rate': args.lr * hvd.size()}
opt = mx.optimizer.create('sgd', **optimizer_params)

# Initialize parameters
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                             magnitude=2)
model.initialize(initializer, ctx=context)

# Horovod: fetch and broadcast parameters
params = model.collect_params()
if params is not None:
    hvd.broadcast_parameters(params, root_rank=0)

# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

# Train model
for epoch in range(args.epochs):
    tic = time.time()
    train_data.reset()
    metric.reset()
    for nbatch, batch in enumerate(train_data, start=1):
        data = batch.data[0].as_in_context(context)
        label = batch.label[0].as_in_context(context)
        with autograd.record():
            output = model(data.astype(args.dtype, copy=False))
            loss = loss_fn(output, label)
        loss.backward()
        # DistributedTrainer allreduces gradients across workers during step()
        trainer.step(args.batch_size)
        metric.update([label], [output])

        if nbatch % 100 == 0:
            name, acc = metric.get()
            logging.info('[Epoch %d Batch %d] Training: %s=%f' %
                         (epoch, nbatch, name, acc))

    if hvd.rank() == 0:
        elapsed = time.time() - tic
        speed = nbatch * args.batch_size * hvd.size() / elapsed
        logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f',
                     epoch, speed, elapsed)

    # Evaluate model accuracy
    _, train_acc = metric.get()
    name, val_acc = evaluate(model, val_data, context)
    if hvd.rank() == 0:
        logging.info('Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f', epoch, name,
                     train_acc, name, val_acc)

    if hvd.rank() == 0 and epoch == args.epochs - 1:
        assert val_acc > 0.96, \
            "Achieved accuracy (%f) is lower than expected (0.96)" % val_acc
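For a quick sanity check outside Kubernetes, the script can also be launched with the wrapped mpirun from the Dockerfile above, inside the container on a machine with GPUs; the process count of 4 is an arbitrary example:

# Launch 4 Horovod workers on one node (process count is an arbitrary example)
mpirun -np 4 python /examples/mxnet_mnist.py --batch-size 64 --epochs 5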