diff --git a/test/test_align.py b/test/test_align.py index 1a442f3..c6ac352 100644 --- a/test/test_align.py +++ b/test/test_align.py @@ -3,45 +3,70 @@ import stable_whisper -def check_result(result, expected_text: str): +def check_result(result, expected_text: str, test_name: str): assert result.text == expected_text timing_checked = False for segment in result: for word in segment: - assert word.start < word.end + assert word.start < word.end, (word.start, word.end, test_name) if word.word.strip(" ,") == "americans": - assert word.start <= 1.8, word.start - assert word.end >= 1.8, word.end + assert word.start <= 1.8, (word.start, test_name) + assert word.end >= 1.8, (word.end, test_name) timing_checked = True - assert timing_checked + assert timing_checked, test_name -def test_transcribe(model0_name: str, model1_name: str): +def test_align(model_names): device = "cuda" if torch.cuda.is_available() else "cpu" - model0 = stable_whisper.load_model(model0_name, device=device) audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") - - language = "en" if model0_name.endswith(".en") else None - orig_result = model0.transcribe( - audio_path, language=language, temperature=0.0, word_timestamps=True + models = [stable_whisper.load_model(name, device=device) for name in model_names] + orig_result = models[0].transcribe( + audio_path, language='en', temperature=0.0, word_timestamps=True ) for word in orig_result.all_words(): word.word = word.word.replace('Americans', 'americans') - model1 = stable_whisper.load_model(model1_name, device=device) + def single_test(m, meth: str, prep, extra_check, **kwargs): + model_type = 'multilingual-model' if m.is_multilingual else 'en-model' + meth = getattr(m, meth) + test_name = f'{model_type} {meth.__name__}(WhisperResult)' + try: + result = meth(audio_path, orig_result, **kwargs) + check_same_segment_text(orig_result, result) + except Exception as e: + raise Exception(f'failed test {test_name} -> {e.__class__.__name__}: {e}') + check_result(result, orig_result.text, test_name) + + test_name = f'{model_type} {meth.__name__}(plain-text)' + try: + result = meth(audio_path, prep(orig_result), language=orig_result.language) + if extra_check: + extra_check(orig_result, result) + except Exception as e: + raise Exception(f'failed test {test_name} -> {e.__class__.__name__}: {e}') + check_result(result, orig_result.text, test_name) + + def get_text(res): + return res.text + + def get_segment_dicts(res): + return [dict(start=s.start, end=s.end, text=s.text) for s in res] - result = model1.align(audio_path, orig_result, original_split=True) - assert [s.text for s in result] == [s.text for s in orig_result] - check_result(result, orig_result.text) + def check_same_segment_text(res0, res1): + assert [s.text for s in res0] == [s.text for s in res1], 'mismatch segment text' - result = model1.align(audio_path, orig_result.text, language=orig_result.language) - check_result(result, orig_result.text) + for model in models: + for method in ('align', 'align_words'): + options = dict(original_split=True) if method == 'align' else {} + preprocess = get_text if method == 'align' else get_segment_dicts + check_seg = None if method == 'align' else check_same_segment_text + single_test(model, method, preprocess, check_seg, **options) def test(): - test_transcribe('tiny', 'tiny.en') + test_align(['tiny', 'tiny.en']) if __name__ == '__main__': diff --git a/test/test_refine.py b/test/test_refine.py index 10e0daf..01a2a4c 100644 --- a/test/test_refine.py +++ b/test/test_refine.py @@ -23,7 +23,7 @@ def check_result(result, orig_result, expect_change: bool = True): assert timing_checked -def test_transcribe(model0_name: str, model1_name: str): +def test_refine(model0_name: str, model1_name: str): device = "cuda" if torch.cuda.is_available() else "cpu" model0 = stable_whisper.load_model(model0_name, device=device) audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac") @@ -40,7 +40,7 @@ def test_transcribe(model0_name: str, model1_name: str): def test(): - test_transcribe('tiny.en', 'tiny') + test_refine('tiny.en', 'tiny') if __name__ == '__main__':