diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 6f76ccb7..61b83bc4 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -487,8 +487,6 @@ def ensure_loadable_granite_checkpoint( tmpdir: str, ): - # this has to be done per node, so we use local rank - local_rank = int(os.environ["LOCAL_RANK"]) try: GPTDolomiteConfig.from_pretrained(model_name_or_path) yield model_name_or_path @@ -501,18 +499,31 @@ def ensure_loadable_granite_checkpoint( # for now just assume its a llama # make a temp directory name, but do not create it # previously we used mktemp, but it caused problems in multi node settings - # so now we use a provided + # so now we use a provided tmpdir + # Assumption: tmpdir should be accessible by all ranks, even those + # in different nodes tmpdir = Path(tmpdir) / 'tmp' if os.path.exists(tmpdir): # need to delete if it exists because import doesnt like it to shutil.rmtree(tmpdir, ignore_errors=True) - if not dist.is_initialized() or local_rank == 0: + if not dist.is_initialized() or dist.get_rank() == 0: import_from_huggingface(model_name_or_path, tmpdir) + if dist.is_initialized(): + # the first barrier is to wait for rank 0 to finish converting the model + # and place into tmpdir dist.barrier() + + # return tmpdir out for loading yield tmpdir - if not dist.is_initialized() or local_rank == 0: + + if dist.is_initialized(): + # the second barrier is to wait for all the models to finish loading + dist.barrier() + + if not dist.is_initialized() or dist.get_rank() == 0: + # at this point, we can be confident that the tmpdir is no longer needed shutil.rmtree(tmpdir, ignore_errors=True)