Multi-GPU training #308

Open · wants to merge 2 commits into base: 202410-sunlab-hackthon

243 changes: 169 additions & 74 deletions pyhealth/trainer.py
@@ -5,17 +5,40 @@

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from tqdm.autonotebook import trange
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from pyhealth.metrics import (binary_metrics_fn, multiclass_metrics_fn,
multilabel_metrics_fn, regression_metrics_fn)
from pyhealth.utils import create_directory

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def ddp_setup(rank: int, world_size: int):
"""
Args:
rank: Unique identifier of each process
world_size: Total number of processes
"""
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"
torch.cuda.set_device(rank)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
return


def ddp_cleanup():
dist.destroy_process_group()
return
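# A minimal sketch of the per-process lifecycle these two helpers assume; the
# worker function and model below are illustrative only, not part of this PR:
#
#     def _worker(rank: int, world_size: int):
#         ddp_setup(rank, world_size)                        # once per spawned process
#         model = DDP(MyModel().to(rank), device_ids=[rank])
#         ...                                                # train on this rank's data shard
#         ddp_cleanup()                                      # tear down the process group
#
#     mp.spawn(_worker, args=(world_size,), nprocs=world_size, join=True)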


def is_best(best_score: float, score: float, monitor_criterion: str) -> bool:
@@ -27,8 +50,10 @@ def is_best(best_score: float, score: float, monitor_criterion: str) -> bool:
raise ValueError(f"Monitor criterion {monitor_criterion} is not supported")


def set_logger(log_path: str) -> None:
create_directory(log_path)
def set_logger(log_path: str, world_size: int, gpu_id: Optional[int]) -> None:
# create only on CPU/single-GPU or first process in multi-GPU
if world_size in [0, 1] or (world_size > 1 and gpu_id == 0):
create_directory(log_path)
log_filename = os.path.join(log_path, "log.txt")
handler = logging.FileHandler(log_filename)
formatter = logging.Formatter("%(asctime)s %(message)s", "%Y-%m-%d %H:%M:%S")
@@ -50,6 +75,56 @@ def get_metrics_fn(mode: str) -> Callable:
raise ValueError(f"Mode {mode} is not supported")


class MNISTDataset(Dataset):
    """Toy MNIST dataset used by the multi-GPU example at the bottom of this file."""

    def __init__(self, train=True):
        # torchvision is only needed for this example, so import it lazily
        from torchvision import datasets, transforms

        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        self.dataset = datasets.MNIST(
            "../data", train=train, download=True, transform=transform
        )

def __getitem__(self, index):
x, y = self.dataset[index]
return {"x": x, "y": y}

def __len__(self):
return len(self.dataset)


class Model(nn.Module):
    """Small CNN classifier used by the multi-GPU example at the bottom of this file."""

    def __init__(self, device, gpu_id):
super(Model, self).__init__()
self.mode = "multiclass"
self.device = device
self.gpu_id = gpu_id
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
self.loss = nn.CrossEntropyLoss()

    def forward(self, x, y, **kwargs):
        # on multi-GPU, tensors go to this process's GPU; otherwise to the trainer's device
        target_device = self.gpu_id if self.gpu_id is not None else self.device
        x = torch.stack(x, dim=0).to(target_device)
        y = torch.tensor(y).to(target_device)
x = self.conv1(x)
x = torch.relu(x)
x = self.conv2(x)
x = torch.relu(x)
x = torch.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = torch.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
loss = self.loss(x, y)
y_prob = torch.softmax(x, dim=1)
return {"loss": loss, "y_prob": y_prob, "y_true": y}


class Trainer:
"""Trainer for PyTorch models.

@@ -61,6 +136,12 @@ class Trainer:
means the default metrics in each metrics_fn will be used.
device: Device to be used for training. Default is None, which means
the device will be GPU if available, otherwise CPU.
        rank: For multi-GPU training, the unique identifier of each process.
            Default is 0, which means single-GPU or CPU.
        world_size: For multi-GPU training, the total number of processes.
            Default is 1, which means single-GPU or CPU.
        gpu_id: For multi-GPU training, the GPU assigned to this process.
            Default is None, which means single-GPU or CPU.
enable_logging: Whether to enable logging. Default is True.
output_path: Path to save the output. Default is "./output".
exp_name: Name of the experiment. Default is current datetime.
@@ -72,6 +153,9 @@ def __init__(
checkpoint_path: Optional[str] = None,
metrics: Optional[List[str]] = None,
device: Optional[str] = None,
rank: int = 0,
world_size: int = 1,
gpu_id: Optional[int] = None,
enable_logging: bool = True,
output_path: Optional[str] = None,
exp_name: Optional[str] = None,
@@ -82,6 +166,9 @@
self.model = model
self.metrics = metrics
self.device = device
self.rank = rank
self.world_size = world_size
self.gpu_id = gpu_id

# set logger
if enable_logging:
@@ -90,12 +177,17 @@ def __init__(
if exp_name is None:
exp_name = datetime.now().strftime("%Y%m%d-%H%M%S")
self.exp_path = os.path.join(output_path, exp_name)
set_logger(self.exp_path)
set_logger(self.exp_path, self.world_size, self.gpu_id)
else:
self.exp_path = None

# set device
self.model.to(self.device)
        if self.world_size > 1:
            # multi-GPU: move the model onto this process's GPU before wrapping it in DDP
            ddp_setup(self.rank, self.world_size)
            self.model.to(self.gpu_id)
            self.model = DDP(self.model, device_ids=[self.gpu_id])
        else:
            self.model.to(self.device)

# logging
logger.info(self.model)
@@ -174,6 +266,8 @@ def train(
optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)

# initialize
if self.world_size > 1:
train_dataloader.sampler.set_epoch(0)
data_iterator = iter(train_dataloader)
best_score = -1 * float("inf") if monitor_criterion == "max" else float("inf")
if steps_per_epoch == None:
@@ -195,6 +289,8 @@
try:
data = next(data_iterator)
except StopIteration:
if self.world_size > 1:
train_dataloader.sampler.set_epoch(epoch+1)
data_iterator = iter(train_dataloader)
data = next(data_iterator)
# forward
@@ -215,7 +311,9 @@
logger.info(f"--- Train epoch-{epoch}, step-{global_step} ---")
logger.info(f"loss: {sum(training_loss) / len(training_loss):.4f}")
if self.exp_path is not None:
self.save_ckpt(os.path.join(self.exp_path, "last.ckpt"))
# save only on CPU/single-GPU or first process in multi-GPU
if self.world_size in [0, 1] or (self.world_size > 1 and self.gpu_id == 0):
self.save_ckpt(os.path.join(self.exp_path, "last.ckpt"))

# validation
if val_dataloader is not None:
@@ -233,7 +331,13 @@
)
best_score = score
if self.exp_path is not None:
self.save_ckpt(os.path.join(self.exp_path, "best.ckpt"))
# save only on CPU/single-GPU or first process in multi-GPU
if self.world_size in [0, 1] or (self.world_size > 1 and self.gpu_id == 0):
self.save_ckpt(os.path.join(self.exp_path, "best.ckpt"))

# clean up DDP resources
if self.world_size > 1:
ddp_cleanup()

# load best model
if load_best_model_at_last and self.exp_path is not None and os.path.isfile(
@@ -328,92 +432,83 @@ def evaluate(self, dataloader) -> Dict[str, float]:

def save_ckpt(self, ckpt_path: str) -> None:
"""Saves the model checkpoint."""
state_dict = self.model.state_dict()
if self.world_size > 1:
state_dict = self.model.module.state_dict()
else:
state_dict = self.model.state_dict()
torch.save(state_dict, ckpt_path)
return

def load_ckpt(self, ckpt_path: str) -> None:
"""Saves the model checkpoint."""
"""Loads the model checkpoint."""
state_dict = torch.load(ckpt_path, map_location=self.device)
self.model.load_state_dict(state_dict)
if self.world_size > 1:
self.model.module.load_state_dict(state_dict)
else:
self.model.load_state_dict(state_dict)
return
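# Why state_dict goes through .module under DDP: wrapping a model in
# DistributedDataParallel prefixes every parameter key with "module.", so saving
# and loading via self.model.module keeps checkpoints usable without DDP.
# Illustrative sketch (names are not part of this PR):
#
#     ddp_model = DDP(net.to(0), device_ids=[0])
#     list(ddp_model.state_dict())[0]          # "module.conv1.weight"
#     list(ddp_model.module.state_dict())[0]   # "conv1.weight"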


def run_training(
rank: int,
world_size: int,
train_dataloader: DataLoader,
val_dataloader: Optional[DataLoader] = None,
):
"""Triggers the training pipeline.

Args:
        rank: For multi-GPU training, the unique identifier of each process;
            0 means single-GPU or CPU.
        world_size: For multi-GPU training, the total number of processes;
            0 or 1 means single-GPU or CPU.
        train_dataloader: Dataloader for training.
        val_dataloader: Dataloader for validation. Default is None.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
    gpu_id = None
    if world_size > 1:
        gpu_id = rank
        # the dataloader was pickled from the parent process with rank=0, so
        # re-point its DistributedSampler at this process's shard of the data
        train_dataloader.sampler.rank = rank

model = Model(device, gpu_id)
trainer = Trainer(model, device=device, rank=rank, world_size=world_size, gpu_id=gpu_id)
trainer.train(
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
monitor="accuracy",
epochs=5,
test_dataloader=val_dataloader,
)
return


if __name__ == "__main__":
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from pyhealth.datasets.utils import collate_fn_dict


class MNISTDataset(Dataset):
def __init__(self, train=True):
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
self.dataset = datasets.MNIST(
"../data", train=train, download=True, transform=transform
)

def __getitem__(self, index):
x, y = self.dataset[index]
return {"x": x, "y": y}

def __len__(self):
return len(self.dataset)


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.mode = "multiclass"
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout(0.25)
self.dropout2 = nn.Dropout(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)
self.loss = nn.CrossEntropyLoss()

def forward(self, x, y, **kwargs):
x = torch.stack(x, dim=0).to(self.device)
y = torch.tensor(y).to(self.device)
x = self.conv1(x)
x = torch.relu(x)
x = self.conv2(x)
x = torch.relu(x)
x = torch.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = torch.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
loss = self.loss(x, y)
y_prob = torch.softmax(x, dim=1)
return {"loss": loss, "y_prob": y_prob, "y_true": y}


train_dataset = MNISTDataset(train=True)
val_dataset = MNISTDataset(train=False)

train_dataloader = DataLoader(
train_dataset, collate_fn=collate_fn_dict, batch_size=64, shuffle=True
)
    world_size = torch.cuda.device_count()
    if world_size > 1:
        # num_replicas and rank are given explicitly because no process group exists
        # yet in the parent; each spawned process re-points the sampler at its own
        # rank inside run_training
        train_dataloader = DataLoader(
            train_dataset,
            collate_fn=collate_fn_dict,
            batch_size=64,
            shuffle=False,
            sampler=DistributedSampler(train_dataset, num_replicas=world_size, rank=0),
        )
    else:
        train_dataloader = DataLoader(
            train_dataset, collate_fn=collate_fn_dict, batch_size=64, shuffle=True
        )
val_dataloader = DataLoader(
val_dataset, collate_fn=collate_fn_dict, batch_size=64, shuffle=False
)

model = Model()

trainer = Trainer(model, device="cuda" if torch.cuda.is_available() else "cpu")
trainer.train(
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
monitor="accuracy",
epochs=5,
test_dataloader=val_dataloader,
)
    if world_size > 1:
        # multi-GPU: mp.spawn passes the process rank as the first argument to run_training
        mp.spawn(
            run_training,
            args=(world_size, train_dataloader, val_dataloader),
            nprocs=world_size,
            join=True,
        )
else:
# single-GPU or CPU
run_training(rank=0, world_size=world_size, train_dataloader=train_dataloader, val_dataloader=val_dataloader)
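
For reference, a standalone sketch (not part of this PR) of how DistributedSampler shards a dataset across ranks and why the training loop above calls sampler.set_epoch() every epoch:

    from torch.utils.data.distributed import DistributedSampler

    dataset = list(range(8))  # stand-in for a real Dataset; the sampler only needs len()
    for rank in range(2):
        sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=True)
        sampler.set_epoch(0)  # same seed + epoch on every rank -> identical shuffle before sharding
        print(rank, list(sampler))
    # The two index lists are disjoint and together cover the whole dataset; calling
    # set_epoch with a new epoch number reshuffles the data before it is re-sharded.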