Skip to content

Commit

Permalink
feat: Automatically set hyperparameters related to multi-GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
saattrupdan committed Nov 22, 2023
1 parent 4b356a6 commit cbeb44e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
5 changes: 4 additions & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ seed: 4242
characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü '
max_seconds_per_example: 10
dataloader_num_workers: 8
padding: longest # longest/max_length/do_not_pad

# Can be `longest`, `max_length` or `do_not_pad`
# NOTE: This is automatically set to `max_length` in a multi-gpu setting
padding: longest

# This is a list of the sampling probability of each dataset, where null means that
# each dataset will be sampled equally often
Expand Down
2 changes: 1 addition & 1 deletion config/model/wav2vec2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ mask_time_prob: 0.5
mask_time_length: 10
mask_feature_prob: 0.5
mask_feature_length: 64
layerdrop: 0.0 # NOTE: This parameter cannot be used in a multi-gpu setting!
layerdrop: 0.1 # This will automatically be set to 0 in a multi-gpu setting
ctc_loss_reduction: mean

# Decoder hyperparameters
Expand Down
9 changes: 9 additions & 0 deletions src/scripts/finetune_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import hydra
from omegaconf import DictConfig
import os

from coral_models.finetune import finetune

Expand All @@ -18,6 +19,14 @@ def main(cfg: DictConfig) -> None:
cfg (DictConfig):
The Hydra configuration object.
"""
# In case we are running in a multi-GPU setting, we need to force certain
# hyperparameters
if os.getenv("WORLD_SIZE") is not None:
if "layerdrop" in cfg.model:
cfg.model.layerdrop = 0.0
cfg.padding = "max_length"

breakpoint()
finetune(cfg)


Expand Down

0 comments on commit cbeb44e

Please sign in to comment.