Add checkpoint config num_to_keep for pretrainers (intel#39)
* add num_to_keep for pretrainers

Signed-off-by: Zhi Lin <[email protected]>

* add num_to_keep to config

Signed-off-by: Zhi Lin <[email protected]>

---------

Signed-off-by: Zhi Lin <[email protected]>
kira-lin authored Sep 22, 2023
1 parent 71cc3ce commit 005b27a
Showing 4 changed files with 29 additions and 0 deletions.
1 change: 1 addition & 0 deletions pretrain/config/bloom1b7_8gpus_pretrain.conf
@@ -93,6 +93,7 @@
"checkpoint": {
# The root path of the checkpoint. Only absolute paths are supported
"root_path": "/tmp/llm-ray/checkpoint",
# Maximum number of checkpoints to keep; older checkpoints are removed
"num_to_keep": 10
}
},
# Ray related configuration, only used when mode is set to ray
1 change: 1 addition & 0 deletions pretrain/megatron_pretrain_template.conf
@@ -82,6 +82,7 @@
"checkpoint": {
# The root path of the checkpoint. Only absolute paths are supported
"root_path": "/tmp/llm-ray/checkpoint",
# Maximum number of checkpoints to keep; older checkpoints are removed
"num_to_keep": 10
}
},
# Ray related configuration, only used when mode is set to ray
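For reference, the new key sits next to root_path in the checkpoint block of both config files. Below is a minimal, illustrative sketch of how a trainer could read and sanity-check that block once the .conf file has been parsed into a plain Python dict; the dict literal and the read_checkpoint_config helper are assumptions for illustration, not code from this repository.

# Illustrative only: validate the checkpoint block after the .conf file has
# been parsed into a Python dict.
checkpoint_config = {
    "root_path": "/tmp/llm-ray/checkpoint",  # only absolute paths are supported
    "num_to_keep": 10,                       # retain at most the 10 newest checkpoints
}


def read_checkpoint_config(config):
    """Return (root_path, num_to_keep); num_to_keep is None when retention is unlimited."""
    root_path = config.get("root_path")
    if not root_path:
        raise ValueError("checkpoint root_path must be set")
    num_to_keep = config.get("num_to_keep")
    if num_to_keep is not None and num_to_keep <= 0:
        num_to_keep = None  # non-positive values disable pruning
    return root_path, num_to_keep


print(read_checkpoint_config(checkpoint_config))
# ('/tmp/llm-ray/checkpoint', 10)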
14 changes: 14 additions & 0 deletions pretrain/plugin/megatron_pretrainer.py
@@ -1,6 +1,7 @@
import os
import math
import time
import shutil

import torch
import transformers
@@ -251,13 +252,24 @@ def _save_done(self, root_path, step):
        donefile = self._get_local_donefile_path(root_path, step)
        Checkpoint.to_directory(placeholder, donefile)

    def _remove_stale_checkpoint(self, root_path, num_to_keep):
        checkpoints = self._get_all_checkpoint_step(root_path)
        if len(checkpoints) > num_to_keep:
            stale = checkpoints[-1]
            logger.warning(f"Removing stale checkpoint {root_path}/{stale}")
            shutil.rmtree(f"{root_path}/{stale}")

    def save(self, config, step):
        if config is None or config == {}:
            logger.warning("checkpoint config is empty, skip")
            return
        root_path = config.get("root_path")
        if root_path is None:
            logger.warning("checkpoint root_path is empty, skip")
            return
        num_to_keep = config.get("num_to_keep")
        if num_to_keep is not None and num_to_keep <= 0:
            logger.warning("checkpoint num_to_keep must be a positive integer, ignored")
            num_to_keep = None
        local_checkpoint_path = self._get_local_path(root_path, step)
        if self.mode == "ddp":
            if int(self.rank) == 0:
@@ -267,3 +279,5 @@ def save(self, config, step):
        else:
            pass
        self._save_done(root_path, step)
        if num_to_keep is not None and self.mode == "ddp" and int(self.rank) == 0:
            self._remove_stale_checkpoint(root_path, num_to_keep)
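The pruning helper leans on _get_all_checkpoint_step and the trainer's on-disk layout, neither of which appears in this diff. As a rough standalone sketch of the same idea, assuming one directory per step named by its integer step number and that the smallest (oldest) step is the stale one, the logic could look like the following; the function name, layout, and ordering are assumptions, not taken from the repository.

import os
import shutil


def prune_stale_checkpoints(root_path, num_to_keep):
    """Keep only the num_to_keep newest step directories under root_path.

    Assumes checkpoints live at root_path/<step> with <step> an integer; this
    is an illustrative layout, not necessarily the one used by the trainers.
    """
    if num_to_keep is None or num_to_keep <= 0:
        return  # retention disabled
    steps = sorted(
        int(name)
        for name in os.listdir(root_path)
        if name.isdigit() and os.path.isdir(os.path.join(root_path, name))
    )
    # Everything except the newest num_to_keep entries is stale.
    for step in steps[:-num_to_keep]:
        shutil.rmtree(os.path.join(root_path, str(step)))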
13 changes: 13 additions & 0 deletions pretrain/plugin/pretrainer.py
@@ -289,13 +289,24 @@ def _save_done(self, root_path, episode):
        donefile = self._get_local_donefile_path(root_path, episode)
        Checkpoint.to_directory(placeholder, donefile)

    def _remove_stale_checkpoint(self, root_path, num_to_keep):
        checkpoints = self._get_all_checkpoint_episode(root_path)
        if len(checkpoints) > num_to_keep:
            stale = checkpoints[-1]
            logger.warning(f"Removing stale checkpoint {root_path}/{stale}")
            shutil.rmtree(f"{root_path}/{stale}")

    def save(self, config, episode):
        if config is None or config == {}:
            logger.warning("checkpoint config is empty, skip")
            return
        root_path = config.get("root_path")
        if root_path is None:
            logger.warning("checkpoint root_path is empty, skip")
            return
        num_to_keep = config.get("num_to_keep")
        if num_to_keep is not None and num_to_keep <= 0:
            logger.warning("checkpoint num_to_keep must be a positive integer, ignored")
            num_to_keep = None
        local_checkpoint_path = self._get_local_path(root_path, episode)
        if self.mode == "ddp":
            if int(self.rank) == 0:
@@ -305,3 +316,5 @@ def save(self, config, episode):
        else:
            pass
        self._save_done(root_path, episode)
        if num_to_keep is not None and self.mode == "ddp" and int(self.rank) == 0:
            self._remove_stale_checkpoint(root_path, num_to_keep)
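To see the retention behaviour end to end, the snippet below simulates 15 checkpoint saves into a temporary directory with num_to_keep = 10 and prunes after every save. It uses the same hypothetical one-directory-per-step layout as the sketch above and does not touch the real trainers.

import os
import shutil
import tempfile

# Self-contained illustration of the retention semantics: write 15 fake
# checkpoint directories, pruning after each one, then show which steps survive.
root_path = tempfile.mkdtemp()
num_to_keep = 10

for step in range(15):
    os.makedirs(os.path.join(root_path, str(step)))  # stand-in for a saved checkpoint
    steps = sorted(int(d) for d in os.listdir(root_path) if d.isdigit())
    for stale in steps[:-num_to_keep]:
        shutil.rmtree(os.path.join(root_path, str(stale)))

print(sorted(int(d) for d in os.listdir(root_path)))  # [5, 6, 7, ..., 14]
shutil.rmtree(root_path)

In the trainers themselves, the final condition in save() restricts pruning to rank 0 in ddp mode, so only one process removes directories.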
