Feat ckpt (#18)
* fix vocab size for debugmodel and real data

* add torchdata

* add ckpt v0

* fix real dataset loading error

* ckpt save in the right step folder

* change total tokens diloco

* add tests

* refactor easier step

* use fsspec rsync for backing up the ckpt to remote

* add async saving to remote

* remove unused file

* fix rebase

* fix ckpt

* add diloco ckpt

* save into diloco specific folder

* remove process group

* remove process group

* add diloco rank

* fix ckpt issue

* remove ckpt tests
samsja authored Sep 30, 2024
1 parent 46a0d04 commit 6475ea4
Showing 6 changed files with 554 additions and 16 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -11,7 +11,9 @@ dependencies = [
     "transformers>=4.44.2",
     "datasets>=3.0.0",
     "pydantic_config @ git+https://github.com/samsja/pydantic_config.git@e529c9c",
-    "einops"
+    "einops",
+    "torchdata>=0.8.0",
+    "fsspec[gcs]>=2024.3.1",
 ]

[project.optional-dependencies]
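The new fsspec[gcs] dependency is what lets the checkpoint code below push a local checkpoint folder to cloud storage through one generic rsync call, without hard-coding a backend library. A minimal sketch of that call, with hypothetical local and remote paths:

# Sketch only: mirror a finished checkpoint folder to a remote location.
# The paths are placeholders; any protocol fsspec understands (gs://, s3://, ...)
# works as long as the matching filesystem package is installed (gcsfs via fsspec[gcs] here).
from fsspec.generic import rsync

rsync("outputs/ckpt/step_100", destination="gs://my-bucket/ckpt/step_100")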
225 changes: 225 additions & 0 deletions src/zeroband/checkpoint.py
@@ -0,0 +1,225 @@
from dataclasses import dataclass
import multiprocessing
import os
import time
from typing import Any
from fsspec.generic import rsync as rsync_fsspec
import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torchdata.stateful_dataloader import StatefulDataLoader
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
set_optimizer_state_dict,
set_model_state_dict,
get_model_state_dict,
get_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.checkpoint.stateful import Stateful
from zeroband.utils.logging import get_logger
import warnings
import logging

from zeroband.utils.world_info import get_world_info

## code inspired by torchtitan https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py


@dataclass
class TrainingProgress(Stateful):
total_tokens: int
outer_step: int
step: int

def state_dict(self) -> dict[str, Any]:
return {"total_tokens": self.total_tokens, "outer_step": self.outer_step, "step": self.step}

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self.total_tokens = state_dict["total_tokens"]
self.outer_step = state_dict["outer_step"]
self.step = state_dict["step"]


class ModelWrapper(Stateful):
def __init__(self, model: nn.Module) -> None:
self.model = model

def state_dict(self) -> dict[str, Any]:
return get_model_state_dict(self.model)

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
set_model_state_dict(model=self.model, model_state_dict=state_dict, options=StateDictOptions(strict=False))


class OptimizerWrapper(Stateful):
def __init__(
self,
model: nn.Module,
optim: torch.optim.Optimizer,
) -> None:
self.model = model
self.optim = optim

def state_dict(self) -> dict[str, Any]:
return get_optimizer_state_dict(
model=self.model, optimizers=self.optim, options=StateDictOptions(flatten_optimizer_state_dict=True)
)

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
set_optimizer_state_dict(
model=self.model, optimizers=self.optim, optim_state_dict=state_dict, options=StateDictOptions(strict=False)
)


class CkptManager:
"""Its name CkptManager because I (sami) always misstyped chekcpoint.
Checkpoint are saved in a folder with the following structure:
ckpt_path/
step_0/
_0_0.pt
_1_0.pt
...
step_1/
...
"""

def __init__(
self,
model: nn.Module,
optimizer: Optimizer,
scheduler: LambdaLR,
dataloader: StatefulDataLoader,
training_progress: TrainingProgress,
diloco_offloaded_param_list: list[nn.Parameter] | None,
diloco_offloaded_optimizer: Optimizer | None,
):
self.model = ModelWrapper(model)
self.optimizer = OptimizerWrapper(model, optimizer)
self.scheduler = scheduler
self.dataloader = dataloader
self.training_progress = training_progress

# states can only contain Stateful objects, hence we need to wrap the model and optimizer
self.states: dict[str, Stateful] = {
"model": self.model,
"optimizer": self.optimizer,
"scheduler": self.scheduler,
# "dataloader": self.dataloader, # ignoring dataloader for now as each rank has its own dataloader
"training_progress": self.training_progress,
}

assert (diloco_offloaded_param_list is None) == (
diloco_offloaded_optimizer is None
), "diloco_offloaded_model and diloco_offloaded_optimizer must be both None or both have values"

self.diloco_offloaded_optimizer = diloco_offloaded_optimizer  # here we don't use the Wrapper because it failed,
# which might make the ckpt less generic in terms of loading from a different number of devices. FSDP ckpt seems to be a mess though
self.diloco_offloaded_param_list = diloco_offloaded_param_list

if diloco_offloaded_optimizer is not None:
# even though diloco_offloaded targets the cpu param list, we still use the gpu model to load and save state.
# the main reason is that we don't actually have a cpu model, just a list of cpu parameters.
self.states["diloco_optimizer"] = self.diloco_offloaded_optimizer

self._logger = get_logger()

self.async_save_process: list[multiprocessing.Process] = []

def save(self, ckpt_path: str, remote_ckpt_path: str | None) -> None:
"""
Each rank will save the right shard of the model and optimizer.
Saving is done inplace
"""

time_start = time.perf_counter()
world_info = get_world_info()

ckpt_path = os.path.join(ckpt_path, f"step_{self.training_progress.step}")
if self.diloco_offloaded_optimizer:
# here we save the model and offloaded optimizer on each diloco rank even though they are the same
# this is done for two reasons:
#  * if the nodes share neither a filesystem nor a remote path, each of them still saves all of the data
#  * it's easier to implement and avoids race conditions on the shared data.
ckpt_path = os.path.join(ckpt_path, f"diloco_{world_info.diloco_rank}")

catch_warning = self._logger.getEffectiveLevel() <= logging.INFO

with warnings.catch_warnings():
# pytorch has an annoying warning when saving the optimizer state https://github.com/pytorch/pytorch/issues/136907
# we can ignore it if we are not logging in DEBUG mode
if catch_warning:
warnings.simplefilter("ignore")

dcp.save(self.states, checkpoint_id=ckpt_path)

## the next part is a fix so that each rank saves its own dataloader state. It is not efficient because it writes to disk a second time
with open(os.path.join(ckpt_path, f"__{world_info.local_rank}_0.pt"), "wb") as f:
torch.save({"data_loader": self.dataloader.state_dict()}, f)

self._logger.info(f"Saved checkpoint to {ckpt_path} in {time.perf_counter() - time_start} seconds")

if remote_ckpt_path is not None:
self._async_save_remote(ckpt_path, remote_ckpt_path)

def _async_save_remote(self, ckpt_path: str, remote_ckpt_path: str):
"""asyncronously rsync a ckpt folder to a remote location. Using fsspec to handle remote cloud storage without to install
specific libraries (e.g. s3fs)
"""

def rsync():
time_start = time.perf_counter()
self._logger.info(f"start pushing {ckpt_path} to {remote_ckpt_path} asynchronously")
rsync_fsspec(ckpt_path, destination=remote_ckpt_path)
self._logger.info(
f"finish pushing {ckpt_path} to {remote_ckpt_path} in {time.perf_counter() - time_start} seconds"
)

process = multiprocessing.Process(target=rsync, daemon=True)
process.start()

self.async_save_process.append(process)

def wait_async_save_process(self):
"""
wait for all async save processes to finish
"""
for process in self.async_save_process:
process.join()

def __del__(self):
self.wait_async_save_process()

def load(self, resume_ckpt_path: str) -> None:
"""
loading should be done after fsdp wrap and optimizer init.
Each rank will load the right shard of the model and optimizer.
All ranks will load the global states (scheduler, step, total_tokens, dataloader).
`resume_ckpt_path` should point to a specific step and not to the base ckpt folder. Example: `ckpt_path/step_100`
Loading is done inplace
"""
time_start = time.perf_counter()

world_info = get_world_info()
if self.diloco_offloaded_param_list is not None:
resume_ckpt_path = os.path.join(resume_ckpt_path, f"diloco_{world_info.diloco_rank}")

dcp.load(self.states, checkpoint_id=resume_ckpt_path)

# the offloaded param list is not loaded from the state dict since it is the same as the model's parameters, so we just copy it over
if self.diloco_offloaded_param_list is not None:
for param_offloaded, param_model in zip(self.diloco_offloaded_param_list, self.model.model.parameters()):
param_offloaded.data.copy_(param_model.data)

## the next part is a fix so that each rank loads its own dataloader state. It is not efficient because it reads from disk a second time
with open(os.path.join(resume_ckpt_path, f"__{world_info.local_rank}_0.pt"), "rb") as f:
rank_state_dict = torch.load(f)

self.dataloader.load_state_dict(rank_state_dict["data_loader"])

self._logger.info(f"Loaded checkpoint from {resume_ckpt_path} in {time.perf_counter() - time_start} seconds")
3 changes: 2 additions & 1 deletion src/zeroband/data.py
@@ -4,6 +4,7 @@
 import torch
 from torch.utils.data import DataLoader
 from torch.utils.data import IterableDataset
+from torchdata.stateful_dataloader import StatefulDataLoader

 from datasets import load_dataset
 from datasets.distributed import split_dataset_by_node
@@ -79,7 +80,7 @@ def tokenize_function(data):

     data_collator = collate_causal_mask(max_seq_length=seq_length, pad_id=tokenizer.pad_token_id, ignore_index=-100)

-    return DataLoader(
+    return StatefulDataLoader(
         train_dataset,
         collate_fn=data_collator,
         batch_size=batch_size,
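The switch from DataLoader to StatefulDataLoader is what makes the per-rank dataloader state in CkptManager.save/load possible: unlike a plain DataLoader, it can checkpoint and restore its position in the stream. A minimal sketch of that behaviour, assuming a placeholder train_dataset:

# Sketch only: StatefulDataLoader can save and restore where it is in the data stream.
# "train_dataset" stands in for the tokenized dataset built in data.py.
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(train_dataset, batch_size=8, num_workers=2)
it = iter(loader)
next(it)                     # consume one batch
state = loader.state_dict()  # this is what CkptManager writes into __{local_rank}_0.pt

resumed = StatefulDataLoader(train_dataset, batch_size=8, num_workers=2)
resumed.load_state_dict(state)  # iteration continues from the second batch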
