add ddp #2

Open · wants to merge 5 commits into master
31 changes: 27 additions & 4 deletions pytorch/ugatit/run.sh
@@ -1,9 +1,32 @@
#!/bin/bash

# dataset can be downloaded from
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

# dataset can be downloaded from
# https://drive.google.com/file/d/1xOWj1UVgp6NKMT3HbPhBbtq2A4EDkghF/view

export DISTRIBUTED_FRAMEWORK=byteps
SHOULD_DOWNLOAD_DATASET=${SHOULD_DOWNLOAD_DATASET:-1}
OMPI_COMM_WORLD_LOCAL_RANK=${OMPI_COMM_WORLD_LOCAL_RANK:-0}
if [[ "$SHOULD_DOWNLOAD_DATASET" == "1" ]] && [[ "$OMPI_COMM_WORLD_LOCAL_RANK" == "0" ]]; then
    cd
    [[ -f ./selfie2anime.zip ]] || wget -nv https://byteps.tos-cn-qingdao.volces.com/datasets/selfie2anime.zip
    unzip -qn ./selfie2anime.zip -d selfie2anime
    cd -
fi

export DISTRIBUTED_FRAMEWORK=${DISTRIBUTED_FRAMEWORK:-byteps}

cd ./ugatit
bpslaunch python3 main.py $@
if [[ "$DISTRIBUTED_FRAMEWORK" == "byteps" ]]; then
    bytepsrun python3 ${this_dir}/ugatit/main.py $@
elif [[ "$DISTRIBUTED_FRAMEWORK" == "horovod" ]]; then
    python3 ${this_dir}/ugatit/main.py $@
elif [[ "$DISTRIBUTED_FRAMEWORK" == "torch_native" ]]; then
    torchrun --nproc_per_node=${ML_PLATFORM_WORKER_GPU} \
        --nnodes=${ML_PLATFORM_WORKER_NUM} \
        --node_rank=${RANK} \
        --master_addr="${MASTER_ADDR}" \
        --master_port=${MASTER_PORT} ${this_dir}/ugatit/main.py $@
else
    echo "Unsupported distributed training framework: $DISTRIBUTED_FRAMEWORK"
    echo "Please choose from: byteps, horovod and torch_native"
fi
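
For reference, a hypothetical way to invoke the updated run.sh once it dispatches on DISTRIBUTED_FRAMEWORK; the worker counts, addresses, and dataset argument below are placeholders, not values taken from this change:

    # BytePS (the default backend)
    DISTRIBUTED_FRAMEWORK=byteps bash pytorch/ugatit/run.sh --dataset selfie2anime

    # Native PyTorch DDP; run.sh forwards these platform variables to torchrun
    DISTRIBUTED_FRAMEWORK=torch_native \
    ML_PLATFORM_WORKER_GPU=8 ML_PLATFORM_WORKER_NUM=2 RANK=0 \
    MASTER_ADDR=10.0.0.1 MASTER_PORT=29500 \
    bash pytorch/ugatit/run.sh --dataset selfie2anime
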
27 changes: 27 additions & 0 deletions pytorch/ugatit/run_hvd.sh
@@ -0,0 +1,27 @@
#!/bin/bash

this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

cd
wget https://byteps.tos-cn-qingdao.volces.com/byteps-examples.tar.gz
mkdir byteps-examples
tar xvzf byteps-examples.tar.gz --strip-components=1 -C byteps-examples

# dataset can be downloaded from
# https://drive.google.com/file/d/1xOWj1UVgp6NKMT3HbPhBbtq2A4EDkghF/view

SHOULD_DOWNLOAD_DATASET=${SHOULD_DOWNLOAD_DATASET:-1}
if [[ "$SHOULD_DOWNLOAD_DATASET" == "1" ]]; then
    cd
    wget https://byteps.tos-cn-qingdao.volces.com/datasets/selfie2anime.zip
    unzip ./selfie2anime.zip -d selfie2anime
    cd -
fi

export DISTRIBUTED_FRAMEWORK=${DISTRIBUTED_FRAMEWORK:-byteps}

if [[ "$DISTRIBUTED_FRAMEWORK" != "horovod" ]]; then
echo "This script can only be used to launch training uisng horovod."
exit 1
fi
python3 ${this_dir}/ugatit/main.py $@
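
run_hvd.sh only validates DISTRIBUTED_FRAMEWORK and then execs main.py, so the Horovod worker processes themselves have to be spawned by an outer launcher. A minimal sketch assuming the standard horovodrun CLI; the process count and host list are placeholders:

    DISTRIBUTED_FRAMEWORK=horovod \
    horovodrun -np 8 -H localhost:8 bash pytorch/ugatit/run_hvd.sh --dataset selfie2anime
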
15 changes: 15 additions & 0 deletions pytorch/ugatit/run_torch_ddp.sh
@@ -0,0 +1,15 @@
#!/bin/bash

# dataset can be downloaded from
# https://drive.google.com/file/d/1xOWj1UVgp6NKMT3HbPhBbtq2A4EDkghF/view

export DISTRIBUTED_FRAMEWORK=torch_native

cd ./ugatit
torchrun --nproc_per_node=${ML_PLATFORM_WORKER_GPU} \
    --nnodes=${ML_PLATFORM_WORKER_NUM} \
    --node_rank=${RANK} \
    --master_addr="${MASTER_ADDR}" \
    --master_port=${MASTER_PORT} main.py $@

# https://github.com/pytorch/pytorch/blob/v1.10.0/torch/distributed/run.py
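
The torchrun flags above take the cluster shape from platform environment variables. A single-node stand-in for those variables, assuming torchrun from PyTorch 1.10+; all values are placeholders:

    export ML_PLATFORM_WORKER_GPU=8   # processes (GPUs) per node -> --nproc_per_node
    export ML_PLATFORM_WORKER_NUM=1   # number of nodes -> --nnodes
    export RANK=0                     # this node's index, consumed here as --node_rank
    export MASTER_ADDR=127.0.0.1
    export MASTER_PORT=29500
    cd pytorch/ugatit && bash run_torch_ddp.sh --dataset selfie2anime

torchrun then exports LOCAL_RANK and LOCAL_WORLD_SIZE to every worker process, which is what the torch_native fallback in main.py and UGATIT.py reads.
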
117 changes: 73 additions & 44 deletions pytorch/ugatit/ugatit/UGATIT.py
@@ -7,10 +7,27 @@
from glob import glob
from collections import OrderedDict

if os.getenv("DISTRIBUTED_FRAMEWORK") == "byteps":
dist_framework = os.getenv("DISTRIBUTED_FRAMEWORK", "").lower()
if dist_framework == "byteps":
import byteps.torch as bps
else:
elif dist_framework == "horovod":
import horovod.torch as bps
else:
import torch.distributed as bps
from torch.nn.parallel import DistributedDataParallel as DDP

def local_rank():
return int(os.getenv("LOCAL_RANK", "-1"))
def local_size():
return int(os.getenv("LOCAL_WORLD_SIZE", "-1"))
def rank():
return bps.get_rank()
def size():
return bps.get_world_size()
bps.local_rank = local_rank
bps.local_size = local_size
bps.rank = rank
bps.size = size

class UGATIT(object) :
def __init__(self, args):
@@ -145,31 +162,40 @@ def build_model(self):
self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)

if dist_framework == "torch_native":
self.genA2B = DDP(self.genA2B, device_ids=[bps.local_rank()])
self.genB2A = DDP(self.genB2A, device_ids=[bps.local_rank()])
self.disGA = DDP(self.disGA, device_ids=[bps.local_rank()])
self.disGB = DDP(self.disGB, device_ids=[bps.local_rank()])
self.disLA = DDP(self.disLA, device_ids=[bps.local_rank()])
self.disLB = DDP(self.disLB, device_ids=[bps.local_rank()])

""" Define Loss """
self.L1_loss = nn.L1Loss().to(self.device)
self.MSE_loss = nn.MSELoss().to(self.device)
self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)

gen_named_parameters = []
dis_named_parameters = []
for n, p in (list(self.genA2B.named_parameters(prefix='genA2B')) +
list(self.genB2A.named_parameters(prefix='genB2A'))):
gen_named_parameters.append((n, p))
for n, p in (list(self.disGA.named_parameters(prefix='disGA')) +
list(self.disGB.named_parameters(prefix='disGB')) +
list(self.disLA.named_parameters(prefix='disLA')) +
list(self.disLB.named_parameters(prefix='disLB'))):
dis_named_parameters.append((n, p))

gen_state_dict = OrderedDict([("genA2B."+k, v) for k, v in self.genA2B.state_dict().items()] +
[("genB2A."+k, v) for k, v in self.genB2A.state_dict().items()])
dis_state_dict = OrderedDict([("disGA."+k, v) for k, v in self.disGA.state_dict().items()] +
[("disGB."+k, v) for k, v in self.disGB.state_dict().items()] +
[("disLA."+k, v) for k, v in self.disLA.state_dict().items()] +
[("disLB."+k, v) for k, v in self.disLB.state_dict().items()])
if dist_framework != "torch_native":
    gen_named_parameters = []
    dis_named_parameters = []
    for n, p in (list(self.genA2B.named_parameters(prefix='genA2B')) +
                 list(self.genB2A.named_parameters(prefix='genB2A'))):
        gen_named_parameters.append((n, p))
    for n, p in (list(self.disGA.named_parameters(prefix='disGA')) +
                 list(self.disGB.named_parameters(prefix='disGB')) +
                 list(self.disLA.named_parameters(prefix='disLA')) +
                 list(self.disLB.named_parameters(prefix='disLB'))):
        dis_named_parameters.append((n, p))

    gen_state_dict = OrderedDict([("genA2B."+k, v) for k, v in self.genA2B.state_dict().items()] +
                                 [("genB2A."+k, v) for k, v in self.genB2A.state_dict().items()])
    dis_state_dict = OrderedDict([("disGA."+k, v) for k, v in self.disGA.state_dict().items()] +
                                 [("disGB."+k, v) for k, v in self.disGB.state_dict().items()] +
                                 [("disLA."+k, v) for k, v in self.disLA.state_dict().items()] +
                                 [("disLB."+k, v) for k, v in self.disLB.state_dict().items()])

bps.broadcast_parameters(gen_state_dict, root_rank=0)
bps.broadcast_parameters(dis_state_dict, root_rank=0)
    bps.broadcast_parameters(gen_state_dict, root_rank=0)
    bps.broadcast_parameters(dis_state_dict, root_rank=0)

""" Trainer """
self.G_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()),
@@ -187,16 +213,17 @@ def build_model(self):
for n, p in list(self.genB2A.named_parameters()):
named_parameters.append(("genB2A." + n, p))

self.G_optim = bps.DistributedOptimizer(self.G_optim,
named_parameters=gen_named_parameters,
compression=bps.Compression.none)
if dist_framework != "torch_native":
    self.G_optim = bps.DistributedOptimizer(self.G_optim,
                                            named_parameters=gen_named_parameters,
                                            compression=bps.Compression.none)

self.D_optim = bps.DistributedOptimizer(self.D_optim,
named_parameters=dis_named_parameters,
compression=bps.Compression.none)
    self.D_optim = bps.DistributedOptimizer(self.D_optim,
                                            named_parameters=dis_named_parameters,
                                            compression=bps.Compression.none)

self.G_optim._handles.clear()
self.D_optim._handles.clear()
    self.G_optim._handles.clear()
    self.D_optim._handles.clear()

""" Define Rho clipper to constraint the value of rho in AdaILN and ILN"""
self.Rho_clipper = RhoClipper(0, 1)
@@ -240,10 +267,11 @@ def train(self):
real_A, real_B = real_A.to(self.device), real_B.to(self.device)

# Update D
self.D_optim._handles.clear()
if dist_framework != "torch_native":
    self.D_optim._handles.clear()
    self.D_optim.set_backward_passes_per_step(1)
    self.G_optim.set_backward_passes_per_step(10)
self.D_optim.zero_grad()
self.D_optim.set_backward_passes_per_step(1)
self.G_optim.set_backward_passes_per_step(10)

fake_A2B, _, _ = self.genA2B(real_A)
fake_B2A, _, _ = self.genB2A(real_B)
@@ -253,10 +281,10 @@
real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)

fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)
fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A.detach())
fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A.detach())
fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B.detach())
fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B.detach())

D_ad_loss_GA = self.MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(self.device)) + self.MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(self.device))
D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(self.device)) + self.MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(self.device))
@@ -275,10 +303,11 @@
self.D_optim.step()

# Update G
self.G_optim._handles.clear()
if dist_framework != "torch_native":
    self.G_optim._handles.clear()
    self.D_optim.set_backward_passes_per_step(10)
    self.G_optim.set_backward_passes_per_step(1)
self.G_optim.zero_grad()
self.D_optim.set_backward_passes_per_step(10)
self.G_optim.set_backward_passes_per_step(1)

fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)
@@ -288,11 +317,11 @@

fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)
with torch.no_grad():
    fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
    fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
    fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
    fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)

G_ad_loss_GA = self.MSE_loss(fake_GA_logit, torch.ones_like(fake_GA_logit).to(self.device))
G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(self.device))
28 changes: 25 additions & 3 deletions pytorch/ugatit/ugatit/main.py
@@ -2,10 +2,31 @@
import argparse
from utils import *

if os.getenv("DISTRIBUTED_FRAMEWORK") == "byteps":
dist_framework = os.getenv("DISTRIBUTED_FRAMEWORK", "").lower()
if dist_framework == "byteps":
import byteps.torch as bps
else:
elif dist_framework == "horovod":
import horovod.torch as bps
else:
import torch.distributed as bps
def local_rank():
return int(os.getenv("LOCAL_RANK", "-1"))
def local_size():
return int(os.getenv("LOCAL_WORLD_SIZE", "-1"))
def rank():
return bps.get_rank()
def size():
return bps.get_world_size()
def init():
bps.init_process_group(backend="nccl")
return None

bps.local_rank = local_rank
bps.local_size = local_size
bps.rank = rank
bps.size = size
bps.init = init


"""parsing and configuration"""

@@ -15,7 +36,7 @@ def parse_args():
parser.add_argument('--phase', type=str, default='train', help='[train / test]')
parser.add_argument('--light', type=str2bool, default=False, help='[U-GAT-IT full version / U-GAT-IT light version]')
parser.add_argument('--dataset_dir', type=str, default='dataset', help='dataset dir path')
parser.add_argument('--dataset', type=str, default='YOUR_D:ATASET_NAME', help='dataset_name')
parser.add_argument('--dataset', type=str, default='YOUR_DATASET_NAME', help='dataset_name')

parser.add_argument('--iteration', type=int, default=1000000, help='The number of training iterations')
parser.add_argument('--batch_size', type=int, default=1, help='The size of batch size')
@@ -72,6 +93,7 @@ def main():
bps.init()
torch.manual_seed(1)
torch.cuda.manual_seed(1)
print(f'xxxx bps.local_rank() {bps.local_rank()}', flush=True)
torch.cuda.set_device(bps.local_rank())
# parse arguments
args = parse_args()