From bdc4d5a1ed0e27ab9d366c6d26eec96889d2d26c Mon Sep 17 00:00:00 2001 From: Dan Saattrup Nielsen Date: Tue, 21 Nov 2023 12:58:12 +0100 Subject: [PATCH] fix: Min should've been a max --- src/coral_models/wav2vec2.py | 2 +- src/coral_models/whisper.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/coral_models/wav2vec2.py b/src/coral_models/wav2vec2.py index 0d63a90d..a8ef12c2 100644 --- a/src/coral_models/wav2vec2.py +++ b/src/coral_models/wav2vec2.py @@ -200,7 +200,7 @@ def load_compute_metrics(self) -> Callable[[EvalPrediction], dict]: def load_training_arguments(self) -> TrainingArguments: # Compute the gradient accumulation based on the total batch size in the config - num_devices = min(torch.cuda.device_count(), 1) + num_devices = max(torch.cuda.device_count(), 1) per_device_total_batch_size = self.cfg.total_batch_size // num_devices gradient_accumulation_steps = ( per_device_total_batch_size // self.cfg.per_device_batch_size diff --git a/src/coral_models/whisper.py b/src/coral_models/whisper.py index f8bdd167..2cea92d0 100644 --- a/src/coral_models/whisper.py +++ b/src/coral_models/whisper.py @@ -173,7 +173,7 @@ def load_compute_metrics(self) -> Callable[[EvalPrediction], dict]: def load_training_arguments(self) -> TrainingArguments: # Compute the gradient accumulation based on the total batch size in the config - num_devices = min(torch.cuda.device_count(), 1) + num_devices = max(torch.cuda.device_count(), 1) per_device_total_batch_size = self.cfg.total_batch_size // num_devices gradient_accumulation_steps = ( per_device_total_batch_size // self.cfg.per_device_batch_size