Commit

fix: Min should've been a max
saattrupdan committed Nov 21, 2023
1 parent 63ebc36 commit bdc4d5a
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/coral_models/wav2vec2.py
@@ -200,7 +200,7 @@ def load_compute_metrics(self) -> Callable[[EvalPrediction], dict]:
 
     def load_training_arguments(self) -> TrainingArguments:
         # Compute the gradient accumulation based on the total batch size in the config
-        num_devices = min(torch.cuda.device_count(), 1)
+        num_devices = max(torch.cuda.device_count(), 1)
         per_device_total_batch_size = self.cfg.total_batch_size // num_devices
         gradient_accumulation_steps = (
             per_device_total_batch_size // self.cfg.per_device_batch_size
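
The one-character fix matters most on a CPU-only machine: torch.cuda.device_count() returns 0 there, so the old min(0, 1) evaluated to 0 and the integer division on the next line raised a ZeroDivisionError; max(0, 1) clamps the divisor to 1. A minimal sketch of the corrected computation, with hypothetical config values standing in for self.cfg:

    import torch

    total_batch_size = 32      # hypothetical stand-in for cfg.total_batch_size
    per_device_batch_size = 8  # hypothetical stand-in for cfg.per_device_batch_size

    # With no GPU available, torch.cuda.device_count() == 0, so max(..., 1)
    # keeps the divisor at 1; min(..., 1) would make it 0 and crash below.
    num_devices = max(torch.cuda.device_count(), 1)
    per_device_total_batch_size = total_batch_size // num_devices
    gradient_accumulation_steps = (
        per_device_total_batch_size // per_device_batch_size
    )
    print(num_devices, per_device_total_batch_size, gradient_accumulation_steps)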
2 changes: 1 addition & 1 deletion src/coral_models/whisper.py
@@ -173,7 +173,7 @@ def load_compute_metrics(self) -> Callable[[EvalPrediction], dict]:
 
     def load_training_arguments(self) -> TrainingArguments:
         # Compute the gradient accumulation based on the total batch size in the config
-        num_devices = min(torch.cuda.device_count(), 1)
+        num_devices = max(torch.cuda.device_count(), 1)
         per_device_total_batch_size = self.cfg.total_batch_size // num_devices
         gradient_accumulation_steps = (
             per_device_total_batch_size // self.cfg.per_device_batch_size
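
The identical fix lands in whisper.py. The old min also misbehaved on multi-GPU hosts: with N GPUs, min(N, 1) is 1, so the total batch size was never divided across devices and the effective batch (per-device batch × N devices × accumulation steps) overshot the configured total by a factor of N. Using max restores the intended per-device split.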
