fix: Do not hardcode max_seconds_per_example, and add | and space to characters_to_keep
saattrupdan committed Dec 5, 2023
1 parent 6478ee6 commit ed4d97c
Showing 2 changed files with 14 additions and 7 deletions.
src/coral_models/data.py: 1 addition & 1 deletion

@@ -281,7 +281,7 @@ def clean_dataset(
     # transcriptions, as they do not have an influence on the pronunciation of the
     # words.
     non_standard_characters_regex = re.compile(
-        f"[^{re.escape(cfg.characters_to_keep + ' ')}]"
+        f"[^{re.escape(cfg.characters_to_keep + ' |')}]"
     )

     mapped = dataset.map(
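
A minimal sketch of what the updated regex does, with an assumed characters_to_keep value standing in for the real Hydra config:

import re

# Assumed example value; the real set comes from cfg.characters_to_keep.
characters_to_keep = "abcdefghijklmnopqrstuvwxyzæøå0123456789"

# The space and "|" are kept as well: spaces separate words, and "|" is the
# default word-delimiter token of the Wav2Vec2 CTC tokenizer.
non_standard_characters_regex = re.compile(
    f"[^{re.escape(characters_to_keep + ' |')}]"
)

print(non_standard_characters_regex.sub("", "hej, verden!"))  # -> "hej verden"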
src/coral_models/wav2vec2.py: 13 additions & 6 deletions

@@ -43,6 +43,8 @@ class DataCollatorCTCWithPadding(DataCollatorMixin):
     Args:
         processor (Wav2Vec2Processor)
             The processor used for proccessing the data.
+        max_seconds_per_example (float):
+            The maximum number of seconds per example.
         padding (bool, str or PaddingStrategy, optional):
             Select a strategy to pad the returned sequences (according to the model's
             padding side and padding index) among:

@@ -61,6 +63,7 @@ class DataCollatorCTCWithPadding(DataCollatorMixin):

     processor: Wav2Vec2Processor
     padding: bool | str
+    max_seconds_per_example: float
     return_tensors: str = "pt"

     def torch_call(self, features: list[dict]) -> BatchFeature:
@@ -86,12 +89,12 @@ def torch_call(self, features: list[dict]) -> BatchFeature:
             audio_features,
             padding=self.padding,
             return_tensors=self.return_tensors,
-            max_length=16_000 * 10,
+            max_length=16_000 * self.max_seconds_per_example,
         )

         label_features = [dict(input_ids=feature["labels"]) for feature in features]
-        labels_batch: BatchEncoding = self.processor.tokenizer.pad(
-            label_features,
+        labels_batch: BatchEncoding = self.processor.pad(
+            labels=label_features,
             padding=self.padding,
             return_tensors=self.return_tensors,
             max_length=512,
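
The hardcoded 10-second cap on the audio padding is replaced by the new max_seconds_per_example field. A rough sketch of the resulting call, using a bare Wav2Vec2FeatureExtractor and made-up inputs rather than the project's processor and features:

from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(sampling_rate=16_000)
max_seconds_per_example = 10.0  # assumed config value

audio_features = [
    {"input_values": [0.0] * 16_000 * 3},  # a 3-second example
    {"input_values": [0.0] * 16_000 * 5},  # a 5-second example
]

# max_length is measured in samples, hence sample rate times seconds.
batch = feature_extractor.pad(
    audio_features,
    padding="longest",
    return_tensors="pt",
    max_length=int(16_000 * max_seconds_per_example),
)
print(batch["input_values"].shape)  # torch.Size([2, 80000])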
@@ -192,7 +195,9 @@ def load_model(self) -> Wav2Vec2ForCTC:

     def load_data_collator(self) -> DataCollatorCTCWithPadding:
         return DataCollatorCTCWithPadding(
-            processor=self.processor, padding=self.cfg.padding
+            processor=self.processor,
+            max_seconds_per_example=self.cfg.max_seconds_per_example,
+            padding=self.cfg.padding,
         )

     def load_trainer_class(self) -> Type[Trainer]:
@@ -278,7 +283,9 @@ def load_saved(self) -> PreTrainedModelData:

         model = Wav2Vec2ForCTC.from_pretrained(self.cfg.hub_id, token=True)
         data_collator = DataCollatorCTCWithPadding(
-            processor=processor, padding=self.cfg.padding
+            processor=processor,
+            max_seconds_per_example=self.cfg.max_seconds_per_example,
+            padding=self.cfg.padding,
         )
         compute_metrics = partial(compute_wer_metrics, processor=processor)
         return PreTrainedModelData(
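
Both load_data_collator and load_saved now read max_seconds_per_example from the Hydra config, so that key has to exist there. A hedged sketch of the relevant keys, with names taken from this diff and purely illustrative values:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "max_seconds_per_example": 10,
        "padding": "longest",
        "characters_to_keep": "abcdefghijklmnopqrstuvwxyzæøå0123456789",
    }
)

# 10 seconds at 16 kHz gives the padding cap used by the collator.
print(16_000 * cfg.max_seconds_per_example)  # 160000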
@@ -299,7 +306,7 @@ def dump_vocabulary(cfg: DictConfig) -> None:
             The Hydra configuration object.
     """
     # Build the set of all unique characters in the dataset
-    unique_characters: set[str] = set(cfg.characters_to_keep)
+    unique_characters: set[str] = set(cfg.characters_to_keep + "|")

     # Build vocabulary
     vocab = {char: idx for idx, char in enumerate(unique_characters)}
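
dump_vocabulary now adds "|" to the character set before building the vocabulary, which matters because "|" is the default word-delimiter token of the Wav2Vec2 CTC tokenizer. A small sketch with an assumed characters_to_keep value; the file name is also just illustrative:

import json

characters_to_keep = "abcdefghijklmnopqrstuvwxyzæøå"  # assumed config value

# Include the word delimiter "|" alongside the configured characters.
unique_characters: set[str] = set(characters_to_keep + "|")
vocab = {char: idx for idx, char in enumerate(unique_characters)}

# The CTC tokenizer would typically read this mapping from a vocab.json file.
with open("vocab.json", "w") as f:
    json.dump(vocab, f)

print("|" in vocab)  # True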
