From dd5007262ec2c57f0a76e9522ddf968cc43eee67 Mon Sep 17 00:00:00 2001
From: Dan Saattrup Nielsen
Date: Tue, 21 Nov 2023 12:46:44 +0100
Subject: [PATCH] fix: Set up Whisper training arguments to new config

---
 src/coral_models/whisper.py | 25 ++++++++++++++++++++++---
 1 file changed, 22 insertions(+), 3 deletions(-)

diff --git a/src/coral_models/whisper.py b/src/coral_models/whisper.py
index ecd6de8e..858fa325 100644
--- a/src/coral_models/whisper.py
+++ b/src/coral_models/whisper.py
@@ -7,6 +7,7 @@
 from typing import Callable, Type
 
 from omegaconf import DictConfig
+import torch
 from torch.backends.mps import is_available as mps_is_available
 from transformers import (
     BatchFeature,
@@ -171,6 +172,23 @@ def load_compute_metrics(self) -> Callable[[EvalPrediction], dict]:
         return partial(compute_wer_metrics, processor=self.processor)
 
     def load_training_arguments(self) -> TrainingArguments:
+        # Compute the gradient accumulation based on the total batch size in the config
+        num_gpus_available = torch.cuda.device_count()
+        per_device_total_batch_size = self.cfg.total_batch_size // num_gpus_available
+        gradient_accumulation_steps = (
+            per_device_total_batch_size // self.cfg.per_device_batch_size
+        )
+
+        if gradient_accumulation_steps == 0:
+            logger.warning(
+                f"Your `total_batch_size` is too small ({self.cfg.total_batch_size}), "
+                f"relative to the number of GPUs ({num_gpus_available}) and your "
+                f"`per_device_batch_size` ({self.cfg.per_device_batch_size}). It has "
+                f"been set to `per_device_batch_size * num_gpus_available` = "
+                f"{self.cfg.per_device_batch_size * num_gpus_available}."
+            )
+            gradient_accumulation_steps = 1
+
         do_eval = any(
             [
                 dataset_cfg.val_name is not None
@@ -180,9 +198,9 @@
         args = Seq2SeqTrainingArguments(
             output_dir=self.cfg.model_dir,
             hub_model_id=self.cfg.hub_id,
-            per_device_train_batch_size=self.cfg.batch_size,
-            per_device_eval_batch_size=self.cfg.batch_size,
-            gradient_accumulation_steps=self.cfg.gradient_accumulation,
+            per_device_train_batch_size=self.cfg.per_device_batch_size,
+            per_device_eval_batch_size=self.cfg.per_device_batch_size,
+            gradient_accumulation_steps=gradient_accumulation_steps,
             learning_rate=self.cfg.learning_rate,
             warmup_steps=self.cfg.warmup_steps,
             max_steps=self.cfg.max_steps,
@@ -208,6 +226,7 @@
             generation_max_length=self.cfg.model.generation_max_length,
             use_cpu=hasattr(sys, "_called_from_test"),
             dataloader_num_workers=self.cfg.dataloader_num_workers,
+            ddp_find_unused_parameters=False,
         )
         return args
 
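
Note (not part of the patch): a minimal standalone sketch of the batch-size arithmetic the patch introduces, using hypothetical values in place of the config fields and of torch.cuda.device_count(); the zero-GPU guard is an addition for this sketch only, since the patch divides by the raw GPU count.

    def derive_gradient_accumulation(
        total_batch_size: int, per_device_batch_size: int, num_gpus: int
    ) -> int:
        """Gradient accumulation steps needed to reach the requested total batch size."""
        # Effective batch size per optimiser step is
        # per_device_batch_size * num_gpus * gradient_accumulation_steps.
        per_device_total = total_batch_size // max(num_gpus, 1)  # guard added for the sketch
        steps = per_device_total // per_device_batch_size
        # Mirror the patch's fallback: no accumulation when the requested total is too small.
        return steps if steps > 0 else 1

    # A total of 256 over 4 GPUs with per-device batches of 8 gives 8 accumulation steps.
    assert derive_gradient_accumulation(256, 8, 4) == 8
    # Too-small total: falls back to 1, so the effective batch size becomes 8 * 4 = 32,
    # which is what the warning in the patch reports.
    assert derive_gradient_accumulation(16, 8, 4) == 1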