From 5e20bd417a7937e8352c2cd2445f23cb58bdcfe7 Mon Sep 17 00:00:00 2001 From: Dan Saattrup Nielsen Date: Tue, 5 Dec 2023 16:05:55 +0100 Subject: [PATCH] chore: Deal with word delimiters --- config/config.yaml | 2 +- src/coral_models/data.py | 2 +- src/coral_models/wav2vec2.py | 6 ++++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 24401896..2b55cc00 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -15,7 +15,7 @@ dirs: seed: 4242 # Dataset parameters -characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü ' +characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü' max_seconds_per_example: 10 dataloader_num_workers: 8 diff --git a/src/coral_models/data.py b/src/coral_models/data.py index a3011ce7..37ac0e69 100644 --- a/src/coral_models/data.py +++ b/src/coral_models/data.py @@ -281,7 +281,7 @@ def clean_dataset( # transcriptions, as they do not have an influence on the pronunciation of the # words. non_standard_characters_regex = re.compile( - f"[^{re.escape(cfg.characters_to_keep)}]" + f"[^{re.escape(cfg.characters_to_keep + ' ')}]" ) mapped = dataset.map( diff --git a/src/coral_models/wav2vec2.py b/src/coral_models/wav2vec2.py index 1dee9980..552769f5 100644 --- a/src/coral_models/wav2vec2.py +++ b/src/coral_models/wav2vec2.py @@ -129,7 +129,8 @@ def load_processor(self) -> Wav2Vec2Processor: pad_token="", bos_token="", eos_token="", - word_delimiter_token=" ", + word_delimiter_token="|", + replace_word_delimiter_char=" ", ) break except json.decoder.JSONDecodeError: @@ -156,6 +157,7 @@ def load_processor(self) -> Wav2Vec2Processor: self.processor = Wav2Vec2Processor( feature_extractor=extractor, tokenizer=tokenizer ) + return self.processor def load_model(self) -> Wav2Vec2ForCTC: @@ -180,7 +182,7 @@ def load_model(self) -> Wav2Vec2ForCTC: vocab_size=len(self.processor.tokenizer.get_vocab()), ctc_zero_infinity=True, ) - assert isinstance(model, Wav2Vec2ForCTC) + assert isinstance(model, Wav2Vec2ForCTC) if self.cfg.model.freeze_feature_encoder: for param in model.wav2vec2.parameters():