diff --git a/pyhealth/trainer.py b/pyhealth/trainer.py
index 25414204..de147bec 100644
--- a/pyhealth/trainer.py
+++ b/pyhealth/trainer.py
@@ -5,17 +5,40 @@
 import numpy as np
 import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
 from torch import nn
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 from tqdm import tqdm
 from tqdm.autonotebook import trange
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
 
 from pyhealth.metrics import (binary_metrics_fn, multiclass_metrics_fn,
                               multilabel_metrics_fn, regression_metrics_fn)
 from pyhealth.utils import create_directory
 
 logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+def ddp_setup(rank: int, world_size: int):
+    """
+    Args:
+        rank: Unique identifier of each process
+        world_size: Total number of processes
+    """
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    torch.cuda.set_device(rank)
+    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    return
+
+
+def ddp_cleanup():
+    dist.destroy_process_group()
+    return
 
 
 def is_best(best_score: float, score: float, monitor_criterion: str) -> bool:
@@ -27,8 +50,10 @@ def is_best(best_score: float, score: float, monitor_criterion: str) -> bool:
         raise ValueError(f"Monitor criterion {monitor_criterion} is not supported")
 
 
-def set_logger(log_path: str) -> None:
-    create_directory(log_path)
+def set_logger(log_path: str, world_size: int, gpu_id: int) -> None:
+    # create only on CPU/single-GPU or first process in multi-GPU
+    if world_size in [0, 1] or (world_size > 1 and gpu_id == 0):
+        create_directory(log_path)
     log_filename = os.path.join(log_path, "log.txt")
     handler = logging.FileHandler(log_filename)
     formatter = logging.Formatter("%(asctime)s %(message)s", "%Y-%m-%d %H:%M:%S")
@@ -50,6 +75,56 @@ def get_metrics_fn(mode: str) -> Callable:
         raise ValueError(f"Mode {mode} is not supported")
 
 
+class MNISTDataset(Dataset):
+    def __init__(self, train=True):
+        transform = transforms.Compose(
+            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
+        )
+        self.dataset = datasets.MNIST(
+            "../data", train=train, download=True, transform=transform
+        )
+
+    def __getitem__(self, index):
+        x, y = self.dataset[index]
+        return {"x": x, "y": y}
+
+    def __len__(self):
+        return len(self.dataset)
+
+
+class Model(nn.Module):
+    def __init__(self, device, gpu_id):
+        super(Model, self).__init__()
+        self.mode = "multiclass"
+        self.device = device
+        self.gpu_id = gpu_id
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+        self.loss = nn.CrossEntropyLoss()
+
+    def forward(self, x, y, **kwargs):
+        x = torch.stack(x, dim=0).to(self.gpu_id if self.gpu_id is not None else self.device)
+        y = torch.tensor(y).to(self.gpu_id if self.gpu_id is not None else self.device)
+        x = self.conv1(x)
+        x = torch.relu(x)
+        x = self.conv2(x)
+        x = torch.relu(x)
+        x = torch.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = torch.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        loss = self.loss(x, y)
+        y_prob = torch.softmax(x, dim=1)
+        return {"loss": loss, "y_prob": y_prob, "y_true": y}
+
+
 class Trainer:
     """Trainer for PyTorch models.
 
@@ -61,6 +136,12 @@ class Trainer:
             means the default metrics in each metrics_fn will be used.
         device: Device to be used for training. Default is None, which means
             the device will be GPU if available, otherwise CPU.
+        rank: In case of multi-GPU, it's the unique identifier of each process.
+            Default is 0, which means single-GPU or CPU.
+        world_size: In case of multi-GPU, it's the total number of processes.
+            Default is 1, which means single-GPU or CPU.
+        gpu_id: In case of multi-GPU, it's the unique identifier of each GPU.
+            Default is None, which means single-GPU or CPU.
         enable_logging: Whether to enable logging. Default is True.
         output_path: Path to save the output. Default is "./output".
         exp_name: Name of the experiment. Default is current datetime.
@@ -72,6 +153,9 @@ def __init__(
         checkpoint_path: Optional[str] = None,
         metrics: Optional[List[str]] = None,
         device: Optional[str] = None,
+        rank: int = 0,
+        world_size: int = 1,
+        gpu_id: Optional[int] = None,
         enable_logging: bool = True,
         output_path: Optional[str] = None,
         exp_name: Optional[str] = None,
@@ -82,6 +166,9 @@
         self.model = model
         self.metrics = metrics
         self.device = device
+        self.rank = rank
+        self.world_size = world_size
+        self.gpu_id = gpu_id
 
         # set logger
         if enable_logging:
@@ -90,12 +177,17 @@
             if exp_name is None:
                 exp_name = datetime.now().strftime("%Y%m%d-%H%M%S")
             self.exp_path = os.path.join(output_path, exp_name)
-            set_logger(self.exp_path)
+            set_logger(self.exp_path, self.world_size, self.gpu_id)
         else:
             self.exp_path = None
 
         # set device
-        self.model.to(self.device)
+        if self.world_size > 1:
+            # multi-GPU: move the model to its GPU before wrapping it in DDP
+            ddp_setup(self.rank, self.world_size)
+            self.model = DDP(self.model.to(self.gpu_id), device_ids=[self.gpu_id])
+        else:
+            self.model.to(self.device)
 
         # logging
         logger.info(self.model)
@@ -174,6 +266,8 @@ def train(
         optimizer = optimizer_class(optimizer_grouped_parameters, **optimizer_params)
 
         # initialize
+        if self.world_size > 1:
+            train_dataloader.sampler.set_epoch(0)
         data_iterator = iter(train_dataloader)
         best_score = -1 * float("inf") if monitor_criterion == "max" else float("inf")
         if steps_per_epoch == None:
@@ -195,6 +289,8 @@
                 try:
                     data = next(data_iterator)
                 except StopIteration:
+                    if self.world_size > 1:
+                        train_dataloader.sampler.set_epoch(epoch + 1)
                     data_iterator = iter(train_dataloader)
                     data = next(data_iterator)
                 # forward
@@ -215,7 +311,9 @@
             logger.info(f"--- Train epoch-{epoch}, step-{global_step} ---")
             logger.info(f"loss: {sum(training_loss) / len(training_loss):.4f}")
             if self.exp_path is not None:
-                self.save_ckpt(os.path.join(self.exp_path, "last.ckpt"))
+                # save only on CPU/single-GPU or first process in multi-GPU
+                if self.world_size in [0, 1] or (self.world_size > 1 and self.gpu_id == 0):
+                    self.save_ckpt(os.path.join(self.exp_path, "last.ckpt"))
 
             # validation
             if val_dataloader is not None:
@@ -233,7 +331,13 @@
                     )
                     best_score = score
                     if self.exp_path is not None:
-                        self.save_ckpt(os.path.join(self.exp_path, "best.ckpt"))
+                        # save only on CPU/single-GPU or first process in multi-GPU
+                        if self.world_size in [0, 1] or (self.world_size > 1 and self.gpu_id == 0):
+                            self.save_ckpt(os.path.join(self.exp_path, "best.ckpt"))
+
+        # clean up DDP resources
+        if self.world_size > 1:
+            ddp_cleanup()
 
         # load best model
         if load_best_model_at_last and self.exp_path is not None and os.path.isfile(
@@ -328,92 +432,83 @@ def evaluate(self, dataloader) -> Dict[str, float]:
 
     def save_ckpt(self, ckpt_path: str) -> None:
         """Saves the model checkpoint."""
-        state_dict = self.model.state_dict()
+        if self.world_size > 1:
+            state_dict = self.model.module.state_dict()
+        else:
+            state_dict = self.model.state_dict()
         torch.save(state_dict, ckpt_path)
         return
 
     def load_ckpt(self, ckpt_path: str) -> None:
-        """Saves the model checkpoint."""
+        """Loads the model checkpoint."""
         state_dict = torch.load(ckpt_path, map_location=self.device)
-        self.model.load_state_dict(state_dict)
+        if self.world_size > 1:
+            self.model.module.load_state_dict(state_dict)
+        else:
+            self.model.load_state_dict(state_dict)
         return
 
 
+def run_training(
+    rank: int,
+    world_size: int,
+    train_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader] = None,
+):
+    """Triggers the training pipeline.
+
+    Args:
+        rank: In case of multi-GPU, it's the unique identifier of each process.
+            0 means single-GPU or CPU.
+        world_size: In case of multi-GPU, it's the total number of processes.
+            0 means single-GPU or CPU.
+        train_dataloader: Dataloader for training.
+        val_dataloader: Dataloader for validation. Default is None.
+    """
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    gpu_id = None
+    if world_size > 1:
+        gpu_id = rank
+
+    model = Model(device, gpu_id)
+    trainer = Trainer(model, device=device, rank=rank, world_size=world_size, gpu_id=gpu_id)
+    trainer.train(
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        monitor="accuracy",
+        epochs=5,
+        test_dataloader=val_dataloader,
+    )
+    return
+
+
 if __name__ == "__main__":
     import torch
     import torch.nn as nn
-    from torch.utils.data import DataLoader, Dataset
+    from torch.utils.data import DataLoader
     from torchvision import datasets, transforms
 
     from pyhealth.datasets.utils import collate_fn_dict
-
-    class MNISTDataset(Dataset):
-        def __init__(self, train=True):
-            transform = transforms.Compose(
-                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
-            )
-            self.dataset = datasets.MNIST(
-                "../data", train=train, download=True, transform=transform
-            )
-
-        def __getitem__(self, index):
-            x, y = self.dataset[index]
-            return {"x": x, "y": y}
-
-        def __len__(self):
-            return len(self.dataset)
-
-
-    class Model(nn.Module):
-        def __init__(self):
-            super(Model, self).__init__()
-            self.mode = "multiclass"
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
-            self.conv1 = nn.Conv2d(1, 32, 3, 1)
-            self.conv2 = nn.Conv2d(32, 64, 3, 1)
-            self.dropout1 = nn.Dropout(0.25)
-            self.dropout2 = nn.Dropout(0.5)
-            self.fc1 = nn.Linear(9216, 128)
-            self.fc2 = nn.Linear(128, 10)
-            self.loss = nn.CrossEntropyLoss()
-
-        def forward(self, x, y, **kwargs):
-            x = torch.stack(x, dim=0).to(self.device)
-            y = torch.tensor(y).to(self.device)
-            x = self.conv1(x)
-            x = torch.relu(x)
-            x = self.conv2(x)
-            x = torch.relu(x)
-            x = torch.max_pool2d(x, 2)
-            x = self.dropout1(x)
-            x = torch.flatten(x, 1)
-            x = self.fc1(x)
-            x = torch.relu(x)
-            x = self.dropout2(x)
-            x = self.fc2(x)
-            loss = self.loss(x, y)
-            y_prob = torch.softmax(x, dim=1)
-            return {"loss": loss, "y_prob": y_prob, "y_true": y}
-
-
     train_dataset = MNISTDataset(train=True)
     val_dataset = MNISTDataset(train=False)
-    train_dataloader = DataLoader(
-        train_dataset, collate_fn=collate_fn_dict, batch_size=64, shuffle=True
-    )
+    world_size = torch.cuda.device_count()
+    if world_size > 1:
+        train_dataloader = DataLoader(
+            train_dataset, collate_fn=collate_fn_dict, batch_size=64, shuffle=False, sampler=DistributedSampler(train_dataset)
+        )
+    else:
+        train_dataloader = DataLoader(
+            train_dataset, collate_fn=collate_fn_dict, batch_size=64, shuffle=True
+        )
     val_dataloader = DataLoader(
         val_dataset, collate_fn=collate_fn_dict, batch_size=64, shuffle=False
     )
-    model = Model()
-
-    trainer = Trainer(model, device="cuda" if torch.cuda.is_available() else "cpu")
-    trainer.train(
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        monitor="accuracy",
-        epochs=5,
-        test_dataloader=val_dataloader,
-    )
+    if world_size > 1:
+        # multi-GPU
+        mp.spawn(run_training, args=(world_size, train_dataloader, val_dataloader), nprocs=world_size, join=True)
+    else:
+        # single-GPU or CPU
+        run_training(rank=0, world_size=world_size, train_dataloader=train_dataloader, val_dataloader=val_dataloader)
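
Usage note (not part of the patch): a minimal sketch of how a downstream script might drive the new multi-GPU path with its own PyHealth-style model. Unlike the `__main__` example above, each spawned process builds its own DataLoader, so the DistributedSampler can be given explicit num_replicas/rank instead of relying on an already-initialized process group; `build_model` and `build_datasets` are hypothetical helpers standing in for user code.

    import torch
    import torch.multiprocessing as mp
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    from pyhealth.datasets.utils import collate_fn_dict
    from pyhealth.trainer import Trainer


    def launch(rank, world_size, train_set, val_set):
        # one process per GPU; the rank doubles as the GPU id in DDP mode
        gpu_id = rank if world_size > 1 else None
        sampler = (
            DistributedSampler(train_set, num_replicas=world_size, rank=rank)
            if world_size > 1
            else None
        )
        train_loader = DataLoader(
            train_set,
            batch_size=64,
            shuffle=(sampler is None),  # the sampler shards and shuffles in DDP mode
            sampler=sampler,
            collate_fn=collate_fn_dict,
        )
        val_loader = DataLoader(
            val_set, batch_size=64, shuffle=False, collate_fn=collate_fn_dict
        )
        model = build_model()  # hypothetical: any PyHealth-style nn.Module
        trainer = Trainer(
            model,
            device="cuda" if torch.cuda.is_available() else "cpu",
            rank=rank,
            world_size=world_size,
            gpu_id=gpu_id,
        )
        trainer.train(
            train_dataloader=train_loader,
            val_dataloader=val_loader,
            monitor="accuracy",
            epochs=5,
        )


    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        train_set, val_set = build_datasets()  # hypothetical: map-style datasets
        if world_size > 1:
            mp.spawn(launch, args=(world_size, train_set, val_set), nprocs=world_size, join=True)
        else:
            launch(rank=0, world_size=world_size, train_set=train_set, val_set=val_set)

Creating the sampler inside each spawned process keeps the per-rank shard assignment correct and avoids sharing a sampler that was constructed before the process group exists.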