diff --git a/CHANGELOG.md b/CHANGELOG.md
index a82f5f75..f359159e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`.
 - Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint.
 - Added a callback for sending Slack notifications.
+- Added `SkipStepAdamW` optimizer.
 - The trainer can load model-only checkpoints now.
 
 ### Changed
 
@@ -25,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 
 - Added missing `weights_only=False` argument to fix loading train checkpoints with newer versions of PyTorch.
+- Fixed a bug where GCS uploads would not retry on transient failures.
 
 ## [v1.7.0](https://github.com/allenai/OLMo-core/releases/tag/v1.7.0) - 2024-11-27
diff --git a/src/olmo_core/distributed/checkpoint/__init__.py b/src/olmo_core/distributed/checkpoint/__init__.py
index 70d0e542..068fc3e8 100644
--- a/src/olmo_core/distributed/checkpoint/__init__.py
+++ b/src/olmo_core/distributed/checkpoint/__init__.py
@@ -63,6 +63,7 @@ def save_state_dict(
     state_dict: Dict[str, Any],
     process_group: Optional[dist.ProcessGroup] = None,
     save_overwrite: bool = False,
+    thread_count: Optional[int] = None,
 ):
     """
     Save an arbitrary state dictionary to a distributed format that can be loaded again with
@@ -80,7 +81,7 @@
     dir = _prepare_env_for_save(dir, process_group=process_group, save_overwrite=save_overwrite)
     dist_cp.state_dict_saver.save(
         state_dict,
-        storage_writer=RemoteFileSystemWriter(dir),
+        storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
         process_group=process_group,
     )
@@ -93,6 +94,7 @@ def save_model_and_optim_state(
     *,
     process_group: Optional[dist.ProcessGroup] = None,
     save_overwrite: bool = False,
+    thread_count: Optional[int] = None,
 ) -> None:
     """
     Save model and optimizer state dictionaries. The model state can be a sharded model, in which
@@ -123,7 +125,7 @@
     planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
     dist_cp.state_dict_saver.save(
         state_dict,
-        storage_writer=RemoteFileSystemWriter(dir),
+        storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
         process_group=process_group,
         planner=planner,
     )
@@ -137,6 +139,7 @@ def async_save_model_and_optim_state(
     *,
     process_group: Optional[dist.ProcessGroup] = None,
     save_overwrite: bool = False,
+    thread_count: Optional[int] = None,
 ) -> Future[None]:
     """
     An async version of :func:`save_model_and_optim_state()`.
@@ -148,7 +151,7 @@
     planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
     return dist_cp.state_dict_saver.async_save(
         state_dict,
-        storage_writer=RemoteFileSystemWriter(dir),
+        storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
         process_group=process_group,
         planner=planner,
     )
@@ -164,6 +167,7 @@ def load_model_and_optim_state(
     key_mapping: Optional[Dict[str, str]] = None,
     pre_download: bool = False,
     work_dir: Optional[PathOrStr] = None,
+    thread_count: Optional[int] = None,
 ):
     """
     Load model and optimizer state in-place from a checkpoint saved via :func:`save_model_and_optim_state()`.
@@ -201,10 +205,13 @@
         This dictionary should map current keys to keys in the checkpoint to be loaded.
    :param pre_download: Download and cache relevant remote checkpoint files before trying to read from them.
    :param work_dir: A working directory for caching files/directories.
+    :param thread_count: The number of threads to use when reading/downloading checkpoint files.
     """
     dir = normalize_path(dir)
     state_dict = _prepare_state_dict(model, optim, process_group=process_group)
-    reader = RemoteFileSystemReader(dir, pre_download=pre_download, work_dir=work_dir)
+    reader = RemoteFileSystemReader(
+        dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir
+    )
 
     if key_mapping is not None:
         metadata = reader.read_metadata()
diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py
index 93b03fb8..24a0a784 100644
--- a/src/olmo_core/distributed/utils.py
+++ b/src/olmo_core/distributed/utils.py
@@ -92,6 +92,7 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
             "enp6s0,enp7s0,enp13s0,enp14s0,enp134s0,enp135s0,enp141s0,enp142s0",
         )
         set_env_var("NCCL_SOCKET_IFNAME", "enp0s12")
+        set_env_var("NCCL_DEBUG_SUBSYS", "INIT,NET")
 
     if backend_supports_cuda(backend):
         # Set CUDA device.
diff --git a/src/olmo_core/internal/common.py b/src/olmo_core/internal/common.py
index 1c2d426a..d660f0a0 100644
--- a/src/olmo_core/internal/common.py
+++ b/src/olmo_core/internal/common.py
@@ -102,6 +102,7 @@ def build_launch_config(
         # Set up the Python environment.
         "conda shell.bash activate base",
         "pip install -e '.[all]'",
+        "pip install --upgrade beaker-py",
         # Quickly try a new version of PyTorch like this
         # "pip install --upgrade --pre torch==2.6.0.dev20241112+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121",
         "pip freeze",
diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py
index 4d1e9ee7..1015a60d 100644
--- a/src/olmo_core/internal/experiment.py
+++ b/src/olmo_core/internal/experiment.py
@@ -130,6 +130,7 @@ def build_common_components(
         root_dir=root_dir,
         cmd=[script, cmd_to_launch, run_name, cluster, *overrides],
         cluster=cluster,
+        nccl_debug=False,
     )
 
     beaker_user = get_beaker_username()
diff --git a/src/olmo_core/io.py b/src/olmo_core/io.py
index 5fda2741..42ff17f3 100644
--- a/src/olmo_core/io.py
+++ b/src/olmo_core/io.py
@@ -532,16 +532,25 @@ def _get_gcs_client():
 
 
 def _gcs_is_retriable(exc: Exception) -> bool:
+    from google.api_core.exceptions import BadRequest
     from google.api_core.retry import if_transient_error
 
-    return if_transient_error(exc) or isinstance(exc, requests.exceptions.Timeout)
+    return (
+        if_transient_error(exc)
+        or isinstance(exc, requests.exceptions.Timeout)
+        or isinstance(exc, BadRequest)  # Weird choice, but Google throws this transiently
+    )
 
 
 def _get_gcs_retry():
     from google.api_core.retry import Retry
 
     return Retry(
-        predicate=_gcs_is_retriable, initial=1.0, maximum=10.0, multiplier=2.0, timeout=600.0
+        predicate=_gcs_is_retriable,  # NOTE: it appears Google might ignore this
+        initial=1.0,
+        maximum=10.0,
+        multiplier=2.0,
+        timeout=600.0,
     )
 
@@ -554,7 +563,7 @@ def _get_gcs_conditional_retry():
     return ConditionalRetryPolicy(_get_gcs_retry(), is_generation_specified, ["query_params"])
 
 
-@retriable()
+@retriable(retry_condition=_gcs_is_retriable)
 def _gcs_file_size(bucket_name: str, key: str) -> int:
     from google.api_core.exceptions import NotFound
 
@@ -569,7 +578,7 @@ def _gcs_file_size(bucket_name: str, key: str) -> int:
     return blob.size
 
 
-@retriable()
+@retriable(retry_condition=_gcs_is_retriable)
 def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
    from google.api_core.exceptions import NotFound
@@ -577,27 +586,43 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes
     bucket = storage_client.bucket(bucket_name)
     blob = bucket.blob(key)
     try:
-        blob.reload()
+        blob.reload(retry=_get_gcs_retry())
     except NotFound:
         raise FileNotFoundError(f"gs://{bucket_name}/{key}")
     return blob.download_as_bytes(
-        start=bytes_start, end=bytes_start + num_bytes - 1, retry=_get_gcs_retry()
+        start=bytes_start,
+        end=bytes_start + num_bytes - 1,
+        retry=_get_gcs_retry(),
+        checksum=None,  # type: ignore
     )
 
 
-@retriable()
+@retriable(retry_condition=_gcs_is_retriable)
 def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
     storage_client = _get_gcs_client()
     bucket = storage_client.bucket(bucket_name)
     blob = bucket.blob(key)
-    if not save_overwrite and blob.exists():
-        raise FileExistsError(
-            f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
-        )
-    blob.upload_from_filename(source, retry=_get_gcs_conditional_retry())
+    generation: int = 0
+    if blob.exists(retry=_get_gcs_retry()):
+        if not save_overwrite:
+            raise FileExistsError(
+                f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
+            )
 
-@retriable()
+        blob.reload(retry=_get_gcs_retry())
+        assert blob.generation is not None
+        generation = blob.generation
+
+    blob.upload_from_filename(
+        source,
+        if_generation_match=generation,
+        retry=_get_gcs_conditional_retry(),
+        checksum=None,
+    )
+
+
+@retriable(retry_condition=_gcs_is_retriable)
 def _gcs_clear_directory(bucket_name: str, prefix: str):
     from google.api_core.exceptions import NotFound
diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py
index 435142f3..7d234d31 100644
--- a/src/olmo_core/launch/beaker.py
+++ b/src/olmo_core/launch/beaker.py
@@ -317,12 +317,23 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec:
             "#!/usr/bin/env bash",
             "set -exuo pipefail",
             "[[ -d /var/lib/tcpxo/lib64 ]] && export LD_LIBRARY_PATH=/var/lib/tcpxo/lib64:$LD_LIBRARY_PATH",
+            # Set up the kernel cache directory used by PyTorch
+            "mkdir -p /root/.cache/torch/kernels && export PYTORCH_KERNEL_CACHE_PATH=/root/.cache/torch/kernels",
             "mkdir -p /olmo-core-runtime",
             "cd /olmo-core-runtime",
             *self.setup_steps,
         ]
 
         if torchrun:
+            if any(["augusta" in cluster for cluster in self.clusters]):
+                entrypoint_script.append(
+                    "export BEAKER_REPLICA_RANK=$("
+                    "python -m olmo_core.launch.reorder_ranks_in_gcp "
+                    "${BEAKER_REPLICA_RANK} "
+                    "${BEAKER_REPLICA_COUNT} "
+                    "${BEAKER_LEADER_REPLICA_HOSTNAME}"
+                    ")"
+                )
             entrypoint_script.append(" ".join(self._get_torchrun_cmd()) + ' "$@"')
         else:
             entrypoint_script.append('python "$@"')
@@ -341,7 +352,7 @@
             leader_selection=self.num_nodes > 1,
             host_networking=self.num_nodes > 1
             or any(["augusta" in cluster for cluster in self.clusters]),
-            propagate_failure=True if self.num_nodes > 1 else None,
+            propagate_failure=False if self.num_nodes > 1 else None,
             propagate_preemption=True if self.num_nodes > 1 else None,
             synchronized_start_timeout="90m" if self.num_nodes > 1 else None,
             resources=TaskResources(gpu_count=self.num_gpus, shared_memory="10GiB"),
diff --git a/src/olmo_core/launch/reorder_ranks_in_gcp.py b/src/olmo_core/launch/reorder_ranks_in_gcp.py
new file mode 100644
index 00000000..d1381ea2
--- /dev/null
+++ b/src/olmo_core/launch/reorder_ranks_in_gcp.py
@@ -0,0 +1,70 @@
+import argparse
+import sys
+
+import requests
+import torch.distributed as dist +from urllib3.exceptions import MaxRetryError, NameResolutionError + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("rank", type=int, help="Worker number") + parser.add_argument("world_size", type=int, help="Total number of workers") + parser.add_argument("master_addr", help="Hostname of worker 0") + parser.add_argument("--master_port", type=int, default=29501, help="Port for TCPStore") + parser.add_argument("--debug", action="store_true", help="Enable debug mode (outside of GCP)") + args = parser.parse_args() + + # Create or connect to the store + store = dist.TCPStore( + host_name=args.master_addr, + port=args.master_port, + world_size=args.world_size, + is_master=(args.rank == 0), + ) + + # Get our own host id + if args.debug: + import socket + + host_id = f"{socket.gethostname()}_{args.rank}" + else: + try: + response = requests.get( + "http://metadata.google.internal/computeMetadata/v1/instance/attributes/physical_host", + headers={"Metadata-Flavor": "Google"}, + ) + assert response.status_code == 200 + host_id = response.text.strip() + except requests.exceptions.ConnectionError as e: + # Unwrap the exception + e = e.args[0] + if not isinstance(e, MaxRetryError): + raise + e = e.reason + if not isinstance(e, NameResolutionError): + raise + # Seems we called this outside of GCP, so we do nothing and just print our original rank. + print(args.rank) + sys.exit(0) + + # Find the index of our host id + store.set(f"node_{args.rank}_hostid", host_id) + store.wait([f"node_{i}_hostid" for i in range(args.world_size)]) + all_host_ids = [store.get(f"node_{i}_hostid").decode("UTF-8") for i in range(args.world_size)] + assert len(set(all_host_ids)) == len(all_host_ids) + assert host_id in all_host_ids + rank0_host_id = all_host_ids[0] + all_host_ids.sort() + # Rank 0 needs to remain rank 0, so we reshuffle around it + rank0_index = all_host_ids.index(rank0_host_id) + all_host_ids = all_host_ids[rank0_index:] + all_host_ids[:rank0_index] + print(all_host_ids.index(host_id)) + + # Make sure we're all done before exiting + store.set(f"node_{args.rank}_done", host_id) + store.wait([f"node_{i}_done" for i in range(args.world_size)]) + + +if __name__ == "__main__": + main() diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index f77ec4e2..44d34252 100644 --- a/src/olmo_core/nn/transformer/config.py +++ b/src/olmo_core/nn/transformer/config.py @@ -460,19 +460,22 @@ def olmo2_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": ) @classmethod - def olmo2_26B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": + def olmo2_32B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": """ - A 26B OLMo model config. + A 32B OLMo model config. 
""" + d_model = 5120 return cls.llama_like( vocab_size=vocab_size, - d_model=7168, - n_layers=kwargs.pop("n_layers", 40), - n_heads=kwargs.pop("n_heads", 56), + d_model=d_model, + n_layers=kwargs.pop("n_layers", 64), + n_heads=kwargs.pop("n_heads", 40), + n_kv_heads=kwargs.pop("n_kv_heads", 8), block_name=kwargs.pop("block_name", TransformerBlockType.reordered_norm), qk_norm=kwargs.pop("qk_norm", True), rope_theta=kwargs.pop("rope_theta", 500_000), - hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 1024), + hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 512), + hidden_size_multiplier=kwargs.pop("hidden_size_multiplier", 27648 / (8 * d_model / 3)), layer_norm_eps=1e-6, **kwargs, ) diff --git a/src/olmo_core/optim/__init__.py b/src/olmo_core/optim/__init__.py index 0e1cf986..c050d5e9 100644 --- a/src/olmo_core/optim/__init__.py +++ b/src/olmo_core/optim/__init__.py @@ -1,5 +1,5 @@ from .adam import AdamConfig -from .adamw import AdamWConfig +from .adamw import AdamWConfig, SkipStepAdamW, SkipStepAdamWConfig from .config import OptimConfig, OptimGroupOverride from .lion import Lion, LionConfig, SkipStepLion, SkipStepLionConfig from .scheduler import ( @@ -18,6 +18,8 @@ "OptimGroupOverride", "SkipStepOptimizer", "AdamWConfig", + "SkipStepAdamWConfig", + "SkipStepAdamW", "AdamConfig", "LionConfig", "Lion", diff --git a/src/olmo_core/optim/adamw.py b/src/olmo_core/optim/adamw.py index bc5f1e46..e4a24c90 100644 --- a/src/olmo_core/optim/adamw.py +++ b/src/olmo_core/optim/adamw.py @@ -1,4 +1,3 @@ -import math from dataclasses import dataclass from typing import Optional, Tuple, Type @@ -6,9 +5,9 @@ import torch.nn as nn from .config import OptimConfig +from .skip_step_optimizer import SkipStepOptimizer -# TODO: use this when we implement a "skip step" version of AdamW. def adamw_step( p: nn.Parameter, *, @@ -18,7 +17,7 @@ def adamw_step( weight_decay: float, exp_avg: torch.Tensor, exp_avg_sq: torch.Tensor, - step: int, + step: torch.Tensor, step_factor: torch.Tensor, ): if p.grad is None: @@ -34,19 +33,87 @@ def adamw_step( exp_avg_sq.mul_(1 - step_factor * (1 - beta2)) exp_avg_sq.add_(step_factor * p.grad * p.grad, alpha=1 - beta2) - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step + bias_correction1 = 1 - beta1 ** (step + 1) + bias_correction2 = 1 - beta2 ** (step + 1) step_size = lr / bias_correction1 - bias_correction2_sqrt = math.sqrt(bias_correction2) - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + denom = (exp_avg_sq.sqrt() / bias_correction2.sqrt()).add_(eps) update = -step_size * torch.div(exp_avg, denom) update.mul_(step_factor) p.add_(update) +class SkipStepAdamW(SkipStepOptimizer): + """ + A "skip step" version of :class:`AdamW`. 
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-3,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 1e-2,
+        foreach: Optional[bool] = None,
+        fused: Optional[bool] = None,
+        rolling_interval_length: int = 128,
+        sigma_factor: int = 6,
+    ) -> None:
+        assert lr > 0.0
+        assert all([0.0 <= beta <= 1.0 for beta in betas])
+        defaults = dict(
+            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, foreach=foreach, fused=fused
+        )
+        super().__init__(
+            params,
+            defaults,
+            rolling_interval_length=rolling_interval_length,
+            sigma_factor=sigma_factor,
+        )
+        self._step_skipped: Optional[torch.Tensor] = None
+
+    @property
+    def step_skipped(self) -> torch.Tensor:
+        if self._step_skipped is not None:
+            return self._step_skipped
+        else:
+            return torch.tensor(0.0)
+
+    @torch.no_grad()
+    def step(self, closure=None) -> None:
+        if closure is not None:
+            with torch.enable_grad():
+                closure()
+
+        step_factor = self.get_step_factor()
+        self._step_skipped = 1 - step_factor
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+
+                state = self.state[p]
+                if len(state) == 0:
+                    state["step"] = torch.tensor(0.0, dtype=torch.float32, device=p.device)
+                    state["exp_avg"] = torch.zeros_like(p)
+                    state["exp_avg_sq"] = torch.zeros_like(p)
+
+                adamw_step(
+                    p,
+                    lr=group["lr"],
+                    betas=group["betas"],
+                    eps=group["eps"],
+                    weight_decay=group["weight_decay"],
+                    exp_avg=state["exp_avg"],
+                    exp_avg_sq=state["exp_avg_sq"],
+                    step=state["step"],
+                    step_factor=step_factor,
+                )
+
+
 @dataclass
 class AdamWConfig(OptimConfig):  # NOTE: omegaconf doesn't like "OptimConfig[torch.optim.AdamW]"
     """
@@ -63,3 +130,21 @@ class AdamWConfig(OptimConfig):  # NOTE: omegaconf doesn't like "OptimConfig[tor
     @classmethod
     def optimizer(cls) -> Type[torch.optim.AdamW]:
         return torch.optim.AdamW
+
+
+@dataclass
+class SkipStepAdamWConfig(OptimConfig):
+    """
+    Configuration class for building a :class:`SkipStepAdamW` optimizer.
+    """
+
+    lr: float = 1e-3
+    betas: Tuple[float, float] = (0.9, 0.999)
+    eps: float = 1e-8
+    weight_decay: float = 1e-2
+    rolling_interval_length: int = 128
+    sigma_factor: int = 6
+
+    @classmethod
+    def optimizer(cls) -> Type[SkipStepAdamW]:
+        return SkipStepAdamW
diff --git a/src/olmo_core/optim/skip_step_optimizer.py b/src/olmo_core/optim/skip_step_optimizer.py
index 98ada1bd..40b0b034 100644
--- a/src/olmo_core/optim/skip_step_optimizer.py
+++ b/src/olmo_core/optim/skip_step_optimizer.py
@@ -91,17 +91,20 @@ def get_step_factor(self) -> torch.Tensor:
         The tensor can be used within the optimizer's step computation to essentially skip a step
         without a host-device sync.
""" - if len(self._losses) < max(20, self.rolling_interval_length // 2): + if len(self._losses) < max(2, self.rolling_interval_length // 2): return torch.tensor(1.0).to(device=self.device, non_blocking=True) loss_std, loss_mean = torch.std_mean(torch.stack(self._losses[:-1])) if self._grad_norms: grad_norm_std, grad_norm_mean = torch.std_mean(torch.stack(self._grad_norms[:-1])) - return ((self.latest_loss - loss_mean) <= self.sigma_factor * loss_std) and ( - (self.latest_grad_norm - grad_norm_mean) <= self.sigma_factor * grad_norm_std + step_factor = torch.logical_and( + (self.latest_loss - loss_mean) <= self.sigma_factor * loss_std, + (self.latest_grad_norm - grad_norm_mean) <= self.sigma_factor * grad_norm_std, ) else: - return (self.latest_loss - loss_mean) <= self.sigma_factor * loss_std + step_factor = (self.latest_loss - loss_mean) <= self.sigma_factor * loss_std + + return step_factor.float() @property def step_skipped(self) -> torch.Tensor: diff --git a/src/olmo_core/train/__init__.py b/src/olmo_core/train/__init__.py index ba59008b..e14f3dc7 100644 --- a/src/olmo_core/train/__init__.py +++ b/src/olmo_core/train/__init__.py @@ -75,7 +75,7 @@ def prepare_training_environment( *, seed: Optional[int] = None, backend: Optional[str] = "cpu:gloo,cuda:nccl", - timeout: timedelta = timedelta(minutes=10), + timeout: timedelta = timedelta(minutes=30), log_filter_type: Optional[LogFilterType] = None, ): """ diff --git a/src/olmo_core/train/callbacks/evaluator_callback.py b/src/olmo_core/train/callbacks/evaluator_callback.py index ea2bfa58..556492b7 100644 --- a/src/olmo_core/train/callbacks/evaluator_callback.py +++ b/src/olmo_core/train/callbacks/evaluator_callback.py @@ -129,7 +129,7 @@ def build(self, trainer: "Trainer") -> Optional[Callback]: eval_batch_size = ( self.eval_batch_size if self.eval_batch_size is not None - else trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group) + else 2 * trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group) ) dataset = self.eval_dataset.build() if not isinstance(dataset, NumpyPaddedFSLDataset): diff --git a/src/olmo_core/train/callbacks/grad_clipper.py b/src/olmo_core/train/callbacks/grad_clipper.py index 0a0ebbcb..97ad3b8d 100644 --- a/src/olmo_core/train/callbacks/grad_clipper.py +++ b/src/olmo_core/train/callbacks/grad_clipper.py @@ -4,6 +4,7 @@ import torch.nn as nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from olmo_core.distributed.utils import get_local_tensor from olmo_core.optim import SkipStepOptimizer from .callback import Callback @@ -26,6 +27,8 @@ def pre_optim_step(self): self.trainer.model.parameters(), self.max_grad_norm ) + grad_norm = get_local_tensor(grad_norm.detach()) + # NOTE: grad norm is already reduced over ranks, so we set `reduce_type` to `None`. 
self.trainer.record_metric("optim/total grad norm", grad_norm, reduce_type=None) if isinstance(self.trainer.optim, SkipStepOptimizer): diff --git a/src/olmo_core/train/checkpoint.py b/src/olmo_core/train/checkpoint.py index c018c9ff..29f4f513 100644 --- a/src/olmo_core/train/checkpoint.py +++ b/src/olmo_core/train/checkpoint.py @@ -55,6 +55,8 @@ class CheckpointerConfig(Config): work_dir: Optional[str] = None save_overwrite: Optional[bool] = None pre_download: bool = False + save_thread_count: Optional[int] = None + load_thread_count: Optional[int] = None def build(self, process_group: Optional[dist.ProcessGroup] = None, **kwargs) -> "Checkpointer": kwargs = {**self.as_dict(exclude_none=True, recurse=False), **kwargs} @@ -82,6 +84,8 @@ class Checkpointer: save_overwrite: bool = False pre_download: bool = False process_group: Optional[dist.ProcessGroup] = None + save_thread_count: Optional[int] = None + load_thread_count: Optional[int] = None def __post_init__(self): self.work_dir = Path(self.work_dir) @@ -107,6 +111,7 @@ def save(self, dir: PathOrStr, model: nn.Module, optim: Optimizer, train_state: optim, process_group=self.process_group, save_overwrite=self.save_overwrite, + thread_count=self.save_thread_count, ) self._save_metadata(dir, CheckpointMetadata()) @@ -136,6 +141,7 @@ def save_async( optim, process_group=self.process_group, save_overwrite=self.save_overwrite, + thread_count=self.save_thread_count, ) def done_callback(fut: Future): @@ -210,6 +216,7 @@ def load( key_mapping=key_mapping, pre_download=is_url(dir) and self.pre_download, work_dir=self.work_dir, + thread_count=self.load_thread_count, ) return trainer_state @@ -332,7 +339,7 @@ def _save_train_state(self, dir: PathOrStr, wd: Path, train_state: Dict[str, Any # NOTE: if 'dir' is a URL, the 'wd' will be a different temp dir for each rank. if is_url(dir) or get_fs_local_rank() == 0: train_dir.mkdir(exist_ok=True, parents=True) - wait_for(train_dir.exists, description=f"Waiting on '{train_dir}' to be created...") + wait_for(train_dir.exists, description=f"Waiting for '{train_dir}' to be created...") torch.save(train_state, train_dir / f"rank{get_rank()}.pt") def _save_metadata(self, dir: PathOrStr, metadata: CheckpointMetadata): diff --git a/src/scripts/train/OLMo2-26B.py b/src/scripts/train/OLMo2-26B.py deleted file mode 100644 index 6453407c..00000000 --- a/src/scripts/train/OLMo2-26B.py +++ /dev/null @@ -1,102 +0,0 @@ -""" -Train a 26B OLMo model. Run this script without any arguments to see usage info. 
-""" - -import logging - -from olmo_core.config import DType -from olmo_core.distributed.parallel import DataParallelType -from olmo_core.float8 import Float8Config -from olmo_core.internal.experiment import CommonComponents, main -from olmo_core.nn.transformer import ( - TransformerActivationCheckpointingConfig, - TransformerActivationCheckpointingMode, - TransformerConfig, - TransformerDataParallelConfig, -) -from olmo_core.optim import AdamWConfig, OptimGroupOverride -from olmo_core.train import TrainerConfig -from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback - -log = logging.getLogger(__name__) - - -def build_model_config(common: CommonComponents) -> TransformerConfig: - compile = True - return TransformerConfig.olmo2_26B( - vocab_size=common.tokenizer.padded_vocab_size(), - compile=compile, - fused_ops=False, - use_flash=not compile, - dp_config=TransformerDataParallelConfig( - name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 - ), - ac_config=TransformerActivationCheckpointingConfig( - mode=TransformerActivationCheckpointingMode.full - ), - float8_config=Float8Config(compile=compile, enabled=False), - ) - - -def build_optim_config(common: CommonComponents) -> AdamWConfig: - del common - return AdamWConfig( - lr=6e-4, - weight_decay=0.1, - betas=(0.9, 0.95), - group_overrides=[ - OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0)) - ], - fused=True, - ) - - -def build_trainer_config(common: CommonComponents) -> TrainerConfig: - return ( - TrainerConfig( - save_folder=common.save_folder, - rank_microbatch_size=4 * 4096, - save_overwrite=True, - metrics_collect_interval=10, - cancel_check_interval=1, - z_loss_multiplier=1e-5, - compile_loss=True, - ) - .with_callback( - "checkpointer", - CheckpointerCallback( - save_interval=10_000, - ephemeral_save_interval=250, - save_async=True, - ), - ) - .with_callback( - "comet", - CometCallback( - name=common.run_name, - workspace="ai2", - project="OLMo-core-26B", - enabled=True, - cancel_check_interval=10, - ), - ) - .with_callback( - "wandb", - WandBCallback( - name=common.run_name, - entity="ai2-llm", - project="OLMo-core-26B", - enabled=False, - cancel_check_interval=10, - ), - ) - ) - - -if __name__ == "__main__": - main( - global_batch_size=2048 * 4096, - model_config_builder=build_model_config, - optim_config_builder=build_optim_config, - trainer_config_builder=build_trainer_config, - ) diff --git a/src/test/optim/adamw_test.py b/src/test/optim/adamw_test.py index 5756f9a6..a792ace9 100644 --- a/src/test/optim/adamw_test.py +++ b/src/test/optim/adamw_test.py @@ -1,7 +1,10 @@ +from test.utils import DEVICES + +import pytest import torch import torch.nn as nn -from olmo_core.optim import AdamWConfig, OptimGroupOverride +from olmo_core.optim import AdamWConfig, OptimGroupOverride, SkipStepAdamWConfig class MyModel(nn.Module): @@ -43,3 +46,33 @@ def test_adamw_config_to_optim_with_group_overrides(): for group in optim.param_groups: assert "initial_lr" in group + + +@pytest.mark.parametrize("device", DEVICES) +def test_adamw(device: torch.device): + config = AdamWConfig() + model = MyModel().train().to(device) + optim = config.build(model) + + for group in optim.param_groups: + assert "initial_lr" in group + + # Take a step. 
+ optim.zero_grad(set_to_none=True) + model(torch.randint(0, 1024, (2, 8), device=device).int()).sum().backward() + optim.step() + + +@pytest.mark.parametrize("device", DEVICES) +def test_skip_step_adamw(device: torch.device): + config = SkipStepAdamWConfig() + model = MyModel().train().to(device) + optim = config.build(model) + + for group in optim.param_groups: + assert "initial_lr" in group + + # Take a step. + optim.zero_grad(set_to_none=True) + model(torch.randint(0, 1024, (2, 8), device=device).int()).sum().backward() + optim.step()
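For reference, here is a minimal usage sketch of the new `SkipStepAdamW` optimizer (not part of the patch). `MyModel` is a hypothetical stand-in module, and the `latest_loss` setter on `SkipStepOptimizer`, which the trainer normally calls each step, is assumed from context rather than shown in this diff:

import torch
import torch.nn as nn

from olmo_core.optim import SkipStepAdamWConfig


class MyModel(nn.Sequential):  # hypothetical stand-in for a real model
    def __init__(self):
        super().__init__(nn.Embedding(1024, 16), nn.Linear(16, 8))


config = SkipStepAdamWConfig(
    lr=1e-3,
    betas=(0.9, 0.95),
    weight_decay=0.1,
    # An update is skipped when the latest loss (or grad norm) sits more than
    # `sigma_factor` standard deviations above the rolling mean, per
    # `SkipStepOptimizer.get_step_factor()` above.
    rolling_interval_length=128,
    sigma_factor=6,
)
model = MyModel().train()
optim = config.build(model)

for _ in range(3):
    optim.zero_grad(set_to_none=True)
    loss = model(torch.randint(0, 1024, (2, 8)).int()).sum()
    loss.backward()
    # Assumed API: the trainer records the loss on the optimizer so that
    # `get_step_factor()` can make its decision without a host-device sync.
    optim.latest_loss = loss.detach()
    optim.step()
    # With fewer recorded losses than max(2, rolling_interval_length // 2),
    # the step factor is 1.0, so nothing is skipped yet.
    print("step skipped:", optim.step_skipped.item())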