diff --git a/pytorch/ugatit/run.sh b/pytorch/ugatit/run.sh
index 4b8dbc5..2350c3f 100644
--- a/pytorch/ugatit/run.sh
+++ b/pytorch/ugatit/run.sh
@@ -1,9 +1,33 @@
 #!/bin/bash
-# dataset can be downloaded from
+this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+
+# dataset can be downloaded from
 # https://drive.google.com/file/d/1xOWj1UVgp6NKMT3HbPhBbtq2A4EDkghF/view
-export DISTRIBUTED_FRAMEWORK=byteps
+SHOULD_DOWNLOAD_DATASET=${SHOULD_DOWNLOAD_DATASET:-1}
+OMPI_COMM_WORLD_LOCAL_RANK=${OMPI_COMM_WORLD_LOCAL_RANK:-0}
+if [[ "$SHOULD_DOWNLOAD_DATASET" == "1" ]] && [[ "$OMPI_COMM_WORLD_LOCAL_RANK" == "0" ]]; then
+    cd
+    [[ -f ./selfie2anime.zip ]] || wget -nv https://byteps.tos-cn-qingdao.volces.com/datasets/selfie2anime.zip
+    unzip -qn ./selfie2anime.zip -d selfie2anime
+    cd -
+fi
+
+export DISTRIBUTED_FRAMEWORK=${DISTRIBUTED_FRAMEWORK:-byteps}
 
-cd ./ugatit
-bpslaunch python3 main.py $@
\ No newline at end of file
+if [[ "$DISTRIBUTED_FRAMEWORK" == "byteps" ]]; then
+    bpslaunch python3 ${this_dir}/ugatit/main.py $@
+elif [[ "$DISTRIBUTED_FRAMEWORK" == "horovod" ]]; then
+    python3 ${this_dir}/ugatit/main.py $@
+elif [[ "$DISTRIBUTED_FRAMEWORK" == "torch_native" ]]; then
+    torchrun --nproc_per_node=${ML_PLATFORM_WORKER_GPU} \
+        --nnodes=${ML_PLATFORM_WORKER_NUM} \
+        --node_rank=${RANK} \
+        --master_addr="${MASTER_ADDR}" \
+        --master_port=${MASTER_PORT} ${this_dir}/ugatit/main.py $@
+else
+    echo "Unsupported distributed training framework: $DISTRIBUTED_FRAMEWORK"
+    echo "Please choose from: byteps, horovod and torch_native"
+    exit 1
+fi
diff --git a/pytorch/ugatit/run_hvd.sh b/pytorch/ugatit/run_hvd.sh
new file mode 100755
index 0000000..f494133
--- /dev/null
+++ b/pytorch/ugatit/run_hvd.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
+
+cd
+wget https://byteps.tos-cn-qingdao.volces.com/byteps-examples.tar.gz
+mkdir byteps-examples
+tar xvzf byteps-examples.tar.gz --strip-components=1 -C byteps-examples
+
+# dataset can be downloaded from
+# https://drive.google.com/file/d/1xOWj1UVgp6NKMT3HbPhBbtq2A4EDkghF/view
+
+SHOULD_DOWNLOAD_DATASET=${SHOULD_DOWNLOAD_DATASET:-1}
+if [[ "$SHOULD_DOWNLOAD_DATASET" == "1" ]]; then
+    cd
+    wget https://byteps.tos-cn-qingdao.volces.com/datasets/selfie2anime.zip
+    unzip ./selfie2anime.zip -d selfie2anime
+    cd -
+fi
+
+export DISTRIBUTED_FRAMEWORK=${DISTRIBUTED_FRAMEWORK:-horovod}
+
+if [[ "$DISTRIBUTED_FRAMEWORK" != "horovod" ]]; then
+    echo "This script can only be used to launch training using horovod."
+    exit 1
+fi
+python3 ${this_dir}/ugatit/main.py $@
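
Note: run.sh gates the dataset download on OMPI_COMM_WORLD_LOCAL_RANK so that only one process per node touches the network, while run_hvd.sh above downloads unconditionally. A minimal Python sketch of the same guard, assuming a torchrun-style LOCAL_RANK variable and an already-initialized process group (the helper name and URL handling are illustrative, not part of this patch):

    import os
    import urllib.request
    import zipfile

    import torch.distributed as dist

    DATASET_URL = "https://byteps.tos-cn-qingdao.volces.com/datasets/selfie2anime.zip"

    def maybe_download_dataset(root: str) -> None:
        # Only the first process on each node downloads; the rest wait at the barrier.
        archive = os.path.join(root, "selfie2anime.zip")
        if int(os.getenv("LOCAL_RANK", "0")) == 0 and not os.path.exists(archive):
            urllib.request.urlretrieve(DATASET_URL, archive)
            with zipfile.ZipFile(archive) as zf:
                zf.extractall(os.path.join(root, "selfie2anime"))
        dist.barrier()  # everyone blocks here until the data is in place
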
diff --git a/pytorch/ugatit/run_torch_ddp.sh b/pytorch/ugatit/run_torch_ddp.sh
new file mode 100644
index 0000000..ecd43d2
--- /dev/null
+++ b/pytorch/ugatit/run_torch_ddp.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+# dataset can be downloaded from
+# https://drive.google.com/file/d/1xOWj1UVgp6NKMT3HbPhBbtq2A4EDkghF/view
+
+export DISTRIBUTED_FRAMEWORK=torch_native
+
+cd ./ugatit
+torchrun --nproc_per_node=${ML_PLATFORM_WORKER_GPU} \
+    --nnodes=${ML_PLATFORM_WORKER_NUM} \
+    --node_rank=${RANK} \
+    --master_addr="${MASTER_ADDR}" \
+    --master_port=${MASTER_PORT} main.py $@
+
+# https://github.com/pytorch/pytorch/blob/v1.10.0/torch/distributed/run.py
diff --git a/pytorch/ugatit/ugatit/UGATIT.py b/pytorch/ugatit/ugatit/UGATIT.py
index fa5e877..f0c6dfb 100644
--- a/pytorch/ugatit/ugatit/UGATIT.py
+++ b/pytorch/ugatit/ugatit/UGATIT.py
@@ -7,10 +7,27 @@
 from glob import glob
 from collections import OrderedDict
 
-if os.getenv("DISTRIBUTED_FRAMEWORK") == "byteps":
+dist_framework = os.getenv("DISTRIBUTED_FRAMEWORK", "").lower()
+if dist_framework == "byteps":
     import byteps.torch as bps
-else:
+elif dist_framework == "horovod":
     import horovod.torch as bps
+else:
+    import torch.distributed as bps
+    from torch.nn.parallel import DistributedDataParallel as DDP
+
+    def local_rank():
+        return int(os.getenv("LOCAL_RANK", "-1"))
+    def local_size():
+        return int(os.getenv("LOCAL_WORLD_SIZE", "-1"))
+    def rank():
+        return bps.get_rank()
+    def size():
+        return bps.get_world_size()
+    bps.local_rank = local_rank
+    bps.local_size = local_size
+    bps.rank = rank
+    bps.size = size
 
 class UGATIT(object) :
     def __init__(self, args):
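
Note: the else branch above monkey-patches torch.distributed so the rest of UGATIT.py can keep calling bps.rank(), bps.size() and bps.local_rank() regardless of backend. The same adapter written out as a class, to make the seams explicit: a sketch under the patch's own env-var assumptions, not code from the repo.

    import os

    import torch.distributed as dist

    class TorchNativeComm:
        """Horovod/BytePS-style facade over torch.distributed (illustrative)."""

        def init(self) -> None:
            # Rendezvous comes from MASTER_ADDR/MASTER_PORT, as set by torchrun.
            dist.init_process_group(backend="nccl")

        def rank(self) -> int:
            return dist.get_rank()

        def size(self) -> int:
            return dist.get_world_size()

        def local_rank(self) -> int:
            return int(os.getenv("LOCAL_RANK", "-1"))  # exported per worker by torchrun

        def local_size(self) -> int:
            return int(os.getenv("LOCAL_WORLD_SIZE", "-1"))
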
@@ -145,31 +162,40 @@ def build_model(self):
         self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
         self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
 
+        if dist_framework == "torch_native":
+            self.genA2B = DDP(self.genA2B, device_ids=[bps.local_rank()])
+            self.genB2A = DDP(self.genB2A, device_ids=[bps.local_rank()])
+            self.disGA = DDP(self.disGA, device_ids=[bps.local_rank()])
+            self.disGB = DDP(self.disGB, device_ids=[bps.local_rank()])
+            self.disLA = DDP(self.disLA, device_ids=[bps.local_rank()])
+            self.disLB = DDP(self.disLB, device_ids=[bps.local_rank()])
+
         """ Define Loss """
         self.L1_loss = nn.L1Loss().to(self.device)
         self.MSE_loss = nn.MSELoss().to(self.device)
         self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)
 
-        gen_named_parameters = []
-        dis_named_parameters = []
-        for n, p in (list(self.genA2B.named_parameters(prefix='genA2B')) +
-                     list(self.genB2A.named_parameters(prefix='genB2A'))):
-            gen_named_parameters.append((n, p))
-        for n, p in (list(self.disGA.named_parameters(prefix='disGA')) +
-                     list(self.disGB.named_parameters(prefix='disGB')) +
-                     list(self.disLA.named_parameters(prefix='disLA')) +
-                     list(self.disLB.named_parameters(prefix='disLB'))):
-            dis_named_parameters.append((n, p))
-
-        gen_state_dict = OrderedDict([("genA2B."+k, v) for k, v in self.genA2B.state_dict().items()] +
-                                     [("genB2A."+k, v) for k, v in self.genB2A.state_dict().items()])
-        dis_state_dict = OrderedDict([("disGA."+k, v) for k, v in self.disGA.state_dict().items()] +
-                                     [("disGB."+k, v) for k, v in self.disGB.state_dict().items()] +
-                                     [("disLA."+k, v) for k, v in self.disLA.state_dict().items()] +
-                                     [("disLB."+k, v) for k, v in self.disLB.state_dict().items()])
+        if dist_framework != "torch_native":
+            gen_named_parameters = []
+            dis_named_parameters = []
+            for n, p in (list(self.genA2B.named_parameters(prefix='genA2B')) +
+                         list(self.genB2A.named_parameters(prefix='genB2A'))):
+                gen_named_parameters.append((n, p))
+            for n, p in (list(self.disGA.named_parameters(prefix='disGA')) +
+                         list(self.disGB.named_parameters(prefix='disGB')) +
+                         list(self.disLA.named_parameters(prefix='disLA')) +
+                         list(self.disLB.named_parameters(prefix='disLB'))):
+                dis_named_parameters.append((n, p))
+
+            gen_state_dict = OrderedDict([("genA2B."+k, v) for k, v in self.genA2B.state_dict().items()] +
+                                         [("genB2A."+k, v) for k, v in self.genB2A.state_dict().items()])
+            dis_state_dict = OrderedDict([("disGA."+k, v) for k, v in self.disGA.state_dict().items()] +
+                                         [("disGB."+k, v) for k, v in self.disGB.state_dict().items()] +
+                                         [("disLA."+k, v) for k, v in self.disLA.state_dict().items()] +
+                                         [("disLB."+k, v) for k, v in self.disLB.state_dict().items()])
 
-        bps.broadcast_parameters(gen_state_dict, root_rank=0)
-        bps.broadcast_parameters(dis_state_dict, root_rank=0)
+            bps.broadcast_parameters(gen_state_dict, root_rank=0)
+            bps.broadcast_parameters(dis_state_dict, root_rank=0)
 
         """ Trainer """
         self.G_optim = torch.optim.Adam(itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()),
@@ -187,16 +213,17 @@ def build_model(self):
         for n, p in list(self.genB2A.named_parameters()):
             named_parameters.append(("genB2A." + n, p))
 
-        self.G_optim = bps.DistributedOptimizer(self.G_optim,
-                                                named_parameters=gen_named_parameters,
-                                                compression=bps.Compression.none)
+        if dist_framework != "torch_native":
+            self.G_optim = bps.DistributedOptimizer(self.G_optim,
+                                                    named_parameters=gen_named_parameters,
+                                                    compression=bps.Compression.none)
 
-        self.D_optim = bps.DistributedOptimizer(self.D_optim,
-                                                named_parameters=dis_named_parameters,
-                                                compression=bps.Compression.none)
+            self.D_optim = bps.DistributedOptimizer(self.D_optim,
+                                                    named_parameters=dis_named_parameters,
+                                                    compression=bps.Compression.none)
 
-        self.G_optim._handles.clear()
-        self.D_optim._handles.clear()
+            self.G_optim._handles.clear()
+            self.D_optim._handles.clear()
 
         """ Define Rho clipper to constraint the value of rho in AdaILN and ILN"""
         self.Rho_clipper = RhoClipper(0, 1)
@@ -240,10 +267,11 @@ def train(self):
             real_A, real_B = real_A.to(self.device), real_B.to(self.device)
 
             # Update D
-            self.D_optim._handles.clear()
+            if dist_framework != "torch_native":
+                self.D_optim._handles.clear()
+                self.D_optim.set_backward_passes_per_step(1)
+                self.G_optim.set_backward_passes_per_step(10)
             self.D_optim.zero_grad()
-            self.D_optim.set_backward_passes_per_step(1)
-            self.G_optim.set_backward_passes_per_step(10)
 
             fake_A2B, _, _ = self.genA2B(real_A)
             fake_B2A, _, _ = self.genB2A(real_B)
@@ -253,10 +281,10 @@ def train(self):
             real_GB_logit, real_GB_cam_logit, _ = self.disGB(real_B)
             real_LB_logit, real_LB_cam_logit, _ = self.disLB(real_B)
 
-            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
-            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
-            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
-            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)
+            fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A.detach())
+            fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A.detach())
+            fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B.detach())
+            fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B.detach())
 
             D_ad_loss_GA = self.MSE_loss(real_GA_logit, torch.ones_like(real_GA_logit).to(self.device)) + self.MSE_loss(fake_GA_logit, torch.zeros_like(fake_GA_logit).to(self.device))
             D_ad_cam_loss_GA = self.MSE_loss(real_GA_cam_logit, torch.ones_like(real_GA_cam_logit).to(self.device)) + self.MSE_loss(fake_GA_cam_logit, torch.zeros_like(fake_GA_cam_logit).to(self.device))
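
Note: the .detach() calls in the hunk above are the key correctness fix for the discriminator step: they cut the autograd graph at the generator outputs, so the discriminator's backward pass no longer runs through (or retains activations of) the generators and, once the models are DDP-wrapped, no longer triggers gradient synchronization on them during the D update. A self-contained toy check of that behavior:

    import torch
    import torch.nn as nn

    G = nn.Linear(4, 4)   # stand-in generator
    D = nn.Linear(4, 1)   # stand-in discriminator
    mse = nn.MSELoss()

    fake = G(torch.randn(2, 4))
    d_loss = mse(D(fake.detach()), torch.zeros(2, 1))  # graph is cut at G's output
    d_loss.backward()

    assert all(p.grad is None for p in G.parameters())      # generator untouched
    assert all(p.grad is not None for p in D.parameters())  # discriminator still trains
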
@@ -275,10 +303,11 @@ def train(self):
             self.D_optim.step()
 
             # Update G
-            self.G_optim._handles.clear()
+            if dist_framework != "torch_native":
+                self.G_optim._handles.clear()
+                self.D_optim.set_backward_passes_per_step(10)
+                self.G_optim.set_backward_passes_per_step(1)
             self.G_optim.zero_grad()
-            self.D_optim.set_backward_passes_per_step(10)
-            self.G_optim.set_backward_passes_per_step(1)
 
             fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
             fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)
@@ -288,11 +317,10 @@ def train(self):
             fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
             fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)
-
             fake_GA_logit, fake_GA_cam_logit, _ = self.disGA(fake_B2A)
             fake_LA_logit, fake_LA_cam_logit, _ = self.disLA(fake_B2A)
             fake_GB_logit, fake_GB_cam_logit, _ = self.disGB(fake_A2B)
             fake_LB_logit, fake_LB_cam_logit, _ = self.disLB(fake_A2B)
 
             G_ad_loss_GA = self.MSE_loss(fake_GA_logit, torch.ones_like(fake_GA_logit).to(self.device))
             G_ad_cam_loss_GA = self.MSE_loss(fake_GA_cam_logit, torch.ones_like(fake_GA_cam_logit).to(self.device))
diff --git a/pytorch/ugatit/ugatit/main.py b/pytorch/ugatit/ugatit/main.py
index 8dd21a9..33562d6 100644
--- a/pytorch/ugatit/ugatit/main.py
+++ b/pytorch/ugatit/ugatit/main.py
@@ -2,10 +2,31 @@
 import argparse
 from utils import *
 
-if os.getenv("DISTRIBUTED_FRAMEWORK") == "byteps":
+dist_framework = os.getenv("DISTRIBUTED_FRAMEWORK", "").lower()
+if dist_framework == "byteps":
     import byteps.torch as bps
-else:
+elif dist_framework == "horovod":
     import horovod.torch as bps
+else:
+    import torch.distributed as bps
+
+    def local_rank():
+        return int(os.getenv("LOCAL_RANK", "-1"))
+    def local_size():
+        return int(os.getenv("LOCAL_WORLD_SIZE", "-1"))
+    def rank():
+        return bps.get_rank()
+    def size():
+        return bps.get_world_size()
+    def init():
+        bps.init_process_group(backend="nccl")
+        return None
+
+    bps.local_rank = local_rank
+    bps.local_size = local_size
+    bps.rank = rank
+    bps.size = size
+    bps.init = init
+
 
 """parsing and configuration"""
@@ -15,7 +36,7 @@ def parse_args():
     parser.add_argument('--phase', type=str, default='train', help='[train / test]')
     parser.add_argument('--light', type=str2bool, default=False, help='[U-GAT-IT full version / U-GAT-IT light version]')
     parser.add_argument('--dataset_dir', type=str, default='dataset', help='dataset dir path')
-    parser.add_argument('--dataset', type=str, default='YOUR_D:ATASET_NAME', help='dataset_name')
+    parser.add_argument('--dataset', type=str, default='YOUR_DATASET_NAME', help='dataset_name')
 
     parser.add_argument('--iteration', type=int, default=1000000, help='The number of training iterations')
     parser.add_argument('--batch_size', type=int, default=1, help='The size of batch size')
@@ -72,6 +93,7 @@ def main():
     bps.init()
     torch.manual_seed(1)
     torch.cuda.manual_seed(1)
+    print(f'local rank: {bps.local_rank()}', flush=True)
     torch.cuda.set_device(bps.local_rank())
     # parse arguments
     args = parse_args()
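
Note: the torch_native path relies on the environment contract torchrun establishes: RANK, WORLD_SIZE, LOCAL_RANK, LOCAL_WORLD_SIZE, MASTER_ADDR and MASTER_PORT are exported to every worker, and init_process_group() reads the rendezvous from them. A standalone probe of that contract (hypothetical file, not part of the patch), launchable as "torchrun --nproc_per_node=2 probe.py":

    import os

    import torch.distributed as dist

    def main() -> None:
        dist.init_process_group(backend="gloo")  # gloo, so the probe runs without GPUs
        print(f"rank {dist.get_rank()}/{dist.get_world_size()}, "
              f"local rank {os.getenv('LOCAL_RANK')} of {os.getenv('LOCAL_WORLD_SIZE')}",
              flush=True)
        dist.destroy_process_group()

    if __name__ == "__main__":
        main()
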