fix: Update padding kwargs in Whisper analogous to Wav2Vec2
saattrupdan committed Dec 7, 2023
1 parent 5d15643 commit 4818ec8
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/coral_models/whisper.py
@@ -77,15 +77,23 @@ def torch_call(self, features: list[dict]) -> BatchFeature:
             raise ValueError(
                 "Features must contain either 'input_features' or 'audio' key."
             )
-        batch = self.processor.feature_extractor.pad(
-            audio_features, return_tensors="pt"
+        batch = self.processor.pad(
+            audio_features,
+            padding=self.padding,
+            return_tensors=self.return_tensors,
+            max_length=16_000 * self.max_seconds_per_example,
         )

         # Get the tokenized label sequences
         label_features = [{"input_ids": feature["labels"]} for feature in features]

         # Pad the labels to max length
-        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
+        labels_batch = self.processor.pad(
+            labels=label_features,
+            padding=self.padding,
+            return_tensors=self.return_tensors,
+            max_length=512,
+        )

         # replace padding with -100 to ignore loss correctly
         labels = labels_batch["input_ids"].masked_fill(
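For context on the two changed calls: the feature pad now caps padded audio at 16_000 * max_seconds_per_example, which reads as a sample count at a presumed 16 kHz sampling rate, and the padded labels are subsequently masked with -100 so the loss ignores the padding positions. Below is a minimal, self-contained sketch of that label-padding step using the public transformers tokenizer API rather than this repository's processor wrapper; the checkpoint name and example sentences are placeholders, not taken from the repo.

from transformers import WhisperTokenizer

# Sketch of the label-padding-plus-masking pattern the hunk ends on; the
# checkpoint and example sentences are illustrative only.
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

label_features = [
    {"input_ids": tokenizer("a short sentence").input_ids},
    {"input_ids": tokenizer("a somewhat longer sentence to force padding").input_ids},
]

# Pad all label sequences in the batch to a common length.
labels_batch = tokenizer.pad(label_features, padding="longest", return_tensors="pt")

# attention_mask is 1 on real tokens and 0 on padding; replacing the padded
# positions with -100 makes PyTorch's cross-entropy loss skip them.
labels = labels_batch["input_ids"].masked_fill(
    labels_batch["attention_mask"].ne(1), -100
)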
