Skip to content

Commit

Permalink
Misc fixes 20250130 (#2301)
Browse files Browse the repository at this point in the history
* misc fixes for garbage collection and L40S with NCCL P2P

* patch bnb fix for triton check

* chore: lint

* change up import

* try patching differently

* remove patch for bnb fix for now

* more verbose checks and tweak train loss threshold
  • Loading branch information
winglian authored Jan 31, 2025
1 parent 6f294c3 commit cf17649
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 5 deletions.
2 changes: 1 addition & 1 deletion docs/dataset-formats/stepwise_supervised.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ Here's a simple example of a stepwise supervised dataset entry:
],
"labels": [true, false]
}
```
```
8 changes: 7 additions & 1 deletion src/axolotl/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,12 @@ def __init__(self, gc_steps=None):
def on_step_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if state.global_step % self.gc_steps == 0:
if self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
torch.cuda.empty_cache()
gc.collect()

def on_epoch_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
torch.cuda.empty_cache()
gc.collect()
2 changes: 1 addition & 1 deletion src/axolotl/utils/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
def check_cuda_p2p_ib_support():
if not accelerate_check_cuda_p2p_ib_support():
return False
unsupported_devices = {"RTX 6000 Ada"}
unsupported_devices = {"RTX 6000 Ada", "L40S"}
try:
device_names, device_count = get_gpu_info()
if 1 < device_count < 8:
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/test_process_reward_model_smollm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_prm(self, temp_dir):

train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.7, "Train Loss (%s) is too high"
)

check_model_output_exists(temp_dir, cfg)
5 changes: 4 additions & 1 deletion tests/e2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def check_tensorboard(
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
assert df.value.values[-1] < lt_val, assertion_err
if "%s" in assertion_err:
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
else:
assert df.value.values[-1] < lt_val, assertion_err


def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
Expand Down

0 comments on commit cf17649

Please sign in to comment.