diff --git a/src/coral_models/whisper.py b/src/coral_models/whisper.py index 3ade4940..6c459097 100644 --- a/src/coral_models/whisper.py +++ b/src/coral_models/whisper.py @@ -77,15 +77,23 @@ def torch_call(self, features: list[dict]) -> BatchFeature: raise ValueError( "Features must contain either 'input_features' or 'audio' key." ) - batch = self.processor.feature_extractor.pad( - audio_features, return_tensors="pt" + batch = self.processor.pad( + audio_features, + padding=self.padding, + return_tensors=self.return_tensors, + max_length=16_000 * self.max_seconds_per_example, ) # Get the tokenized label sequences label_features = [{"input_ids": feature["labels"]} for feature in features] # Pad the labels to max length - labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") + labels_batch = self.processor.pad( + labels=label_features, + padding=self.padding, + return_tensors=self.return_tensors, + max_length=512, + ) # replace padding with -100 to ignore loss correctly labels = labels_batch["input_ids"].masked_fill(