pull fixes from 32B branch #139

Merged 5 commits on Jan 21, 2025
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`.
- Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint.
- Added a callback for sending Slack notifications.
- Added `SkipStepAdamW` optimizer.
- The trainer can load model-only checkpoints now.

### Changed
@@ -25,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Added missing `weights_only=False` argument to fix loading train checkpoints with newer versions of PyTorch.
- Fixed bug where GCS upload would not retry on transient failures.

## [v1.7.0](https://github.com/allenai/OLMo-core/releases/tag/v1.7.0) - 2024-11-27

15 changes: 11 additions & 4 deletions src/olmo_core/distributed/checkpoint/__init__.py
@@ -63,6 +63,7 @@ def save_state_dict(
state_dict: Dict[str, Any],
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
):
"""
Save an arbitrary state dictionary to a distributed format that can be loaded again with
@@ -80,7 +81,7 @@ def save_state_dict(
dir = _prepare_env_for_save(dir, process_group=process_group, save_overwrite=save_overwrite)
dist_cp.state_dict_saver.save(
state_dict,
-        storage_writer=RemoteFileSystemWriter(dir),
+        storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
process_group=process_group,
)

@@ -93,6 +94,7 @@ def save_model_and_optim_state(
*,
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
) -> None:
"""
Save model and optimizer state dictionaries. The model state can be a sharded model, in which
@@ -123,7 +125,7 @@ def save_model_and_optim_state(
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
dist_cp.state_dict_saver.save(
state_dict,
-        storage_writer=RemoteFileSystemWriter(dir),
+        storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
process_group=process_group,
planner=planner,
)
@@ -137,6 +139,7 @@ def async_save_model_and_optim_state(
*,
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
) -> Future[None]:
"""
An async version of :func:`save_model_and_optim_state()`.
@@ -148,7 +151,7 @@ def async_save_model_and_optim_state(
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
return dist_cp.state_dict_saver.async_save(
state_dict,
-        storage_writer=RemoteFileSystemWriter(dir),
+        storage_writer=RemoteFileSystemWriter(dir, thread_count=thread_count),
process_group=process_group,
planner=planner,
)
@@ -164,6 +167,7 @@ def load_model_and_optim_state(
key_mapping: Optional[Dict[str, str]] = None,
pre_download: bool = False,
work_dir: Optional[PathOrStr] = None,
thread_count: Optional[int] = None,
):
"""
Load model and optimizer state in-place from a checkpoint saved via :func:`save_model_and_optim_state()`.
@@ -201,10 +205,13 @@ def load_model_and_optim_state(
This dictionary should map current keys to keys in the checkpoint to be loaded.
:param pre_download: Download and cache relevant remote checkpoint files before trying to read from them.
:param work_dir: A working directory for caching files/directories.
:param thread_count: Set the number of threads used for certain operations.
"""
dir = normalize_path(dir)
state_dict = _prepare_state_dict(model, optim, process_group=process_group)
-    reader = RemoteFileSystemReader(dir, pre_download=pre_download, work_dir=work_dir)
+    reader = RemoteFileSystemReader(
+        dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir
+    )

if key_mapping is not None:
metadata = reader.read_metadata()
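The new `thread_count` argument threads through every save/load entry point in this module. A minimal sketch of how it might be used, with a placeholder model, optimizer, and local path (signatures are as shown in this diff):

```python
import torch
import torch.nn as nn

from olmo_core.distributed.checkpoint import (
    load_model_and_optim_state,
    save_model_and_optim_state,
)

# Placeholder model/optimizer; in practice these are the trainer's sharded objects.
model = nn.Linear(8, 8)
optim = torch.optim.AdamW(model.parameters())

# More I/O threads can help with large sharded checkpoints on remote storage;
# thread_count=None keeps the underlying writer/reader default.
save_model_and_optim_state("/tmp/ckpt", model, optim, save_overwrite=True, thread_count=8)

# The same knob applies on the read path, alongside pre_download/work_dir.
load_model_and_optim_state("/tmp/ckpt", model, optim, thread_count=8)
```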
1 change: 1 addition & 0 deletions src/olmo_core/distributed/utils.py
@@ -92,6 +92,7 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
"enp6s0,enp7s0,enp13s0,enp14s0,enp134s0,enp135s0,enp141s0,enp142s0",
)
set_env_var("NCCL_SOCKET_IFNAME", "enp0s12")
set_env_var("NCCL_DEBUG_SUBSYS", "INIT,NET")

if backend_supports_cuda(backend):
# Set CUDA device.
1 change: 1 addition & 0 deletions src/olmo_core/internal/common.py
@@ -102,6 +102,7 @@ def build_launch_config(
# Setup python environment.
"conda shell.bash activate base",
"pip install -e '.[all]'",
"pip install --upgrade beaker-py",
# Quickly try a new version of PyTorch like this
# "pip install --upgrade --pre torch==2.6.0.dev20241112+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121",
"pip freeze",
1 change: 1 addition & 0 deletions src/olmo_core/internal/experiment.py
@@ -130,6 +130,7 @@ def build_common_components(
root_dir=root_dir,
cmd=[script, cmd_to_launch, run_name, cluster, *overrides],
cluster=cluster,
nccl_debug=False,
)

beaker_user = get_beaker_username()
51 changes: 38 additions & 13 deletions src/olmo_core/io.py
@@ -532,16 +532,25 @@ def _get_gcs_client()


def _gcs_is_retriable(exc: Exception) -> bool:
+    from google.api_core.exceptions import BadRequest
    from google.api_core.retry import if_transient_error

-    return if_transient_error(exc) or isinstance(exc, requests.exceptions.Timeout)
+    return (
+        if_transient_error(exc)
+        or isinstance(exc, requests.exceptions.Timeout)
+        or isinstance(exc, BadRequest)  # Weird choice, but Google throws this transiently
+    )


def _get_gcs_retry():
from google.api_core.retry import Retry

return Retry(
-        predicate=_gcs_is_retriable, initial=1.0, maximum=10.0, multiplier=2.0, timeout=600.0
+        predicate=_gcs_is_retriable,  # NOTE: it appears google might ignore this
+        initial=1.0,
+        maximum=10.0,
+        multiplier=2.0,
+        timeout=600.0,
)


@@ -554,7 +563,7 @@ def _get_gcs_conditional_retry():
return ConditionalRetryPolicy(_get_gcs_retry(), is_generation_specified, ["query_params"])


-@retriable()
+@retriable(retry_condition=_gcs_is_retriable)
def _gcs_file_size(bucket_name: str, key: str) -> int:
from google.api_core.exceptions import NotFound

@@ -569,35 +578,51 @@ def _gcs_file_size(bucket_name: str, key: str) -> int:
return blob.size


-@retriable()
+@retriable(retry_condition=_gcs_is_retriable)
def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
from google.api_core.exceptions import NotFound

storage_client = _get_gcs_client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(key)
try:
-        blob.reload()
+        blob.reload(retry=_get_gcs_retry())
except NotFound:
raise FileNotFoundError(f"gs://{bucket_name}/{key}")
return blob.download_as_bytes(
-        start=bytes_start, end=bytes_start + num_bytes - 1, retry=_get_gcs_retry()
+        start=bytes_start,
+        end=bytes_start + num_bytes - 1,
+        retry=_get_gcs_retry(),
+        checksum=None,  # type: ignore
)


-@retriable()
+@retriable(retry_condition=_gcs_is_retriable)
def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
storage_client = _get_gcs_client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(key)
-    if not save_overwrite and blob.exists():
-        raise FileExistsError(
-            f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
-        )
-    blob.upload_from_filename(source, retry=_get_gcs_conditional_retry())
+    generation: int = 0
+    if blob.exists(retry=_get_gcs_retry()):
+        if not save_overwrite:
+            raise FileExistsError(
+                f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
+            )
+
+        blob.reload(retry=_get_gcs_retry())
+        assert blob.generation is not None
+        generation = blob.generation
+
+    blob.upload_from_filename(
+        source,
+        if_generation_match=generation,
+        retry=_get_gcs_conditional_retry(),
+        checksum=None,
+    )


-@retriable()
+@retriable(retry_condition=_gcs_is_retriable)
def _gcs_clear_directory(bucket_name: str, prefix: str):
from google.api_core.exceptions import NotFound

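The reworked `_gcs_upload` guards retried uploads with a GCS generation precondition: `if_generation_match=0` means "only create, never overwrite", while a concrete generation number means "only replace this exact version". A standalone sketch of the same pattern, assuming the `google-cloud-storage` client and ambient credentials:

```python
from google.cloud import storage


def upload_once(bucket_name: str, key: str, source: str, overwrite: bool = False) -> None:
    """Upload with a generation precondition so a retried request stays idempotent."""
    client = storage.Client()
    blob = client.bucket(bucket_name).blob(key)

    generation = 0  # 0 == "the object must not exist yet"
    if blob.exists():
        if not overwrite:
            raise FileExistsError(f"gs://{bucket_name}/{key} already exists")
        blob.reload()
        generation = blob.generation  # "the object must still be this version"

    # If a retry races another writer, GCS rejects the request with a 412
    # precondition failure instead of silently overwriting the newer object.
    blob.upload_from_filename(source, if_generation_match=generation)
```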
13 changes: 12 additions & 1 deletion src/olmo_core/launch/beaker.py
@@ -317,12 +317,23 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec:
"#!/usr/bin/env bash",
"set -exuo pipefail",
"[[ -d /var/lib/tcpxo/lib64 ]] && export LD_LIBRARY_PATH=/var/lib/tcpxo/lib64:$LD_LIBRARY_PATH",
# Setup the kernel cache directory used by pytorch
"mkdir -p /root/.cache/torch/kernels && export PYTORCH_KERNEL_CACHE_PATH=/root/.cache/torch/kernels",
"mkdir -p /olmo-core-runtime",
"cd /olmo-core-runtime",
*self.setup_steps,
]

if torchrun:
if any(["augusta" in cluster for cluster in self.clusters]):
entrypoint_script.append(
"export BEAKER_REPLICA_RANK=$("
"python -m olmo_core.launch.reorder_ranks_in_gcp "
"${BEAKER_REPLICA_RANK} "
"${BEAKER_REPLICA_COUNT} "
"${BEAKER_LEADER_REPLICA_HOSTNAME}"
")"
)
entrypoint_script.append(" ".join(self._get_torchrun_cmd()) + ' "$@"')
else:
entrypoint_script.append('python "$@"')
@@ -341,7 +352,7 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec:
leader_selection=self.num_nodes > 1,
host_networking=self.num_nodes > 1
or any(["augusta" in cluster for cluster in self.clusters]),
-            propagate_failure=True if self.num_nodes > 1 else None,
+            propagate_failure=False if self.num_nodes > 1 else None,
propagate_preemption=True if self.num_nodes > 1 else None,
synchronized_start_timeout="90m" if self.num_nodes > 1 else None,
resources=TaskResources(gpu_count=self.num_gpus, shared_memory="10GiB"),
70 changes: 70 additions & 0 deletions src/olmo_core/launch/reorder_ranks_in_gcp.py
@@ -0,0 +1,70 @@
import argparse
import sys

import requests
import torch.distributed as dist
from urllib3.exceptions import MaxRetryError, NameResolutionError


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("rank", type=int, help="Worker number")
    parser.add_argument("world_size", type=int, help="Total number of workers")
    parser.add_argument("master_addr", help="Hostname of worker 0")
    parser.add_argument("--master_port", type=int, default=29501, help="Port for TCPStore")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode (outside of GCP)")
    args = parser.parse_args()

    # Create or connect to the store
    store = dist.TCPStore(
        host_name=args.master_addr,
        port=args.master_port,
        world_size=args.world_size,
        is_master=(args.rank == 0),
    )

    # Get our own host id
    if args.debug:
        import socket

        host_id = f"{socket.gethostname()}_{args.rank}"
    else:
        try:
            response = requests.get(
                "http://metadata.google.internal/computeMetadata/v1/instance/attributes/physical_host",
                headers={"Metadata-Flavor": "Google"},
            )
            assert response.status_code == 200
            host_id = response.text.strip()
        except requests.exceptions.ConnectionError as e:
            # Unwrap the exception
            e = e.args[0]
            if not isinstance(e, MaxRetryError):
                raise
            e = e.reason
            if not isinstance(e, NameResolutionError):
                raise
            # Seems we called this outside of GCP, so we do nothing and just print our original rank.
            print(args.rank)
            sys.exit(0)

    # Find the index of our host id
    store.set(f"node_{args.rank}_hostid", host_id)
    store.wait([f"node_{i}_hostid" for i in range(args.world_size)])
    all_host_ids = [store.get(f"node_{i}_hostid").decode("UTF-8") for i in range(args.world_size)]
    assert len(set(all_host_ids)) == len(all_host_ids)
    assert host_id in all_host_ids
    rank0_host_id = all_host_ids[0]
    all_host_ids.sort()
    # Rank 0 needs to remain rank 0, so we reshuffle around it
    rank0_index = all_host_ids.index(rank0_host_id)
    all_host_ids = all_host_ids[rank0_index:] + all_host_ids[:rank0_index]
    print(all_host_ids.index(host_id))

    # Make sure we're all done before exiting
    store.set(f"node_{args.rank}_done", host_id)
    store.wait([f"node_{i}_done" for i in range(args.world_size)])


if __name__ == "__main__":
    main()
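The reordering keeps rank 0 pinned: the GCP `physical_host` ids are sorted (so hosts that are physically close get neighboring ranks), then the sorted list is rotated so that rank 0's id lands back at index 0. A toy illustration of that logic with invented host ids:

```python
# Toy illustration of the reordering logic; the host ids are made up.
all_host_ids = ["h-07", "h-42", "h-13", "h-99"]  # index == current replica rank
rank0_host_id = all_host_ids[0]

ordered = sorted(all_host_ids)        # ["h-07", "h-13", "h-42", "h-99"]
pivot = ordered.index(rank0_host_id)  # rotate so rank 0 stays rank 0
ordered = ordered[pivot:] + ordered[:pivot]

for old_rank, host_id in enumerate(all_host_ids):
    print(old_rank, "->", ordered.index(host_id))
# 0 -> 0, 1 -> 2, 2 -> 1, 3 -> 3: ranks now follow the sorted host order.
```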
15 changes: 9 additions & 6 deletions src/olmo_core/nn/transformer/config.py
@@ -460,19 +460,22 @@ def olmo2_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
)

@classmethod
-    def olmo2_26B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
+    def olmo2_32B(cls, vocab_size: int, **kwargs) -> "TransformerConfig":
        """
-        A 26B OLMo model config.
+        A 32B OLMo model config.
"""
d_model = 5120
return cls.llama_like(
vocab_size=vocab_size,
-            d_model=7168,
-            n_layers=kwargs.pop("n_layers", 40),
-            n_heads=kwargs.pop("n_heads", 56),
+            d_model=d_model,
+            n_layers=kwargs.pop("n_layers", 64),
+            n_heads=kwargs.pop("n_heads", 40),
n_kv_heads=kwargs.pop("n_kv_heads", 8),
block_name=kwargs.pop("block_name", TransformerBlockType.reordered_norm),
qk_norm=kwargs.pop("qk_norm", True),
rope_theta=kwargs.pop("rope_theta", 500_000),
-            hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 1024),
+            hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 512),
+            hidden_size_multiplier=kwargs.pop("hidden_size_multiplier", 27648 / (8 * d_model / 3)),
layer_norm_eps=1e-6,
**kwargs,
)
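The `hidden_size_multiplier` default looks opaque, but it is chosen to hit an exact MLP width. Assuming `llama_like` follows the usual LLaMA sizing rule (hidden size is roughly `multiplier * 8/3 * d_model`, rounded up to a multiple of `hidden_size_multiple_of`; the exact rounding inside `llama_like` is not shown in this diff), the arithmetic cancels to exactly 27648:

```python
# Rough check of the FFN width implied by the new 32B defaults.
d_model = 5120
multiplier = 27648 / (8 * d_model / 3)  # ~2.025
base = 8 * d_model / 3                  # LLaMA's default 8/3 expansion, ~13653.3
target = multiplier * base              # cancels back to 27648

assert round(target) == 27648
assert 27648 % 512 == 0  # already a multiple of hidden_size_multiple_of=512
```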
4 changes: 3 additions & 1 deletion src/olmo_core/optim/__init__.py
@@ -1,5 +1,5 @@
from .adam import AdamConfig
-from .adamw import AdamWConfig
+from .adamw import AdamWConfig, SkipStepAdamW, SkipStepAdamWConfig
from .config import OptimConfig, OptimGroupOverride
from .lion import Lion, LionConfig, SkipStepLion, SkipStepLionConfig
from .scheduler import (
@@ -18,6 +18,8 @@
"OptimGroupOverride",
"SkipStepOptimizer",
"AdamWConfig",
"SkipStepAdamWConfig",
"SkipStepAdamW",
"AdamConfig",
"LionConfig",
"Lion",
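With the exports in place, the new optimizer can presumably be configured like the existing `AdamWConfig`. A hypothetical sketch; the field names and the `build()` entry point are assumed to mirror the other optimizer configs, since this PR only shows the exports:

```python
import torch.nn as nn

from olmo_core.optim import SkipStepAdamWConfig

model = nn.Linear(8, 8)  # placeholder model

# Assumed to mirror AdamWConfig. A SkipStep optimizer variant (like the
# existing SkipStepLion) typically skips update steps whose loss is an
# extreme outlier relative to recent history.
optim_config = SkipStepAdamWConfig(lr=3e-4)  # field name assumed
optim = optim_config.build(model)            # build(model) entry point assumed
```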