-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #33 from JSchmie/tests
Good Job :)
- Loading branch information
Showing
8 changed files
with
327 additions
and
120 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Runs the pytest suite for pull requests that target main or develop, and
# on manual dispatch from the Actions tab.
name: Run tests

on:
  #push:

  pull_request:
    branches: ['main', 'develop']

  workflow_dispatch:

jobs:
  pytest:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout
        uses: actions/checkout@v4

      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML keeps the version a string (an unquoted 3.10
          # would be read as the number 3.1).
          python-version: '3.9'

      - name: Install Dependencies
        run: |
          # System libraries go first so they are present before any Python
          # package that needs them is installed. apt-get with -y never
          # waits for confirmation; a full upgrade of the runner image is
          # deliberately not performed (slow and not needed for the tests).
          sudo apt-get update
          sudo apt-get install -y libsndfile1-dev ffmpeg
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install .
          pip install pytest

      - name: Run pytest
        env:
          # Evaluates to an empty string when the secret is not available
          # to the run (for example pull requests opened from forks).
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          pytest
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import pytest | ||
from scraibe.audio import AudioProcessor | ||
import torch | ||
|
||
|
||
|
||
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# TEST_SR is the rate handed to AudioProcessor explicitly by the tests;
# SAMPLE_RATE is the value the default-rate test compares against.
TEST_SR = 16000
SAMPLE_RATE = 16000
# Not referenced by the tests visible in this module.
NORMALIZATION_FACTOR = 32768

# 160000 samples of sin(standard-normal noise), moved to DEVICE.
TEST_WAVEFORM = torch.randn(160000).sin().to(DEVICE)
|
||
|
||
@pytest.fixture
def probe_audio_processor():
    """Provide an AudioProcessor built from the shared test signal.

    The processor is constructed from the module-level ``TEST_WAVEFORM`` and
    ``TEST_SR``, so every test that requests this fixture starts from the
    same inputs.

    Returns:
        AudioProcessor: Instance created from the test waveform and sample
            rate.
    """
    processor = AudioProcessor(TEST_WAVEFORM, TEST_SR)
    return processor
|
||
|
||
|
||
|
||
|
||
|
||
def test_AudioProcessor_init(probe_audio_processor):
    """Check that AudioProcessor keeps the waveform and sample rate it was given.

    Args:
        probe_audio_processor (AudioProcessor): Instance supplied by the
            ``probe_audio_processor`` fixture.

    Returns:
        None
    """
    processor = probe_audio_processor
    assert isinstance(processor, AudioProcessor)

    stored = processor.waveform
    # Device is compared first, then the samples element by element.
    assert stored.device == TEST_WAVEFORM.device
    assert torch.equal(stored, TEST_WAVEFORM)

    assert processor.sr == TEST_SR
|
||
|
||
|
||
def test_cut(probe_audio_processor):
    """Check that ``cut`` returns a segment of the expected length.

    A window from ``start`` to ``end`` is cut out of the fixture's waveform
    and the size of its first dimension is compared with
    ``(end - start) * TEST_SR``.

    Args:
        probe_audio_processor (AudioProcessor): Instance supplied by the
            ``probe_audio_processor`` fixture.

    Returns:
        None
    """
    start, end = 4, 7
    segment = probe_audio_processor.cut(start, end)
    assert segment.size(0) == int((end - start) * TEST_SR)
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
def test_audio_processor_invalid_sr():
    """Check that a list passed as the sample rate raises ``ValueError``.

    The constructor is called with the test waveform and a list of two
    rates instead of a single value; ``pytest.raises`` fails the test if no
    ``ValueError`` is raised.

    Returns:
        None
    """
    invalid_sr = [44100, 48000]
    with pytest.raises(ValueError):
        AudioProcessor(TEST_WAVEFORM, invalid_sr)
|
||
|
||
def test_audio_processor_SAMPLE_RATE():
    """Check the sample rate AudioProcessor uses when none is passed.

    An AudioProcessor is created from the test waveform alone and its ``sr``
    attribute is compared with the module-level ``SAMPLE_RATE`` constant.

    Returns:
        None
    """
    default_sr = AudioProcessor(TEST_WAVEFORM).sr
    assert default_sr == SAMPLE_RATE
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import pytest | ||
from scraibe import Scraibe, Diariser, Transcriber, Transcript | ||
from unittest.mock import MagicMock, patch | ||
import os | ||
|
||
|
||
|
||
|
||
|
||
@pytest.fixture
def create_scraibe_instance():
    """Build a Scraibe instance, passing an auth token only when one is set.

    ``HF_TOKEN`` is read from the environment. A missing *or empty* value is
    treated as "no token": CI can export a secret-backed variable as an
    empty string when the secret is not available to the run (for example
    on pull requests from forks), and an empty string should not be
    forwarded as if it were a real token.

    Returns:
        Scraibe: Created with ``use_auth_token`` when a non-empty
            ``HF_TOKEN`` is present, otherwise created with no arguments.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        # Covers both "variable unset" (None) and "variable set but empty".
        return Scraibe()
    return Scraibe(use_auth_token=token)
|
||
|
||
|
||
|
||
def test_scraibe_init(create_scraibe_instance):
    """Check that a Scraibe instance exposes a Transcriber and a Diariser.

    Args:
        create_scraibe_instance (Scraibe): Instance supplied by the
            ``create_scraibe_instance`` fixture.
    """
    scraibe = create_scraibe_instance
    components = (("transcriber", Transcriber), ("diariser", Diariser))
    for attribute, expected_type in components:
        assert isinstance(getattr(scraibe, attribute), expected_type)
|
||
|
||
def test_scraibe_autotranscribe(create_scraibe_instance):
    """Check that ``autotranscribe`` returns a Transcript for the sample file.

    Args:
        create_scraibe_instance (Scraibe): Instance supplied by the
            ``create_scraibe_instance`` fixture.
    """
    # Relative path: resolved against the directory pytest is started from.
    audio_path = 'test/audio_test_2.mp4'
    result = create_scraibe_instance.autotranscribe(audio_path)
    assert isinstance(result, Transcript)
|
||
|
||
def test_scraibe_diarization(create_scraibe_instance):
    """Check that ``diarization`` returns a dict for the sample file.

    Args:
        create_scraibe_instance (Scraibe): Instance supplied by the
            ``create_scraibe_instance`` fixture.
    """
    # Relative path: resolved against the directory pytest is started from.
    audio_path = 'test/audio_test_2.mp4'
    result = create_scraibe_instance.diarization(audio_path)
    assert isinstance(result, dict)
|
||
|
||
def test_scraibe_transcribe(create_scraibe_instance):
    """Check that ``transcribe`` returns a string for the sample file.

    Args:
        create_scraibe_instance (Scraibe): Instance supplied by the
            ``create_scraibe_instance`` fixture.
    """
    # Relative path: resolved against the directory pytest is started from.
    audio_path = 'test/audio_test_2.mp4'
    text = create_scraibe_instance.transcribe(audio_path)
    assert isinstance(text, str)
|
||
|
||
""" def test_remove_audio_file(create_scraibe_instance): | ||
model = create_scraibe_instance | ||
with pytest.raises(ValueError): | ||
model.remove_audio_file("non_existing_audio_file") | ||
model.remove_audio_file("audio_test_2.mp4") | ||
assert not os.path.exists("audio_test_2.mp4") """ | ||
|
||
|
||
""" def test_get_audio_file(create_scraibe_instance): | ||
model = create_scraibe_instance | ||
audio_file = os.path.exist("audio_test_2.mp4") | ||
assert isinstance(audio_file, AudioProcessor) | ||
assert isinstance(audio_file.waveform, torch.Tensor) | ||
assert isinstance(audio_file.sr, torch.Tensor) """ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import pytest | ||
import os | ||
from unittest import mock | ||
from scraibe import diarisation, Diariser | ||
|
||
|
||
|
||
@pytest.fixture
def diariser_instance():
    """Provide a Diariser constructed with the string ``'pyannote'``.

    No token mocking takes place here: the constructor is called directly
    with the literal ``'pyannote'`` and the resulting object is handed to
    the requesting test.

    Returns:
        Diariser: The freshly constructed instance.
    """
    instance = Diariser('pyannote')
    return instance
|
||
|
||
|
||
def test_Diariser_init(diariser_instance):
    """Check that Diariser stores the model it was constructed with.

    Args:
        diariser_instance (Diariser): Instance supplied by the
            ``diariser_instance`` fixture.

    Returns:
        None
    """
    expected_model = 'pyannote'
    assert diariser_instance.model == expected_model
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import pytest | ||
from unittest.mock import patch | ||
from scraibe import Transcriber | ||
import torch | ||
|
||
|
||
|
||
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Despite the name, this is a plain text placeholder rather than audio data.
TEST_WAVEFORM = "Hello World"
|
||
""" | ||
@pytest.mark.parametrize("audio_file, expected_transcription",[("path_to_test_audiofile", "test_transcription")] ) | ||
@patch("scraibe.Transcriber.load_model") | ||
def test_transcriber(mock_load_model, audio_file, expected_transcription): | ||
Args: | ||
mock_load_model (_type_): _description_ | ||
audio_file (_type_): _description_ | ||
expected_transcription (_type_): _description_ | ||
mock_model = mock_load_model.return_value | ||
mock_model.transcribe.return_value ={"text": expected_transcription} | ||
transcriber = Transcriber.load_model(model="medium") | ||
transcription_result = transcriber.transcribe(audio=audio_file) | ||
assert transcription_result == expected_transcription """ | ||
|
||
@pytest.fixture
def transcriber_instance():
    """Provide the object returned by ``Transcriber.load_model('medium')``."""
    loaded = Transcriber.load_model('medium')
    return loaded
|
||
def test_transcriber_initialization(transcriber_instance):
    """Check that the ``transcriber_instance`` fixture yields a Transcriber."""
    is_transcriber = isinstance(transcriber_instance, Transcriber)
    assert is_transcriber
|
||
def test_get_whisper_kwargs():
    """Check that ``_get_whisper_kwargs`` does not echo made-up keywords back.

    Two arbitrary keyword arguments are passed in; the test only requires
    that the returned value differs from that input mapping.
    """
    unknown = {"arg1": 1, "arg3": 3}
    filtered = Transcriber._get_whisper_kwargs(**unknown)
    assert filtered != {"arg1": 1, "arg3": 3}
|
||
|
||
def test_transcribe(transcriber_instance):
    """Check that ``transcribe`` returns a string for the sample file.

    Args:
        transcriber_instance (Transcriber): Instance supplied by the
            ``transcriber_instance`` fixture.
    """
    # Relative path: resolved against the directory pytest is started from.
    audio_path = 'test/audio_test_2.mp4'
    text = transcriber_instance.transcribe(audio_path)
    assert isinstance(text, str)
|
||
|
||
|
Oops, something went wrong.