
Commit

Missing update
Signed-off-by: Mustafa Eyceoz <[email protected]>
Maxusmusti committed Oct 14, 2024
1 parent 4a90e6d commit 7ead6af
Showing 1 changed file with 33 additions and 14 deletions.
47 changes: 33 additions & 14 deletions src/instructlab/training/main_ds.py
@@ -4,6 +4,7 @@
from copy import deepcopy
from pathlib import Path
import argparse
import json
import math
import os
import re
@@ -650,23 +651,35 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
)

if train_args.process_data:
dp.main(
DataProcessArgs(
# XXX(osilkin): make a decision here, either:
# 1. the CLI is fully responsible for managing where the data is written
# 2. we never cache it and simply write it to a tmp file every time.
#
# An important reason for why #1 would be preferable is in the case of OpenShift/SELinux
# where the user has a defined place for new temporary data to be written.
data_output_path=train_args.data_output_dir,
model_path=train_args.model_path,
data_path=train_args.data_path,
max_seq_len=train_args.max_seq_len,
chat_tmpl_path=train_args.chat_tmpl_path,
if os.path.exists(train_args.model_path):
if not os.path.isdir(train_args.model_path):
raise RuntimeError(
"Model path does not appear to be a directory; please validate or update the path"
)
else:
raise RuntimeError(
"Model path cannot be found; please verify its existence and permissions"
)

# process the training data
if not os.path.exists(train_args.data_output_dir):
os.makedirs(train_args.data_output_dir, exist_ok=True)
dp.main(
DataProcessArgs(
# XXX(osilkin): make a decision here, either:
# 1. the CLI is fully responsible for managing where the data is written
# 2. we never cache it and simply write it to a tmp file every time.
#
# An important reason for why #1 would be preferable is in the case of OpenShift/SELinux
# where the user has a defined place for new temporary data to be written.
data_output_path=train_args.data_output_dir,
model_path=train_args.model_path,
data_path=train_args.data_path,
max_seq_len=train_args.max_seq_len,
chat_tmpl_path=train_args.chat_tmpl_path,
)
)

if not os.path.exists(train_args.ckpt_output_dir):
os.makedirs(train_args.ckpt_output_dir, exist_ok=True)
command = [
@@ -703,6 +716,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
command.append(f"--mock_len={train_args.mock_len}")

if train_args.use_dolomite:
with open(Path(train_args.model_path) / "config.json") as conf_json:
model_conf = json.load(conf_json)
if model_conf["model_type"] == "granite":
raise RuntimeError(
"Converting Granite models to Dolomite format is currently unsupported."
)
command.append("--use_dolomite")

if train_args.disable_flash_attn:
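For context, a minimal sketch of how the checks added in this commit surface to a caller of run_training. It assumes the TorchrunArgs/TrainingArgs dataclasses exported by instructlab.training; field names other than model_path, data_path, data_output_dir, ckpt_output_dir, max_seq_len, and max_batch_len (which appear in the diff above) are assumptions and may differ between releases.

# Sketch only: illustrates how a caller observes the new validation; not taken from this commit.
from instructlab.training import TorchrunArgs, TrainingArgs, run_training

train_args = TrainingArgs(
    model_path="/models/base-model",    # must exist and be a directory, else RuntimeError
    data_path="/data/messages.jsonl",
    data_output_dir="/data/processed",  # created if missing before data processing
    ckpt_output_dir="/checkpoints",     # created if missing
    max_seq_len=4096,
    max_batch_len=60000,                # must not be smaller than max_seq_len
    num_epochs=1,                       # assumed field
    effective_batch_size=128,           # assumed field
    learning_rate=2e-5,                 # assumed field
    warmup_steps=25,                    # assumed field
    save_samples=0,                     # assumed field
)
torch_args = TorchrunArgs(
    nnodes=1, nproc_per_node=8, node_rank=0,
    rdzv_id=123, rdzv_endpoint="127.0.0.1:29500",  # assumed fields
)

try:
    run_training(torch_args=torch_args, train_args=train_args)
except RuntimeError as err:
    # Raised by the checks in this commit: a missing or non-directory model_path,
    # or use_dolomite requested for a model whose config.json reports
    # model_type == "granite".
    print(f"Training aborted: {err}")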

