Skip to content

Commit

Permalink
Merge pull request #50 from alexandrainst/feat/enable-multi-gpu-training
Browse files Browse the repository at this point in the history
  • Loading branch information
saattrupdan authored Nov 23, 2023
2 parents 47d83b7 + ef5d4d5 commit d44145c
Show file tree
Hide file tree
Showing 14 changed files with 403 additions and 50 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
pytest:
strategy:
matrix:
os: [windows-latest, macos-latest, ubuntu-latest]
os: [macos-latest, ubuntu-latest]
python-version: ["3.11"]
runs-on: ${{ matrix.os }}
steps:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ ______________________________________________________________________
[![Documentation](https://img.shields.io/badge/docs-passing-green)](https://alexandrainst.github.io/coral_models/coral_models.html)
[![License](https://img.shields.io/github/license/alexandrainst/coral_models)](https://github.com/alexandrainst/coral_models/blob/main/LICENSE)
[![LastCommit](https://img.shields.io/github/last-commit/alexandrainst/coral_models)](https://github.com/alexandrainst/coral_models/commits/main)
[![Code Coverage](https://img.shields.io/badge/Coverage-53%25-orange.svg)](https://github.com/alexandrainst/coral_models/tree/main/tests)
[![Code Coverage](https://img.shields.io/badge/Coverage-54%25-orange.svg)](https://github.com/alexandrainst/coral_models/tree/main/tests)


Developers:
Expand Down
8 changes: 6 additions & 2 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü '
max_seconds_per_example: 10
dataloader_num_workers: 8

# Can be `longest`, `max_length` or `do_not_pad`
# NOTE: This is automatically set to `max_length` in a multi-gpu setting
padding: longest

# This is a list of the sampling probability of each dataset, where null means that
# each dataset will be sampled equally often
dataset_probabilities:
Expand Down Expand Up @@ -46,8 +50,8 @@ save_total_limit: 2
learning_rate: 3e-5
adam_first_momentum: 0.9
adam_second_momentum: 0.98
batch_size: 8
gradient_accumulation: 32
total_batch_size: 256
per_device_batch_size: 16
max_steps: 50_000
warmup_steps: 1_000
logging_steps: 10
Expand Down
2 changes: 1 addition & 1 deletion config/model/test_wav2vec2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ mask_time_prob: 0.075
mask_time_length: 10
mask_feature_prob: 0.075
mask_feature_length: 10
layerdrop: 0.1
layerdrop: 0.0 # NOTE: This parameter cannot be used in a multi-gpu setting!
ctc_loss_reduction: sum

# Decoder hyperparameters
Expand Down
6 changes: 3 additions & 3 deletions config/model/wav2vec2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ hidden_dropout: 0.0
feat_proj_dropout: 0.0
feat_quantizer_dropout: 0.0
final_dropout: 0.0
mask_time_prob: 0.3
mask_time_prob: 0.5
mask_time_length: 10
mask_feature_prob: 0.3
mask_feature_prob: 0.5
mask_feature_length: 64
layerdrop: 0.1
layerdrop: 0.1 # This will automatically be set to 0 in a multi-gpu setting
ctc_loss_reduction: mean

# Decoder hyperparameters
Expand Down
244 changes: 243 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pycountry = "^22.3.5"
wave = ">=0.0.2,<1.0.0"
kenlm = {url = "https://github.com/kpu/kenlm/archive/master.zip"}
matplotlib = "3.7.3"
deepspeed = ">=0.12.3,<1.0.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.0.0"
Expand Down
15 changes: 14 additions & 1 deletion src/coral_models/compute_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
from evaluate.loading import load as load_metric
from numpy.typing import NDArray
from transformers import EvalPrediction, PreTrainedTokenizerBase
import logging
import os

from .protocols import Processor

logger = logging.getLogger(__name__)


def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str, float]:
"""Compute the word error rate of predictions.
Expand Down Expand Up @@ -63,7 +67,16 @@ def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str,
labels[labels == -100] = pad_token

# Decode the ground truth labels
labels_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
labels_str = tokenizer.batch_decode(
sequences=labels, skip_special_tokens=True, group_tokens=False
)

# TEMP: Log one randomly sampled prediction alongside its ground truth label
is_main_process = os.getenv("RANK", "0") == "0"
if is_main_process:
random_idx = np.random.randint(0, len(predictions_str))
logger.info(f"Sample document: {labels_str[random_idx]}")
logger.info(f"Predicted: {predictions_str[random_idx]}")

# Compute the word error rate
computed = wer_metric.compute(predictions=predictions_str, references=labels_str)
Expand Down
24 changes: 15 additions & 9 deletions src/coral_models/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,13 @@ def load_data(cfg: DictConfig) -> DatasetDict | IterableDatasetDict:
ValueError:
If the dataset is not supported.
"""
# Record whether this is the main process, in case we are running in a distributed setting
is_main_process = os.getenv("RANK", "0") == "0"

all_datasets: list[DatasetDict | IterableDatasetDict] = list()
for dataset_name, dataset_cfg in cfg.datasets.items():
logger.info(f"Loading dataset {dataset_name!r}")
if is_main_process:
logger.info(f"Loading dataset {dataset_name!r}")

# Load from disk if the dataset ID is a path
if Path(dataset_cfg.id).exists():
Expand Down Expand Up @@ -126,14 +130,16 @@ def load_data(cfg: DictConfig) -> DatasetDict | IterableDatasetDict:
assert len(all_datasets) > 0, "No datasets were loaded"

if len(all_datasets) > 1:
logger.info("Interleaving datasets")
if cfg.dataset_probabilities["train"] is None and len(all_datasets) > 1:
logger.warning(
"No dataset probabilities were specified for the training split. "
"This means that each dataset will be sampled with equal probability, "
"which means that the smaller datasets will be sampled more often than "
"the larger datasets. This is probably not what you want."
)
if is_main_process:
logger.info("Interleaving datasets")
if cfg.dataset_probabilities["train"] is None and len(all_datasets) > 1:
logger.warning(
"No dataset probabilities were specified for the training split. "
"This means that each dataset will be sampled with equal "
"probability, which means that the smaller datasets will be "
"sampled more often than the larger datasets. This is probably "
"not what you want."
)

probabilities: dict[str, list[float]] = dict()
for split_name, split_probs in cfg.dataset_probabilities.items():
Expand Down
19 changes: 12 additions & 7 deletions src/coral_models/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@
from functools import partial
import logging
from typing import Callable
import os

from omegaconf import DictConfig
from transformers import EarlyStoppingCallback, TrainerCallback
from wandb.sdk.wandb_init import init as wandb_init
from wandb.sdk.wandb_run import finish as wandb_finish

from .utils import disable_tqdm
from .data import load_data
from .model_setup import load_model_setup
from .protocols import ModelSetup
from .utils import disable_tqdm

logger = logging.getLogger(__package__)

Expand Down Expand Up @@ -64,6 +65,9 @@ def finetune(cfg: DictConfig) -> None:
Args:
cfg: The Hydra configuration object.
"""
# Record whether this is the main process, in case we are running in a distributed setting
is_main_process = os.getenv("RANK", "0") == "0"

model_setup: ModelSetup = load_model_setup(cfg)
processor = model_setup.load_processor()
processor.save_pretrained(cfg.model_dir)
Expand All @@ -81,15 +85,15 @@ def finetune(cfg: DictConfig) -> None:
),
)

if cfg.wandb:
if cfg.wandb and is_main_process:
wandb_init(
project=cfg.wandb_project,
group=cfg.wandb_group,
name=cfg.wandb_name,
config=dict(cfg),
)

if "val" not in dataset:
if "val" not in dataset and is_main_process:
logger.info("No validation set found. Disabling early stopping.")

trainer = model_setup.load_trainer_class()(
Expand All @@ -105,11 +109,12 @@ def finetune(cfg: DictConfig) -> None:

with disable_tqdm():
trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
wandb_finish()

model.save_pretrained(cfg.model_dir)
if cfg.push_to_hub:
trainer.push_to_hub()
if is_main_process:
wandb_finish()
model.save_pretrained(cfg.model_dir)
if cfg.push_to_hub:
trainer.push_to_hub()


def load_early_stopping_callback(cfg: DictConfig) -> list[TrainerCallback]:
Expand Down
73 changes: 56 additions & 17 deletions src/coral_models/wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from functools import partial
from pathlib import Path
from typing import Callable, Type
import time
import os

import torch
from omegaconf import DictConfig
Expand Down Expand Up @@ -58,7 +60,7 @@ class DataCollatorCTCWithPadding(DataCollatorMixin):
"""

processor: Wav2Vec2Processor
padding: bool | str = True
padding: bool | str
return_tensors: str = "pt"

def torch_call(self, features: list[dict]) -> BatchFeature:
Expand All @@ -81,12 +83,18 @@ def torch_call(self, features: list[dict]) -> BatchFeature:
"Features must contain either 'input_values' or 'audio' key."
)
batch: BatchFeature = self.processor.pad(
audio_features, padding=self.padding, return_tensors="pt"
audio_features,
padding=self.padding,
return_tensors=self.return_tensors,
max_length=16_000 * 10,
)

label_features = [dict(input_ids=feature["labels"]) for feature in features]
labels_batch: BatchEncoding = self.processor.tokenizer.pad(
label_features, padding=self.padding, return_tensors="pt"
label_features,
padding=self.padding,
return_tensors=self.return_tensors,
max_length=512,
)

# Replace padding with -100 to ignore loss correctly
Expand All @@ -112,15 +120,25 @@ def __init__(self, cfg: DictConfig) -> None:
def load_processor(self) -> Wav2Vec2Processor:
# We dump the vocabulary to a file since the tokenizer uses this file during
# initialisation
dump_vocabulary(self.cfg)
tokenizer: Wav2Vec2CTCTokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
self.cfg.model_dir,
unk_token="<unk>",
pad_token="<pad>",
bos_token="<s>",
eos_token="</s>",
word_delimiter_token=" ",
)
while True:
try:
dump_vocabulary(self.cfg)
tokenizer: Wav2Vec2CTCTokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
self.cfg.model_dir,
unk_token="<unk>",
pad_token="<pad>",
bos_token="<s>",
eos_token="</s>",
word_delimiter_token=" ",
)
break
except json.decoder.JSONDecodeError:
process_id = os.getenv("RANK", 0)
logger.warning(
f"JSONDecodeError while loading tokenizer on process {process_id}. "
"Retrying in a second."
)
time.sleep(1)

# Set the `model_max_length` attribute of the tokenizer, if it hasn't been set,
# to ensure that truncation is done correctly
Expand Down Expand Up @@ -170,7 +188,9 @@ def load_model(self) -> Wav2Vec2ForCTC:
return model

def load_data_collator(self) -> DataCollatorCTCWithPadding:
return DataCollatorCTCWithPadding(processor=self.processor, padding=True)
return DataCollatorCTCWithPadding(
processor=self.processor, padding=self.cfg.padding
)

def load_trainer_class(self) -> Type[Trainer]:
return Trainer
Expand All @@ -179,6 +199,23 @@ def load_compute_metrics(self) -> Callable[[EvalPrediction], dict]:
return partial(compute_wer_metrics, processor=self.processor)

def load_training_arguments(self) -> TrainingArguments:
# Compute the gradient accumulation based on the total batch size in the config
num_devices = max(torch.cuda.device_count(), 1)
per_device_total_batch_size = self.cfg.total_batch_size // num_devices
gradient_accumulation_steps = (
per_device_total_batch_size // self.cfg.per_device_batch_size
)

if gradient_accumulation_steps == 0:
logger.warning(
f"Your `total_batch_size` is too small ({self.cfg.total_batch_size}), "
f"relative to the number of devices ({num_devices}) and your "
f"`per_device_batch_size` ({self.cfg.per_device_batch_size}). It has "
f"been set to `per_device_batch_size * num_devices` = "
f"{self.cfg.per_device_batch_size * num_devices}."
)
gradient_accumulation_steps = 1

do_eval = any(
[
dataset_cfg.val_name is not None
Expand All @@ -188,9 +225,9 @@ def load_training_arguments(self) -> TrainingArguments:
args = TrainingArguments(
output_dir=self.cfg.model_dir,
hub_model_id=self.cfg.hub_id,
per_device_train_batch_size=self.cfg.batch_size,
per_device_eval_batch_size=self.cfg.batch_size,
gradient_accumulation_steps=self.cfg.gradient_accumulation,
per_device_train_batch_size=self.cfg.per_device_batch_size,
per_device_eval_batch_size=self.cfg.per_device_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
lr_scheduler_type=SchedulerType.COSINE,
warmup_steps=self.cfg.warmup_steps,
Expand All @@ -200,6 +237,7 @@ def load_training_arguments(self) -> TrainingArguments:
evaluation_strategy="steps" if do_eval else "no",
eval_steps=self.cfg.eval_steps if do_eval else None,
save_steps=self.cfg.save_steps,
save_strategy="no" if self.cfg.save_total_limit == 0 else "steps",
logging_steps=self.cfg.logging_steps,
length_column_name="input_length",
gradient_checkpointing=True,
Expand All @@ -217,6 +255,7 @@ def load_training_arguments(self) -> TrainingArguments:
save_safetensors=True,
use_cpu=hasattr(sys, "_called_from_test"),
dataloader_num_workers=self.cfg.dataloader_num_workers,
ddp_find_unused_parameters=False,
)
return args

Expand All @@ -236,7 +275,7 @@ def load_saved(self) -> PreTrainedModelData:

model = Wav2Vec2ForCTC.from_pretrained(self.cfg.hub_id, token=True)
data_collator = DataCollatorCTCWithPadding(
processor=processor, padding="longest"
processor=processor, padding=self.cfg.padding
)
compute_metrics = partial(compute_wer_metrics, processor=processor)
return PreTrainedModelData(
Expand Down
Loading

0 comments on commit d44145c

Please sign in to comment.