Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
  • Loading branch information
fabianlim committed Jun 27, 2024
1 parent 8ba8eb9 commit 142dba6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
5 changes: 4 additions & 1 deletion src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
)

if args.is_granite:
with ensure_loadable_granite_checkpoint(args.model_name_or_path) as path:
with ensure_loadable_granite_checkpoint(
args.model_name_or_path,
args.output_dir
) as path:
model = GPTDolomiteForCausalLM.from_pretrained(
path,
attn_implementation="flash_attention_2",
Expand Down
18 changes: 11 additions & 7 deletions src/instructlab/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,12 +482,10 @@ class UniversalCheckpointArgs:


@contextmanager
def ensure_loadable_granite_checkpoint(model_name_or_path: str):

def temp_filename():
# previously we used mktemp, but it caused problems in multi node settings
# so now we just fix it and hope it does not clash
return "/tmp/tmpin9f06ge"
def ensure_loadable_granite_checkpoint(
model_name_or_path: str,
tmpdir: str,
):

# this has to be done per node, so we use local rank
local_rank = int(os.environ["LOCAL_RANK"])
Expand All @@ -502,7 +500,13 @@ def temp_filename():
# if the load failed then it must not be a granite
# for now just assume its a llama
# make a temp directory name, but do not create it
tmpdir = temp_filename()
# previously we used mktemp, but it caused problems in multi node settings
# so now we use a directory provided by the caller
tmpdir = Path(tmpdir) / 'tmp'
if os.path.exists(tmpdir):
# need to delete it if it exists because the import does not like it to exist
shutil.rmtree(tmpdir, ignore_errors=True)

if not dist.is_initialized() or local_rank == 0:
import_from_huggingface(model_name_or_path, tmpdir)
if dist.is_initialized():
Expand Down

0 comments on commit 142dba6

Please sign in to comment.