Chore/small changes to wav2vec2 finetuning #54

Merged
merged 20 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
afdc325  chore: Update configs (saattrupdan, Nov 28, 2023)
77bc20f  style: Logging, kwargs (saattrupdan, Nov 28, 2023)
8a15b80  debug: Breakpoint (saattrupdan, Nov 28, 2023)
26ccd56  fix: Do not remove special tokens when decoding, as it prevents dupli… (saattrupdan, Nov 28, 2023)
309a60d  chore: Remove breakpoint (saattrupdan, Nov 28, 2023)
bc07ea0  docs: Add note (saattrupdan, Nov 28, 2023)
e6f3f43  docs: Always print sample predictions when computing metrics (saattrupdan, Dec 5, 2023)
5e20bd4  chore: Deal with word delimiters (saattrupdan, Dec 5, 2023)
6478ee6  chore: Update configs (saattrupdan, Dec 5, 2023)
ed4d97c  fix: Do not hardcode max_seconds_per_example, and add | and space to … (saattrupdan, Dec 5, 2023)
6270399  fix: Ensure that we pad with pad_token when using a LM decoder (saattrupdan, Dec 7, 2023)
4c0d09a  fix: Ensure that pad_token is chosen when all logits for a token are … (saattrupdan, Dec 7, 2023)
5d15643  fix: Do not add special tokens to vocab, as then they won't count as … (saattrupdan, Dec 7, 2023)
4818ec8  fix: Update padding kwargs in Whisper analogous to Wav2Vec2 (saattrupdan, Dec 7, 2023)
7a9d4bb  fix: Padding with a WhisperProcessor (saattrupdan, Dec 7, 2023)
56d617c  fix: Add `max_seconds_per_example` to Whisper Processor (saattrupdan, Dec 13, 2023)
61d2b04  chore: Change config (saattrupdan, Dec 13, 2023)
2f7573b  fix: Add max_seconds_per_example as argument to DataCollatorSpeechSeq… (saattrupdan, Dec 13, 2023)
6430fd7  fix: Typo in config max_seconds_per_example (saattrupdan, Dec 13, 2023)
6b80161  fix: Remove `labels` kwarg from Whisper tokenizer padding (saattrupdan, Dec 13, 2023)
6 changes: 3 additions & 3 deletions config/config.yaml
@@ -15,7 +15,7 @@ dirs:
seed: 4242

# Dataset parameters
characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü '
characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü'
max_seconds_per_example: 10
dataloader_num_workers: 8

@@ -47,12 +47,12 @@ ignore_data_skip: false
save_total_limit: 2

# Optimisation parameters
learning_rate: 3e-5
learning_rate: 1e-4
adam_first_momentum: 0.9
adam_second_momentum: 0.98
total_batch_size: 256
per_device_batch_size: 16
max_steps: 50_000
max_steps: 10_000
warmup_steps: 1_000
logging_steps: 10
eval_steps: 100
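As context for the batch-size settings above, `total_batch_size` and `per_device_batch_size` are typically reconciled via gradient accumulation. Below is a minimal sketch of that arithmetic; the helper function and the `WORLD_SIZE` handling are illustrative assumptions, not code from this repository.

```python
import os


def gradient_accumulation_steps(total_batch_size: int, per_device_batch_size: int) -> int:
    """Derive how many batches to accumulate before each optimiser step.

    Assumes the effective batch size is
    per_device_batch_size * num_devices * accumulation_steps.
    """
    num_devices = int(os.getenv("WORLD_SIZE", "1"))
    return max(total_batch_size // (per_device_batch_size * num_devices), 1)


# With the values from config.yaml on a single GPU:
print(gradient_accumulation_steps(256, 16))  # 16
```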
19 changes: 10 additions & 9 deletions config/model/test_wav2vec2.yaml
@@ -9,16 +9,17 @@ clean_dataset: true
# Model hyperparameters
sampling_rate: 16_000
activation_dropout: 0.1
attention_dropout: 0.1
hidden_dropout: 0.1
feat_proj_dropout: 0.1
final_dropout: 0.1
mask_time_prob: 0.075
attention_dropout: 0.0
hidden_dropout: 0.0
feat_proj_dropout: 0.0
feat_quantizer_dropout: 0.0
final_dropout: 0.0
mask_time_prob: 0.5
mask_time_length: 10
mask_feature_prob: 0.075
mask_feature_length: 10
layerdrop: 0.0 # NOTE: This parameter cannot be used in a multi-gpu setting!
ctc_loss_reduction: sum
mask_feature_prob: 0.5
mask_feature_length: 64
layerdrop: 0.1 # NOTE: This will automatically be set to 0 in a multi-gpu setting
ctc_loss_reduction: mean

# Decoder hyperparameters
language_model_decoder: null
4 changes: 2 additions & 2 deletions config/model/wav2vec2.yaml
@@ -18,8 +18,8 @@ mask_time_prob: 0.5
mask_time_length: 10
mask_feature_prob: 0.5
mask_feature_length: 64
layerdrop: 0.1 # This will automatically be set to 0 in a multi-gpu setting
ctc_loss_reduction: mean
layerdrop: 0.1 # NOTE: This will automatically be set to 0 in a multi-gpu setting
ctc_loss_reduction: sum

# Decoder hyperparameters
language_model_decoder: ngram
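To show how hyperparameters like these typically reach the model, here is a sketch of passing them as config overrides to `Wav2Vec2ForCTC.from_pretrained`; the checkpoint id is a placeholder, and the exact wiring in this repository lives in `load_model`.

```python
from transformers import Wav2Vec2ForCTC

# Sketch only: the checkpoint id is a placeholder, not this repository's default.
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-xls-r-300m",
    mask_time_prob=0.5,        # SpecAugment masking along the time axis
    mask_time_length=10,
    mask_feature_prob=0.5,     # SpecAugment masking along the feature axis
    mask_feature_length=64,
    layerdrop=0.1,             # set to 0 automatically in a multi-gpu setting
    ctc_loss_reduction="sum",  # value from config/model/wav2vec2.yaml
    ctc_zero_infinity=True,
)
```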
24 changes: 15 additions & 9 deletions src/coral_models/compute_metrics.py
@@ -40,19 +40,27 @@ def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str,
if predictions.dtype == np.int_:
vocab_size = tokenizer.get_vocab()
mismatch_dim = len(vocab_size) - predictions.shape[-1]
predictions = np.pad(predictions, ((0, 0), (0, 0), (0, mismatch_dim)))
predictions_str = tokenizer.batch_decode(
predictions, skip_special_tokens=True
predictions = np.pad(
array=predictions,
pad_width=((0, 0), (0, 0), (0, mismatch_dim)),
mode="constant",
constant_values=pad_token,
)
predictions_str = tokenizer.batch_decode(sequences=predictions)

# Otherwise, if we are not using a language model, we need to convert the
# logits to token IDs and then decode the token IDs to get the predicted string
else:
# If all the logits are -100 for a token, then we set the logit for the
# padding token for that token to 0. This is to ensure that this token gets
# decoded to a padding token, and are therefore ignored
predictions[np.all(predictions == -100, axis=-1), pad_token] = 0

pred_ids: NDArray[np.int_] = np.argmax(predictions, axis=-1)
predictions_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
predictions_str = tokenizer.batch_decode(pred_ids)

elif len(predictions.shape) == 2 and predictions.dtype == np.int_:
predictions_str = tokenizer.batch_decode(predictions, skip_special_tokens=True)
predictions_str = tokenizer.batch_decode(sequences=predictions)

else:
raise ValueError(
@@ -67,11 +75,9 @@ def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str,
labels[labels == -100] = pad_token

# Decode the ground truth labels
labels_str = tokenizer.batch_decode(
sequences=labels, skip_special_tokens=True, group_tokens=False
)
labels_str = tokenizer.batch_decode(sequences=labels, group_tokens=False)

# TEMP: Log both the predictions and the ground truth labels
# Log both the predictions and the ground truth labels
is_main_process = os.getenv("RANK", "0") == "0"
if is_main_process:
random_idx = np.random.randint(0, len(predictions_str))
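The new `-100` handling can be exercised in isolation; in this small sketch the pad token id of 0 is an assumption made for the example only.

```python
import numpy as np

pad_token = 0  # assumed pad token id for this example

# Batch of one utterance with three time steps and a vocabulary of five tokens.
logits = np.full((1, 3, 5), -100.0)
logits[0, 0, 2] = 1.0  # only the first time step carries a real prediction

# Where every logit is -100 (i.e. the step was padding), force the pad token.
logits[np.all(logits == -100, axis=-1), pad_token] = 0
pred_ids = np.argmax(logits, axis=-1)
print(pred_ids)  # [[2 0 0]]: fully masked steps decode to the pad token
```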
2 changes: 1 addition & 1 deletion src/coral_models/data.py
@@ -281,7 +281,7 @@ def clean_dataset(
# transcriptions, as they do not have an influence on the pronunciation of the
# words.
non_standard_characters_regex = re.compile(
f"[^{re.escape(cfg.characters_to_keep)}]"
f"[^{re.escape(cfg.characters_to_keep + ' |')}]"
)

mapped = dataset.map(
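The effect of appending `' |'` to the kept characters is easy to check; the sample strings below are made up.

```python
import re

characters_to_keep = "abcdefghijklmnopqrstuvwxyzæøå0123456789éü"
non_standard_characters_regex = re.compile(
    f"[^{re.escape(characters_to_keep + ' |')}]"
)

# Punctuation is stripped, while spaces and the "|" word delimiter survive.
print(non_standard_characters_regex.sub("", "hej, verden!"))  # "hej verden"
print(non_standard_characters_regex.sub("", "hej|verden"))    # "hej|verden"
```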
40 changes: 24 additions & 16 deletions src/coral_models/wav2vec2.py
@@ -43,6 +43,8 @@ class DataCollatorCTCWithPadding(DataCollatorMixin):
Args:
processor (Wav2Vec2Processor)
The processor used for processing the data.
max_seconds_per_example (float):
The maximum number of seconds per example.
padding (bool, str or PaddingStrategy, optional):
Select a strategy to pad the returned sequences (according to the model's
padding side and padding index) among:
@@ -60,6 +62,7 @@ """
"""

processor: Wav2Vec2Processor
max_seconds_per_example: float
padding: bool | str
return_tensors: str = "pt"

@@ -86,12 +89,12 @@ def torch_call(self, features: list[dict]) -> BatchFeature:
audio_features,
padding=self.padding,
return_tensors=self.return_tensors,
max_length=16_000 * 10,
max_length=16_000 * self.max_seconds_per_example,
)

label_features = [dict(input_ids=feature["labels"]) for feature in features]
labels_batch: BatchEncoding = self.processor.tokenizer.pad(
label_features,
labels_batch: BatchEncoding = self.processor.pad(
labels=label_features,
padding=self.padding,
return_tensors=self.return_tensors,
max_length=512,
@@ -125,19 +128,21 @@ def load_processor(self) -> Wav2Vec2Processor:
dump_vocabulary(self.cfg)
tokenizer: Wav2Vec2CTCTokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
self.cfg.model_dir,
unk_token="<unk>",
pad_token="<pad>",
unk_token="<unk>",
bos_token="<s>",
eos_token="</s>",
word_delimiter_token=" ",
word_delimiter_token="|",
replace_word_delimiter_char=" ",
)
break
except json.decoder.JSONDecodeError:
process_id = os.getenv("RANK", 0)
logger.warning(
f"JSONDecodeError while loading tokenizer on process {process_id}. "
"Retrying in a second."
)
log_message = "JSONDecodeError while loading tokenizer"
process_id = os.getenv("RANK")
if process_id is not None:
log_message += f" in process {process_id}"
log_message += ". Retrying in a second."
logger.warning(log_message)
time.sleep(1)

# Set the `model_max_length` attribute of the tokenizer, if it hasn't been set,
@@ -155,6 +160,7 @@ def load_processor(self) -> Wav2Vec2Processor:
self.processor = Wav2Vec2Processor(
feature_extractor=extractor, tokenizer=tokenizer
)

return self.processor

def load_model(self) -> Wav2Vec2ForCTC:
@@ -179,7 +185,7 @@ def load_model(self) -> Wav2Vec2ForCTC:
vocab_size=len(self.processor.tokenizer.get_vocab()),
ctc_zero_infinity=True,
)
assert isinstance(model, Wav2Vec2ForCTC)
assert isinstance(model, Wav2Vec2ForCTC)

if self.cfg.model.freeze_feature_encoder:
for param in model.wav2vec2.parameters():
@@ -189,7 +195,9 @@

def load_data_collator(self) -> DataCollatorCTCWithPadding:
return DataCollatorCTCWithPadding(
processor=self.processor, padding=self.cfg.padding
processor=self.processor,
max_seconds_per_example=self.cfg.max_seconds_per_example,
padding=self.cfg.padding,
)

def load_trainer_class(self) -> Type[Trainer]:
@@ -275,7 +283,9 @@ def load_saved(self) -> PreTrainedModelData:

model = Wav2Vec2ForCTC.from_pretrained(self.cfg.hub_id, token=True)
data_collator = DataCollatorCTCWithPadding(
processor=processor, padding=self.cfg.padding
processor=processor,
max_seconds_per_example=self.cfg.max_seconds_per_example,
padding=self.cfg.padding,
)
compute_metrics = partial(compute_wer_metrics, processor=processor)
return PreTrainedModelData(
@@ -296,12 +306,10 @@ def dump_vocabulary(cfg: DictConfig) -> None:
The Hydra configuration object.
"""
# Build the set of all unique characters in the dataset
unique_characters: set[str] = set(cfg.characters_to_keep)
unique_characters: set[str] = set(cfg.characters_to_keep + "|")

# Build vocabulary
vocab = {char: idx for idx, char in enumerate(unique_characters)}
for tok in ["<unk>", "<pad>", "<s>", "</s>"]:
vocab[tok] = len(vocab)

# Dump the vocabulary to a json file
model_dir = Path(cfg.model_dir)
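For illustration, this is roughly the vocabulary file that `dump_vocabulary` now writes: character-level entries including the `|` word delimiter, with no special tokens appended. Sorting is only for a stable example; the function itself enumerates the raw set.

```python
import json

characters_to_keep = "abcdefghijklmnopqrstuvwxyzæøå0123456789éü"
unique_characters = set(characters_to_keep + "|")

# Sorted here only to make the example stable; dump_vocabulary enumerates the raw set.
vocab = {char: idx for idx, char in enumerate(sorted(unique_characters))}

# Special tokens (<pad>, <unk>, <s>, </s>) are no longer appended to this file;
# they are passed to the tokenizer in load_processor instead.
with open("vocab.json", "w") as f:
    json.dump(vocab, f, ensure_ascii=False, indent=2)
```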
19 changes: 16 additions & 3 deletions src/coral_models/whisper.py
@@ -36,6 +36,8 @@ class DataCollatorSpeechSeq2SeqWithPadding(DataCollatorMixin):
Args:
processor (WhisperProcessor)
The processor used for processing the data.
max_seconds_per_example (float):
The maximum number of seconds per example.
padding (bool, str or PaddingStrategy, optional):
Select a strategy to pad the returned sequences (according to the model's
padding side and padding index) among:
@@ -53,6 +55,7 @@ """
"""

processor: WhisperProcessor
max_seconds_per_example: float
padding: bool | str = True
return_tensors: str = "pt"

@@ -78,14 +81,22 @@ def torch_call(self, features: list[dict]) -> BatchFeature:
"Features must contain either 'input_features' or 'audio' key."
)
batch = self.processor.feature_extractor.pad(
audio_features, return_tensors="pt"
audio_features,
padding=self.padding,
return_tensors=self.return_tensors,
max_length=16_000 * self.max_seconds_per_example,
)

# Get the tokenized label sequences
label_features = [{"input_ids": feature["labels"]} for feature in features]

# Pad the labels to max length
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
labels_batch = self.processor.tokenizer.pad(
label_features,
padding=self.padding,
return_tensors=self.return_tensors,
max_length=512,
)

# replace padding with -100 to ignore loss correctly
labels = labels_batch["input_ids"].masked_fill(
@@ -162,7 +173,9 @@ def load_model(self) -> WhisperForConditionalGeneration:

def load_data_collator(self) -> DataCollatorSpeechSeq2SeqWithPadding:
return DataCollatorSpeechSeq2SeqWithPadding(
processor=self.processor, padding=self.cfg.padding
processor=self.processor,
max_seconds_per_example=self.cfg.max_seconds_per_example,
padding=self.cfg.padding,
)

def load_trainer_class(self) -> Type[Trainer]:
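Finally, a short sketch of constructing the updated Whisper collator; the checkpoint id and the config values are illustrative, in the repository they come from the Hydra config.

```python
from transformers import WhisperProcessor

from coral_models.whisper import DataCollatorSpeechSeq2SeqWithPadding

# Illustrative values; in the repository they come from the Hydra config.
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    max_seconds_per_example=10,
    padding=True,
)
# The collator now passes max_length=16_000 * max_seconds_per_example
# to the feature extractor's pad() call, mirroring the Wav2Vec2 collator.
```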