Skip to content

Commit

Permalink
use global rank
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
  • Loading branch information
fabianlim committed Jun 27, 2024
1 parent 142dba6 commit db7a751
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions src/instructlab/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,6 @@ def ensure_loadable_granite_checkpoint(
tmpdir: str,
):

# this has to be done per node, so we use local rank
local_rank = int(os.environ["LOCAL_RANK"])
try:
GPTDolomiteConfig.from_pretrained(model_name_or_path)
yield model_name_or_path
Expand All @@ -501,18 +499,31 @@ def ensure_loadable_granite_checkpoint(
# for now just assume it's a llama
# make a temp directory name, but do not create it
# previously we used mktemp, but it caused problems in multi node settings
# so now we use a provided
# so now we use a provided tmpdir
# Assumption: tmpdir must be accessible by all ranks, including those
# on different nodes
tmpdir = Path(tmpdir) / 'tmp'
if os.path.exists(tmpdir):
# need to delete it if it exists, because the import does not like the directory to already exist
shutil.rmtree(tmpdir, ignore_errors=True)

if not dist.is_initialized() or local_rank == 0:
if not dist.is_initialized() or dist.get_rank() == 0:
import_from_huggingface(model_name_or_path, tmpdir)

if dist.is_initialized():
# the first barrier is to wait for rank 0 to finish converting the model

Check warning on line 514 in src/instructlab/training/utils.py

View workflow job for this annotation

GitHub Actions / lint

C0303: Trailing whitespace (trailing-whitespace)
# and place into tmpdir
dist.barrier()

# return tmpdir out for loading
yield tmpdir
if not dist.is_initialized() or local_rank == 0:

if dist.is_initialized():
# the second barrier is to wait for all the models to finish loading
dist.barrier()

if not dist.is_initialized() or dist.get_rank() == 0:
# at this point, we can be confident that the tmpdir is no longer needed
shutil.rmtree(tmpdir, ignore_errors=True)


Expand Down

0 comments on commit db7a751

Please sign in to comment.