
Commit

fix merge conflicts
epwalsh committed Jan 17, 2025
2 parents 35aca24 + 48abe8c commit 12cffc9
Showing 10 changed files with 75 additions and 38 deletions.
10 changes: 2 additions & 8 deletions .github/RELEASE_PROCESS.md
@@ -7,18 +7,12 @@
3. Run the release script:

```bash
./src/scripts/release.sh
./src/scripts/release/release.sh
```

This will commit the changes to the CHANGELOG and `version.py` files and then create a new tag in git
which will trigger a workflow on GitHub Actions that handles the rest.

## Fixing a failed release

If for some reason the GitHub Actions release workflow failed with an error that needs to be fixed, you'll have to delete both the tag and corresponding release from GitHub. After you've pushed a fix, delete the tag from your local clone with

```bash
git tag -l | xargs git tag -d && git fetch -t
```

Then repeat the steps above.
If for some reason the GitHub Actions release workflow failed with an error that needs to be fixed, you'll have to delete the tag on GitHub. Once you've pushed a fix you can simply repeat the steps above.
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
@@ -241,7 +241,7 @@ jobs:
- name: Generate release notes
run: |
. .venv/bin/activate
python src/scripts/release_notes.py > ${{ github.workspace }}-RELEASE_NOTES.md
python src/scripts/release/release_notes.py > ${{ github.workspace }}-RELEASE_NOTES.md
- name: Publish package to PyPI
run: |
@@ -262,4 +262,4 @@ jobs:
env:
GH_TOKEN: ${{ github.token }}
run: |
./scripts/add_pr_comments_on_release.sh
./src/scripts/release/add_pr_comments_on_release.sh
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -35,6 +35,7 @@ This major release introduces a few breaking changes. As such, we've provided an
- Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`.
- Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint.
- Added a callback for sending Slack notifications.
- The trainer can load model-only checkpoints now.

### Changed

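The new "model-only checkpoints" entry is implemented by the `checkpoint.py`, `transformer.py`, and `trainer.py` changes below. A minimal usage sketch (hypothetical, not code from this commit: `trainer` is assumed to be an existing `Trainer` and `ckpt_dir` a directory holding only model weights):

```python
# With the new default (load_trainer_state=None), trainer state is loaded
# when present and skipped without error when the directory holds only
# model (and possibly optimizer) state.
trainer.load_checkpoint(ckpt_dir)

# Passing True makes missing trainer state a hard error instead.
trainer.load_checkpoint(ckpt_dir, load_trainer_state=True)  # FileNotFoundError if absent
```
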
54 changes: 40 additions & 14 deletions src/olmo_core/train/checkpoint.py
@@ -12,6 +12,7 @@
import torch
import torch.distributed as dist
from cached_path import cached_path
from torch.distributed.checkpoint.metadata import Metadata

from ..aliases import PathOrStr
from ..config import Config
@@ -21,7 +22,13 @@
load_state_dict,
save_state_dict,
)
from ..distributed.utils import barrier, get_fs_local_rank, get_rank, is_distributed
from ..distributed.utils import (
barrier,
get_fs_local_rank,
get_rank,
is_distributed,
scatter_object,
)
from ..exceptions import OLMoConfigurationError
from ..io import (
clear_directory,
@@ -141,7 +148,7 @@ def load(
dir: PathOrStr,
train_module: TrainModule,
*,
load_trainer_state: bool = True,
load_trainer_state: Optional[bool] = None,
) -> Optional[Dict[str, Any]]:
"""
Load model, optim, and other training state from a local or remote checkpoint directory
@@ -151,21 +158,38 @@

# Maybe load trainer state.
trainer_state: Optional[Dict[str, Any]] = None
if load_trainer_state:
try:
trainer_state = torch.load(
cached_path(f"{dir}/train/rank{get_rank()}.pt", quiet=True), weights_only=False
)
except FileNotFoundError:
# Fall back to rank 0 train state.
# This can happen when we're restoring a checkpoint with a different world size.
trainer_state = torch.load(
cached_path(f"{dir}/train/rank0.pt", quiet=True), weights_only=False
)
if load_trainer_state is not False:
# Try loading the given rank's state first, then fall back to rank 0 train state if it
# doesn't exist, which can happen when we're restoring a checkpoint with a different world size.
for path in (f"{dir}/train/rank{get_rank()}.pt", f"{dir}/train/rank0.pt"):
try:
trainer_state = torch.load(cached_path(path, quiet=True), weights_only=False)
break
except FileNotFoundError:
pass

if load_trainer_state is True and trainer_state is None:
raise FileNotFoundError(f"Missing trainer state in checkpoint dir '{dir}'")

# Load train module state.
train_module_dir = f"{dir}/model_and_optim"
metadata = get_checkpoint_metadata(train_module_dir)
metadata: Optional[Metadata] = None
if get_rank(self.process_group) == 0:
try:
metadata = get_checkpoint_metadata(train_module_dir)
except FileNotFoundError:
# Try base directory, which could be the case if user is trying to load model weights
# (possibly with optimizer state), and not an actual train checkpoint.
if trainer_state is None:
metadata = get_checkpoint_metadata(dir)
train_module_dir = dir
else:
raise

train_module_dir = scatter_object(train_module_dir)
if metadata is None:
metadata = get_checkpoint_metadata(train_module_dir)

state_dict = train_module.state_dict_to_load(metadata)
load_state_dict(
train_module_dir,
@@ -228,6 +252,8 @@ def dir_is_checkpoint(cls, dir: PathOrStr) -> bool:
Check if a directory is a checkpoint directory.
"""
dir = normalize_path(dir)
if file_exists(f"{dir}/.metadata"): # just model (and maybe optim state), no trainer state
return True
paths_to_check = [
f"{dir}/train/rank0.pt",
f"{dir}/model_and_optim/.metadata",
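
Net effect of the `load()` changes: trainer state becomes optional, and when `model_and_optim/.metadata` is missing, rank 0 probes the base directory (the model-weights-only case) and shares the resolved directory with the other ranks via `scatter_object`. A sketch of the three `load_trainer_state` modes, assuming existing `checkpointer` and `train_module` objects and illustrative paths:

```python
# None (the new default): load trainer state if present, otherwise continue without it.
state = checkpointer.load("s3://bucket/run/step1000", train_module)

# True: trainer state is required; if no train/rank*.pt exists, FileNotFoundError is raised.
state = checkpointer.load("s3://bucket/run/step1000", train_module, load_trainer_state=True)

# False: skip trainer state entirely, e.g. when fine-tuning from model-only weights.
state = checkpointer.load("s3://bucket/model_only", train_module, load_trainer_state=False)
```

The `dir_is_checkpoint()` addition matches this: a top-level `.metadata` file alone now marks a directory as loadable.
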
25 changes: 18 additions & 7 deletions src/olmo_core/train/train_module/transformer.py
@@ -685,6 +685,16 @@ def state_dict_to_load(self, metadata: Metadata) -> Dict[str, Any]:
if self.load_key_mapping is not None:
_swap_param_keys(state_dict, self.load_key_mapping, metadata=metadata)

has_optim_state: bool = False
for key in metadata.state_dict_metadata.keys():
if key.startswith("optim."):
has_optim_state = True
break

if not has_optim_state:
del state_dict["optim"]
log.warning("No optimizer state found in checkpoint")

return state_dict

def state_dict_to_save(self) -> Dict[str, Any]:
@@ -700,13 +710,14 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
options=self.state_dict_load_opts,
)
gc_cuda()
dist_cp_sd.set_optimizer_state_dict(
model,
optim,
state_dict["optim"],
options=self.state_dict_load_opts,
)
gc_cuda()
if "optim" in state_dict:
dist_cp_sd.set_optimizer_state_dict(
model,
optim,
state_dict["optim"],
options=self.state_dict_load_opts,
)
gc_cuda()

def train_batch(self, batch: Dict[str, Any], dry_run: bool = False):
# Set model to train mode if it isn't already.
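
The `has_optim_state` loop just scans the flat keys in the checkpoint metadata. An equivalent compact form (a sketch, not part of the commit):

```python
# Metadata.state_dict_metadata maps flat state-dict keys to per-entry metadata,
# so any key prefixed with "optim." means optimizer state was saved.
has_optim_state = any(key.startswith("optim.") for key in metadata.state_dict_metadata)
```

Deleting the `"optim"` entry up front is what lets the guarded `set_optimizer_state_dict()` call above be skipped cleanly for model-only checkpoints.
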
9 changes: 5 additions & 4 deletions src/olmo_core/train/trainer.py
@@ -605,7 +605,7 @@ def load_state_dict(self, state_dict: TrainerStateDict):
"were saved with a different world size."
)

def load_checkpoint(self, dir: PathOrStr, *, load_trainer_state: bool = True):
def load_checkpoint(self, dir: PathOrStr, *, load_trainer_state: Optional[bool] = None):
"""
Load a checkpoint.
@@ -630,8 +630,7 @@ def load_checkpoint(self, dir: PathOrStr, *, load_trainer_state: bool = True):
self.train_module,
load_trainer_state=load_trainer_state,
)
if load_trainer_state:
assert trainer_state is not None
if trainer_state is not None:
self.load_state_dict(cast(TrainerStateDict, trainer_state))

for callback in self.callbacks.values():
Expand All @@ -640,7 +639,9 @@ def load_checkpoint(self, dir: PathOrStr, *, load_trainer_state: bool = True):
self._checkpoint_loaded = True
log.info("Checkpoint successfully loaded")

def maybe_load_checkpoint(self, dir: PathOrStr, *, load_trainer_state: bool = True) -> bool:
def maybe_load_checkpoint(
self, dir: PathOrStr, *, load_trainer_state: Optional[bool] = None
) -> bool:
"""
Like :meth:`load_checkpoint()` but is a no-op if there is no checkpoint in the ``dir`` provided.
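
Since `load()` now returns trainer state only when it was actually found, the trainer can check the return value instead of asserting. A resume-or-start sketch (assumes a configured `trainer` and a `log` instance; the save-folder attribute is illustrative):

```python
# maybe_load_checkpoint() is a no-op that returns False when no checkpoint exists.
if not trainer.maybe_load_checkpoint(trainer.save_folder):
    log.info("No checkpoint found, training from scratch.")
```
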
1 change: 1 addition & 0 deletions scripts/add_pr_comments_on_release.sh → src/scripts/release/add_pr_comments_on_release.sh
@@ -3,6 +3,7 @@
set -e

repo_url=https://github.com/allenai/OLMo-core

tags=$(git tag -l --sort=-version:refname 'v*' | head -n 2)
current_tag=$(echo "$tags" | head -n 1)
last_tag=$(echo "$tags" | tail -n 1)
src/scripts/prepare_changelog.py → src/scripts/release/prepare_changelog.py
File renamed without changes.
9 changes: 6 additions & 3 deletions src/scripts/release.sh → src/scripts/release/release.sh
@@ -2,9 +2,12 @@

set -e

TAG=$(python -c 'from olmo_core.version import VERSION; print("v" + VERSION)')
# Make sure clone is up-to-date with remote.
git pull > /dev/null
git tag -l | xargs git tag -d > /dev/null
git fetch -t > /dev/null

git pull
TAG=$(python -c 'from olmo_core.version import VERSION; print("v" + VERSION)')

# Make sure tag/release doesn't already exist.
STATUS_CODE=$(curl -s -o /dev/null -w "%{http_code}" "https://github.com/allenai/OLMo-core/releases/tag/${TAG}")
@@ -13,7 +16,7 @@ if [[ $STATUS_CODE == "200" ]]; then
exit 1
fi

python src/scripts/prepare_changelog.py
python src/scripts/release/prepare_changelog.py

read -rp "Creating new release for $TAG. Do you want to continue? [Y/n] " prompt

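
The tag check in `release.sh` treats an HTTP 200 from the release page as "tag already exists". The same probe in Python (a sketch using `requests`, which is an assumption; the script itself shells out to `curl`):

```python
import requests

TAG = "v1.2.3"  # illustrative
url = f"https://github.com/allenai/OLMo-core/releases/tag/{TAG}"
# A HEAD request is enough to read the status code without fetching the page body.
if requests.head(url, allow_redirects=True).status_code == 200:
    raise SystemExit(f"Tag/release {TAG} already exists on GitHub.")
```
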
src/scripts/release_notes.py → src/scripts/release/release_notes.py
File renamed without changes.
