diff --git a/flowsettings.py b/flowsettings.py index 0962eefc2..9a3bd8c46 100644 --- a/flowsettings.py +++ b/flowsettings.py @@ -321,7 +321,7 @@ "config": { "supported_file_types": ( ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, " - ".pptx, .csv, .html, .mhtml, .txt, .md, .zip" + ".pptx, .csv, .html, .mhtml, .txt, .md, .zip, .mp3" ), "private": False, }, @@ -336,7 +336,7 @@ "config": { "supported_file_types": ( ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, " - ".pptx, .csv, .html, .mhtml, .txt, .md, .zip" + ".pptx, .csv, .html, .mhtml, .txt, .md, .zip, .mp3" ), "private": False, }, diff --git a/libs/kotaemon/kotaemon/indices/ingests/files.py b/libs/kotaemon/kotaemon/indices/ingests/files.py index 18db7ca86..cb15868af 100644 --- a/libs/kotaemon/kotaemon/indices/ingests/files.py +++ b/libs/kotaemon/kotaemon/indices/ingests/files.py @@ -17,6 +17,7 @@ HtmlReader, MathpixPDFReader, MhtmlReader, + MP3Reader, OCRReader, PandasExcelReader, PDFThumbnailReader, @@ -53,6 +54,7 @@ ".tiff": unstructured, ".tif": unstructured, ".pdf": PDFThumbnailReader(), + ".mp3": MP3Reader(), ".txt": TxtReader(), ".md": TxtReader(), } diff --git a/libs/kotaemon/kotaemon/loaders/__init__.py b/libs/kotaemon/kotaemon/loaders/__init__.py index f498da806..f7748326f 100644 --- a/libs/kotaemon/kotaemon/loaders/__init__.py +++ b/libs/kotaemon/kotaemon/loaders/__init__.py @@ -7,6 +7,7 @@ from .excel_loader import ExcelReader, PandasExcelReader from .html_loader import HtmlReader, MhtmlReader from .mathpix_loader import MathpixPDFReader +from .mp3_loader import MP3Reader from .ocr_loader import ImageReader, OCRReader from .pdf_loader import PDFThumbnailReader from .txt_loader import TxtReader @@ -30,6 +31,7 @@ "AdobeReader", "TxtReader", "PDFThumbnailReader", + "MP3Reader", "WebReader", "DoclingReader", ] diff --git a/libs/kotaemon/kotaemon/loaders/mp3_loader.py b/libs/kotaemon/kotaemon/loaders/mp3_loader.py new file mode 100644 index 000000000..b427c7700 --- 
/dev/null
+++ b/libs/kotaemon/kotaemon/loaders/mp3_loader.py
@@ -0,0 +1,146 @@
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional
+
+from loguru import logger
+
+from kotaemon.base import Document, Param
+
+from .base import BaseReader
+
+if TYPE_CHECKING:
+    from transformers import pipeline
+
+
+class MP3Reader(BaseReader):
+    """Transcribe an audio file into a single Document.
+
+    Speech recognition runs through a Hugging Face `transformers`
+    automatic-speech-recognition pipeline. The heavy dependencies
+    (accelerate, torch, transformers, librosa) are imported inside the
+    methods that need them rather than at module import time.
+    """
+
+    model_name_or_path: str = Param(
+        help="The model name or path to use for speech recognition.",
+        default="distil-whisper/distil-large-v3",
+    )
+    cache_dir: str = Param(
+        help="The cache directory to use for the model.",
+        default="models",
+    )
+
+    @Param.auto()
+    def asr_pipeline(self) -> "pipeline":
+        """Setup the ASR pipeline for speech recognition
+
+        Returns:
+            The automatic-speech-recognition pipeline built from
+            `model_name_or_path`.
+
+        Raises:
+            ImportError: if accelerate, torch or transformers is missing.
+            Exception: any failure while building the pipeline is logged
+                and re-raised unchanged.
+        """
+        try:
+            import accelerate  # noqa: F401
+            import torch
+            from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+        except ImportError as exc:
+            raise ImportError(
+                "Please install the required packages to use the MP3Reader: "
+                "'pip install accelerate torch transformers'"
+            ) from exc
+
+        try:
+            # Device and model configuration (half precision only on GPU)
+            use_cuda = torch.cuda.is_available()
+            torch_dtype = torch.float16 if use_cuda else torch.float32
+            device = "cuda:0" if use_cuda else "cpu"
+
+            # Model and processor initialization
+            model = AutoModelForSpeechSeq2Seq.from_pretrained(
+                self.model_name_or_path,
+                torch_dtype=torch_dtype,
+                low_cpu_mem_usage=True,
+                use_safetensors=True,
+                cache_dir=self.cache_dir,
+            ).to(device)
+
+            # Store the processor files next to the model weights
+            processor = AutoProcessor.from_pretrained(
+                self.model_name_or_path,
+                cache_dir=self.cache_dir,
+            )
+
+            # ASR pipeline setup
+            asr_pipeline = pipeline(
+                "automatic-speech-recognition",
+                model=model,
+                tokenizer=processor.tokenizer,
+                feature_extractor=processor.feature_extractor,
+                max_new_tokens=128,
+                torch_dtype=torch_dtype,
+                device=device,
+                return_timestamps=True,
+            )
+            logger.info("ASR pipeline setup successful.")
+        except Exception as e:
+            logger.error(f"Error occurred during ASR pipeline setup: {e}")
+            raise
+
+        return asr_pipeline
+
+    def speech_to_text(self, audio_path: str) -> str:
+        """Transcribe the audio file at `audio_path`.
+
+        Best effort: a failure is logged and "" is returned instead of
+        raising, so callers always get a string back.
+        """
+        try:
+            import librosa
+        except ImportError:
+            logger.error(
+                "Please install librosa to use the MP3Reader: "
+                "'pip install librosa'"
+            )
+            return ""
+
+        try:
+            # Decode and resample to 16 kHz up front: the pipeline is handed
+            # a raw array, which carries no sampling-rate information
+            audio_array, _ = librosa.load(audio_path, sr=16000)
+            result = self.asr_pipeline(audio_array)
+
+            text = result.get("text", "").strip()
+            if not text:
+                logger.warning("No text found in the audio file.")
+            return text
+        except Exception as e:
+            # The error is swallowed here, so keep the traceback in the log
+            logger.exception(f"Error occurred during speech recognition: {e}")
+            return ""
+
+    def run(
+        self, file_path: str | Path, extra_info: Optional[dict] = None, **kwargs
+    ) -> list[Document]:
+        """Transcribe `file_path`; delegates to `load_data`."""
+        return self.load_data(str(file_path), extra_info=extra_info, **kwargs)
+
+    def load_data(
+        self, audio_file: str, extra_info: Optional[dict] = None, **kwargs
+    ) -> List[Document]:
+        """Transcribe `audio_file` into a one-element list of Documents.
+
+        Args:
+            audio_file: path of the audio file to transcribe.
+            extra_info: optional metadata to attach to the Document.
+
+        Returns:
+            A list holding one Document; its text is "" when transcription
+            failed or produced no text.
+        """
+        text = self.speech_to_text(audio_file)
+        metadata = extra_info or {}
+
+        return [Document(text=text, metadata=metadata)]
diff --git a/libs/kotaemon/tests/conftest.py b/libs/kotaemon/tests/conftest.py
index 3f46c7092..8c50431f3 100644
--- a/libs/kotaemon/tests/conftest.py
+++ b/libs/kotaemon/tests/conftest.py
@@ -70,6 +70,15 @@ def if_llama_cpp_not_installed():
         return False
 
 
+def if_librosa_not_installed():
+    try:
+        import librosa  # noqa: F401
+    except ImportError:
+        return True
+    else:
+        return False
+
+
 skip_when_haystack_not_installed = pytest.mark.skipif(
     if_haystack_not_installed(), reason="Haystack is not installed"
 )
@@ -97,3 +106,7 @@ def if_llama_cpp_not_installed():
 skip_llama_cpp_not_installed = pytest.mark.skipif(
     if_llama_cpp_not_installed(), reason="llama_cpp is not installed"
 )
+
+skip_when_librosa_not_installed = pytest.mark.skipif(
+    if_librosa_not_installed(), reason="librosa is not installed"
+)
diff --git a/libs/kotaemon/tests/resources/dummy.mp3 b/libs/kotaemon/tests/resources/dummy.mp3
new file mode 100644
index 000000000..61de3714a
Binary files /dev/null and b/libs/kotaemon/tests/resources/dummy.mp3 differ
diff --git a/libs/kotaemon/tests/test_reader.py b/libs/kotaemon/tests/test_reader.py
index
d27c774bf..72788f9d5 100644
--- a/libs/kotaemon/tests/test_reader.py
+++ b/libs/kotaemon/tests/test_reader.py
@@ -11,10 +11,14 @@
     DocxReader,
     HtmlReader,
     MhtmlReader,
+    MP3Reader,
     UnstructuredReader,
 )
 
-from .conftest import skip_when_unstructured_pdf_not_installed
+from .conftest import (
+    skip_when_librosa_not_installed,
+    skip_when_unstructured_pdf_not_installed,
+)
 
 
 def test_docx_reader():
@@ -93,3 +97,19 @@ def test_azureai_document_intelligence_reader(mock_client):
 
     assert len(docs) == 1
     mock_client.assert_called_once()
+
+
+@skip_when_librosa_not_installed
+@patch("kotaemon.loaders.MP3Reader.asr_pipeline")
+def test_mp3_reader(mock_pipeline):
+    # The reader looks the transcript up with result.get("text"), so the mocked
+    # pipeline must return a dict; a bare string would be silently turned into ""
+    mock_pipeline.return_value = {"text": "This is the transcript"}
+
+    reader = MP3Reader()
+    docs = reader.load_data(str(Path(__file__).parent / "resources" / "dummy.mp3"))
+
+    assert len(docs) == 1
+    assert docs[0].text == "This is the transcript"
+
+    mock_pipeline.assert_called_once()