Skip to content

Commit

Permalink
Clean up early validation checks and move to utils
Browse files Browse the repository at this point in the history
Signed-off-by: Mustafa Eyceoz <[email protected]>
  • Loading branch information
Maxusmusti committed Oct 14, 2024
1 parent 7ead6af commit dba7b29
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 26 deletions.
28 changes: 2 additions & 26 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from copy import deepcopy
from pathlib import Path
import argparse
import json
import math
import os
import re
Expand Down Expand Up @@ -43,6 +42,7 @@
add_noisy_embeddings,
apply_gradient_checkpointing,
check_flash_attn_enabled,
check_valid_train_args,
convert_loss_to_reduce_sum,
ensure_loadable_dolomite_checkpoint,
get_projection_layer_names,
Expand Down Expand Up @@ -645,21 +645,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
"""
Wrapper around the main training job that calls torchrun.
"""
# early validation logic here
if train_args.max_batch_len < train_args.max_seq_len:
raise ValueError(
f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
)

if os.path.exists(train_args.model_path):
if not os.path.isdir(train_args.model_path):
raise RuntimeError(
"Model path does not appear to be a dir, please validate or update the path"
)
else:
raise RuntimeError(
"Model Path cannot be found, please verify existense and permissions"
)
check_valid_train_args(train_args)

# process the training data
if not os.path.exists(train_args.data_output_dir):
Expand Down Expand Up @@ -716,19 +702,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
command.append(f"--mock_len={train_args.mock_len}")

if train_args.use_dolomite:
with open(Path(train_args.model_path) / "config.json") as conf_json:
model_conf = json.load(conf_json)
if model_conf["model_type"] == "granite":
raise RuntimeError(
"Converting Granite models to Dolomite format is currently unsupported."
)
command.append("--use_dolomite")

if train_args.disable_flash_attn:
if train_args.use_dolomite:
raise RuntimeError(
"ERROR: Trying to use padding-free transformer without flash attention is not supported"
)
command.append("--disable_flash_attn")

if train_args.lora:
Expand Down
34 changes: 34 additions & 0 deletions src/instructlab/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Any, List, Optional
import importlib
import inspect
import json
import logging
import os
import random
Expand Down Expand Up @@ -40,6 +41,39 @@
import torch
import torch.nn.functional as F

# First Party
from instructlab.training.config import TrainingArgs


def check_valid_train_args(train_args: TrainingArgs):
# early validation logic here
if train_args.max_batch_len < train_args.max_seq_len:
raise ValueError(
f"the `max_batch_len` cannot be less than `max_seq_len`: {train_args.max_batch_len=} < {train_args.max_seq_len=}"
)

if os.path.exists(train_args.model_path):
if not os.path.isdir(train_args.model_path):
raise FileNotFoundError(
"Model path does not appear to be a dir, please validate or update the path"
)
else:
raise FileNotFoundError(
"Model Path cannot be found, please verify existense and permissions"
)

if train_args.use_dolomite:
with open(Path(train_args.model_path) / "config.json") as conf_json:
model_conf = json.load(conf_json)
if model_conf["model_type"] == "granite":
raise RuntimeError(
"Converting Granite models to Dolomite format is currently unsupported."
)
if train_args.disable_flash_attn:
raise RuntimeError(
"ERROR: Trying to use dolomite padding-free transformer without flash attention is not supported"
)


def retrieve_chat_template(chat_tmpl_path):
try:
Expand Down

0 comments on commit dba7b29

Please sign in to comment.