Skip to content

Commit

Permalink
chore: add exit code & tox fix
Browse files Browse the repository at this point in the history
Currently, the training library does not exit when an error is encountered
within the training loop (invoked through torchrun). This commit updates
that functionality so we correctly return an exit code of 1 on child failure.

Additionally, this commit also adds the `make fix` command which
automatically fixes all trivial issues picked up on by ruff

Signed-off-by: Oleg S <[email protected]>
  • Loading branch information
RobotSail committed Oct 16, 2024
1 parent e680bd8 commit 9c899dc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
15 changes: 10 additions & 5 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
print(f"\033[92mRunning training command as subprocess: {' '.join(command)}\033[0m")
process = None
interrupt: KeyboardInterrupt | Exception | None = None
failure = False
try:
process = StreamablePopen(
f"{train_args.ckpt_output_dir}/full_logs_global{torch_args.node_rank}.log",
Expand All @@ -771,19 +772,20 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
print("Training subprocess interrupted by user.")
interrupt = e
except Exception as e:
print(f"An error occurred: {str(e)}")
print("Unexpected exception received during distributed training")
interrupt = e
finally:
if "process" not in locals() or process is None:
return
if process.poll() == 0:
print("\033[92mTraining subprocess exited successfully! 🎉\033[0m")

failure = process.poll() != 0
if not failure:
print("\033[92mOperation completed successfully! 🎉\033[0m")
else:
print(
"\033[91mTraining subprocess has not exited yet. Sending SIGTERM.\033[0m"
)

print("Sending interrupt signal to Training subprocess.")
process.terminate()
try:
print("Waiting for process to exit, 60s...")
Expand All @@ -795,8 +797,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
process.kill()

if interrupt:
print(f"Error caught from training subprocess.: {interrupt}")
raise interrupt
if failure:
raise RuntimeError(
"Suffered a failure during distributed training. Please see the training logs for more context."
)


if __name__ == "__main__":
Expand Down
1 change: 0 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ commands =
sh -c 'git diff --exit-code || (echo "pyproject.toml formatting is incorrect. Please run \"make toml-fmt\" and commit the changes." && exit 1)'
allowlist_externals = make, sh


[testenv:spellcheck]
description = spell check (needs 'aspell' command)
basepython = {[testenv:py3]basepython}
Expand Down

0 comments on commit 9c899dc

Please sign in to comment.